Interpretability Analysis of Decision Trees in Python

2023-04-14 | analysis · methods · interpretability
  1. Feature Importance Analysis
    When a decision tree is fitted to training data, the features can be ranked by how strongly they influence the classification. This ranking is feature importance analysis, and it is usually measured by "Gini importance", also known as "Mean Decrease in Impurity" (MDI).
    Code example:
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
import matplotlib.pyplot as plt

# Load the dataset
iris = load_iris()
X, y = iris.data, iris.target

# Fit the decision tree model
clf = DecisionTreeClassifier(random_state=42)
clf.fit(X, y)

# Impurity-based (Gini) feature importances, normalized to sum to 1
importances = clf.feature_importances_

# Plot the importances as a bar chart
plt.bar(range(X.shape[1]), importances)
plt.xticks(range(X.shape[1]), iris.feature_names, rotation=90)
plt.show()

The resulting plot:
[Figure: bar chart of the four iris feature importances]
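Impurity-based importances are computed from the training splits and can be biased toward features with many distinct values. As a complementary check, scikit-learn's `permutation_importance` measures how much the score drops when a single feature's values are shuffled; a minimal sketch on the same iris tree:

```python
from sklearn.datasets import load_iris
from sklearn.inspection import permutation_importance
from sklearn.tree import DecisionTreeClassifier

# Refit the same tree as above
iris = load_iris()
X, y = iris.data, iris.target
clf = DecisionTreeClassifier(random_state=42).fit(X, y)

# Shuffle each feature n_repeats times and record the accuracy drop;
# a large mean drop means the model genuinely relies on that feature
result = permutation_importance(clf, X, y, n_repeats=10, random_state=42)
for name, mean in zip(iris.feature_names, result.importances_mean):
    print('%s: %.4f' % (name, mean))
```

Here the importance is measured on the training set for simplicity; with a train/test split, scoring on the held-out set gives a less optimistic picture.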
2. Decision Path Analysis
Decision path analysis traces the decision paths through the tree to understand how each sample is classified. It can be done by parsing the tree's rules directly, or by drawing the tree with Graphviz.
Code example:
Parsing the rules:

# Print the decision tree's rules as nested if/else blocks
def tree_to_string(tree, feature_names):
    left = tree.tree_.children_left
    right = tree.tree_.children_right
    threshold = tree.tree_.threshold
    # Map each node's feature index to its name (leaves hold -2 and are unused)
    features = [feature_names[i] for i in tree.tree_.feature]
    value = tree.tree_.value

    def recurse(node):
        if threshold[node] != -2:  # -2 (TREE_UNDEFINED) marks a leaf
            print("if ( " + features[node] + " <= "
                  + "{:.6f}".format(threshold[node]) + " ) {")
            recurse(left[node])
            print("} else {")
            recurse(right[node])
            print("}")
        else:
            # Leaf node: value holds the per-class sample counts
            print("return " + str(value[node]))

    recurse(0)

# Parse the fitted classifier
tree_to_string(clf, iris.feature_names)

The output looks like this:

if ( petal length (cm) <= 2.450000 ) {
if ( petal width (cm) <= 0.800000 ) {
return [[50.  0.  0.]]
} else {
if ( petal length (cm) <= 1.750000 ) {
return [[ 0. 48.  0.]]
} else {
if ( sepal width (cm) <= 3.100000 ) {
return [[0. 1. 0.]]
} else {
return [[0. 0. 2.]]
}
}
}
} else {
if ( petal length (cm) <= 4.850000 ) {
if ( petal width (cm) <= 1.650000 ) {
if ( petal length (cm) <= 3.350000 ) {
return [[ 0. 0. 46.]]
} else {
return [[ 0.  1. 39.]]
}
} else {
if ( sepal length (cm) <= 6.000000 ) {
return [[0. 0. 3.]]
} else {
return [[0. 0. 1.]]
}
}
} else {
if ( petal width (cm) <= 1.750000 ) {
return [[0. 0.  3.]]
} else {
return [[0. 0. 43.]]
}
}
}
}
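Hand-rolling the recursion is instructive, but scikit-learn already ships a built-in exporter, `export_text`, that prints the same rules in an indented, readable form:

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()
clf = DecisionTreeClassifier(random_state=42).fit(iris.data, iris.target)

# export_text walks the fitted tree and returns its rules as a string
rules = export_text(clf, feature_names=iris.feature_names)
print(rules)
```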

Using Graphviz:

from sklearn.tree import export_graphviz
import graphviz

dot_data = export_graphviz(clf, out_file=None,
                           feature_names=iris.feature_names,
                           class_names=iris.target_names,
                           filled=True, rounded=True,
                           special_characters=True)
graph = graphviz.Source(dot_data)
graph  # renders inline in a Jupyter notebook; otherwise call graph.render()

The rendered tree:
[Figure: Graphviz rendering of the fitted decision tree]
3. Local Interpretability Analysis
Local interpretability analysis explains why a specific sample was classified the way it was: by tracing the sample's path through the tree, we can see which feature test it satisfied at each node.
Code example:

# Take the first five samples
X, y = iris.data[:5, :], iris.target[:5]

# decision_path returns a sparse indicator matrix whose row i marks
# the nodes sample i passes through
path = clf.decision_path(X)

# Report, for each sample, how it satisfied each node's test.
# Note: the feature index must come from clf.tree_.feature, not from
# parsing the feature-name string, which fails for names like "petal length (cm)".
for i in range(X.shape[0]):
    print('\nSample %d' % (i + 1))
    for j in path[i].indices[:-1]:   # the last node is the leaf; skip it
        feat_idx = clf.tree_.feature[j]          # feature tested at node j
        feature = iris.feature_names[feat_idx]
        threshold = clf.tree_.threshold[j]
        if X[i, feat_idx] <= threshold:
            print(feature, '<=', '%.6f' % threshold)
        else:
            print(feature, '>', '%.6f' % threshold)

The output:

Sample 1
petal length (cm) <= 2.450000
petal width (cm) <= 0.800000
Sample 2
petal length (cm) <= 2.450000
petal width (cm) <= 0.800000
Sample 3
petal length (cm) <= 2.450000
petal width (cm) > 0.800000
petal length (cm) <= 1.750000
Sample 4
petal length (cm) <= 2.450000
petal width (cm) > 0.800000
petal length (cm) > 1.750000
sepal width (cm) <= 3.100000
Sample 5
petal length (cm) <= 2.450000
petal width (cm) > 0.800000
petal length (cm) > 1.750000
sepal width (cm) > 3.100000
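The path shows which tests a sample passed; the leaf it lands in shows the training evidence behind the prediction. `clf.apply` returns each sample's leaf id, and that leaf's entry in `tree_.value` holds its class distribution (counts or fractions, depending on the scikit-learn version); a short sketch:

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X, y = iris.data, iris.target
clf = DecisionTreeClassifier(random_state=42).fit(X, y)

# apply() maps each sample to the id of the leaf it falls into
leaves = clf.apply(X[:5])
for i, leaf in enumerate(leaves):
    # Per-class distribution of training samples at this leaf
    dist = clf.tree_.value[leaf].ravel()
    pred = iris.target_names[int(np.argmax(dist))]
    print('Sample %d: leaf %d, class distribution %s, predicted %s'
          % (i + 1, leaf, dist, pred))
```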
