【发布时间】:2018-02-08 02:11:19
【问题描述】:
我在 Tensorflow 中做了一个简单的线性回归。我怎么知道回归的系数是多少? 我已经阅读了文档,但我无法在任何地方找到它! (https://www.tensorflow.org/api_docs/python/tf/estimator/LinearRegressor)
编辑代码示例
import numpy as np
import tensorflow as tf
# Declare list of features, we only have one real-valued feature
def model_fn(features, labels, mode):
# Build a linear model and predict values
W = tf.get_variable("W", [1], dtype=tf.float64)
b = tf.get_variable("b", [1], dtype=tf.float64)
y = W * features['x'] + b
# Loss sub-graph
loss = tf.reduce_sum(tf.square(y - labels))
# Training sub-graph
global_step = tf.train.get_global_step()
optimizer = tf.train.GradientDescentOptimizer(0.01)
train = tf.group(optimizer.minimize(loss),
tf.assign_add(global_step, 1))
# EstimatorSpec connects subgraphs we built to the
# appropriate functionality.
return tf.estimator.EstimatorSpec(
mode=mode,
predictions=y,
loss=loss,
train_op=train)
estimator = tf.estimator.Estimator(model_fn=model_fn)
# define our data sets
x_train = np.array([1., 2., 3., 4.])
y_train = np.array([0., -1., -2., -3.])
x_eval = np.array([2., 5., 8., 1.])
y_eval = np.array([-1.01, -4.1, -7, 0.])
input_fn = tf.estimator.inputs.numpy_input_fn(
{"x": x_train}, y_train, batch_size=4, num_epochs=None, shuffle=True)
train_input_fn = tf.estimator.inputs.numpy_input_fn(
{"x": x_train}, y_train, batch_size=4, num_epochs=1000, shuffle=False)
eval_input_fn = tf.estimator.inputs.numpy_input_fn(
{"x": x_eval}, y_eval, batch_size=4, num_epochs=1000, shuffle=False)
# train
estimator.train(input_fn=input_fn, steps=1000)
# Here we evaluate how well our model did.
train_metrics = estimator.evaluate(input_fn=train_input_fn)
eval_metrics = estimator.evaluate(input_fn=eval_input_fn)
print("train metrics: %r"% train_metrics)
print("eval metrics: %r"% eval_metrics)
【问题讨论】:
-
您能否发布您的代码,以便我们了解您到底做了什么?一般很难回答...
-
对不起。我正在关注tensorflow.org/get_started/get_started 中的教程示例中的示例
-
您没有使用
tf.estimator.LinearRegressor,而是实现了您自己的估算器,对吧? (相当于一个线性回归器)。 -
@jdehesa 我也尝试过使用该选项。在那种情况下,我怎样才能得到系数值?
标签: python machine-learning tensorflow