【发布时间】:2021-07-10 23:07:15
【问题描述】:
我正在尝试按照 Keras 官方演练为 TF2/Keras 编写自己的训练循环。 vanilla 版本就像一个魅力,但是当我尝试将 @tf.function 装饰器添加到我的训练步骤时,一些内存泄漏会占用我所有的内存并且我失去对我的机器的控制,有人知道发生了什么吗?。
代码的重要部分如下所示:
@tf.function
def train_step(x, y):
with tf.GradientTape() as tape:
logits = siamese_network(x, training=True)
loss_value = loss_fn(y, logits)
grads = tape.gradient(loss_value, siamese_network.trainable_weights)
optimizer.apply_gradients(zip(grads, siamese_network.trainable_weights))
train_acc_metric.update_state(y, logits)
return loss_value
@tf.function
def test_step(x, y):
val_logits = siamese_network(x, training=False)
val_acc_metric.update_state(y, val_logits)
val_prec_metric.update_state(y_batch_val, val_logits)
val_rec_metric.update_state(y_batch_val, val_logits)
for epoch in range(epochs):
step_time = 0
epoch_time = time.time()
print("Start of {} epoch".format(epoch))
for step, (x_batch_train, y_batch_train) in enumerate(train_ds):
if step > steps_epoch:
break
loss_value = train_step(x_batch_train, y_batch_train)
train_acc = train_acc_metric.result()
train_acc_metric.reset_states()
for val_step,(x_batch_val, y_batch_val) in enumerate(test_ds):
if val_step>validation_steps:
break
test_step(x_batch_val, y_batch_val)
val_acc = val_acc_metric.result()
val_prec = val_prec_metric.result()
val_rec = val_rec_metric.result()
val_acc_metric.reset_states()
val_prec_metric.reset_states()
val_rec_metric.reset_states()
如果我评论@tf.function 行,则不会发生内存泄漏,但步骤时间慢了 3 倍。我的猜测是,不知何故,图表是在每个时期或类似的情况下再次创建的 bean,但我不知道如何解决它。
这是我正在学习的教程:https://keras.io/guides/writing_a_training_loop_from_scratch/
【问题讨论】:
-
您使用的是 GPU 吗?如果不是,则将其更改为 GPU。另外,尽量减少批量大小。
-
您的
train_ds和test_ds是如何创建的?当您枚举它们时,您会得到张量还是其他类型?
标签: python tensorflow keras custom-training