【问题标题】:Saving and Restoring a trained LSTM in Tensor Flow在 TensorFlow 中保存和恢复经过训练的 LSTM
【发布时间】:2016-11-05 19:05:01
【问题描述】:

我使用 BasicLSTMCell 训练了 LSTM 分类器。如何保存我的模型并恢复它以供以后分类使用?

【问题讨论】:

    标签: tensorflow recurrent-neural-network lstm


    【解决方案1】:

    我们发现了同样的问题。我们不确定是否保存了内部变量。我们发现您必须在创建/定义 BasicLSTMCell 后创建保护程序。否则不保存。

    【讨论】:

    • 如何检查内部变量是否被保存?加载保存的 RNN 模型后,我不确定将什么用作 initial_state?
    【解决方案2】:

    保存和恢复模型的最简单方法是使用tf.train.Saverobject。构造函数为图中的所有变量或指定列表添加保存和恢复操作到图中。 saver 对象提供了运行这些操作的方法,指定要写入或读取的检查点文件的路径。

    参考:

    https://www.tensorflow.org/versions/r0.11/how_tos/variables/index.html

    检查点文件

    变量保存在二进制文件中,大致包含从变量名称到张量值的映射。

    当您创建 Saver 对象时,您可以选择为检查点文件中的变量选择名称。默认情况下,它使用每个变量的 Variable.name 属性的值。

    要了解检查点中的变量,您可以使用 inspect_checkpoint 库,尤其是 print_tensors_in_checkpoint_file 函数。

    保存变量

    使用 tf.train.Saver() 创建一个 Saver 来管理模型中的所有变量。

    # Create some variables.
    v1 = tf.Variable(..., name="v1")
    v2 = tf.Variable(..., name="v2")
    ...
    # Add an op to initialize the variables.
    init_op = tf.initialize_all_variables()
    
    # Add ops to save and restore all the variables.
    saver = tf.train.Saver()
    
    # Later, launch the model, initialize the variables, do some work, save the
    # variables to disk.
    with tf.Session() as sess:
      sess.run(init_op)
      # Do some work with the model.
      ..
      # Save the variables to disk.
      save_path = saver.save(sess, "/tmp/model.ckpt")
      print("Model saved in file: %s" % save_path)
    

    恢复变量

    相同的 Saver 对象用于恢复变量。请注意,当您从文件中恢复变量时,您不必事先初始化它们。

    # Create some variables.
    v1 = tf.Variable(..., name="v1")
    v2 = tf.Variable(..., name="v2")
    ...
    # Add ops to save and restore all the variables.
    saver = tf.train.Saver()
    
    # Later, launch the model, use the saver to restore variables from disk, and
    # do some work with the model.
    with tf.Session() as sess:
      # Restore variables from disk.
      saver.restore(sess, "/tmp/model.ckpt")
      print("Model restored.")
      # Do some work with the model
      ...
    

    【讨论】:

    • 谢谢@JoeMurf,但是如何保存 LSTM 单元的门值?
    • 有人可以回答马塞洛的问题吗?
    【解决方案3】:

    我自己也在想这个。正如其他人指出的那样,在 TensorFlow 中保存模型的常用方法是使用 tf.train.Saver(),但我相信这会保存 tf.Variables 的值。 我不确定BasicLSTMCell 实现中是否有tf.Variables 在您执行此操作时会自动保存,或者是否可能需要采取另一个步骤,但如果所有其他方法都失败,BasicLSTMCell可以很容易地保存并加载到泡菜文件中。

    【讨论】:

    • 你是对的!我试过了,好像 Tensorflow 做了所有的“拯救魔法”
    【解决方案4】:

    是的,LSTM 单元内部存在权重和偏差变量(实际上,所有神经网络单元都必须在某处具有权重变量)。正如其他答案中已经指出的那样,使用 Saver 对象似乎是要走的路……以一种相当方便的方式保存您的变量和您的(元)图。如果您想恢复整个模型,您将需要元图,而不仅仅是一些孤立地坐在那里的 tf.Variables。它确实需要知道它必须保存的所有变量,所以在创建图表后创建保存器。

    在处理任何“是否存在变量?”/“它是否正确地重用权重?”/“我如何才能真正查看我的 LSTM 中的权重,它没有绑定到任何 python var? “/ETC。情况就是这个小sn-p:

    for i in tf.global_variables():
        print(i)
    

    对于变量和

    for i in my_graph.get_operations():
        print (i)
    

    用于操作。如果要查看未绑定到 python var 的张量,

    tf.Graph.get_tensor_by_name('name_of_op:N')
    

    其中 op 的名称是生成张量的操作的名称,N 是您所追求的(可能是多个)输出张量的索引。

    如果您的图表有大量操作...大多数倾向于...

    【讨论】:

      【解决方案5】:

      我已经为 LSTM 保存和恢复制作了示例代码。 我也花了很多时间来解决这个问题。 参考这个网址:https://github.com/MareArts/rnn_save_restore_test 希望对这段代码有所帮助。

      【讨论】:

        【解决方案6】:

        您可以实例化一个tf.train.Saver 对象并在训练期间调用save 传递当前会话和输出检查点文件(*.ckpt) 路径。您可以在您认为合适的任何时候致电save(例如,每隔几个时期,当验证错误下降时):

        # Create some variables.
        v1 = tf.Variable(..., name="v1")
        v2 = tf.Variable(..., name="v2")
        ...
        # Add an op to initialize the variables.
        init_op = tf.initialize_all_variables()
        
        # Add ops to save and restore all the variables.
        saver = tf.train.Saver()
        
        # Later, launch the model, initialize the variables, do some work, save the
        # variables to disk.
        with tf.Session() as sess:
          sess.run(init_op)
          # Do some work with the model.
          ..
          # Save the variables to disk.
          save_path = saver.save(sess, "/tmp/model.ckpt")
          print("Model saved in file: %s" % save_path)
        

        在分类/推理期间,您实例化另一个 tf.train.Saver 并调用 restore 传递当前会话和要恢复的检查点文件。您可以在使用模型进行分类之前调用​​restore,方法是调用session.run

        # Create some variables.
        v1 = tf.Variable(..., name="v1")
        v2 = tf.Variable(..., name="v2")
        ...
        # Add ops to save and restore all the variables.
        saver = tf.train.Saver()
        
        # Later, launch the model, use the saver to restore variables from disk, and
        # do some work with the model.
        with tf.Session() as sess:
          # Restore variables from disk.
          saver.restore(sess, "/tmp/model.ckpt")
          print("Model restored.")
          # Do some work with the model
          ...
        

        参考:https://www.tensorflow.org/versions/r0.11/how_tos/variables/index.html#saving-and-restoring

        【讨论】:

        • 如果我的回答有问题,我将不胜感激。
        • 我相信提问者试图找出在训练后保存 BasicLSTMCell 对象是否有特殊考虑。我想知道自己,因为 TensorFlow 文档只演示了如何保存 tf.Variable's。
        • 在推理过程中第一次需要零作为 LSTM 单元的初始状态吗?
        猜你喜欢
        • 1970-01-01
        • 2018-04-25
        • 1970-01-01
        • 2018-02-08
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        • 2020-08-15
        相关资源
        最近更新 更多