【问题标题】:Tensorflow 1.0 : Retreive hidden state from restored RNNTensorflow 1.0:从恢复的 RNN 中检索隐藏状态
【发布时间】:2018-04-27 14:53:28
【问题描述】:

我想恢复一个 RNN 并获得隐藏状态。

我这样做是为了保存 RNN:

loc="path/to/save/rnn"
with tf.variable_scope("lstm") as scope:
    outputs, state = tf.nn.dynamic_rnn(..)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
save_path = saver.save(sess,loc)

现在我想检索state

graph = tf.Graph()
sess = tf.Session(graph=graph)
with graph.as_default():
      saver = tf.train.import_meta_graph(loc + '.meta', clear_devices=True)
      saver.restore(sess, loc)
      state= ...

【问题讨论】:

    标签: tensorflow recurrent-neural-network restore


    【解决方案1】:

    您可以使用tf.add_to_collectionstate 张量添加到图collection 中,该图基本上是用于跟踪张量的键值存储,稍后使用tf.get_collection 检索它。例如:

    loc="path/to/save/rnn"
    with tf.variable_scope("lstm") as scope:
        outputs, state = tf.nn.dynamic_rnn(..)
        tf.add_to_collection('state', state)
    
    
    graph = tf.Graph()
    with graph.as_default():
          saver = tf.train.import_meta_graph(loc + '.meta', clear_devices=True)
          state = tf.get_collection('state')[0]  # Note: tf.get_collection returns a list.
    

    【讨论】:

    • 它有效,谢谢!但是必须知道“状态”不能保存为元组,必须在添加到集合之前将其分成两个对象。然后必须恢复这两个对象以重建状态元组。也许你会编辑你的答案。
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 2017-02-04
    • 2021-05-29
    • 2019-12-05
    • 1970-01-01
    • 1970-01-01
    • 2019-07-17
    • 2019-05-05
    相关资源
    最近更新 更多