diff --git a/机器学习算法理论及应用/第二章——手写线性回归算法/LinearRegression/UnivariateLinearRegression.py b/机器学习算法理论及应用/第二章——手写线性回归算法/LinearRegression/UnivariateLinearRegression.py index 92547e2..6aa6db2 100644 --- a/机器学习算法理论及应用/第二章——手写线性回归算法/LinearRegression/UnivariateLinearRegression.py +++ b/机器学习算法理论及应用/第二章——手写线性回归算法/LinearRegression/UnivariateLinearRegression.py @@ -26,5 +26,18 @@ plt.title('Happy') plt.legend() plt.show() -num_iterations = 500 -learning_rate = 0.01 \ No newline at end of file +# 训练线性回归模型 +num_iterations = 500 # 迭代次数 +learning_rate = 0.01 # 学习率 + +linear_regression = LinearRegression(x_train, y_train) # 初始化模型 +(theta, cost_history) = linear_regression.train(learning_rate, num_iterations) + +print('开始时的损失:', cost_history[0]) +print('训练后的损失:', cost_history[-1]) + +plt.plot(range(num_iterations), cost_history) +plt.xlabel('Iteration') +plt.ylabel('Cost') +plt.title('GD') +plt.show()