Python scikit-learn RandomForestClassifier: individual tree node elements command

Question:

I know that I can get an individual tree (for example, tree 0) of a random forest object rf using:
rf.estimators_[0]

Is there a way to get the list of events in each terminal node of a specific tree, along with that node's classification?

Using Python 3, scikit-learn, sklearn.ensemble.RandomForestClassifier:

from sklearn.datasets import load_wine

wine = load_wine()
X = wine.data
y = wine.target

from sklearn.ensemble import RandomForestClassifier

rf = RandomForestClassifier(n_estimators=100, 
                            max_depth=3,
                            max_features='sqrt',  # 'auto' was removed in scikit-learn 1.3; for classifiers it meant 'sqrt'
                            min_samples_leaf=4,
                            bootstrap=True, 
                            n_jobs=-1, 
                            random_state=0)
rf.fit(X, y)

import matplotlib.pyplot as plt
from sklearn.tree import plot_tree

fig = plt.figure(figsize=(15, 10))
plot_tree(rf.estimators_[0], 
          feature_names=wine.feature_names,
          class_names=wine.target_names, 
          filled=True, impurity=True, 
          rounded=True)
plt.show()

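What I am looking for is something like this (not a real attribute, just an illustration of what I want):
rf.treenodeelements[0,1]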
Asked By: NPHARD


Answers:

I am not sure I understand the question correctly, but if I do, you can get the list of samples in each terminal node of a specific tree in a random forest classifier by using the tree's apply method, which returns the index of the leaf each sample ends up in:

# Get the first tree in the random forest
tree = rf.estimators_[0]

# Get the node indices for each sample
node_indices = tree.apply(X)

# Create a dictionary to store the events in each terminal node
terminal_nodes = {}

# Iterate over the samples and add the events to the corresponding terminal node
for i, node_index in enumerate(node_indices):
    if node_index not in terminal_nodes:
        terminal_nodes[node_index] = []
    terminal_nodes[node_index].append(i)

# Print the events in each terminal node
for node_index, events in terminal_nodes.items():
    print(f'Node {node_index}: {events}')
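The forest object also has its own apply method, which returns the leaf index of every sample in every tree at once, as an array of shape (n_samples, n_estimators); column 0 holds the same values as rf.estimators_[0].apply(X). Below is a minimal sketch of grouping one column with NumPy. The names leaf_ids, tree_index and samples_per_leaf are just illustrative, and it uses max_features='sqrt' because 'auto' is not accepted by recent scikit-learn versions. Keep in mind that apply(X) routes every row of X through the tree, including rows that a bootstrapped tree never saw during training.

```python
import numpy as np
from sklearn.datasets import load_wine
from sklearn.ensemble import RandomForestClassifier

wine = load_wine()
X, y = wine.data, wine.target

rf = RandomForestClassifier(n_estimators=100, max_depth=3,
                            max_features='sqrt', min_samples_leaf=4,
                            bootstrap=True, n_jobs=-1, random_state=0)
rf.fit(X, y)

# Leaf index of every sample in every tree: shape (n_samples, n_estimators)
leaf_ids = rf.apply(X)

# Column t holds the same values as rf.estimators_[t].apply(X)
tree_index = 0
tree_leaves = leaf_ids[:, tree_index]

# Map each leaf (node) index to the row indices of the samples that land in it
samples_per_leaf = {int(leaf): np.flatnonzero(tree_leaves == leaf).tolist()
                    for leaf in np.unique(tree_leaves)}

for leaf, rows in samples_per_leaf.items():
    print(f'Node {leaf}: {len(rows)} samples -> {rows}')
```

np.flatnonzero returns the positions where the comparison is True, so each list contains row indices into X, just like the loop in the answer.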

Next, to get the classification of each terminal node, call the tree's predict method on the same samples X (predict takes the feature matrix, not sample indices). It returns the predicted class for each sample, which can then be grouped by terminal node:

# Get the predictions for each sample
predictions = tree.predict(X)

# Create a dictionary to store the classification of each terminal node
terminal_node_classification = {}

# Iterate over the samples and add the classification to the corresponding terminal node
for i, node_index in enumerate(node_indices):
    if node_index not in terminal_node_classification:
        terminal_node_classification[node_index] = []
    terminal_node_classification[node_index].append(predictions[i])

# Print the classification of each terminal node
for node_index, classification in terminal_node_classification.items():
    print(f'Node {node_index}: {classification}')
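Because every sample that reaches a given leaf receives the same prediction, each per-node list above just repeats a single label. The class of each terminal node can also be read directly from the fitted tree structure, tree.tree_: leaves are the nodes whose children_left entry is -1, and the largest entry of tree_.value for a leaf is the class that leaf predicts. Below is a minimal sketch under a few assumptions. The trees inside a forest are fitted on label-encoded targets, so the argmax position is treated as an index into rf.classes_. With bootstrap=True, the stored values reflect the bootstrap sample the tree was trained on, not all of X. The names structure, leaf_nodes and leaf_class are illustrative.

```python
import numpy as np
from sklearn.datasets import load_wine
from sklearn.ensemble import RandomForestClassifier

wine = load_wine()
X, y = wine.data, wine.target

rf = RandomForestClassifier(n_estimators=100, max_depth=3,
                            max_features='sqrt', min_samples_leaf=4,
                            bootstrap=True, n_jobs=-1, random_state=0)
rf.fit(X, y)

tree = rf.estimators_[0]
structure = tree.tree_

# A node is terminal (a leaf) when it has no children; sklearn marks this with -1
leaf_nodes = np.flatnonzero(structure.children_left == -1)

# structure.value has shape (node_count, n_outputs, n_classes); the position of
# the largest entry in a leaf's row is the class that leaf predicts, and that
# position indexes into rf.classes_
leaf_class = {int(node): rf.classes_[np.argmax(structure.value[node, 0])]
              for node in leaf_nodes}

for node, label in leaf_class.items():
    print(f'Node {node}: {wine.target_names[label]}')
```

This gives one label per terminal node without running predict on every sample, and it also covers leaves that none of the rows you pass in happen to reach.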
Answered By: Phoenix