在Python中使用决策树进行过拟合和欠拟合的检测

2023-04-14 00:00:00 检测 决策树 拟合

过拟合和欠拟合是机器学习中常见的问题,决策树也不例外。在Python中,我们可以使用sklearn库中的DecisionTreeRegressor或DecisionTreeClassifier类来训练决策树模型,然后通过可视化和评估模型误差等方法来检测过拟合和欠拟合。

下面是一个用DecisionTreeRegressor类训练决策树的例子,其中我们使用一个简单的二元分类数据集来演示。

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor
import matplotlib.pyplot as plt
import numpy as np

# 生成数据集
X, y = make_classification(n_samples=1000, n_features=5, n_informative=3, n_redundant=2, random_state=42)

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 训练模型
dt = DecisionTreeRegressor(max_depth=10)
dt.fit(X_train, y_train)

# 绘制学习曲线
train_errors = []
test_errors = []
for i in range(1, 100):
    dt = DecisionTreeRegressor(max_depth=i)
    dt.fit(X_train, y_train)
    train_errors.append(dt.score(X_train, y_train))
    test_errors.append(dt.score(X_test, y_test))

plt.plot(np.arange(1, 100), train_errors, label='Training Score')
plt.plot(np.arange(1, 100), test_errors, label='Validation Score')
plt.xlabel('Tree Depth')
plt.ylabel('Score')
plt.legend()
plt.show()

在上面的代码中,我们使用make_classification函数生成一个具有5个特征的二元分类数据集,其中有3个特征是有信息量的,2个特征是冗余的。使用train_test_split函数将数据集划分为训练集和测试集,并使用DecisionTreeRegressor类训练了一个最大深度为10的决策树模型。

学习曲线的绘制是通过逐步增加树的深度来实现的。在每个深度上,我们都会记录模型在训练集和测试集上的性能得分,并将它们绘制成两条曲线。这样可以让我们更直观地看出模型的过拟合或欠拟合情况。

下面是绘制的学习曲线图:

决策树学习曲线

从图中可以看出,当树的深度较小时,模型在训练集和测试集上都表现不佳,即欠拟合。随着树的深度的增加,模型的性能得分得到了显著提高,在测试集上的得分也逐渐提高。但是到了某个点之后,测试集上的得分开始下降,即出现了过拟合的情况。在这个例子中,树的最大深度大约在15左右,过拟合开始出现。

如果将最大深度设置为20,决策树将会非常复杂,过度拟合数据集。在训练数据上,树的分数很高,但在测试数据上表现非常糟糕。

# 训练深度为20的决策树
dt = DecisionTreeRegressor(max_depth=20)
dt.fit(X_train, y_train)

# 检查过拟合
train_score = dt.score(X_train, y_train)
test_score = dt.score(X_test, y_test)

print("Training Score: {:.2f}".format(train_score))
print("Test Score: {:.2f}".format(test_score))

输出结果:

Training Score: 1.00
Test Score: 0.42

因此,我们需要选择一个合适的树深度,以避免过拟合和欠拟合。

决策树还可以可视化来检查过拟合和欠拟合。为了进行可视化,我们需要安装graphviz包和pydotplus包。如果你的环境中没有这些包,请使用pip install命令进行安装。

下面是一个可视化的决策树的示例。

from sklearn.tree import export_graphviz
import pydotplus
from IPython.display import Image

dot_data = export_graphviz(dt, out_file=None, feature_names=['Feature 1', 'Feature 2', 'Feature 3', 'Feature 4', 'Feature 5'])
graph = pydotplus.graph_from_dot_data(dot_data)
Image(graph.create_png())

输出结果:

可视化决策树

通过分析输出结果,我们可以看到该决策树的规则并确定将数据分类为某个类别的依据。我们可以使用树的深度和数量和最佳拆分点来确定它是否过度拟合,但是需要注意过度拟合和欠拟合是相对的,因此,我们需要特别注意找到平衡的点,以避免出现任何一种情况。

以上就是Python中使用决策树进行过拟合和欠拟合的检测的实现方法。

相关文章