I am building a decision tree using
clf = tree.DecisionTreeClassifier()
clf = clf.fit(X_train, Y_train)
This all works fine. However, how do I then explore the decision tree?
For example, how do I find which entries from X_train appear in a particular leaf?
You need to use the predict method.
After training the tree, you feed it the X values to predict their output.
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
clf = DecisionTreeClassifier(random_state=0)
iris = load_iris()
tree = clf.fit(iris.data, iris.target)
tree.predict(iris.data)
output:
>>> tree.predict(iris.data)
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])
To get details on the tree structure, we can use tree_.__getstate__()
Tree structure translated into an "ASCII art" picture:
                      0
           ___________|___________
           1                     2
                      ___________|___________
                      3                     12
                ______|______         ______|______
                4           7         13          16
             ___|___     ___|___   ___|___
             5     6     8     9   14    15
                            ___|___
                            10    11
Tree structure as an array:
In [38]: tree.tree_.__getstate__()['nodes']
Out[38]:
array([(1, 2, 3, 0.800000011920929, 0.6666666666666667, 150, 150.0),
(-1, -1, -2, -2.0, 0.0, 50, 50.0),
(3, 12, 3, 1.75, 0.5, 100, 100.0),
(4, 7, 2, 4.949999809265137, 0.16803840877914955, 54, 54.0),
(5, 6, 3, 1.6500000953674316, 0.04079861111111116, 48, 48.0),
(-1, -1, -2, -2.0, 0.0, 47, 47.0),
(-1, -1, -2, -2.0, 0.0, 1, 1.0),
(8, 9, 3, 1.5499999523162842, 0.4444444444444444, 6, 6.0),
(-1, -1, -2, -2.0, 0.0, 3, 3.0),
(10, 11, 2, 5.449999809265137, 0.4444444444444444, 3, 3.0),
(-1, -1, -2, -2.0, 0.0, 2, 2.0),
(-1, -1, -2, -2.0, 0.0, 1, 1.0),
(13, 16, 2, 4.850000381469727, 0.042533081285444196, 46, 46.0),
(14, 15, 1, 3.0999999046325684, 0.4444444444444444, 3, 3.0),
(-1, -1, -2, -2.0, 0.0, 2, 2.0),
(-1, -1, -2, -2.0, 0.0, 1, 1.0),
(-1, -1, -2, -2.0, 0.0, 43, 43.0)],
dtype=[('left_child', '<i8'), ('right_child', '<i8'),
('feature', '<i8'), ('threshold', '<f8'),
('impurity', '<f8'), ('n_node_samples', '<i8'),
('weighted_n_node_samples', '<f8')])
Where:
The first node [0] is the root node.
internal nodes have left_child and right_child referring to node indices that are positive and greater than the current node's index.
leaves have -1 value for the left and right child nodes.
nodes 1, 5, 6, 8, 10, 11, 14, 15 and 16 are leaves.
the node structure is built using a depth-first search algorithm.
the feature field tells us which of the iris.data features was used in the node to determine the path for this sample.
the threshold tells us the value used to evaluate the direction based on the feature.
impurity reaches 0 at the leaves... since all the samples are in the same class once you reach the leaf.
n_node_samples tells us how many samples reach each node.
Using this information we could trivially track each sample X to the leaf where it eventually lands by following the split rules and thresholds in a script, as sketched below. Additionally, n_node_samples would allow us to write unit tests ensuring that each node gets the correct number of samples. Then, using the output of tree.predict, we could map each leaf to its associated class.
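Here is a minimal sketch of that idea (my addition, not part of the original answer). It uses the public tree_ arrays children_left, children_right, feature and threshold rather than __getstate__():

import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
clf = DecisionTreeClassifier(random_state=0).fit(iris.data, iris.target)
t = clf.tree_

def leaf_for(sample):
    # follow the split rules from the root until a leaf (child == -1) is reached
    node = 0
    while t.children_left[node] != -1:
        if sample[t.feature[node]] <= t.threshold[node]:
            node = t.children_left[node]
        else:
            node = t.children_right[node]
    return node

# group the training samples by the leaf they land in
leaves = {}
for i, sample in enumerate(iris.data):
    leaves.setdefault(leaf_for(sample), []).append(i)
for leaf, idx in sorted(leaves.items()):
    print("leaf %d: %d samples" % (leaf, len(idx)))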
NOTE: This is not an answer, only a hint on possible solutions.
I encountered a similar problem recently in my project. My goal is to extract the corresponding chain of decisions for some particular samples. I think your problem is a subset of mine, since you just need to record the last step in the decision chain.
Up to now, it seems the only viable solution is to write a custom predict method in Python that keeps track of the decisions along the way. The reason is that the predict method provided by scikit-learn cannot do this out of the box (as far as I know). And to make it worse, it is a wrapper for a C implementation, which is pretty hard to customize.
Customization is fine for my problem, since I'm dealing with an unbalanced dataset, and the samples I care about (the positive ones) are rare. So I can filter them out first using sklearn's predict and then get the decision chain using my customization.
However, this may not work for you if you have a large dataset, because if you parse the tree and do predict in Python, it will run at Python speed and will not (easily) scale. You may have to fall back to customizing the C implementation.
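For illustration, a bare-bones version of such a custom predict might look like this (a sketch of the idea only, recording each (feature, threshold, went_left) decision on the way down; not the poster's actual code):

import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
clf = DecisionTreeClassifier(random_state=0).fit(iris.data, iris.target)
t = clf.tree_

def predict_with_chain(sample):
    # walk the tree, recording every decision as (feature, threshold, went_left)
    node, chain = 0, []
    while t.children_left[node] != -1:
        went_left = sample[t.feature[node]] <= t.threshold[node]
        chain.append((t.feature[node], t.threshold[node], went_left))
        node = t.children_left[node] if went_left else t.children_right[node]
    # the leaf's majority class is the prediction
    return np.argmax(t.value[node]), chain

pred, chain = predict_with_chain(iris.data[100])
print(pred, chain)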
I've slightly changed what Dr. Drew posted.
The following code, given a data frame and the decision tree after being fitted, returns:
rules_list: a list of rules
values_path: a list of entries (entries for each class going through the path)
import numpy as np
import pandas as pd
from sklearn.tree import DecisionTreeClassifier
def get_rules(dtc, df):
    rules_list = []
    values_path = []
    values = dtc.tree_.value

    def RevTraverseTree(tree, node, rules, pathValues):
        '''
        Traverse an skl decision tree from a node (presumably a leaf node)
        up to the top, building the decision rules. The rules should be
        input as an empty list, which will be modified in place. The result
        is a nested list of tuples: (feature, direction (left=-1), threshold).
        The "tree" is a nested list of simplified tree attributes:
        [split feature, split threshold, left node, right node]
        '''
        # now find the node as either a left or right child of something
        # first try to find it as a left node
        try:
            prevnode = tree[2].index(node)
            leftright = '<='
            pathValues.append(values[prevnode])
        except ValueError:
            # failed, so find it as a right node - if this also causes an exception, something's really f'd up
            prevnode = tree[3].index(node)
            leftright = '>'
            pathValues.append(values[prevnode])

        # now let's get the rule that caused prevnode to -> node
        p1 = df.columns[tree[0][prevnode]]
        p2 = tree[1][prevnode]
        rules.append(str(p1) + ' ' + leftright + ' ' + str(p2))

        # if we've not yet reached the top, go up the tree one more step
        if prevnode != 0:
            RevTraverseTree(tree, prevnode, rules, pathValues)

    # get the nodes which are leaves
    leaves = dtc.tree_.children_left == -1
    leaves = np.arange(0, dtc.tree_.node_count)[leaves]

    # build a simpler tree as a nested list: [split feature, split threshold, left node, right node]
    thistree = [dtc.tree_.feature.tolist()]
    thistree.append(dtc.tree_.threshold.tolist())
    thistree.append(dtc.tree_.children_left.tolist())
    thistree.append(dtc.tree_.children_right.tolist())

    # get the decision rules for each leaf node & apply them
    for (ind, nod) in enumerate(leaves):
        # get the decision rules
        rules = []
        pathValues = []
        RevTraverseTree(thistree, nod, rules, pathValues)
        pathValues.insert(0, values[nod])
        pathValues = list(reversed(pathValues))
        rules = list(reversed(rules))
        rules_list.append(rules)
        values_path.append(pathValues)

    return (rules_list, values_path)
An example follows:
from sklearn.model_selection import train_test_split

df = pd.read_csv('df.csv')
X = df[df.columns[:-1]]
y = df['classification']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
dtc = DecisionTreeClassifier(max_depth=2)
dtc.fit(X_train, y_train)
The fitted Decision Tree generated the following structure (figure: "Decision Tree with width 2").
At this point, just calling the function:
get_rules(dtc, df)
This is what the function returns:
rules = [
['first <= 63.5', 'first <= 43.5'],
['first <= 63.5', 'first > 43.5'],
['first > 63.5', 'second <= 19.700000762939453'],
['first > 63.5', 'second > 19.700000762939453']
]
values = [
[array([[ 1568., 1569.]]), array([[ 636., 241.]]), array([[ 284., 57.]])],
[array([[ 1568., 1569.]]), array([[ 636., 241.]]), array([[ 352., 184.]])],
[array([[ 1568., 1569.]]), array([[ 932., 1328.]]), array([[ 645., 620.]])],
[array([[ 1568., 1569.]]), array([[ 932., 1328.]]), array([[ 287., 708.]])]
]
Note that values also contains, for each path, the values at the leaf as the final entry.
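As a quick follow-up (my own sketch, not part of the original answer): the predicted class for each path can be read off that last value array, assuming its column order matches dtc.classes_:

import numpy as np

rules_list, values_path = get_rules(dtc, df)
for path_rules, path_values in zip(rules_list, values_path):
    # the final entry of each path holds the leaf's per-class sample counts
    predicted = dtc.classes_[np.argmax(path_values[-1])]
    print(' AND '.join(path_rules), '->', predicted)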
The below code should produce a plot of your top ten features:
import numpy as np
import matplotlib.pyplot as plt

importances = clf.feature_importances_
indices = np.argsort(importances)[::-1][:10]  # indices of the top ten features

# Print the feature ranking
print("Feature ranking:")
for f in range(len(indices)):
    print("%d. feature %d (%f)" % (f + 1, indices[f], importances[indices[f]]))

# Plot the feature importances
plt.figure()
plt.title("Feature importances")
plt.bar(range(len(indices)), importances[indices], color="r", align="center")
plt.xticks(range(len(indices)), indices)
plt.xlim([-1, len(indices)])
plt.show()
Taken from here and modified slightly to fit the DecisionTreeClassifier. (The error bars from the original forest example are dropped, since a single tree's feature_importances_ is one-dimensional and there is no ensemble spread to take a standard deviation over.)
This doesn't exactly help you explore the tree, but it does tell you about the tree.
This code will do exactly what you want. Here, n is the number of observations in X_train. At the end, the (n, number_of_leaves)-sized array leaf_observations holds, in each column, boolean values for indexing into X_train to get the observations in each leaf. Each column of leaf_observations corresponds to an element in leaves, which has the node IDs for the leaves.
# get the nodes which are leaves
leaves = clf.tree_.children_left == -1
leaves = np.arange(0, clf.tree_.node_count)[leaves]

# loop through each leaf and figure out the data in it
leaf_observations = np.zeros((n, len(leaves)), dtype=bool)

# build a simpler tree as a nested list: [split feature, split threshold, left node, right node]
thistree = [clf.tree_.feature.tolist()]
thistree.append(clf.tree_.threshold.tolist())
thistree.append(clf.tree_.children_left.tolist())
thistree.append(clf.tree_.children_right.tolist())

# get the decision rules for each leaf node & apply them
for (ind, nod) in enumerate(leaves):
    # get the decision rules in numeric list form
    rules = []
    RevTraverseTree(thistree, nod, rules)
    # convert & apply to the data by sequentially ANDing the rules
    thisnode = np.ones(n, dtype=bool)
    for rule in rules:
        if rule[1] == 1:
            thisnode = np.logical_and(thisnode, X_train[:, rule[0]] > rule[2])
        else:
            thisnode = np.logical_and(thisnode, X_train[:, rule[0]] <= rule[2])
    # get the observations that obey all the rules - they are the ones in this leaf node
    leaf_observations[:, ind] = thisnode
This needs the helper function defined below, which recursively traverses the tree starting from a specified node to build the decision rules.
def RevTraverseTree(tree, node, rules):
    '''
    Traverse an skl decision tree from a node (presumably a leaf node)
    up to the top, building the decision rules. The rules should be
    input as an empty list, which will be modified in place. The result
    is a nested list of tuples: (feature, direction (left=-1), threshold).
    The "tree" is a nested list of simplified tree attributes:
    [split feature, split threshold, left node, right node]
    '''
    # now find the node as either a left or right child of something
    # first try to find it as a left node
    try:
        prevnode = tree[2].index(node)
        leftright = -1
    except ValueError:
        # failed, so find it as a right node - if this also causes an exception, something's really f'd up
        prevnode = tree[3].index(node)
        leftright = 1
    # now let's get the rule that caused prevnode to -> node
    rules.append((tree[0][prevnode], leftright, tree[1][prevnode]))
    # if we've not yet reached the top, go up the tree one more step
    if prevnode != 0:
        RevTraverseTree(tree, prevnode, rules)
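A minimal way to wire this up end to end (my sketch; the iris data stands in for the question's X_train, and note that RevTraverseTree must be defined before the leaf-extraction loop runs):

import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X_train = iris.data
clf = DecisionTreeClassifier(random_state=0).fit(X_train, iris.target)
n = X_train.shape[0]  # the n used in the snippet above
# ...define RevTraverseTree, run the leaf-extraction loop, then
# leaf_observations.sum(axis=0) gives the number of training samples per leaf.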
I think an easy option would be to use the apply method of the trained decision tree. Train the tree, apply it to the training data, and build a lookup table from the returned leaf indices:
import numpy as np
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_iris

iris = load_iris()
clf = DecisionTreeClassifier()
clf = clf.fit(iris.data, iris.target)

# apply training data to decision tree
leaf_indices = clf.apply(iris.data)
lookup = {}

# build lookup table
for i, leaf_index in enumerate(leaf_indices):
    try:
        lookup[leaf_index].append(iris.data[i])
    except KeyError:
        lookup[leaf_index] = []
        lookup[leaf_index].append(iris.data[i])

# test
unknown_sample = [[4., 3.1, 6.1, 1.2]]
index = clf.apply(unknown_sample)
print(lookup[index[0]])
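If you only need the training rows that fall in one particular leaf, a NumPy boolean mask over the result of apply is a compact alternative to the dictionary (a small sketch reusing the clf and iris objects from above):

import numpy as np

leaf_ids = clf.apply(iris.data)            # leaf node id for every training sample
for leaf in np.unique(leaf_ids):
    members = iris.data[leaf_ids == leaf]  # training rows that land in this leaf
    print("leaf %d: %d samples" % (leaf, len(members)))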
Have you tried dumping your DecisionTree into a Graphviz .dot file [1] and then loading it with graph_tool [2]?
import numpy as np
from sklearn.tree import DecisionTreeClassifier, export_graphviz
from sklearn.datasets import load_iris
from graph_tool.all import *

iris = load_iris()
clf = DecisionTreeClassifier()
clf = clf.fit(iris.data, iris.target)
export_graphviz(clf, out_file='tree.dot')

# load graph with graph_tool and explore structure as you please
g = load_graph('tree.dot')
for v in g.vertices():
    for e in v.out_edges():
        print(e)
    for w in v.out_neighbours():
        print(w)
[1] http://scikit-learn.org/stable/modules/generated/sklearn.tree.export_graphviz.html
[2] https://graph-tool.skewed.de/