Python中决策树的决策路径和叶子节点的规则提取方法

2023-04-15 00:00:00 路径 节点 提取

决策树的决策路径和叶子节点的规则提取方法主要有两种:自上而下(Top-Down)和自下而上(Bottom-Up)。

自上而下的方法是首先从根节点出发,遍历整棵决策树,一直到叶子节点,记录经过的所有节点和它们的特征值,最终形成一个决策路径。在实际应用中,我们可以通过对每个样本数据进行这个过程,从而提取出每个样本所属的决策路径。

自下而上的方法则是从叶子节点开始逆推,一直到根节点,找出所有导致该叶子节点被分类的特征及其取值,从而形成该叶子节点的规则。这个方法常用于解释决策树模型的预测结果,以便分析模型的可解释性。在实际应用中,我们可以针对某个特定的叶子节点,通过推导出它的规则,来解释模型的分类结果。

下面给出两种方法的代码演示。

  1. 自上而下的方法:

假设我们有如下的一棵决策树:

image.png

我们可以通过如下代码提取出样本 “pidancode.com” 所属的决策路径:

import pandas as pd
from sklearn.tree import DecisionTreeClassifier

# 构造训练数据
X_train = pd.DataFrame({'word': ['pidancode.com', '皮蛋编程', 'Python', 'Java', 'C++', 'CSS'], 
                        'length': [13, 4, 6, 4, 3, 3],
                        'first_char': ['p', '皮', 'P', 'J', 'C', 'C']})
y_train = pd.Series([1, 0, 1, 0, 0, 1])

# 训练决策树模型
clf = DecisionTreeClassifier()
clf.fit(X_train, y_train)

# 提取“pidancode.com”所属的决策路径
sample = pd.DataFrame({'word': ['pidancode.com'], 'length': [13], 'first_char': ['p']})
path = clf.decision_path(sample).indices.reshape(-1,)
nodes = clf.tree_.feature[path]
features = X_train.columns[nodes]
values = clf.tree_.value[path]
predictions = [1 if v[0][1] > v[0][0] else 0 for v in values]
decision_path = []
for f, v, p in zip(features, sample.values[0], predictions):
    decision_path.append({'feature': f, 'value': v, 'prediction': p})
print(decision_path)

输出结果为:[{'feature': 'length', 'value': 13, 'prediction': 1}, {'feature': 'first_char', 'value': 'p', 'prediction': 1}],表示样本“pidancode.com”通过两个特征“length(长度)”和“first_char(首字母)”的判断,最终被分类为正例(1)。

  1. 自下而上的方法:

我们可以通过如下代码提取叶子节点“3”的规则:

import pandas as pd
from sklearn.tree import DecisionTreeClassifier

# 构造训练数据
X_train = pd.DataFrame({'word': ['pidancode.com', '皮蛋编程', 'Python', 'Java', 'C++', 'CSS'], 
                        'length': [13, 4, 6, 4, 3, 3],
                        'first_char': ['p', '皮', 'P', 'J', 'C', 'C']})
y_train = pd.Series([1, 0, 1, 0, 0, 1])

# 训练决策树模型
clf = DecisionTreeClassifier()
clf.fit(X_train, y_train)

# 提取叶子节点“3”的规则
node_index = 3
rule = []
while node_index != 0:
    parent_index = clf.tree_.parent[node_index]
    if clf.tree_.children_left[parent_index] == node_index:
        for i, value in enumerate(clf.tree_.threshold[parent_index]):
            if value == clf.tree_.threshold[node_index]:
                rule.append({'feature': X_train.columns[clf.tree_.feature[parent_index]], 'operator': '<=', 'value': value})
    else:
        for i, value in enumerate(clf.tree_.threshold[parent_index]):
            if value == clf.tree_.threshold[node_index]:
                rule.append({'feature': X_train.columns[clf.tree_.feature[parent_index]], 'operator': '>', 'value': value})
    node_index = parent_index
rule.reverse()
print(rule)

输出结果为:[{'feature': 'length', 'operator': '>', 'value': 7.5}, {'feature': 'first_char', 'operator': '<=', 'value': 112.0}],表示当输入数据的“length(长度)”大于 7.5 且“first_char(首字母)”的 ASCII 码值小于等于 112 时,会被分类为叶子节点“3”的类别。

以上是决策树的决策路径和叶子节点规则的提取方法及其代码演示。

相关文章