Extracting decision rules from GradientBoostingClassifier
Problem description
I have gone through the following questions:
How to extract the decision rules from GradientBoostingClassifier
How to extract the decision rules from a scikit-learn decision tree?
However, those two do not solve my problem. Here is my question:
I need to build a model in Python using GradientBoostingClassifier and implement this model on the SAS platform. To do this, I need to extract the decision rules from the GradientBoostingClassifier.
Here is what I have tried so far:
Build the model on the iris data:
# import the most common dataset
from io import StringIO  # sklearn.externals.six was removed from recent scikit-learn versions
import pydotplus
from IPython.display import Image
from sklearn.datasets import load_iris
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.tree import export_graphviz

X, y = load_iris(return_X_y=True)
# there are 150 observations and 4 features
print(X.shape)  # (150, 4)
# let's build a small model = 5 trees with depth no more than 3
model = GradientBoostingClassifier(n_estimators=5, max_depth=3, learning_rate=1.0)
model.fit(X, y == 2)  # predict class 2 (virginica) vs rest, for simplicity
# we can access individual trees
trees = model.estimators_.ravel()

def plot_tree(clf):
    dot_data = StringIO()
    export_graphviz(clf, out_file=dot_data, node_ids=True,
                    filled=True, rounded=True,
                    special_characters=True)
    graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
    return Image(graph.create_png())

# now we can plot the first tree
plot_tree(trees[0])
After plotting the graph, I inspected the graph source for the first tree and wrote it to a text file with the code below:
# raw string, so the backslashes in the Windows path are not read as escapes
with open(r"C:\Users\XXXX\Desktop\Python\input_tree.txt", "w") as wrt:
    wrt.write(export_graphviz(trees[0], out_file=None, node_ids=True,
                              filled=True, rounded=True,
                              special_characters=True))
Here is the output file:
digraph Tree {
node [shape=box, style="filled, rounded", color="black", fontname=helvetica] ;
edge [fontname=helvetica] ;
0 [label=<node #0<br/>X<SUB>3</SUB> ≤ 1.75<br/>friedman_mse = 0.222<br/>samples = 150<br/>value = 0.0>, fillcolor="#e5813955"] ;
1 [label=<node #1<br/>X<SUB>2</SUB> ≤ 4.95<br/>friedman_mse = 0.046<br/>samples = 104<br/>value = -0.285>, fillcolor="#e5813945"] ;
0 -> 1 [labeldistance=2.5, labelangle=45, headlabel="True"] ;
2 [label=<node #2<br/>X<SUB>3</SUB> ≤ 1.65<br/>friedman_mse = 0.01<br/>samples = 98<br/>value = -0.323>, fillcolor="#e5813943"] ;
1 -> 2 ;
3 [label=<node #3<br/>friedman_mse = 0.0<br/>samples = 97<br/>value = -1.5>, fillcolor="#e5813900"] ;
2 -> 3 ;
4 [label=<node #4<br/>friedman_mse = -0.0<br/>samples = 1<br/>value = 3.0>, fillcolor="#e58139ff"] ;
2 -> 4 ;
5 [label=<node #5<br/>X<SUB>3</SUB> ≤ 1.55<br/>friedman_mse = 0.222<br/>samples = 6<br/>value = 0.333>, fillcolor="#e5813968"] ;
1 -> 5 ;
6 [label=<node #6<br/>friedman_mse = 0.0<br/>samples = 3<br/>value = 3.0>, fillcolor="#e58139ff"] ;
5 -> 6 ;
7 [label=<node #7<br/>friedman_mse = 0.222<br/>samples = 3<br/>value = 0.0>, fillcolor="#e5813955"] ;
5 -> 7 ;
8 [label=<node #8<br/>X<SUB>2</SUB> ≤ 4.85<br/>friedman_mse = 0.021<br/>samples = 46<br/>value = 0.645>, fillcolor="#e581397a"] ;
0 -> 8 [labeldistance=2.5, labelangle=-45, headlabel="False"] ;
9 [label=<node #9<br/>X<SUB>1</SUB> ≤ 3.1<br/>friedman_mse = 0.222<br/>samples = 3<br/>value = 0.333>, fillcolor="#e5813968"] ;
8 -> 9 ;
10 [label=<node #10<br/>friedman_mse = 0.0<br/>samples = 2<br/>value = 3.0>, fillcolor="#e58139ff"] ;
9 -> 10 ;
11 [label=<node #11<br/>friedman_mse = -0.0<br/>samples = 1<br/>value = -1.5>, fillcolor="#e5813900"] ;
9 -> 11 ;
12 [label=<node #12<br/>friedman_mse = -0.0<br/>samples = 43<br/>value = 3.0>, fillcolor="#e58139ff"] ;
8 -> 12 ;
}
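The node ids in this dot output are the same indices that the fitted tree exposes through its tree_ attribute, so the structure can be read directly instead of being parsed back out of the text. A minimal sketch, assuming the same iris setup as above (random_state=0 is added only to make runs repeatable). In trees grown depth-first, which is what scikit-learn does when max_leaf_nodes is not set, the left child of node i is node i + 1, which matches the edges above (0 -> 1, 1 -> 2, 8 -> 9); the right child comes after the whole left subtree.

```python
from sklearn.datasets import load_iris
from sklearn.ensemble import GradientBoostingClassifier

X, y = load_iris(return_X_y=True)
# Same setup as in the question; random_state is added only for reproducibility.
model = GradientBoostingClassifier(n_estimators=5, max_depth=3,
                                   learning_rate=1.0, random_state=0)
model.fit(X, y == 2)

tree = model.estimators_[0, 0].tree_  # first boosting stage, single output column

for node in range(tree.node_count):
    left, right = tree.children_left[node], tree.children_right[node]
    if left == -1:  # -1 means "no child", i.e. this node is a leaf
        print(f"node {node}: leaf, value = {tree.value[node, 0, 0]:.3f}")
    else:
        print(f"node {node}: X{tree.feature[node]} <= {tree.threshold[node]:.2f} "
              f"-> left {left}, right {right}")
```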
To extract the decision rules from the output file, I tried the following Python regex code to translate them into SAS code:
import re

with open(r"C:\Users\XXXX\Desktop\Python\input_tree.txt") as f:
    with open(r"C:\Users\XXXX\Desktop\Python\output.txt", "w") as f1:
        result0 = 'value = 0;'
        f1.write(result0)
        for line in f:
            # split node -> "if Xk le t then do;"
            result1 = re.sub(r'^(\d+)\s+.*<br/>([A-Z]+)<SUB>(\d+)</SUB>\s+(.+?)([-\d.]+)<br/>friedman_mse.*;$', r"if \2\3 \4 \5 then do;", line)
            # leaf node -> "value = value + v; end;"
            result2 = re.sub(r'^(\d+).*(?!SUB).*(value\s+=)\s([-\d.]+).*;$', r"\2 value + \3; end;", result1)
            # plain edge "a -> b ;" -> "a -> b"
            result3 = re.sub(r'^(\d+\s+->\s+\d+\s+);$', r'\1', result2)
            # drop the graph header lines
            result4 = re.sub(r'^digraph.+|^node.+|^edge.+', '', result3)
            # HTML entities such as &le; -> le
            result5 = re.sub(r'&(\w{2});', r'\1', result4)
            result6 = re.sub(r'}', 'end;', result5)
            f1.write(result6)
Below is the SAS output of the above code:
value = 0;
if X3 le 1.75 then do;
if X2 le 4.95 then do;
0 -> 1 [labeldistance=2.5, labelangle=45, headlabel="True"] ;
if X3 le 1.65 then do;
1 -> 2
value = value + -1.5; end;
2 -> 3
value = value + 3.0; end;
2 -> 4
if X3 le 1.55 then do;
1 -> 5
value = value + 3.0; end;
5 -> 6
value = value + 0.0; end;
5 -> 7
if X2 le 4.85 then do;
0 -> 8 [labeldistance=2.5, labelangle=-45, headlabel="False"] ;
if X1 le 3.1 then do;
8 -> 9
value = value + 3.0; end;
9 -> 10
value = value + -1.5; end;
9 -> 11
value = value + 3.0; end;
8 -> 12
end;
As you can see, a piece is missing from the output: I am not able to open and close the do-end blocks properly. To do that I need to use the node numbers, but I have not managed to, because I cannot find any pattern in them.
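Opening and closing the do-end blocks is easier with a recursive walk over tree_ than with line-by-line regexes: each internal node emits an "if ... then do;" block for its left child and an "else do;" block for its right child, and each block is closed before the walk moves on. A minimal sketch for a single tree; the function name tree_to_sas and the SAS variable names X0..X3 and value are hypothetical and would need to be mapped to your dataset's columns.

```python
from sklearn.datasets import load_iris
from sklearn.ensemble import GradientBoostingClassifier


def tree_to_sas(tree, var="value", indent="  "):
    """Return SAS statements that add one regression tree's leaf value to `var`."""
    t = tree.tree_
    lines = []

    def recurse(node, depth):
        pad = indent * depth
        left, right = t.children_left[node], t.children_right[node]
        if left == -1:  # leaf: no children, emit the value update
            lines.append(f"{pad}{var} = {var} + {float(t.value[node, 0, 0])!r};")
            return
        feat, thr = t.feature[node], float(t.threshold[node])
        lines.append(f"{pad}if X{feat} le {thr!r} then do;")
        recurse(left, depth + 1)       # samples with X <= threshold go left
        lines.append(f"{pad}end;")
        lines.append(f"{pad}else do;")
        recurse(right, depth + 1)      # everything else goes right
        lines.append(f"{pad}end;")

    recurse(0, 0)
    return "\n".join(lines)


X, y = load_iris(return_X_y=True)
model = GradientBoostingClassifier(n_estimators=5, max_depth=3,
                                   learning_rate=1.0, random_state=0)
model.fit(X, y == 2)
print(tree_to_sas(model.estimators_[0, 0]))
```

Every "do;" is closed by the same call that opened it, so the nesting is balanced regardless of tree shape.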
Could anyone help me with this?
Apart from this, can't I extract children_left, children_right and the threshold values from these trees, as the second link above does for DecisionTreeClassifier? I have successfully extracted each tree of the GBM with
trees = model.estimators_.ravel()
but I didn't find any useful function to extract the values and rules of each tree. Please let me know if I can use the graphviz object the same way as with DecisionTreeClassifier.
Or help me with any other method that would solve my problem.
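One alternative to parsing the graphviz source: sklearn.tree.export_text (scikit-learn 0.21 or later) prints a tree's rules as indented text, and it accepts the DecisionTreeRegressor objects stored in estimators_. A sketch, assuming the iris model from the question; the feature names X0..X3 are placeholders.

```python
from sklearn.datasets import load_iris
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.tree import export_text

X, y = load_iris(return_X_y=True)
model = GradientBoostingClassifier(n_estimators=5, max_depth=3,
                                   learning_rate=1.0, random_state=0)
model.fit(X, y == 2)

# Binary model: one regression tree per boosting stage, in column 0.
for stage, tree in enumerate(model.estimators_[:, 0]):
    print(f"--- tree {stage} ---")
    print(export_text(tree, feature_names=["X0", "X1", "X2", "X3"], decimals=3))
```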
Solution
There is no need to use the graphviz export to access the decision tree data. model.estimators_ contains all the individual estimators that the model consists of. For a GradientBoostingClassifier, this is a 2D numpy array with shape (n_estimators, n_classes), and each item is a DecisionTreeRegressor. For a binary target such as y == 2 in the question, the second dimension is 1, because only one tree is fitted per boosting stage.
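A quick check of that layout, assuming the iris data from the question: the binary model fits one tree per stage, while the three-class model fits one tree per class per stage.

```python
from sklearn.datasets import load_iris
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.tree import DecisionTreeRegressor

X, y = load_iris(return_X_y=True)

binary = GradientBoostingClassifier(n_estimators=5, max_depth=3, random_state=0).fit(X, y == 2)
multi = GradientBoostingClassifier(n_estimators=5, max_depth=3, random_state=0).fit(X, y)

print(binary.estimators_.shape)  # (5, 1): one tree per stage
print(multi.estimators_.shape)   # (5, 3): one tree per class per stage
print(isinstance(multi.estimators_[0, 0], DecisionTreeRegressor))  # True
```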
Each decision tree has a tree_ attribute, and the scikit-learn example "Understanding the decision tree structure" shows how to get the nodes, thresholds and children out of that object.
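The leaf values live in tree_.value, an array of shape (node_count, 1, 1) for these regression trees. A sketch, assuming the iris model from the question, that looks up each row's leaf with apply() and checks that the stored leaf values are exactly what the tree predicts:

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import GradientBoostingClassifier

X, y = load_iris(return_X_y=True)
model = GradientBoostingClassifier(n_estimators=5, max_depth=3,
                                   learning_rate=1.0, random_state=0)
model.fit(X, y == 2)

reg = model.estimators_[0, 0]          # a DecisionTreeRegressor
t = reg.tree_

leaf_ids = reg.apply(X)                # index of the leaf each row lands in
leaf_values = t.value[leaf_ids, 0, 0]  # value has shape (node_count, 1, 1)

# The leaf values looked up by hand match the tree's own predictions.
print(np.allclose(leaf_values, reg.predict(X)))  # True
```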
import numpy
import pandas
from sklearn.ensemble import GradientBoostingClassifier

est = GradientBoostingClassifier(n_estimators=4)
numpy.random.seed(1)
est.fit(numpy.random.random((100, 3)), numpy.random.choice([0, 1, 2], size=(100,)))
print('shape', est.estimators_.shape)  # (n_estimators, n_classes)

# estimators_ is indexed [stage, class]
n_estimators, n_classes = est.estimators_.shape
for c in range(n_classes):
    for t in range(n_estimators):
        dtree = est.estimators_[t, c]
        print("class={}, tree={}: {}".format(c, t, dtree.tree_))
        rules = pandas.DataFrame({
            'child_left': dtree.tree_.children_left,
            'child_right': dtree.tree_.children_right,
            'feature': dtree.tree_.feature,
            'threshold': dtree.tree_.threshold,
        })
        print(rules)
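To score the whole model in SAS, each tree's leaf value has to be scaled by learning_rate and added to a constant initial score. A minimal sketch for a binary model with the default log-loss (deviance) loss, assuming the iris setup from the question; gbm_to_sas and node_to_sas are hypothetical helpers, and X0..X3, score and p_event are placeholder SAS names. The initial score is recovered from the public decision_function rather than from private attributes. scikit-learn compares features after casting them to float32, so a row lying exactly on a threshold could occasionally fall on the other side in SAS.

```python
from sklearn.datasets import load_iris
from sklearn.ensemble import GradientBoostingClassifier


def node_to_sas(t, node, scale, depth, lines):
    """Append SAS statements for the subtree of tree_ object `t` rooted at `node`."""
    pad = "  " * depth
    if t.children_left[node] == -1:  # -1 means no children: this node is a leaf
        leaf = scale * float(t.value[node, 0, 0])
        lines.append(f"{pad}score = score + {leaf!r};")
        return
    threshold = float(t.threshold[node])
    lines.append(f"{pad}if X{t.feature[node]} le {threshold!r} then do;")
    node_to_sas(t, t.children_left[node], scale, depth + 1, lines)
    lines.append(f"{pad}end;")
    lines.append(f"{pad}else do;")
    node_to_sas(t, t.children_right[node], scale, depth + 1, lines)
    lines.append(f"{pad}end;")


def gbm_to_sas(model, X_ref):
    """Translate a fitted binary GradientBoostingClassifier into SAS data-step statements."""
    trees = model.estimators_[:, 0]  # binary models have one tree per stage
    lr = model.learning_rate
    # decision_function(x) == f0 + lr * sum(tree.predict(x)); solve for the constant f0.
    x0 = X_ref[:1]
    f0 = float(model.decision_function(x0)[0] - lr * sum(tr.predict(x0)[0] for tr in trees))
    lines = [f"score = {f0!r};"]
    for tr in trees:
        node_to_sas(tr.tree_, 0, lr, 0, lines)
    lines.append("p_event = 1 / (1 + exp(-score));")  # logistic link of the log-loss
    return "\n".join(lines)


X, y = load_iris(return_X_y=True)
model = GradientBoostingClassifier(n_estimators=5, max_depth=3,
                                   learning_rate=1.0, random_state=0)
model.fit(X, y == 2)
print(gbm_to_sas(model, X))
```

With these assumptions, score should track model.decision_function and p_event should track model.predict_proba(X)[:, 1]; a different loss (for example exponential) would need a different link in the last line.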
For each tree this prints something like:
class=0, tree=0: <sklearn.tree._tree.Tree object at 0x7f18a697f370>
   child_left  child_right  feature  threshold
0           1            2        0   0.020702
1          -1           -1       -2  -2.000000
2           3            6        1   0.879058
3           4            5        1   0.543716
4          -1           -1       -2  -2.000000
5          -1           -1       -2  -2.000000
6           7            8        0   0.292586
7          -1           -1       -2  -2.000000
8          -1           -1       -2  -2.000000
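In this table, -1 in child_left and child_right means the node has no children, i.e. it is a leaf; leaves also carry the placeholder -2 in feature and threshold. A sketch, assuming the iris model from the question, that adds each node's value and an is_leaf flag to the same kind of DataFrame:

```python
import pandas
from sklearn.datasets import load_iris
from sklearn.ensemble import GradientBoostingClassifier

X, y = load_iris(return_X_y=True)
model = GradientBoostingClassifier(n_estimators=5, max_depth=3,
                                   learning_rate=1.0, random_state=0)
model.fit(X, y == 2)
tree = model.estimators_[0, 0].tree_

rules = pandas.DataFrame({
    'child_left': tree.children_left,
    'child_right': tree.children_right,
    'feature': tree.feature,
    'threshold': tree.threshold,
    'value': tree.value[:, 0, 0],  # only the leaf rows' value is used for prediction
})
rules['is_leaf'] = rules['child_left'] == -1  # -1: no child, so the node is a leaf
print(rules)
```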