Make your Decision Tree and Random Forest models explainable. Learn impurity-based importance, permutation importance, partial dependence plots, and quick tree visualization.
Impurity-based importance: uses the decrease in Gini/entropy across splits. Fast, but can be biased toward features with many unique values.
Permutation importance: randomly shuffles a feature to measure the drop in performance. Model-agnostic and less biased; computed on a validation set.
Partial dependence plots: show the average effect of one (or two) features on predictions while marginalizing over the others.
# Imports
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import permutation_importance, PartialDependenceDisplay
from sklearn.metrics import accuracy_score
from sklearn.tree import plot_tree
import matplotlib.pyplot as plt
import numpy as np
# Data & split
# Load the dataset once and reuse the bunch — the original called
# load_breast_cancer() twice (once for X/y, once for feature names).
data = load_breast_cancer()
X, y = data.data, data.target
feature_names = data.feature_names
# Stratify so the class balance is preserved in the 80/20 split.
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)
# Train RF
# Fit a 300-tree random forest (all cores) and report held-out accuracy.
rf = RandomForestClassifier(n_estimators=300, random_state=42, n_jobs=-1)
rf.fit(X_tr, y_tr)
test_acc = accuracy_score(y_te, rf.predict(X_te))
print("Test Acc:", round(test_acc, 3))
# 1) Impurity-based feature importance
# Mean decrease in impurity averaged over the trees. Fast, but biased
# toward high-cardinality features (see the permutation comparison below).
imp = rf.feature_importances_
top = np.argsort(imp)[-8:][::-1]  # indices of the 8 largest importances, descending
print("Top features (impurity):")
for i in top:
    # Fixed: the loop body must be indented — the original line was at
    # column 0, which is a SyntaxError inside a for statement.
    print(f"{feature_names[i]} -> {imp[i]:.4f}")
# 2) Permutation importance (on the held-out test set, so it reflects
# true predictive impact rather than training-set fit)
perm = permutation_importance(rf, X_te, y_te, n_repeats=10, random_state=42, n_jobs=-1)
perm_idx = perm.importances_mean.argsort()[-8:][::-1]
# Fixed: the original "\\n" printed a literal backslash-n; "\n" emits a newline.
print("\nTop features (permutation):")
for i in perm_idx:
    # Fixed: loop body indented (original was at column 0 — SyntaxError).
    print(f"{feature_names[i]} -> {perm.importances_mean[i]:.4f}")
# 3) Partial Dependence (1D and 2D)
# First the 1D effect of the top feature, then its 2D interaction with the
# runner-up; all other features are marginalized out by the display.
for feat_spec in ([top[0]], [(top[0], top[1])]):
    fig, ax = plt.subplots(figsize=(6, 4))
    PartialDependenceDisplay.from_estimator(rf, X_tr, feat_spec, feature_names=feature_names, ax=ax)
    plt.tight_layout()
    plt.show()
# Bonus: Visualize a single tree from the forest.
# Depth is capped at 3 so the diagram stays legible; the full tree is
# usually far too deep to render readably.
first_tree = rf.estimators_[0]
plt.figure(figsize=(10, 6))
plot_tree(
    first_tree,
    max_depth=3,
    filled=True,
    feature_names=feature_names,
    class_names=['neg', 'pos'],
    fontsize=6,
)
plt.tight_layout()
plt.show()
Tip: Compare impurity vs permutation rankings. If they diverge, trust permutation more (it reflects true predictive impact on held-out data).
Limit plot_tree's max_depth to keep diagrams readable in reports.