[Posted]: 2022-01-13 17:56:07
[Question]:
While running some TensorFlow code today, I noticed something interesting:
import matplotlib
from matplotlib import pyplot as plt
import tensorflow as tf

matplotlib.rcParams['figure.figsize'] = [9, 6]

x = tf.linspace(-2., 2., 201)

def f(x):
    y = x**2 + 2*x - 5
    return y

y = f(x) + tf.random.normal(shape=[201])

class Model(tf.keras.Model):
    def __init__(self, units):
        super().__init__()
        self.dense1 = tf.keras.layers.Dense(units=units,
                                            activation=tf.nn.relu,
                                            kernel_initializer=tf.random.normal,
                                            bias_initializer=tf.random.normal)
        self.dense2 = tf.keras.layers.Dense(1)

    def call(self, x, training=True):
        # For Keras layers/models, implement `call` instead of `__call__`.
        x = x[:, tf.newaxis]
        x = self.dense1(x)
        x = self.dense2(x)
        return tf.squeeze(x, axis=1)

model = Model(64)
test = model(x)  ################## model couldn't train without this line ####
variables = model.variables
optimizer = tf.optimizers.SGD(learning_rate=0.001)

for step in range(1000):
    with tf.GradientTape() as tape:
        prediction = model(x)
        error = (y-prediction)**2
        mean_error = tf.reduce_mean(error)
    gradient = tape.gradient(mean_error, variables)
    optimizer.apply_gradients(zip(gradient, variables))
    if step % 100 == 0:
        print(f'Mean squared error: {mean_error.numpy():0.3f}')
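The model itself is very simple. The interesting part is the commented line. If the model is not called once beforehand, for example via test = model(x), it does not train at all. If I delete that line, the output is: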
Mean squared error: 21.782
Mean squared error: 21.782
Mean squared error: 21.782
Mean squared error: 21.782
Mean squared error: 21.782
Mean squared error: 21.782
Mean squared error: 21.782
Mean squared error: 21.782
Mean squared error: 21.782
Mean squared error: 21.782
Why is this line needed?
[Discussion]:
-
In the for loop, if you replace variables with model.variables, the code works, because the variables you are passing is just an empty list.
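The comment above can be checked directly. Keras layers such as Dense create their weights lazily, on the first call, so model.variables read before any forward pass has nothing to return. The sketch below is a trimmed version of the question's model, not the asker's exact code: it assumes default initializers and noise-free targets, and the names vars_before, vars_after, and losses are introduced here for illustration only. It takes one snapshot of model.variables before the first call and one after, then trains with the variable list looked up inside the loop, after the forward pass.

```python
import tensorflow as tf


class Model(tf.keras.Model):
    def __init__(self, units):
        super().__init__()
        # Default initializers are used here (an assumption; the question
        # passes tf.random.normal instead).
        self.dense1 = tf.keras.layers.Dense(units, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(1)

    def call(self, x):
        x = x[:, tf.newaxis]
        x = self.dense2(self.dense1(x))
        return tf.squeeze(x, axis=1)


x = tf.linspace(-2., 2., 201)
y = x**2 + 2*x - 5  # noise-free targets, to keep the sketch deterministic

model = Model(64)
vars_before = list(model.variables)  # snapshot before the first call
model(x)                             # the first call builds both Dense layers
vars_after = list(model.variables)   # snapshot after the layers are built

# The first list is empty; the second holds two kernels and two biases.
print(len(vars_before), len(vars_after))

optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)
losses = []  # hypothetical helper list, used only to inspect progress
for step in range(200):
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean((y - model(x))**2)
    # Look the variables up after the forward pass, so the list is
    # populated even if the model had never been called before the loop.
    variables = model.trainable_variables
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))
    losses.append(float(loss))

print(f'first loss: {losses[0]:0.3f}, last loss: {losses[-1]:0.3f}')
```

Either approach avoids the stale empty list: call the model once before reading model.variables, as the question's commented line does, or read the variables inside the loop as sketched here.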
Tags: python tensorflow machine-learning keras