【问题标题】:The loss in training is always zero训练中的损失始终为零
【发布时间】:2020-06-27 19:11:47
【问题描述】:

我想复制吴恩达课程中的房价预测模型,但是在生成训练结果后,损失总是为零,权重变量也为零。有人可以帮忙吗?

import pandas as pd
import numpy as np

def normalize_feature(df):
    return df.apply(lambda column: (column - column.mean()) / column.std())


df = normalize_feature(pd.read_csv('data1.csv',
                                   names=['square', 'bedrooms', 'price']))

ones = pd.DataFrame({'ones': np.ones(len(df))})
df = pd.concat([ones, df], axis=1)
df.head()X_data = np.array(df[df.columns[0:3]])
y_data = np.array(df[df.columns[-1]]).reshape(len(df), 1)

print(X_data.shape, type(X_data))
print(y_data.shape, type(y_data))import tensorflow as tf

alpha = 0.1 
epoch = 500 

X = tf.placeholder(tf.float32, X_data.shape)
y = tf.placeholder(tf.float32, y_data.shape)

W = tf.get_variable("weights", (X_data.shape[1], 1), initializer=tf.constant_initializer())
y_pred = tf.matmul(X, W)


loss_op = 1 / (2 * len(X_data)) * tf.matmul((y_pred - y), (y_pred - y), transpose_a=True)
opt = tf.train.GradientDescentOptimizer(learning_rate=alpha)
train_op = opt.minimize(loss_op)with tf.Session() as sess:

    sess.run(tf.global_variables_initializer())

    for e in range(1, epoch + 1):
        sess.run(train_op, feed_dict={X:X_data, y:y_data})
        if e % 10 == 0:
            loss, w = sess.run([loss_op, W], feed_dict={X: X_data, y: y_data})
            log_str = "Epoch %d \t Loss=%.4g \t Model: y = %.4gx1 + %.4gx2 + %.4g"
            print(log_str % (e, loss, w[1], w[2], w[0]))

结果是

Epoch 10     Loss=0      Model: y = 0x1 + 0x2 + 0
Epoch 20     Loss=0      Model: y = 0x1 + 0x2 + 0
Epoch 30     Loss=0      Model: y = 0x1 + 0x2 + 0
Epoch 40     Loss=0      Model: y = 0x1 + 0x2 + 0
Epoch 50     Loss=0      Model: y = 0x1 + 0x2 + 0
...
Epoch 500    Loss=0      Model: y = 0x1 + 0x2 + 0

【问题讨论】:

    标签: python tensorflow machine-learning


    【解决方案1】:

    loss_op 变量的结果很可能是自动四舍五入的。尝试用显式浮点数替换分子,如下所示

    loss_op = 1.0 / (2 * len(X_data)) * tf.matmul((y_pred - y), (y_pred - y), transpose_a=True)
    

    通过这种方式,您可以显式强制输出类型为 float

    【讨论】:

    猜你喜欢
    • 2022-12-10
    • 1970-01-01
    • 1970-01-01
    • 2018-05-17
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2021-06-14
    相关资源
    最近更新 更多