Python递归实现梯度提升算法
梯度提升算法是一种集成学习算法,通过逐步优化基学习器的预测结果,最终得到一个性能更好的模型。其中,每一轮的优化都是通过学习当前模型在训练集上的负梯度来实现的。Python递归实现梯度提升算法的示例代码如下:
import numpy as np class GradientBoostingRegressor: def __init__(self, n_estimators=100, learning_rate=0.1, max_depth=3): self.n_estimators = n_estimators self.learning_rate = learning_rate self.max_depth = max_depth self.estimators = [] def fit(self, X, y): self._init_estimator(y) for i in range(self.n_estimators): r = self.residual(X, y) estimator = self._build_tree(X, r, 0) self.estimators.append(estimator) def _init_estimator(self, y): self.init_estimator = np.mean(y) def residual(self, X, y): pred_y = [self.predict(np.array(x).reshape(1,-1)) for x in X] return y - pred_y def _build_tree(self, X, y, depth): if depth == self.max_depth: return np.mean(y) else: estimator = DecisionTreeRegressor(max_depth=1) estimator.fit(X, y) return estimator def predict(self, X): raw_pred = self.init_estimator + sum(self.learning_rate * est.predict(X.reshape(1,-1)) for est in self.estimators) return raw_pred.flat[0]
其中,GradientBoostingRegressor类实现了梯度提升算法的主体框架,其中包括n_estimators个基学习器、每个学习器的学习率learning_rate和最大深度max_depth。在初始化函数中,我们定义了每个学习器是DecisionTreeRegressor,因此需要导入该类。
在fit函数中,我们首先通过_init_estimator函数计算了y的均值,作为初始的预测结果。然后,从0到n_estimators-1遍历每个基学习器,依次计算负梯度r,使用_build_tree函数构建决策树并将其添加到estimators列表中。
在predict函数中,我们通过求和每个基学习器的预测值来计算最终的预测结果。
最后,我们可以使用以下代码演示我们实现的梯度提升算法:
from sklearn.datasets import load_boston from sklearn.model_selection import train_test_split from sklearn.tree import DecisionTreeRegressor from sklearn.metrics import mean_squared_error boston = load_boston() X_train, X_test, y_train, y_test = train_test_split(boston.data, boston.target, test_size=0.2, random_state=42) gb = GradientBoostingRegressor(n_estimators=20, learning_rate=0.1, max_depth=3) gb.fit(X_train, y_train) mse = mean_squared_error(y_test, [gb.predict(x) for x in X_test]) print('MSE:', mse)
在这个例子中,我们使用Boston房价数据集来训练一个梯度提升回归模型。我们将数据集分割成训练集和测试集,使用20个基学习器并将学习率设为0.1,最大深度为3。最后,我们计算测试集上的均方误差(MSE),并输出结果。
在实际使用中,梯度提升算法通常比单个决策树更有效,并且可以通过超参数调整进一步提高性能。
相关文章