Posted: 2019-07-07 20:54:17
Question:
I started training a simple NMT (neural machine translation) model with an encoder and a decoder on Colab:
encoder = Encoder(vocab_inp_size, embedding_dim, units, BATCH_SIZE)
decoder = Decoder(vocab_tar_size, embedding_dim, units, BATCH_SIZE)
and then saved the model with a checkpoint:
# On the local machine, the dir is changed to 'training_checkpoints/' to match the location
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                 encoder=encoder,
                                 decoder=decoder)
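and saved it during training with
checkpoint.save(file_prefix=checkpoint_prefix)
Restoring the checkpoints after training on Colab works fine, even when the whole checkpoint folder is saved to Google Drive and restored from there. But when I restore them on my local machine, the model returns different, garbage results. Before training, the checkpoint is restored with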
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
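Because the local checkpoint directory differs from the Colab one (see the comment in the setup code), one thing worth ruling out before blaming the version difference is that nothing was restored at all. `tf.train.latest_checkpoint` returns `None` when it finds no checkpoint, and according to the `tf.train.Checkpoint.restore` documentation, passing `None` does not raise an error, so the model would decode with freshly initialized weights. The sketch below reads the plain-text `checkpoint` state file that TensorFlow writes into the checkpoint directory and reports which prefix it points to and whether the matching `.index` file exists on this machine. `inspect_checkpoint_state` is a hypothetical helper, not a TensorFlow API, and it assumes the state file contains the usual `model_checkpoint_path: "..."` line.

```python
import os
import re


def inspect_checkpoint_state(checkpoint_dir):
    """Report which checkpoint the `checkpoint` state file points to.

    Hypothetical helper (not part of TensorFlow). Returns a tuple
    (resolved_prefix, index_exists), or (None, False) when the state
    file or its `model_checkpoint_path` entry is missing.
    """
    state_path = os.path.join(checkpoint_dir, "checkpoint")
    if not os.path.isfile(state_path):
        return None, False
    with open(state_path, encoding="utf-8") as f:
        text = f.read()
    # The state file holds lines such as: model_checkpoint_path: "ckpt-3"
    match = re.search(r'^model_checkpoint_path:\s*"([^"]*)"', text, re.MULTILINE)
    if match is None:
        return None, False
    prefix = match.group(1)
    # Relative entries are resolved against the checkpoint directory;
    # absolute ones (e.g. a path recorded on Colab) are kept as-is.
    if not os.path.isabs(prefix):
        prefix = os.path.join(checkpoint_dir, prefix)
    # A TF checkpoint prefix is normally accompanied by a `<prefix>.index` file.
    return prefix, os.path.isfile(prefix + ".index")
```

Running `inspect_checkpoint_state('training_checkpoints/')` locally should give a prefix and `True`; `None` or `False` would suggest that the restore call had nothing to load.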
Colab notebook output:
Input: <start> يلعبون الكرة <end>
Predicted translation: he played soccer . <end>
Local machine output:
Input: <start> يلعبون الكرة <end>
Predicted translation: take either either either either either either either either either either either either either either either either either either either either either either either either either either either either either either either either either either either either either either either either either
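Colab TensorFlow version: 1.13.0-rc1
Local machine TensorFlow version: 1.12.0
Given that this problem is caused by the different TensorFlow versions, how can I save the model so that I don't run into it?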
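A minimal sketch, assuming the goal is only to catch a version mismatch early rather than to convert checkpoints between versions: store the TensorFlow version in a small JSON file next to the checkpoints on Colab and compare it before restoring locally. The file name `run_metadata.json` and both helper functions are hypothetical; the version string is passed in (e.g. `tf.__version__`) so the sketch runs without TensorFlow installed. This only detects the mismatch; installing the same TensorFlow release on both machines is what actually removes it as a variable.

```python
import json
import os

# Hypothetical file name for the metadata stored beside the checkpoints.
METADATA_FILE = "run_metadata.json"


def write_run_metadata(checkpoint_dir, tf_version):
    """Record the TensorFlow version used to write the checkpoints."""
    os.makedirs(checkpoint_dir, exist_ok=True)
    path = os.path.join(checkpoint_dir, METADATA_FILE)
    with open(path, "w", encoding="utf-8") as f:
        json.dump({"tensorflow_version": tf_version}, f)
    return path


def check_run_metadata(checkpoint_dir, tf_version):
    """Return a warning string if the recorded version differs, else None."""
    path = os.path.join(checkpoint_dir, METADATA_FILE)
    if not os.path.isfile(path):
        return "no metadata found in %s" % checkpoint_dir
    with open(path, encoding="utf-8") as f:
        saved = json.load(f).get("tensorflow_version")
    if saved != tf_version:
        return "checkpoint written with TensorFlow %s, restoring with %s" % (saved, tf_version)
    return None
```

On Colab, `write_run_metadata(checkpoint_dir, tf.__version__)` would be called after `checkpoint.save(...)`; locally, `check_run_metadata(checkpoint_dir, tf.__version__)` would be called before `checkpoint.restore(...)`, stopping if it returns a message.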
Additional link to the NMT notebook: Neural Machine Translation with Attention
Comments:
Tags: python tensorflow google-colaboratory checkpoint machine-translation