十三、普通最小二乘
普通最小二乘法(OLS)意味着最小化模型做出的预测与观测数据之间的平方和的误差。
在上查找有关 OLS 的更多信息,请查看这个很酷的互动工具和/或查看在 Python 中执行 OLS 的。
普通最小二乘
通常,我们希望最小化此误差项。因此,例如线性模型的 OLS 解,是具有最小平方误差值的模型,其被计算为数据点的模型预测与数据点本身之间的差的平方。
在这里,我们将创建一个最小数据集,并使用 OLS 探索将简单线性模型拟合到他。
观察上面的数据,我们可以看到和y
之间存在某种关系,但我们想要一种方法来衡量这种关系是什么。OLS 是这样做的过程:找到最小化每个观测数据点与模型预测之间的平方距离的模型(在本例中为直线)。
# 重塑数据来适配 NumPy
x = np.reshape(x, [len(x), 1])
y = np.reshape(y, [len(y), 1])
请注意,我们在这里没有拟合截距(没有b
值,如果你想到y = ax + b
)。 在这个简单的模型中,我们隐含地假设截距值为零。你可以使用 OLS 调整截距(以及具有更多参数的线性模型),你只需将它们添加到其中即可。
# 使用 numpy 拟合(普通)最小二乘最佳直线
# 这给了我们拟合值(theta)和残差(我们在这个拟合中有多少误差)
theta, residuals, _, _ = np.linalg.lstsq(x, y)
# 从数组中拉出 theta 值
theta = theta[0][0]
# 检查 OLS 产生的 θ 解什么:
print(theta)
# 1.98695402961
print('The true relationship between y & x is: \t', true_rel)
print('OLS calculated relationship between y & x is: \t', theta)
'''
The true relationship between y & x is: 2
OLS calculated relationship between y & x is: 1.98695402961
'''
# 检查残差是什么
residuals[0]
# 1.3701226131131277
# 绘制原始数据,具有真实的基础关系,以及 OLS 拟合
fig, ax = plt.subplots(1)
ax.plot(x, y, 'x', markersize=10, label='Data')
ax.plot(x, theta*x, label='True Fit')
ax.legend();
# 我们还可以看到,我们观察到的所有点,模型预测了什么
preds = theta * x
# 残差是模型拟合与观测数据点之间的平方和
# Re-calculate the residuals 'by hand'
error = np.sum(np.subtract(preds, y) ** 2)
# 检查我们的残差计算是否与 scipy 实现相匹配
print('Error from :', residuals[0])
print('Error from :', error)
'''
Error from : 1.37012261311
Error from : 1.37012261311
注意:在实践中,你不会将 numpy 用于 OLS。其他模块,如,拥有更明确用于线性建模的 OLS 实现。