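【Posted】: 2019-12-29 14:28:39
【Question】:
I'm trying to write a neural network from scratch in numpy to recognize hand_written_digits, but I'm a bit confused about how to update the weights and biases.
Here is my code: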
import numpy as np

# sigmoid(z, derivative=False) is used below but was not included in the post

class NeuralNetwork():
    learning_rate = 0.0001
    ephochs = 10000
    nodes_in_input_layer = 784  # features
    nodes_in_hidden_layer = 100
    nodes_in_output_layer = 10  # classes
    np.random.seed(3)

    def __init__(self):
        self.hidden_layer = {'weights': np.random.rand(self.nodes_in_input_layer, self.nodes_in_hidden_layer)*0.1,
                             'biases': np.random.rand(self.nodes_in_hidden_layer)*0.1}
        self.output_layer = {'weights': np.random.rand(self.nodes_in_hidden_layer, self.nodes_in_output_layer)*0.1,
                             'biases': np.random.rand(self.nodes_in_output_layer)*0.1}
        print('self.hidden_layer: ', self.hidden_layer['weights'].shape)
        print('self.output_layer: ', self.output_layer['weights'].shape)

    def fit(self, x, y, ephochs=ephochs):
        for i in range(ephochs):
            # feed forward
            z_hidden_layer = np.dot(x[i], self.hidden_layer['weights']) + self.hidden_layer['biases']
            o_hidden_layer = sigmoid(z_hidden_layer)
            z_output_layer = np.dot(o_hidden_layer, self.output_layer['weights']) + self.output_layer['biases']
            o_output_layer = sigmoid(z_output_layer)
            # back propagation
            error = o_output_layer - y[i]
            '''
            ## at output layer
            derror_dweights = derror/do * do/dz * dz/dw
            derror/do = error
            do/dz = derivative of sigmoid(x[i])
            dz/dw = o_hidden_layer
            '''
            derror_do = error
            do_dz = sigmoid(z_output_layer, derivative=True)
            dz_dw = o_hidden_layer
            nw_output_layer = derror_do * do_dz
            nw_output_layer = np.dot(nw_output_layer, dz_dw.T)  # raises the ValueError shown below
            nb_output_layer = error
            # updating new weights and biases
            self.output_layer['weights'] = self.output_layer['weights'] - (self.learning_rate * nw_output_layer)
            self.output_layer['biases'] = self.output_layer['biases'] - (self.learning_rate * nb_output_layer)
            ## update remaining weights and biases
I get this error when I run it:
nw_output_layer = np.dot(nw_output_layer, dz_dw.T)
ValueError: shapes (10,) and (100,) not aligned: 10 (dim 0) != 100 (dim 0)
Can anyone explain, step by step, how to update the weights and biases of this neural network?
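One possible way to lay out the update for a single sample is sketched below. It is an illustration under assumptions, not a definitive implementation: it assumes a squared-error loss E = 0.5 * sum((o - y)**2), which matches derror/do = error in the post, and it defines its own `sigmoid` helper with a `derivative` flag because the original helper is not shown. `train_step`, `W1`, `b1`, `W2`, `b2` and the tiny layer sizes are illustrative names and values, not taken from the post.

```python
import numpy as np

def sigmoid(z, derivative=False):
    """Logistic function; with derivative=True, its derivative evaluated at z (assumed helper)."""
    s = 1.0 / (1.0 + np.exp(-z))
    return s * (1.0 - s) if derivative else s

def train_step(x, y, W1, b1, W2, b2, lr=0.5):
    """One gradient-descent step on one sample; returns the loss before the update."""
    # 1. forward pass
    z1 = x @ W1 + b1                    # shape (n_hidden,)
    a1 = sigmoid(z1)
    z2 = a1 @ W2 + b2                   # shape (n_out,)
    a2 = sigmoid(z2)

    # 2. output layer: dE/dz2 = (a2 - y) * sigmoid'(z2)
    delta2 = (a2 - y) * sigmoid(z2, derivative=True)
    grad_W2 = np.outer(a1, delta2)      # shape (n_hidden, n_out), same as W2
    grad_b2 = delta2                    # bias gradient is the delta itself

    # 3. hidden layer: push delta2 back through W2 (still the old weights)
    delta1 = (delta2 @ W2.T) * sigmoid(z1, derivative=True)
    grad_W1 = np.outer(x, delta1)       # shape (n_in, n_hidden), same as W1
    grad_b1 = delta1

    # 4. gradient-descent update, in place
    W2 -= lr * grad_W2
    b2 -= lr * grad_b2
    W1 -= lr * grad_W1
    b1 -= lr * grad_b1
    return 0.5 * np.sum((a2 - y) ** 2)

# Tiny sizes for illustration only (the post uses 784, 100 and 10)
rng = np.random.default_rng(3)
n_in, n_hidden, n_out = 4, 5, 3
W1 = rng.random((n_in, n_hidden)) * 0.1
b1 = rng.random(n_hidden) * 0.1
W2 = rng.random((n_hidden, n_out)) * 0.1
b2 = rng.random(n_out) * 0.1

x = rng.random(n_in)                    # one made-up input sample
y = np.array([0.0, 1.0, 0.0])           # one-hot target

losses = [train_step(x, y, W1, b1, W2, b2) for _ in range(200)]
print(losses[0], losses[-1])            # the loss should shrink over the steps
```

Two details differ from the code in the post: the bias gradient uses the same delta as the weights (error times the sigmoid derivative, not the raw error), and the hidden-layer delta is computed from the output weights before they are overwritten.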
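【Comments】:
-
I think you're asking the wrong question. First, you probably want to fix the error (the array dimensions don't match). Then you may want to dig deeper into deep learning concepts.
-
Thanks! Yes, that's what I'm asking. Where did I go wrong?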
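On where it goes wrong, here is a small sketch of the shape problem behind the ValueError. `delta` and `a_hidden` are placeholder arrays standing in for `derror_do * do_dz` and `o_hidden_layer`; only their shapes matter. Because both are 1-D, `.T` does nothing, so `np.dot` attempts an inner product of a length-10 vector with a length-100 vector. An outer product is one way to get a gradient with the same shape as the (100, 10) weight matrix.

```python
import numpy as np

delta = np.ones(10)        # stand-in for derror_do * do_dz, shape (10,)
a_hidden = np.ones(100)    # stand-in for o_hidden_layer, shape (100,)

# Transposing a 1-D array is a no-op: the shape stays (100,)
print(a_hidden.T.shape)    # (100,)

# So this is an inner product of mismatched vectors and raises ValueError
try:
    np.dot(delta, a_hidden.T)
except ValueError as exc:
    print("ValueError:", exc)

# The outer product gives one entry per (hidden unit, output unit) pair
grad_w = np.outer(a_hidden, delta)
print(grad_w.shape)        # (100, 10)
```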
Tags: python-3.x numpy neural-network deep-learning