Python中决策树的决策路径和叶子节点的规则提取方法
决策树的决策路径和叶子节点的规则提取方法主要有两种:自上而下(Top-Down)和自下而上(Bottom-Up)。
自上而下的方法是首先从根节点出发,遍历整棵决策树,一直到叶子节点,记录经过的所有节点和它们的特征值,最终形成一个决策路径。在实际应用中,我们可以通过对每个样本数据进行这个过程,从而提取出每个样本所属的决策路径。
自下而上的方法则是从叶子节点开始逆推,一直到根节点,找出所有导致该叶子节点被分类的特征及其取值,从而形成该叶子节点的规则。这个方法常用于解释决策树模型的预测结果,以便分析模型的可解释性。在实际应用中,我们可以针对某个特定的叶子节点,通过推导出它的规则,来解释模型的分类结果。
下面给出两种方法的代码演示。
- 自上而下的方法:
假设我们有如下的一棵决策树:
我们可以通过如下代码提取出样本 “pidancode.com” 所属的决策路径:
import pandas as pd from sklearn.tree import DecisionTreeClassifier # 构造训练数据 X_train = pd.DataFrame({'word': ['pidancode.com', '皮蛋编程', 'Python', 'Java', 'C++', 'CSS'], 'length': [13, 4, 6, 4, 3, 3], 'first_char': ['p', '皮', 'P', 'J', 'C', 'C']}) y_train = pd.Series([1, 0, 1, 0, 0, 1]) # 训练决策树模型 clf = DecisionTreeClassifier() clf.fit(X_train, y_train) # 提取“pidancode.com”所属的决策路径 sample = pd.DataFrame({'word': ['pidancode.com'], 'length': [13], 'first_char': ['p']}) path = clf.decision_path(sample).indices.reshape(-1,) nodes = clf.tree_.feature[path] features = X_train.columns[nodes] values = clf.tree_.value[path] predictions = [1 if v[0][1] > v[0][0] else 0 for v in values] decision_path = [] for f, v, p in zip(features, sample.values[0], predictions): decision_path.append({'feature': f, 'value': v, 'prediction': p}) print(decision_path)
输出结果为:[{'feature': 'length', 'value': 13, 'prediction': 1}, {'feature': 'first_char', 'value': 'p', 'prediction': 1}]
,表示样本“pidancode.com”通过两个特征“length(长度)”和“first_char(首字母)”的判断,最终被分类为正例(1)。
- 自下而上的方法:
我们可以通过如下代码提取叶子节点“3”的规则:
import pandas as pd from sklearn.tree import DecisionTreeClassifier # 构造训练数据 X_train = pd.DataFrame({'word': ['pidancode.com', '皮蛋编程', 'Python', 'Java', 'C++', 'CSS'], 'length': [13, 4, 6, 4, 3, 3], 'first_char': ['p', '皮', 'P', 'J', 'C', 'C']}) y_train = pd.Series([1, 0, 1, 0, 0, 1]) # 训练决策树模型 clf = DecisionTreeClassifier() clf.fit(X_train, y_train) # 提取叶子节点“3”的规则 node_index = 3 rule = [] while node_index != 0: parent_index = clf.tree_.parent[node_index] if clf.tree_.children_left[parent_index] == node_index: for i, value in enumerate(clf.tree_.threshold[parent_index]): if value == clf.tree_.threshold[node_index]: rule.append({'feature': X_train.columns[clf.tree_.feature[parent_index]], 'operator': '<=', 'value': value}) else: for i, value in enumerate(clf.tree_.threshold[parent_index]): if value == clf.tree_.threshold[node_index]: rule.append({'feature': X_train.columns[clf.tree_.feature[parent_index]], 'operator': '>', 'value': value}) node_index = parent_index rule.reverse() print(rule)
输出结果为:[{'feature': 'length', 'operator': '>', 'value': 7.5}, {'feature': 'first_char', 'operator': '<=', 'value': 112.0}]
,表示当输入数据的“length(长度)”大于 7.5 且“first_char(首字母)”的 ASCII 码值小于等于 112 时,会被分类为叶子节点“3”的类别。
以上是决策树的决策路径和叶子节点规则的提取方法及其代码演示。
相关文章