【问题标题】:Run tensorflow model before training, otherweise it won't train?在训练之前运行 tensorflow 模型,否则它不会训练?
【发布时间】:2022-01-13 17:56:07
【问题描述】:

今天在运行一个tensorflow代码的时候发现了一件很有意思的事情:

import matplotlib
from matplotlib import pyplot as plt
import tensorflow as tf
matplotlib.rcParams['figure.figsize'] = [9, 6]

x = tf.linspace(-2., 2., 201)
def f(x):
  y = x**2 + 2*x - 5
  return y
y = f(x) + tf.random.normal(shape=[201])

class Model(tf.keras.Model):
  def __init__(self, units):
    super().__init__()
    self.dense1 = tf.keras.layers.Dense(units=units,
                                        activation=tf.nn.relu,
                                        kernel_initializer=tf.random.normal,
                                        bias_initializer=tf.random.normal)
    self.dense2 = tf.keras.layers.Dense(1)

  def call(self, x, training=True):
    # For Keras layers/models, implement `call` instead of `__call__`.
    x = x[:, tf.newaxis]
    x = self.dense1(x)
    x = self.dense2(x)
    return tf.squeeze(x, axis=1)

model = Model(64)

test = model(x) ################## model couldn't train without this line ####

variables = model.variables
optimizer = tf.optimizers.SGD(learning_rate=0.001)

for step in range(1000):
  with tf.GradientTape() as tape:
    prediction = model(x)
    error = (y-prediction)**2
    mean_error = tf.reduce_mean(error)
  gradient = tape.gradient(mean_error, variables)
  optimizer.apply_gradients(zip(gradient, variables))

  if step % 100 == 0:
    print(f'Mean squared error: {mean_error.numpy():0.3f}')

模型本身非常简单。有趣的是注释行。如果不调用模型一次,比如通过test = model(x),模型根本不会训练!!!例如,如果我删除这一行。结果将是:

Mean squared error: 21.782
Mean squared error: 21.782
Mean squared error: 21.782
Mean squared error: 21.782
Mean squared error: 21.782
Mean squared error: 21.782
Mean squared error: 21.782
Mean squared error: 21.782
Mean squared error: 21.782
Mean squared error: 21.782

为什么需要这条线?

【问题讨论】:

  • 在 for 循环中,如果将 variables 替换为 model.variables,则代码将起作用。因为您传递的variables 只是一个空列表。

标签: python tensorflow machine-learning keras


【解决方案1】:

在训练模型之前,您必须构建并编译它。

构建模型会根据您的训练数据的input_shape 创建模型的所有变量。

编译模型会设置您希望在训练期间使用的优化器和损失函数。

当您调用模型时,它会自动按照您插入的数据的形状构建。因此,您可以在调用模型后对其进行训练。

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2019-01-20
    • 1970-01-01
    • 2020-01-03
    • 1970-01-01
    • 2023-03-15
    • 1970-01-01
    相关资源
    最近更新 更多