Python中决策树的剪枝方法和防止过拟合的方法
Python中决策树的剪枝方法有两种,一种是预剪枝,在建树时就限制树的深度或节点的样本数量等条件,使得树的规模不会过大,避免过拟合;另一种是后剪枝,先生成完整的树,再通过降低节点不纯度来剪枝,使得树的规模减小,避免过拟合。
预剪枝的代码演示如下:
from sklearn.tree import DecisionTreeClassifier clf = DecisionTreeClassifier(max_depth = 3) # 限制深度为3 clf.fit(X_train, y_train)
后剪枝的代码演示如下:
from sklearn.tree import DecisionTreeClassifier from sklearn.tree import plot_tree import matplotlib.pyplot as plt # 建树 clf = DecisionTreeClassifier() clf.fit(X_train, y_train) # 绘制初始决策树 plt.figure(figsize=(10, 8)) plot_tree(clf) # 后剪枝,降低节点不纯度 clf.cost_complexity_pruning_path(X_train, y_train) ccp_alphas = clf.cost_complexity_pruning_path(X_train, y_train)["ccp_alphas"] clfs = [] for ccp_alpha in ccp_alphas: temp = DecisionTreeClassifier(ccp_alpha=ccp_alpha) temp.fit(X_train, y_train) clfs.append(temp) # 找出精度最高的决策树 train_scores = [clf.score(X_train, y_train) for clf in clfs] test_scores = [clf.score(X_test, y_test) for clf in clfs] best_clf = clfs[test_scores.index(max(test_scores))] # 绘制剪枝后的决策树 plt.figure(figsize=(10, 8)) plot_tree(best_clf)
防止过拟合的方法有以下几种:
1. 数据增强:通过对训练集进行加噪声、旋转、裁剪等操作,增加训练数据的多样性,从而避免过拟合。
2. 正则化:在损失函数中加入正则化项,例如L1正则化、L2正则化等,使得模型的参数更加平滑,避免过于复杂。
3. Dropout:在训练时,以一定概率随机丢弃一些神经元,使得网络的结构更加随机,从而避免过拟合。
4. 交叉验证:将训练集划分成若干部分,每次取其中一部分作为验证集,其余部分作为训练集,重复直至所有部分都充当过验证集,得到多个模型的平均预测结果,从而避免对某些数据过分拟合。
5. 提前停止训练:当模型的训练误差和验证误差都达到一定的稳定水平时,提前停止训练,避免模型过拟合。
相关文章