【问题标题】:How to force bias to zero in tensorflow LinearRegressor?如何在张量流 LinearRegressor 中强制偏差为零?
【发布时间】:2021-08-26 19:09:17
【问题描述】:

我正在使用 tensorflow LinearRegressor API 来解决回归问题 (https://www.tensorflow.org/api_docs/python/tf/estimator/LinearRegressor)。我知道我的模型中的偏差正好是 0。

我如何强制 LinearRegressor 学习 0 的偏差?

这是一个最小的例子:

import tensorflow as tf
import numpy as np
from sklearn.linear_model import SGDRegressor

用 2 个特征模拟一些数据(+ 偏差为 0) y = 0 + 2*x1 + 3*x2 + 噪声

np.random.seed(5332)
n = 1000
weights = np.array([
    [2],
    [3],
])

bias = 0

x = np.random.randn(n, np.shape(weights)[0])
y = (bias + np.matmul(x, weights) + np.random.randn(n, 1)).ravel()

在 sklearn 中,我会使用 fit_intercept=False 将偏差强制为 0:

ols = SGDRegressor(tol=0.000001, fit_intercept=False)
ols.fit(x, y)

print("True weights: {}".format(weights.ravel()))
print("Learned weights: {}".format(np.round(ols.coef_), 3))
print("True bias: {}".format([bias]))
print("Learned bias: {}".format(np.round(ols.intercept_), 3))

输出:

True weights: [2 3]
Learned weights: [2. 3.]
True bias: [0]
Learned bias: [0.]

在张量流中我做了以下事情:

column =  tf.feature_column.numeric_column('x', shape=np.shape(x)[1])
ols = tf.estimator.LinearRegressor(
    feature_columns=[column],
    optimizer=tf.train.GradientDescentOptimizer(0.0001)
)


train_input = tf.estimator.inputs.numpy_input_fn(
    x={"x": x},
    y=y,
    shuffle=False,
    num_epochs=100,
    batch_size=int(len(y) / 20)
)

ols.train(train_input)

print("True weights: {}".format(weights.ravel()))
print("Learned weights: {}".format(np.round(ols.get_variable_value('linear/linear_model/x/weights').flatten(), 3)))
print("True bias: {}".format([bias]))
print("Learned bias: {}".format(np.round(ols.get_variable_value('linear/linear_model/bias_weights').flatten(), 3)))

输出:

True weights: [2 3]
Learned weights: [1.993 2.998]
True bias: [0]
Learned bias: [-0.067]

但是学习到的偏差应该是:[0],我该如何执行呢?

【问题讨论】:

  • 老问题,但 tf.keras.constraints 是您要搜索的,对吧?

标签: python tensorflow linear-regression


【解决方案1】:

我猜 tf.keras.constraints 就是您要搜索的内容。

【讨论】:

    猜你喜欢
    • 2017-06-08
    • 2016-08-10
    • 2018-07-27
    • 2018-10-17
    • 1970-01-01
    • 1970-01-01
    • 2017-09-14
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多