Add. Loss and forecast module

pull/2/head
benjas 5 years ago
parent 21cbf93963
commit d9d0006c45

@ -32,7 +32,7 @@ class LinearRegression:
def train(self, alpha, num_iterations=500): def train(self, alpha, num_iterations=500):
""" """
训练模块执行梯度下降 训练模块执行梯度下降得到theta值和损失值loss
alpha: 学习率 alpha: 学习率
num_iterations: 迭代次数 num_iterations: 迭代次数
@ -46,10 +46,12 @@ class LinearRegression:
alpha: 学习率 alpha: 学习率
num_iterations: 迭代次数 num_iterations: 迭代次数
:return: 返回损失值 loss
""" """
cost_history = [] cost_history = [] # 收集每次的损失值
for _ in range(num_iterations): # 开始迭代 for _ in range(num_iterations): # 开始迭代
self.gradient_step(alpha) self.gradient_step(alpha) # 每次更新theta
cost_history.append(self.cost_function(self.data, self.labels)) cost_history.append(self.cost_function(self.data, self.labels))
return cost_history return cost_history
@ -68,6 +70,15 @@ class LinearRegression:
theta = theta - alpha * (1/num_examples)*(np.dot(delta.T, self.data)).T theta = theta - alpha * (1/num_examples)*(np.dot(delta.T, self.data)).T
self.theta = theta # 计算完theta后更新当前theta self.theta = theta # 计算完theta后更新当前theta
def cost_function(self, data, labels):
"""
损失计算方法计算平均的损失而不是每个数据的损失值
"""
num_examples = data.shape[0]
delta = LinearRegression.hypothesis(data, self.theta) - labels # 预测值-真实值 得到残差
cost = np.dot(delta, delta.T) # 损失值
return cost[0][0]
@staticmethod @staticmethod
def hypothesis(data, theta): def hypothesis(data, theta):
""" """
@ -79,3 +90,25 @@ class LinearRegression:
""" """
predictions = np.dot(data, theta) predictions = np.dot(data, theta)
return predictions return predictions
def get_cost(self, data, labels):
"""
得到当前损失
"""
data_processed = prepare_for_training.prepare_for_training(data,
self.polynomial_degree,
self.sinusoid_degree,
self.normalize_data)[0]
return self.cost_function(data_processed, labels)
def predict(self, data):
"""
用训练的参数模型预测得到回归值的结果
"""
data_processed = prepare_for_training.prepare_for_training(data,
self.polynomial_degree,
self.sinusoid_degree,
self.normalize_data)[0]
predictions = LinearRegression.hypothesis(data_processed, self.theta)
return predictions

Loading…
Cancel
Save