【发布时间】:2020-03-02 19:34:06
【问题描述】:
在 TF1 中以图形模式操作时,我相信我在使用函数式 API 时需要通过 feeddicts 连接 training=True 和 training=False。在 TF2 中执行此操作的正确方法是什么?
我相信这在使用tf.keras.Sequential 时会自动处理。例如,我不需要在docs的以下示例中指定training:
model = tf.keras.Sequential([
tf.keras.layers.Conv2D(32, 3, activation='relu',
kernel_regularizer=tf.keras.regularizers.l2(0.02),
input_shape=(28, 28, 1)),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.Flatten(),
tf.keras.layers.Dropout(0.1),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.BatchNormalization(),
tf.keras.layers.Dense(10, activation='softmax')
])
# Model is the full model w/o custom layers
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
model.fit(train_data, epochs=NUM_EPOCHS)
loss, acc = model.evaluate(test_data)
print("Loss {:0.4f}, Accuracy {:0.4f}".format(loss, acc))
我是否也可以假设 keras 在使用功能 api 进行训练时会自动处理这个问题?这是相同的模型,使用函数 api 重写:
inputs = tf.keras.Input(shape=((28,28,1)), name="input_image")
hid = tf.keras.layers.Conv2D(32, 3, activation='relu',
kernel_regularizer=tf.keras.regularizers.l2(0.02),
input_shape=(28, 28, 1))(inputs)
hid = tf.keras.layers.MaxPooling2D()(hid)
hid = tf.keras.layers.Flatten()(hid)
hid = tf.keras.layers.Dropout(0.1)(hid)
hid = tf.keras.layers.Dense(64, activation='relu')(hid)
hid = tf.keras.layers.BatchNormalization()(hid)
outputs = tf.keras.layers.Dense(10, activation='softmax')(hid)
model_fn = tf.keras.Model(inputs=inputs, outputs=outputs)
# Model is the full model w/o custom layers
model_fn.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
model_fn.fit(train_data, epochs=NUM_EPOCHS)
loss, acc = model_fn.evaluate(test_data)
print("Loss {:0.4f}, Accuracy {:0.4f}".format(loss, acc))
我不确定hid = tf.keras.layers.BatchNormalization()(hid) 是否需要为hid = tf.keras.layers.BatchNormalization()(hid, training)?
可以在 here 找到这些模型的 colab。
【问题讨论】:
-
您是否有特定的理由想要控制训练标志,或者您是否在询问是否需要它?
-
我想我希望能够在
model_fn()(tf.keras.Model#call) 上将其设置为正向传递,以便 BatchNormalization 行为正确。我假设我需要对模型进行子类化并明确定义前向传递调用,以便我可以将training传递给 BN 调用,类似于tensorflow.org/api_docs/python/tf/keras/Model 中的示例。我还想知道在使用model_fn.fit()时是否需要它。 -
@cosentiyes:你提到了我相信这在使用
tf.keras.Sequential时会自动处理。你确定这是真的吗?你有什么可以证明这一点的参考资料吗?
标签: python tensorflow keras tensorflow2.0 tf.keras