【问题标题】:Confused by the behavior of `tf.cond`对 `tf.cond` 的行为感到困惑
【发布时间】:2016-05-06 03:54:53
【问题描述】:

我的图表中需要一个条件控制流。如果predTrue,则图应该调用一个更新变量然后返回它的操作,否则它返回变量不变。一个简化的版本是:

pred = tf.constant(True)
x = tf.Variable([1])
assign_x_2 = tf.assign(x, [2])
def update_x_2():
  with tf.control_dependencies([assign_x_2]):
    return tf.identity(x)
y = tf.cond(pred, update_x_2, lambda: tf.identity(x))
with tf.Session() as session:
  session.run(tf.initialize_all_variables())
  print(y.eval())

但是,我发现pred=Truepred=False 导致相同的结果y=[2],这意味着当update_x_2 未被tf.cond 选择时,也会调用assign 操作。这要怎么解释?以及如何解决这个问题?

【问题讨论】:

    标签: tensorflow


    【解决方案1】:

    TL;DR:如果您希望tf.cond() 在其中一个分支中执行副作用(如赋值),您必须在内部创建执行副作用的操作 传递给tf.cond() 的函数。

    tf.cond() 的行为有点不直观。因为 TensorFlow 图中的执行流经图,所以您在 either 分支中引用的所有操作都必须在评估条件之前执行。这意味着 true 和 false 分支都接收到对 tf.assign() 操作的控制依赖,因此 y 总是设置为 2,即使 pred 是 False

    解决方案是在定义真正分支的函数内创建tf.assign() op。例如,您可以按如下方式构建代码:

    pred = tf.placeholder(tf.bool, shape=[])
    x = tf.Variable([1])
    def update_x_2():
      with tf.control_dependencies([tf.assign(x, [2])]):
        return tf.identity(x)
    y = tf.cond(pred, update_x_2, lambda: tf.identity(x))
    with tf.Session() as session:
      session.run(tf.initialize_all_variables())
      print(y.eval(feed_dict={pred: False}))  # ==> [1]
      print(y.eval(feed_dict={pred: True}))   # ==> [2]
    

    【讨论】:

    • 是的,这也是让我感到困惑的一个。我的理解是,在执行tf.cond 之前,运行时会确保执行所有依赖项。 TrueFalse 分支中操作的依赖关系也是 cond 的依赖关系,所以即使分支中的操作可能永远不会执行,但它的所有依赖关系都会执行,听起来对吗?
    • 是的 - 图修剪考虑了所有潜在的依赖关系(任一分支的)执行,并且只有在它们被定义在其中一个分支内时才会禁止它们的执行,因为 CondContext adds a control dependency on the pivot 和如果它在未采用的分支中,则依赖项将是一个死张量(阻止操作执行)。
    • 这样做的原因是什么?为什么不修剪非活动分支后面的子图?
    • @LenarHoyt:修剪发生在 计算 pred 的值之前。这使 TensorFlow 能够基于一个简单的键(本质上是 Session.run() 的参数)缓存单个修剪后的图,并使条件执行的实现简单而轻量级。相同的机制用于实现tf.while_loop(),在此级别执行控制流的优势更加明显。
    【解决方案2】:
    pred = tf.constant(False)
    x = tf.Variable([1])
    
    def update_x_2():
        assign_x_2 = tf.assign(x, [2])
        with tf.control_dependencies([assign_x_2]):
            return tf.identity(x)
    y = tf.cond(pred, update_x_2, lambda: tf.identity(x))
    with tf.Session() as session:
      session.run(tf.initialize_all_variables())
      print(y.eval())
    

    这将得到[1]的结果。

    这个答案和上面的答案完全一样。但我想分享的是你可以把你想使用的每一个操作放在它的分支函数中。因为,给定您的示例代码,张量 x 可以直接由 update_x_2 函数使用。

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2013-02-17
      • 1970-01-01
      • 2019-04-14
      • 2014-10-21
      • 1970-01-01
      相关资源
      最近更新 更多