Python中决策树的剪枝方法和防止过拟合的方法

2023-04-15 00:00:00 方法 剪枝 拟合

Python中决策树的剪枝方法有两种,一种是预剪枝,在建树时就限制树的深度或节点的样本数量等条件,使得树的规模不会过大,避免过拟合;另一种是后剪枝,先生成完整的树,再通过降低节点不纯度来剪枝,使得树的规模减小,避免过拟合。
预剪枝的代码演示如下:

from sklearn.tree import DecisionTreeClassifier
clf = DecisionTreeClassifier(max_depth = 3)  # 限制深度为3
clf.fit(X_train, y_train)

后剪枝的代码演示如下:

from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import plot_tree
import matplotlib.pyplot as plt
# 建树
clf = DecisionTreeClassifier()
clf.fit(X_train, y_train)
# 绘制初始决策树
plt.figure(figsize=(10, 8))
plot_tree(clf)
# 后剪枝,降低节点不纯度
clf.cost_complexity_pruning_path(X_train, y_train)
ccp_alphas = clf.cost_complexity_pruning_path(X_train, y_train)["ccp_alphas"]
clfs = []
for ccp_alpha in ccp_alphas:
    temp = DecisionTreeClassifier(ccp_alpha=ccp_alpha)
    temp.fit(X_train, y_train)
    clfs.append(temp)
# 找出精度最高的决策树
train_scores = [clf.score(X_train, y_train) for clf in clfs]
test_scores = [clf.score(X_test, y_test) for clf in clfs]
best_clf = clfs[test_scores.index(max(test_scores))]
# 绘制剪枝后的决策树
plt.figure(figsize=(10, 8))
plot_tree(best_clf)

防止过拟合的方法有以下几种:
1. 数据增强:通过对训练集进行加噪声、旋转、裁剪等操作,增加训练数据的多样性,从而避免过拟合。
2. 正则化:在损失函数中加入正则化项,例如L1正则化、L2正则化等,使得模型的参数更加平滑,避免过于复杂。
3. Dropout:在训练时,以一定概率随机丢弃一些神经元,使得网络的结构更加随机,从而避免过拟合。
4. 交叉验证:将训练集划分成若干部分,每次取其中一部分作为验证集,其余部分作为训练集,重复直至所有部分都充当过验证集,得到多个模型的平均预测结果,从而避免对某些数据过分拟合。
5. 提前停止训练:当模型的训练误差和验证误差都达到一定的稳定水平时,提前停止训练,避免模型过拟合。

相关文章