[Posted]: 2019-03-28 21:40:56
[Question]:
I have added a learning rate and momentum to a from-scratch neural network implementation that I found here: https://towardsdatascience.com/how-to-build-your-own-neural-network-from-scratch-in-python-68998a08e4f6
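For reference, the momentum variant added here is the exponentially weighted form: the velocity is a running average of gradients (`v = beta * v + (1 - beta) * g`), and the learning rate `alpha` scales the step. The sketch below applies that rule to a toy one-dimensional problem. The function name `momentum_descent` and the quadratic objective are hypothetical and exist only for illustration. It uses the conventional gradient and subtracts the step. The question's code adds `alpha * v` instead, because its `d_weights` are built from `(y - output)` and therefore already point in the descent direction.

```python
import numpy as np


def momentum_descent(grad_fn, w0, alpha=0.5, beta=0.5, steps=100):
    """Gradient descent with exponentially weighted momentum.

    v = beta * v + (1 - beta) * g   # running average of gradients
    w = w - alpha * v               # step against the averaged gradient
    """
    w = np.asarray(w0, dtype=float)
    v = np.zeros_like(w)  # velocity starts at zero, like v_dw1 / v_dw2 = 0
    for _ in range(steps):
        g = grad_fn(w)                 # gradient of the loss at the current w
        v = beta * v + (1 - beta) * g  # blend the new gradient into the velocity
        w = w - alpha * v              # move downhill by the learning rate
    return w


# Hypothetical toy problem: minimize f(w) = (w - 3)^2, whose gradient is 2 * (w - 3).
w_final = momentum_descent(lambda w: 2 * (w - 3.0), w0=[0.0])
print(w_final)  # approximately 3
```

With `beta = 0` this reduces to plain gradient descent; larger `beta` makes each step depend more on past gradients than on the current one.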
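However, I have a few questions about my implementation:
- Is it correct? Are there any suggested improvements? It generally seems to produce adequate results, but outside advice would be much appreciated.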
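- With a learning rate of 0.9, the network tends to get stuck in a local optimum at loss ≈ 1. I think this is because the step size is not large enough to escape it, but is there a way to overcome this? Or is it inherent to the nature of the data being solved, and therefore unavoidable?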
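For context, a loss of about 1 is exactly what the summed squared error gives on these four XOR targets when the network outputs 0.5 for every input, i.e. a constant prediction that ignores the input. Among all constant outputs, 0.5 is also the one that minimizes this loss. The check below assumes that interpretation of the plateau, which is only one possible reading. The name `stuck_output` is hypothetical.

```python
import numpy as np

# XOR targets from the question
y = np.array([[0], [1], [1], [0]])

# Hypothetical "stuck" prediction: the network outputs 0.5 for every input
stuck_output = np.full(y.shape, 0.5)

# Same summed squared error that the question appends to total_loss
loss = np.sum((y - stuck_output) ** 2)
print(loss)  # 1.0
```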
```python
import numpy as np
import matplotlib.pyplot as plt


def sigmoid(x):
    return 1 / (1 + np.exp(-x))


def sigmoid_derivative(x):
    sig = 1 / (1 + np.exp(-x))
    return sig * (1 - sig)


class NeuralNetwork:
    def __init__(self, x, y):
        self.input = x
        self.weights1 = np.random.rand(self.input.shape[1], 4)
        self.weights2 = np.random.rand(4, 1)
        self.y = y
        self.output = np.zeros(self.y.shape)
        self.v_dw1 = 0
        self.v_dw2 = 0
        self.alpha = 0.5
        self.beta = 0.5

    def feedforward(self):
        self.layer1 = sigmoid(np.dot(self.input, self.weights1))
        self.output = sigmoid(np.dot(self.layer1, self.weights2))

    def backprop(self, alpha, beta):
        # application of the chain rule to find derivative of the loss function with respect to weights2 and weights1
        d_weights2 = np.dot(self.layer1.T, (2 * (self.y - self.output) * sigmoid_derivative(self.output)))
        d_weights1 = np.dot(self.input.T, (np.dot(2 * (self.y - self.output) * sigmoid_derivative(self.output), self.weights2.T) * sigmoid_derivative(self.layer1)))

        # adding effect of momentum
        self.v_dw1 = (beta * self.v_dw1) + ((1 - beta) * d_weights1)
        self.v_dw2 = (beta * self.v_dw2) + ((1 - beta) * d_weights2)

        # update the weights with the derivative (slope) of the loss function
        self.weights1 = self.weights1 + (self.v_dw1 * alpha)
        self.weights2 = self.weights2 + (self.v_dw2 * alpha)


if __name__ == "__main__":
    X = np.array([[0, 0, 1],
                  [0, 1, 1],
                  [1, 0, 1],
                  [1, 1, 1]])
    y = np.array([[0], [1], [1], [0]])
    nn = NeuralNetwork(X, y)

    total_loss = []
    for i in range(10000):
        nn.feedforward()
        nn.backprop(nn.alpha, nn.beta)
        total_loss.append(sum((nn.y - nn.output) ** 2))

    iteration_num = list(range(10000))
    plt.plot(iteration_num, total_loss)
    plt.show()
    print(nn.output)
```
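On the correctness question, one detail that may be worth checking: in the code above, `sigmoid_derivative(x)` recomputes `sigmoid(x)` internally, but `backprop` calls it on `self.output` and `self.layer1`, which have already been passed through `sigmoid`. If `a = sigmoid(z)`, the derivative backprop needs is σ'(z), which can also be written as `a * (1 - a)`. The numeric sketch below compares the two. The pre-activation values `z` are arbitrary and chosen only for illustration.

```python
import numpy as np


def sigmoid(x):
    return 1 / (1 + np.exp(-x))


def sigmoid_derivative(x):
    # As defined in the question: expects a PRE-activation value z
    sig = 1 / (1 + np.exp(-x))
    return sig * (1 - sig)


z = np.array([-2.0, 0.0, 2.0])  # arbitrary pre-activation values (illustrative)
a = sigmoid(z)                  # activations, like self.layer1 / self.output

true_grad = sigmoid_derivative(z)  # sigma'(z), what the chain rule calls for
as_called = sigmoid_derivative(a)  # what backprop computes: sigma'(sigma(z))
from_activation = a * (1 - a)      # sigma'(z) expressed in terms of a

print(np.allclose(true_grad, from_activation))  # True
print(np.allclose(true_grad, as_called))        # False
```

For `a` in (0, 1), σ'(a) always lies between roughly 0.197 and 0.25, so the substituted factor stays positive and the updates keep the right sign. That could explain why the network still converges in many runs even though the gradient magnitudes are off.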
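[Discussion]:
- "Is it correct? Are there any suggested improvements?" Code Review would be a better place to ask this. (Although they want code that already works, so I'm not sure whether "is this correct" is on-topic there...)
- @JETM It does work (usually), thanks.
- Welcome to SO; what @JETM is trying to say is that this kind of question ("Is it correct? Any suggested improvements?") is probably off-topic for SO, and you should seriously consider moving it to Code Review...
- @desertnaut Will do, thanks.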
Tags: python python-3.x numpy machine-learning neural-network