Decision Tree

A simple but potentially powerful machine learning technique, the decision tree has some very appealing properties:

  • it is easy to implement
  • results are easily explained
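
Both properties follow from the fact that a fitted tree is nothing more than a cascade of if/else tests on individual features. As a minimal, purely hypothetical sketch of what a depth-2 tree over word features might look like once written out by hand (the split words and predicted classes here are invented for illustration, not learned from any data):

# A hypothetical depth-2 decision tree for sentiment, written out as plain
# if/else tests; the split words and predicted classes are made up for
# illustration only.
def predict_sentiment(words):
    if 'bad' in words:        # root split
        return 0              # negative
    if 'great' in words:      # second-level split
        return 1              # positive
    return 0                  # fall back to the (assumed) majority class

print(predict_sentiment(['a', 'great', 'movie']))   # -> 1
print(predict_sentiment(['a', 'bad', 'plot']))      # -> 0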

Here is some sample code to pre-process the text in the sentiment file:

In [338]:
import numpy as np

data = [ line.split() for line in open('imdb_labelled.txt').readlines() ]

#  '.' and ',' stick to words (e.g. at the end of each sentence); remove them with strip() and lowercase
sents = [ [ w.strip('.,').lower() for w in line[:-1] ] for line in data ]
y = np.asarray([ int(line[-1]) for line in data ])

# remove stop words
stop = ['a','-','the','it','in','and','or','with','to','of','as','was','is',"it's",'from','for','this','on','at']
sents = [ [ w for w in s if w not in stop ] for s in sents ]

print(sents[:5])
print('mean sentence length:', np.mean([len(s) for s in sents]))
print(y[:5])
print(len(y), sum(y))
[['very', 'very', 'very', 'slow-moving', 'aimless', 'movie', 'about', 'distressed', 'drifting', 'young', 'man'], ['not', 'sure', 'who', 'more', 'lost', 'flat', 'characters', 'audience', 'nearly', 'half', 'whom', 'walked', 'out'], ['attempting', 'artiness', 'black', '&', 'white', 'clever', 'camera', 'angles', 'movie', 'disappointed', 'became', 'even', 'more', 'ridiculous', 'acting', 'poor', 'plot', 'lines', 'almost', 'non-existent'], ['very', 'little', 'music', 'anything', 'speak'], ['best', 'scene', 'movie', 'when', 'gerardo', 'trying', 'find', 'song', 'that', 'keeps', 'running', 'through', 'his', 'head']]
mean sentence length: 10.185
[0 0 0 0 1]
1000 500

Now we encode each sentence simply by membership in a small vocabulary:

In [339]:
from collections import defaultdict

cnt = defaultdict(int)

for s in sents:
    for w in s:
        cnt[w] += 1
        
voc = [ w for w in cnt if cnt[w] > 20 ]
print('voc len:', len(voc))
X = [ [ int(w in sent) for w in voc ] for sent in sents ]
voc len: 77
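
As an aside, scikit-learn's CountVectorizer can produce essentially the same binary membership encoding; here is a sketch (its built-in tokenizer differs slightly from the hand-rolled splitting above, and min_df counts sentences rather than total occurrences, so the vocabulary will not match exactly):

from sklearn.feature_extraction.text import CountVectorizer

# binary=True records only presence/absence of each word; min_df=21 keeps words
# that occur in at least 21 sentences, roughly mirroring the cnt[w] > 20 filter.
vec = CountVectorizer(binary=True, min_df=21)
X_cv = vec.fit_transform([' '.join(s) for s in sents])
print('CountVectorizer vocabulary size:', len(vec.vocabulary_))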

This is enough to start training and evaluating our decision tree:

In [340]:
from sklearn import tree
from sklearn import metrics 
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1)

clf = tree.DecisionTreeClassifier(max_depth=3)
clf.fit(X_train, y_train)

y_pred = clf.predict(X_test)

print("Accuracy:", metrics.accuracy_score(y_test, y_pred))
Accuracy: 0.59

Right away we achieve modest but better-than-chance performance.
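
To put that number in context, it helps to compare against a no-skill baseline that always predicts the most frequent training class; with 500 positive and 500 negative sentences, this baseline should land near 0.5. A quick sketch reusing the split from above:

from sklearn.dummy import DummyClassifier

# Majority-class baseline: always predicts the most common label in y_train,
# so its accuracy simply reflects the class balance of the test split.
baseline = DummyClassifier(strategy='most_frequent')
baseline.fit(X_train, y_train)
print('Baseline accuracy:', baseline.score(X_test, y_test))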

In addition, we can now visualize the process:

In [341]:
from sklearn.tree import export_graphviz
import graphviz

export_graphviz(clf, feature_names=voc, out_file="mytree.dot" )
with open("mytree.dot") as f:
    dot_graph = f.read()
graphviz.Source(dot_graph)
Out[341]:
[Rendered graphviz tree (depth 3): the root splits on 'bad' (900 samples, value [453, 447]); the True branch splits on 'great' and then 'plot' / 'all', while the False branch splits on 'not' and then 'about' / 'also'.]

This gives us the chance to quickly understand how the machine learning method arrives at its result.
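
The fitted tree can also be inspected without graphviz; here is a short sketch using scikit-learn's export_text (available in newer versions of the library) and the feature_importances_ attribute to list the learned rules and the words that drive the splits:

from sklearn.tree import export_text

# Print the learned splits as indented if/else rules, using the vocabulary
# words as feature names.
print(export_text(clf, feature_names=voc))

# Rank the vocabulary words by how much each one contributes to the splits.
top_words = sorted(zip(clf.feature_importances_, voc), reverse=True)[:5]
print(top_words)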

Here is a version of the decision tree with nicer layout; however, it involves some coding to get just what we want.

In [546]:
data = [ line.split() for line in open('imdb_labelled.txt').readlines() ]

#  '.' and ',' stick to words (e.g. at the end of each sentence); remove them with strip() and lowercase
sents = [ [ w.strip('.,').lower() for w in line[:-1] ] for line in data ]
y = np.asarray([ int(line[-1]) for line in data ])

cnt = defaultdict(int)

for s in sents:
    for w in s:
        cnt[w] += 1
        
# remove stop words
stop = 'a,-,the,it,in,and,or,with,to,of,as,was,is,it\'s,from,for,this,on,at,do,been,has,her'
stop += ',after,part,90\'s,we,i\'ve,were,its'
sents = [ [ w for w in s if w not in stop.split(',') ] for s in sents ]

voc = [ w for w in cnt if cnt[w] > 2 and cnt[w] < 50 ]
print('voc len:', len(voc))
print(voc[:50])

# invert the encoding (1 = word absent) so the tree's 'True' branch means the word is present
X = [ [ int(w not in sent) for w in voc ] for sent in sents ]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=1)
print('sum(y_test):', sum(y_test))

clf = tree.DecisionTreeClassifier(max_depth=13)
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)

print("Accuracy:", metrics.accuracy_score(y_test, y_pred))

export_graphviz(clf, feature_names=voc, class_names=['neg','pos'], out_file="mytree.dot" )
g = open("mytree.dot").read()
print(g[:300])
voc len: 662
['about', 'young', 'man', 'sure', 'who', 'more', 'lost', '-', 'characters', 'or', 'audience', 'half', 'out', 'black', '&', 'white', 'clever', 'camera', 'disappointed', 'even', 'acting', 'poor', 'plot', 'lines', 'almost', 'little', 'music', 'anything', 'speak', 'best', 'scene', 'when', 'trying', 'find', 'song', 'through', 'his', 'rest', 'lacks', 'art', 'meaning', 'if', 'works', 'guess', 'because', 'wasted', 'two', 'hours', 'saw', 'today']
sum(y_test): 51
Accuracy: 0.57
digraph Tree {
node [shape=box] ;
0 [label="great <= 0.5\ngini = 0.5\nsamples = 900\nvalue = [451, 449]\nclass = neg"] ;
1 [label="would <= 0.5\ngini = 0.17\nsamples = 32\nvalue = [3, 29]\nclass = pos"] ;
0 -> 1 [labeldistance=2.5, labelangle=45, headlabel="True"] ;
2 [label="gini = 0.0\nsamples = 1

It turns out that the input to the graphviz module is really just a string, so we can manipulate that content to reduce the labels to the word, the number of samples, and the class. This makes the tree more readable and also smaller, so we can go to a higher value for the depth and still fit it nicely on the page.

In [547]:
import re
g = re.sub(r" <= \d\.\d+\\n", "", g)
g = re.sub(r"gini = \d\.\d+\\n", "\\n", g)
g = re.sub(r"value = \[\d+, \d+\]\\n", "", g)
g = re.sub(r"samples = ", "", g)
g = re.sub(r"\\nclass = ", " ", g)
#print(g)
graphviz.Source(g)
Out[547]:
[Rendered graphviz tree (depth 13) with the simplified labels showing only the word, the number of samples, and the class at each node: the root splits on 'great' (900 samples, class neg); the True branch continues with 'would', 'people', 'original' and 'ever', the False branch with 'wonderful', 'love', 'never', 'even', 'excellent', 'beautiful', 'loved', 'stupid' and so on.]
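
The depth of 13 used above fills the page once the labels are simplified, but it is also worth checking how the depth affects accuracy on this split; a small sketch that sweeps a few values (the exact numbers will vary with the random split and vocabulary):

# Try a range of tree depths on the same train/test split and report accuracy.
for depth in [3, 5, 9, 13, 17, 21]:
    clf_d = tree.DecisionTreeClassifier(max_depth=depth, random_state=0)
    clf_d.fit(X_train, y_train)
    acc = metrics.accuracy_score(y_test, clf_d.predict(X_test))
    print('max_depth={:2d}  accuracy={:.2f}'.format(depth, acc))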