【问题标题】:How to use tf.cond for batch processing如何使用 tf.cond 进行批处理
【发布时间】:2017-11-20 06:07:28
【问题描述】:

我想使用tf.cond(pred, fn1, fn2, name=None) 进行条件分支。假设我有两个张量:x, y。每个张量是一批 0/1,我想使用这个张量压缩 x < y 作为源 tf.cond pred 参数:

pred:一个标量,决定是否返回 fn1 或 fn2 的结果。

但是,如果我正在处理批处理,那么看起来我需要遍历图中的源张量并为批处理中的每个项目制作切片并为每个项目应用 tf.cond。看起来对我来说很可疑。为什么 tf.cond 不接受批处理而只接受标量?您能否建议将它与批处理一起使用的正确方法是什么?

【问题讨论】:

    标签: tensorflow conditional


    【解决方案1】:

    tf.where 听起来像你想要的:张量之间的矢量化选择。

    tf.cond 是一个控制流修饰符:它决定执行哪些操作,因此很难想到有用的批处理语义。

    我们还可以将这些操作混合在一起:根据条件进行切片并将这些切片传递到两个分支的操作​​。

    import tensorflow as tf
    from tensorflow.python.util import nest
    
    def slicing_where(condition, full_input, true_branch, false_branch):
      """Split `full_input` between `true_branch` and `false_branch` on `condition`.
    
      Args:
        condition: A boolean Tensor with shape [B_1, ..., B_N].
        full_input: A Tensor or nested tuple of Tensors of any dtype, each with
          shape [B_1, ..., B_N, ...], to be split between `true_branch` and
          `false_branch` based on `condition`.
        true_branch: A function taking a single argument, that argument having the
          same structure and number of batch dimensions as `full_input`. Receives
          slices of `full_input` corresponding to the True entries of
          `condition`. Returns a Tensor or nested tuple of Tensors, each with batch
          dimensions matching its inputs.
        false_branch: Like `true_branch`, but receives inputs corresponding to the
          false elements of `condition`. Returns a Tensor or nested tuple of Tensors
          (with the same structure as the return value of `true_branch`), but with
          batch dimensions matching its inputs.
      Returns:
        Interleaved outputs from `true_branch` and `false_branch`, each Tensor
        having shape [B_1, ..., B_N, ...].
      """
      full_input_flat = nest.flatten(full_input)
      true_indices = tf.where(condition)
      false_indices = tf.where(tf.logical_not(condition))
      true_branch_inputs = nest.pack_sequence_as(
          structure=full_input,
          flat_sequence=[tf.gather_nd(params=input_tensor, indices=true_indices)
                         for input_tensor in full_input_flat])
      false_branch_inputs = nest.pack_sequence_as(
          structure=full_input,
          flat_sequence=[tf.gather_nd(params=input_tensor, indices=false_indices)
                         for input_tensor in full_input_flat])
      true_outputs = true_branch(true_branch_inputs)
      false_outputs = false_branch(false_branch_inputs)
      nest.assert_same_structure(true_outputs, false_outputs)
      def scatter_outputs(true_output, false_output):
        batch_shape = tf.shape(condition)
        scattered_shape = tf.concat(
            [batch_shape, tf.shape(true_output)[tf.rank(batch_shape):]],
            0)
        true_scatter = tf.scatter_nd(
            indices=tf.cast(true_indices, tf.int32),
            updates=true_output,
            shape=scattered_shape)
        false_scatter = tf.scatter_nd(
            indices=tf.cast(false_indices, tf.int32),
            updates=false_output,
            shape=scattered_shape)
        return true_scatter + false_scatter
      result = nest.pack_sequence_as(
          structure=true_outputs,
          flat_sequence=[
              scatter_outputs(true_single_output, false_single_output)
              for true_single_output, false_single_output
              in zip(nest.flatten(true_outputs), nest.flatten(false_outputs))])
      return result
    

    一些例子:

    vector_test = slicing_where(
        condition=tf.equal(tf.range(10) % 2, 0),
        full_input=tf.range(10, dtype=tf.float32),
        true_branch=lambda x: 0.2 + x,
        false_branch=lambda x: 0.1 + x)
    
    cross_range = (tf.range(10, dtype=tf.float32)[:, None]
                   * tf.range(10, dtype=tf.float32)[None, :])
    matrix_test = slicing_where(
        condition=tf.equal(tf.range(10) % 3, 0),
        full_input=cross_range,
        true_branch=lambda x: -x,
        false_branch=lambda x: x + 0.1)
    
    with tf.Session():
      print(vector_test.eval())
      print(matrix_test.eval())
    

    打印:

    [ 0.2         1.10000002  2.20000005  3.0999999   4.19999981  5.0999999
      6.19999981  7.0999999   8.19999981  9.10000038]
    [[  0.           0.           0.           0.           0.           0.
        0.           0.           0.           0.        ]
     [  0.1          1.10000002   2.0999999    3.0999999    4.0999999
        5.0999999    6.0999999    7.0999999    8.10000038   9.10000038]
     [  0.1          2.0999999    4.0999999    6.0999999    8.10000038
       10.10000038  12.10000038  14.10000038  16.10000038  18.10000038]
     [  0.          -3.          -6.          -9.         -12.         -15.
      -18.         -21.         -24.         -27.        ]
     [  0.1          4.0999999    8.10000038  12.10000038  16.10000038
       20.10000038  24.10000038  28.10000038  32.09999847  36.09999847]
     [  0.1          5.0999999   10.10000038  15.10000038  20.10000038
       25.10000038  30.10000038  35.09999847  40.09999847  45.09999847]
     [  0.          -6.         -12.         -18.         -24.         -30.
      -36.         -42.         -48.         -54.        ]
     [  0.1          7.0999999   14.10000038  21.10000038  28.10000038
       35.09999847  42.09999847  49.09999847  56.09999847  63.09999847]
     [  0.1          8.10000038  16.10000038  24.10000038  32.09999847
       40.09999847  48.09999847  56.09999847  64.09999847  72.09999847]
     [  0.          -9.         -18.         -27.         -36.         -45.
      -54.         -63.         -72.         -81.        ]]
    

    【讨论】:

    • 目标是“控制流量”。这就是我需要 tf.cond 的原因。但是您完全正确,对于当前的架构,“很难想到有用的批处理语义”。我只能使用 SGD。谢谢!现在我意识到了这一点。
    • 您能详细介绍一下您要解决的问题吗?很乐意帮助头脑风暴解决方案。
    • 让我们想象一下 tf.cond 之前的控制流上的三个分支(两个分支 - 检测器(d1,d2)和一个 - 数据源(ds)。在 tf.cond 之后还有两个分支(p1, p2). 让我们简化细节,如果第一个检测器的输出大于或等于第二个检测器的输出,那么来自数据源(ds)分支的数据应该由其他的 p1 分支处理case - 通过 p2 分支。我们不应该同时处理两个分支。
    • 处理是否有副作用(例如变量赋值)?如果没有, tf.where 应该按原样工作,并授予一些冗余计算。根据问题的具体情况,您可以通过根据条件对输入进行切片来稍微优化(产生两个张量,一个带有一批输入到第一个管道,一个带有一批输入到第二个)。
    • 是的,问题是p1和p2会将处理结果分配给两个分支末尾的同一个变量。该 var 将用作处理结果和损失计算。
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 2016-12-27
    • 1970-01-01
    • 2012-12-19
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多