【问题标题】:Saving tensorflow encoder, decoder and attention保存 tensorflow 编码器、解码器和注意力
【发布时间】:2019-07-07 20:54:17
【问题描述】:

开始使用编码器和解码器训练一个简单的 NMT(神经机器翻译器),并在 Colab 上进行训练,

encoder = Encoder(vocab_inp_size, embedding_dim, units, BATCH_SIZE)
decoder = Decoder(vocab_tar_size, embedding_dim, units, BATCH_SIZE)

然后使用检查点保存模型,

# On loacl machine dir changed to 'training_checkpoints/' to fit the loaction
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                 encoder=encoder,
                                 decoder=decoder)

并在训练期间使用

保存
checkpoint.save(file_prefix = checkpoint_prefix)

在 Colab 上训练恢复检查点后可以正常工作,即使将整个检查点文件夹保存在 Google 驱动器上并再次恢复它们,但是当尝试在我的本地计算机上恢复它们​​时,它会返回不同的垃圾结果, 训练前开始检查点使用

checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

Colab 笔记本输出:

Input: <start> يلعبون الكرة <end>
Predicted translation: he played soccer . <end> 

本地机器输出:

Input: <start> يلعبون الكرة <end>
Predicted translation: take either either either either either either either either either either either either either either either either either either either either either either either either either either either either either either either either either either either either either either either either either

Colab TensorFlow 版本:1.13.0-rc1

本地机器tensorflow版本:1.12.0

知道这个问题是由于tensorflow的不同版本造成的,如何在不遇到这个问题的情况下保存模型?

NMT 笔记本的附加链接 Neural Machine Translation with Attention

【问题讨论】:

    标签: python tensorflow google-colaboratory checkpoint machine-translation


    【解决方案1】:

    TF 仅提供前向兼容性保证:https://www.tensorflow.org/guide/version_compat#compatibility_of_graphs_and_checkpoints 1.13保存了1.12无法恢复的文件也就不足为奇了。 升级本地机器的 tensorflow?

    【讨论】:

    • 好的,我知道了,但是如何在没有检查点和使用 tf.train.Saver 的情况下保存模型及其编码器和解码器权重?
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 2019-05-20
    • 2019-06-27
    • 2022-01-04
    • 2022-07-05
    • 2020-01-28
    • 2020-09-26
    • 1970-01-01
    相关资源
    最近更新 更多