Interpretability Analysis of Decision Trees in Python
1. Feature Importance Analysis
When a decision tree is fitted to training data, each feature can be ranked by how strongly it influences the classification outcome. This ranking is what feature importance analysis refers to; the usual metric is Gini importance, also known as Mean Decrease in Impurity (MDI).
Code example:
```python
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

# Load the dataset
iris = load_iris()
X, y = iris.data, iris.target

# Fit a decision tree
clf = DecisionTreeClassifier(random_state=42)
clf.fit(X, y)

# Retrieve the feature importances
importances = clf.feature_importances_

# Plot the importances as a bar chart
plt.bar(range(X.shape[1]), importances)
plt.xticks(range(X.shape[1]), iris.feature_names, rotation=90)
plt.tight_layout()
plt.show()
```
Running the code produces a bar chart of the four feature importances.
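Impurity-based importances are computed on the training data and can favor features with many split points. As an optional cross-check (not part of the original workflow), scikit-learn's `permutation_importance` measures how much the score drops when each feature is shuffled; a minimal sketch:

```python
from sklearn.datasets import load_iris
from sklearn.inspection import permutation_importance
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
clf = DecisionTreeClassifier(random_state=42).fit(iris.data, iris.target)

# Shuffle each feature n_repeats times and record the mean score drop;
# a large drop means the model relies heavily on that feature.
result = permutation_importance(clf, iris.data, iris.target,
                                n_repeats=10, random_state=42)

for name, mean in sorted(zip(iris.feature_names, result.importances_mean),
                         key=lambda t: t[1], reverse=True):
    print(f"{name}: {mean:.3f}")
```

If both rankings agree, the impurity-based numbers are likely trustworthy for this dataset.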
2. Decision Path Analysis
Decision path analysis traces the sequence of decisions a sample passes through in the tree, revealing exactly how each sample is classified. It can be carried out either by parsing the tree's rules programmatically or by rendering the tree with Graphviz.
Code example:
Rule-parsing approach:
```python
# Print the decision tree as nested if/else rules
def tree_to_string(tree, feature_names):
    left = tree.tree_.children_left
    right = tree.tree_.children_right
    threshold = tree.tree_.threshold
    features = [feature_names[i] for i in tree.tree_.feature]
    value = tree.tree_.value

    def recurse(left, right, threshold, features, node):
        if threshold[node] != -2:  # -2 marks a leaf node in sklearn trees
            print("if ( " + features[node] + " <= "
                  + "{:.6f}".format(threshold[node]) + " ) {")
            if left[node] != -1:
                recurse(left, right, threshold, features, left[node])
            print("} else {")
            if right[node] != -1:
                recurse(left, right, threshold, features, right[node])
            print("}")
        else:
            print("return " + str(value[node]))

    recurse(left, right, threshold, features, 0)

# Parse the tree fitted above
tree_to_string(clf, iris.feature_names)
```
The output is shown below:
```
if ( petal length (cm) <= 2.450000 ) {
  if ( petal width (cm) <= 0.800000 ) {
    return [[50. 0. 0.]]
  } else {
    if ( petal length (cm) <= 1.750000 ) {
      return [[ 0. 48. 0.]]
    } else {
      if ( sepal width (cm) <= 3.100000 ) {
        return [[0. 1. 0.]]
      } else {
        return [[0. 0. 2.]]
      }
    }
  }
} else {
  if ( petal length (cm) <= 4.850000 ) {
    if ( petal width (cm) <= 1.650000 ) {
      if ( petal length (cm) <= 3.350000 ) {
        return [[ 0. 0. 46.]]
      } else {
        return [[ 0. 1. 39.]]
      }
    } else {
      if ( sepal length (cm) <= 6.000000 ) {
        return [[0. 0. 3.]]
      } else {
        return [[0. 0. 1.]]
      }
    }
  } else {
    if ( petal width (cm) <= 1.750000 ) {
      return [[0. 0. 3.]]
    } else {
      return [[0. 0. 43.]]
    }
  }
}
```
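As an aside, scikit-learn ships a built-in helper, `sklearn.tree.export_text`, that produces a similar textual rule dump without a hand-written parser; a minimal sketch:

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()
clf = DecisionTreeClassifier(random_state=42).fit(iris.data, iris.target)

# export_text walks the fitted tree and emits one indented line per node
rules = export_text(clf, feature_names=iris.feature_names)
print(rules)
```

The output uses `|---` indentation per depth level, which is often easier to scan than nested braces.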
Graphviz approach:
```python
import graphviz
from sklearn.tree import export_graphviz

# Export the fitted tree to DOT format and render it with Graphviz
dot_data = export_graphviz(clf, out_file=None,
                           feature_names=iris.feature_names,
                           class_names=iris.target_names,
                           filled=True, rounded=True,
                           special_characters=True)
graph = graphviz.Source(dot_data)
graph  # in a notebook, this renders the tree inline
```
The rendered diagram shows every split node and leaf of the fitted tree.
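If the Graphviz system binary is not installed, the same diagram can be drawn with matplotlib alone via `sklearn.tree.plot_tree`. A minimal sketch; the `Agg` backend and the output filename `iris_tree.png` are assumptions for running outside a notebook:

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen; drop this line in a notebook
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree

iris = load_iris()
clf = DecisionTreeClassifier(random_state=42).fit(iris.data, iris.target)

# plot_tree draws the node-link diagram with matplotlib only,
# so no external Graphviz installation is needed
fig, ax = plt.subplots(figsize=(14, 8))
annotations = plot_tree(clf, feature_names=iris.feature_names,
                        class_names=list(iris.target_names),
                        filled=True, rounded=True, ax=ax)
fig.savefig("iris_tree.png")
```

The trade-off is layout quality: Graphviz generally produces tidier trees for deep models, while `plot_tree` avoids the extra dependency.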
3. Local Interpretability Analysis
Local interpretability analysis explains why a particular sample receives its classification: by following that sample's path down the tree, we can see which feature test it passes at every split.
Code example:
```python
# Take the first five samples
X, y = iris.data[:5, :], iris.target[:5]

# Compute the decision path of each sample (a sparse indicator matrix
# whose columns are the tree nodes the sample passes through)
path = clf.decision_path(X)

# Walk each sample's path and print the split it takes at every node
for i in range(X.shape[0]):
    print('\nSample %d' % (i + 1))
    for j in path[i].indices[:-1]:  # the last node is the leaf, skip it
        feat_idx = clf.tree_.feature[j]  # index of the feature tested here
        feature = iris.feature_names[feat_idx]
        threshold = clf.tree_.threshold[j]
        if X[i, feat_idx] <= threshold:
            print(feature, '<=', threshold)
        else:
            print(feature, '>', threshold)
```

Note that the feature index must come from `clf.tree_.feature[j]`; parsing it out of the feature-name string (as in `feature.split(' ')[-1]`) would try to convert `"(cm)"` to an integer and raise a `ValueError`.
The output is shown below:
```
Sample 1
petal length (cm) <= 2.450000
petal width (cm) <= 0.800000

Sample 2
petal length (cm) <= 2.450000
petal width (cm) <= 0.800000

Sample 3
petal length (cm) <= 2.450000
petal width (cm) > 0.800000
petal length (cm) <= 1.750000

Sample 4
petal length (cm) <= 2.450000
petal width (cm) > 0.800000
petal length (cm) > 1.750000
sepal width (cm) <= 3.100000

Sample 5
petal length (cm) <= 2.450000
petal width (cm) > 0.800000
petal length (cm) > 1.750000
sepal width (cm) > 3.100000
```
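The per-sample walk above can be packaged into a small helper that also shows the class distribution stored at each node along the path, so you can watch the candidate classes narrow toward the leaf. The function name `explain_sample` is hypothetical, and note that `tree_.value` holds class counts or class fractions depending on the scikit-learn version:

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
clf = DecisionTreeClassifier(random_state=42).fit(iris.data, iris.target)

def explain_sample(clf, x, feature_names):
    """Describe each split on one sample's root-to-leaf path.

    Hypothetical helper for illustration: returns one line per node,
    annotated with the class distribution stored at that node.
    """
    tree = clf.tree_
    # decision_path on a single row yields the node ids on its path
    node_ids = clf.decision_path(x.reshape(1, -1)).indices
    lines = []
    for node in node_ids:
        dist = tree.value[node].ravel()
        if tree.children_left[node] == -1:  # -1 marks a leaf node
            lines.append(f"leaf: class distribution {dist}")
            continue
        feat = tree.feature[node]
        op = "<=" if x[feat] <= tree.threshold[node] else ">"
        lines.append(f"{feature_names[feat]} = {x[feat]:.2f} "
                     f"{op} {tree.threshold[node]:.2f}  (distribution {dist})")
    return lines

for line in explain_sample(clf, iris.data[0], iris.feature_names):
    print(line)
```

Because scikit-learn numbers nodes in depth-first preorder, the node ids on any root-to-leaf path are increasing, so the sparse matrix's `indices` already come out in path order.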