【发布时间】:2017-11-01 03:04:04
【问题描述】:
我正在学习 TensorFlow,并尝试将其应用于一个简单的线性回归问题。 data 是形状为 [42x2] 的 numpy.ndarray。
我有点困惑,为什么在每个连续的时期之后,损失都在增加。是不是预计损失会随着每个连续的时期而下降!
这是我的代码(如果您希望我也分享输出,请告诉我!):(非常感谢您抽出宝贵时间来回答。)
1) 为因变量/自变量创建占位符
X = tf.placeholder(tf.float32, name='X')
Y = tf.placeholder(tf.float32,name='Y')
2) 为权重、偏差、总损失创建变量(在每个 epoch 之后)
w = tf.Variable(0.0,name='weights')
b = tf.Variable(0.0,name='bias')
3) 定义损失函数和优化器
Y_pred = X * w + b
loss = tf.reduce_sum(tf.square(Y - Y_pred), name = 'loss')
optimizer = tf.train.GradientDescentOptimizer(learning_rate = 0.001).minimize(loss)
4) 创建摘要事件和事件文件编写器
tf.summary.scalar(name = 'weight', tensor = w)
tf.summary.scalar(name = 'bias', tensor = b)
tf.summary.scalar(name = 'loss', tensor = loss)
merged = tf.summary.merge_all()
evt_file = tf.summary.FileWriter('def_g')
evt_file.add_graph(tf.get_default_graph())
5) 并在一个会话中执行所有操作
with tf.Session() as sess1:
sess1.run(tf.variables_initializer(tf.global_variables()))
for epoch in range(10):
summary, _,l = sess1.run([merged,optimizer,loss],feed_dict={X:data[:,0],Y:data[:,1]})
evt_file.add_summary(summary,epoch+1)
evt_file.flush()
print(" new_loss: {}".format(sess1.run(loss,feed_dict={X:data[:,0],Y:data[:,1]})))
干杯!
【问题讨论】:
-
您可以发布您的数据样本吗?至少 10 个。
-
: 20477196.0 : 8389799424.0 : 3440635019264.0 : 1410998762209280.0 : 5.786483681258373e+17 : 2.3730280954829e+20 : 9.731749171431463e+22 : 3.99097521643995e+25 : 1.6366924149690585e+28 : 6.712048633325622e+30
-
我注意到体重在 b/w -ve 和 + 区域内摆动到非常大的数字:e0:w 0.0,b 0.0 e1:w 45.9618034362793,b 2.828000068664551 e2:w -885.352294921875,b -43.58602905273 : w 17974.052734375, b 906.8658447265625 e4: w -363946.125, b -18330.419921875 e5: w 7370275.5, b 371251.5625 e6: w -149254576.0, b -7518119.0 e7: w 3022537472.0, b 152248672.0 e8: w -61209063424.0, b -3083169792.0 e9: w 1239537811456.0, b 62436925440.0
-
x 和 y 数据?
-
[6.2, 29.], [9.5, 44.], [10.5, 36.], [7.7, 37.], [8.6, 53.], [34.1, 68.], [11., 75.], [6.9, 18.], [7.3, 31.], [15.1, 25.], [29.1, 34.], [2.2, 14.], [5.7, 11.], [2., 11.], [2.5, 22.], [4., 16.], [5.4, 27.], [2.2, 9.], [7.2, 29.], [15.1, 30.] , [ 16.5, 40. ], [ 18.4, 32. ], [ 36.2, 41. ],
标签: tensorflow