【问题标题】:What is the alternative of tf.Variable.ref() in Tensorflow version 0.12?Tensorflow 0.12 版中 tf.Variable.ref() 的替代方法是什么?
【发布时间】:2016-12-01 01:58:42
【问题描述】:

我正在尝试运行 A3C 强化学习算法的开放代码以在 A3C code 中学习 A3C

但是,我遇到了几个错误,除了一个,我可以修复。 在代码中,使用了 tf.Variable 的成员函数 ref() (1,2),但在最近的 tensorflow 版本 0.12rc 中,该函数似乎已被弃用。 所以我不知道替换它的最佳方法是什么(我不明白作者为什么使用ref())。当我只是将它更改为变量本身时(例如v.ref()v),没有错误,但奖励没有改变。它似乎无法学习,我猜是因为变量没有正确更新。

请告诉我修改代码的正确方法是什么。

【问题讨论】:

    标签: python python-2.7 tensorflow


    【解决方案1】:

    新方法tf.Variable.read_value() 是TensorFlow 0.12 及更高版本中tf.Variable.ref() 的替代品。

    此方法的用例解释起来有些棘手,其动机是某些缓存行为会导致在不同设备上多次使用远程变量以使用缓存值。假设您有以下代码:

    with tf.device("/cpu:0")
      v = tf.Variable([[1.]])
    
    with tf.device("/gpu:0")
      # The value of `v` will be captured at this point and cached until `m2`
      # is computed.
      m1 = tf.matmul(v, ...)
    
    with tf.control_dependencies([m1])
      # The assign happens (on the GPU) after `m1`, but before `m2` is computed.
      assign_op = v.assign([[2.]])
    
    with tf.control_dependencies([assign_op]):
      with tf.device("/gpu:0"):
        # The initially read value of `v` (i.e. [[1.]]) will be used here,
        # even though `m2` is computed after the assign.
        m2 = tf.matmul(v, ...)
    
    sess.run(m2)
    

    您可以使用tf.Variable.read_value() 强制 TensorFlow 稍后再次读取该变量,它将受制于任何控制依赖项。所以如果你想在计算m2的时候看到赋值的结果,你可以修改程序的最后一个块如下:

    with tf.control_dependencies([assign_op]):
      with tf.device("/gpu:0"):
        # The `read_value()` call will cause TensorFlow to transfer the
        # new value of `v` from the CPU to the GPU before computing `m2`.
        m2 = tf.matmul(v.read_value(), ...)
    

    (请注意,目前,如果所有操作都在同一设备上,您将不需要使用read_value(),因为 TensorFlow 不会在它被用作同一设备上操作的输入。这可能会导致很多混乱——例如,当你将一个变量排入队列时!——这也是我们致力于增强内存模型的原因之一变量。)

    【讨论】:

    • 非常感谢您快速详细的回答。内容非常丰富,我可以很好地理解。
    猜你喜欢
    • 2019-06-20
    • 1970-01-01
    • 2022-11-13
    • 2015-01-04
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多