【问题标题】:TensorFlow: Linear Regression with multiple inputs returns NaNsTensorFlow:具有多个输入的线性回归返回 NaN
【发布时间】:2017-04-13 06:41:47
【问题描述】:

这是我在 TensorFlow 的第一次尝试:我正在构建一个具有多个输入线性回归模型。

问题是结果总是NaN,我怀疑这是因为我是一个使用numpy和tensorflow的矩阵运算的完全菜鸟(matlab背景呵呵)。

代码如下:

import numpy as np
import tensorflow as tf

N_INP = 2
N_OUT = 1

# Model params
w = tf.Variable(tf.zeros([1, N_INP]), name='w')
b = tf.Variable(tf.zeros([1, N_INP]), name='b')

# Model input and output
x = tf.placeholder(tf.float32, [None, N_INP], name='x')
y = tf.placeholder(tf.float32, [None, N_OUT], name='y')
linear_model = tf.reduce_sum(x * w + b, axis=1, name='out')

# Loss as sum(error^2)
loss = tf.reduce_sum(tf.square(linear_model - y), name='loss')

# Create optimizer
optimizer = tf.train.GradientDescentOptimizer(0.01)
train = optimizer.minimize(loss, name='train')

# Define training data
w_real = np.array([-1, 4])
b_real = np.array([1, -5])
x_train = np.array([[1, 2, 3, 4], [0, 0.5, 1, 1.5]]).T
y_train = np.sum(x_train * w_real + b_real, 1)[np.newaxis].T
print('Real X:\n', x_train)
print('Real Y:\n', y_train)

# Create session and init parameters
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# Training loop
train_data = {x: x_train, y: y_train}
for i in range(1000):
    sess.run(train, train_data)

# Eval solution
w_est, b_est, curr_loss, y_pred = sess.run([w, b, loss, linear_model], train_data)
print("w: %s b: %s loss: %s" % (w_est, b_est, curr_loss))
print("y_pred: %s" % (y_pred,))

这是输出:

Real X:
 [[ 1.   0. ]
 [ 2.   0.5]
 [ 3.   1. ]
 [ 4.   1.5]]
Real Y:
 [[-5.]
 [-4.]
 [-3.]
 [-2.]]

w: [[ nan  nan]] b: [[ nan  nan]] loss: nan
y_pred: [ nan  nan  nan  nan]

【问题讨论】:

    标签: python numpy tensorflow


    【解决方案1】:

    您需要在linear_model 的定义中添加keep_dims=True。也就是说,

    linear_model = tf.reduce_sum(x * w + b, axis=1, name='out',keep_dims=True)
    

    原因是否则结果是“扁平化”的,你不能从中减去y

    例如,

    'x' is [[1,2,3],
            [4,5,6]]   
    tf.reduce_sum(x, axis=1) is [6, 15]   
    tf.reduce_sum(x, axis=1, keep_dims=True) is [[6], [15]]
    

    【讨论】:

      猜你喜欢
      • 2019-04-12
      • 2019-12-31
      • 1970-01-01
      • 2017-04-10
      • 2017-12-16
      • 2019-10-14
      • 1970-01-01
      • 2021-06-12
      • 2013-03-09
      相关资源
      最近更新 更多