线性回归的原理及Python实现

2023-07-04 09:48:01 python 原理 线性 回归

提到线性回归相信大家应该都不会觉得陌生(不陌生你点进来干嘛[捂脸]),本文就线性回归的基本原理进行讲解,并手把手、肩并肩地带您实现这一算法。

完整实现代码请参考本人的p...哦不是...github:
regression_base.py
linear_regression.py
linear_regression_example.py

1. 原理篇

我们用人话而不是大段的数学公式来讲讲线性回归是怎么一回事。

1.1 线性方程组

上小学或者中学的时候,很多人就接触过线性方程组了。举个栗子,如果x + y = 2且2x + y = 3,那么3x + 4y = ?。我们可以轻松地得出结论,解线性方程组得到x = 1且y = 1,所以3x + 4y = 3 + 4 = 7。

1.2 超定方程组

对于方程组Ra=y,R为n×m矩阵,如果R列满秩,且n>m。则方程组没有解,此时称方程组为超定方程组。翻译成人话就是方程组里方程的个数n太多了,比要求解的变量数m还多,这个方程是没办法求出解的。比如x + y = 2, 2x + y = 3且x + 2y = 4,那么我们是无法求出x和y能够同时满足这三个等式的。

1.3 线性回归问题

我们假设公司有n个同事(n = 10000),他们的年龄为A = [a1, a2...an],职级为B = [b1, b2...bn],工资为C = [c1, c2...cn],满足方程组Ax + By + z = C,我们想求出x, y 和z的值从而预测同事的工资,这样的问题就是典型的线性回归问题。我们有3个未知数x, y, z要求解,却有10000个方程,这显然是一个超定方程组。

1.4 小二乘法

如何求解这个超定方程组呢?当当当当,小二乘法闪亮登场了。假设n个同事有m个特征(年龄、职级等),收集这些特征组成m行n列的矩阵X,同事的工资为m行1列的矩阵Y,且满足m > n。我们要求解n个未知数W = [w1, w2...wn]和1个未知数b,满足方程组W * X + b = Y。
令预测值为 \hat Y ,那么有
MSE = \large\frac{1}{m}\normalsize\sum_{1}^{m}(Y_{i} - \hat Y_{i})^2

当我们的预测值完全等于真实值的时候,MSE等于0。根据上面的讲解,显然我们不太可能找到满足方程的解W,也就不可能准确地预测出Y,所以MSE不可能为0。但是我们想办法找出方程的近似解让MSE小,这就是小二乘法。

1.5 求近似解

如何求让MSE为零的近似解W呢?前方小段数学公式低能预警。

1. 使用MSE作为损失函数L
L = \large\frac{1}{m}\normalsize\sum_{1}^{m}(Y_{i} - \hat Y_{i})^2
2. 已知
\hat Y=WX + b

3. 对w求偏导,得
\large\frac{\mathrm{d}L}{\mathrm{d}W}\normalsize= -\large\frac{2}{m}\normalsize\sum_{1}^{m}(Y_{i} - WX_{i} - b)X_{i}

4. 对b求偏导,得
\large\frac{\mathrm{d}L}{\mathrm{d}b}\normalsize= -\large\frac{2}{m}\normalsize\sum_{1}^{m}(Y_{i} - WX_{i} - b)

所以,参数W的梯度就是式3,参数b的梯度就是式4。

1.6 梯度下降法

请参考我的另一篇文章,在这里就不赘述。链接如下:

1.7 批量梯度下降

遍历数据集中所有的样本,计算梯度并更新参数,记做1个epoch。经过若干个epochs之后,算法收敛或终止,计算量较大。

1.8 随机梯度下降

使用数据集中随机的一个样本,计算梯度并更新参数,直至算法收敛或终止,计算量较小。

2. 实现篇

本人用全宇宙简单的编程语言——Python实现了线性回归算法,没有依赖任何第三方库,便于学习和使用。简单说明一下实现过程,更详细的注释请参考本人github上的代码。

2.1 创建RegressionBase类

初始化,存储权重weights和偏置项bias。

class RegressionBase(object):
    def __init__(self):
        self.bias = None
        self.weights = None

相关文章