Decision tree repeating class names

Question:

I have a very simple sample of data/labels. The problem I’m having is that the generated decision tree (PDF) repeats the class name:

from sklearn import tree
from io import StringIO  # sklearn.externals.six was removed from newer scikit-learn releases
import pydotplus

features_names = ['weight', 'texture']
features = [[140, 1], [130, 1], [150, 0], [110, 0]]
labels = ['apple', 'apple', 'orange', 'orange']

clf = tree.DecisionTreeClassifier()
clf.fit(features, labels)

dot_data = StringIO()
tree.export_graphviz(clf, out_file=dot_data, 
                         feature_names=features_names,  
                         class_names=labels,  
                         filled=True, rounded=True,  
                         special_characters=True,
                         impurity=False)

graph = pydotplus.graph_from_dot_data(dot_data.getvalue()) 
graph.write_pdf("apples_oranges.pdf")

The resulting pdf looks like:

[image: the generated decision tree]

So the problem is pretty obvious: it’s apple for both possibilities. What am I doing wrong?

From the DOCS:

list of strings, bool or None, optional (default=None)
Names of each of the target classes in ascending numerical order. Only relevant for classification and not supported for multi-output. If True, shows a symbolic representation of the class name.

“…ascending numerical order”: this doesn’t make much sense to me. If I change the kwarg to:

class_names=sorted(labels)

The result is the same (obvious in this case).

Asked By: Hula Hula

Answers:

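The class names are literally just that: the names of the classes, not the labels of each example.

So one class is ‘apple’ and the other is ‘orange’, and you just need to pass in ['apple', 'orange']. With class_names=labels, the name shown in each node is looked up by class index, so index 0 gives labels[0] = 'apple' and index 1 gives labels[1] = 'apple', which is why both show apple.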
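To illustrate why the per-sample list gives ‘apple’ twice, here is a minimal sketch in plain Python. The helper node_class_name is hypothetical and only mimics the index lookup; it is not scikit-learn's actual code.

```python
# Simplified illustration: a node's displayed name is taken from
# class_names at the index of the node's majority class.
labels = ['apple', 'apple', 'orange', 'orange']

# The distinct classes, in sorted order (one entry per class).
classes = sorted(set(labels))  # ['apple', 'orange']


def node_class_name(class_names, majority_class_index):
    """Hypothetical helper mimicking the lookup done when drawing a node."""
    return class_names[majority_class_index]


# Passing the per-sample labels: indices 0 and 1 both land on 'apple'.
wrong = [node_class_name(labels, i) for i in range(len(classes))]

# Passing one name per class: each class index gets its own name.
right = [node_class_name(classes, i) for i in range(len(classes))]

print(wrong)  # ['apple', 'apple']
print(right)  # ['apple', 'orange']
```

In the question's code, the equivalent fix would be class_names=['apple', 'orange'].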

Regarding order, to keep it consistent you could use a LabelEncoder to convert your target to integers with int_labels = labelEncoder.fit_transform(labels), fit your decision tree on int_labels, then pass the labelEncoder.classes_ attribute as class_names to export_graphviz.
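The LabelEncoder approach might look roughly like the sketch below, reusing the data from the question. The variable names label_encoder and dot_source and the random_state=0 setting are my own choices. It assumes a scikit-learn version where export_graphviz returns the DOT source as a string when out_file=None.

```python
from sklearn import tree
from sklearn.preprocessing import LabelEncoder

features_names = ['weight', 'texture']
features = [[140, 1], [130, 1], [150, 0], [110, 0]]
labels = ['apple', 'apple', 'orange', 'orange']

# Encode the string labels as integers. classes_ holds the class names
# in the order that matches the integer codes (sorted order).
label_encoder = LabelEncoder()
int_labels = label_encoder.fit_transform(labels)

clf = tree.DecisionTreeClassifier(random_state=0)
clf.fit(features, int_labels)

# One name per class, in the same order as the integer codes.
class_names = [str(name) for name in label_encoder.classes_]

# out_file=None makes export_graphviz return the DOT source as a string.
dot_source = tree.export_graphviz(
    clf,
    out_file=None,
    feature_names=features_names,
    class_names=class_names,
    filled=True,
    rounded=True,
    impurity=False,
)
print(dot_source)
```

The resulting string can be passed to pydotplus.graph_from_dot_data as in the question. Fitting directly on the string labels also works, in which case clf.classes_ holds the sorted class names.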

Answered By: Ken Syme

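Class names should be the set of distinct names in your labels, passed in ascending order.
With a plain list, as in the question, you can build it directly like this (for a pandas Series, sorted(labels.unique()) does the same):

labels_set = sorted(set(labels))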
Answered By: Carol