【问题标题】:Flatten batch in tensorflow在张量流中展平批次
【发布时间】:2016-04-16 19:13:09
【问题描述】:

我有一个形状为[None, 9, 2] 的张量流的输入(其中None 是批处理)。

要对其执行进一步的操作(例如 matmul),我需要将其转换为 [None, 18] 形状。怎么做?

【问题讨论】:

    标签: tensorflow


    【解决方案1】:

    您可以使用 tf.reshape() 轻松完成此操作,而无需知道批量大小。

    x = tf.placeholder(tf.float32, shape=[None, 9,2])
    shape = x.get_shape().as_list()        # a list: [None, 9, 2]
    dim = numpy.prod(shape[1:])            # dim = prod(9,2) = 18
    x2 = tf.reshape(x, [-1, dim])           # -1 means "all"
    

    最后一行中的-1 表示整个列,无论运行时的批处理大小是多少。你可以在tf.reshape()看到它。


    更新:shape = [None, 3, None]

    感谢@kbrose。对于多于一维未定义的情况,我们可以使用tf.shape()tf.reduce_prod() 交替使用。

    x = tf.placeholder(tf.float32, shape=[None, 3, None])
    dim = tf.reduce_prod(tf.shape(x)[1:])
    x2 = tf.reshape(x, [-1, dim])
    

    tf.shape() 返回一个可以在运行时评估的形状张量。 tf.get_shape()和tf.shape()的区别可以看in the doc

    我还尝试了 tf.contrib.layers.flatten() 在另一个 .第一种情况最简单,但第二种情况就不行了。

    【讨论】:

    • 如果您知道所有其他维度的大小,这很有效,但如果其他维度的大小未知,则不会。例如。 x = tf.placeholder(tf.float32, shape=[None, 9, None])
    • 感谢@kbrose。我已经更新了这个案例的答案。
    • @weitang114 太棒了!
    • 我被 reshape 卡住了,tf.reduce_prod 为我解决了问题。非常感谢!
    • 如果稍后将x2 传递给dynamic_rnn,这似乎不起作用。产生ValueError: Input size (depth of inputs) must be accessible via shape inference, but saw value None.
    【解决方案2】:
    flat_inputs = tf.layers.flatten(inputs)
    

    【讨论】:

      【解决方案3】:

      您可以在运行时通过tf.batch使用动态整形来获取批量维度的值,将整组新维度计算到tf.reshape中。这是一个在不知道列表长度的情况下将平面列表重塑为方阵的示例。

      tf.reset_default_graph()
      sess = tf.InteractiveSession("")
      a = tf.placeholder(dtype=tf.int32)
      # get [9]
      ashape = tf.shape(a)
      # slice the list from 0th to 1st position
      ashape0 = tf.slice(ashape, [0], [1])
      # reshape list to scalar, ie from [9] to 9
      ashape0_flat = tf.reshape(ashape0, ())
      # tf.sqrt doesn't support int, so cast to float
      ashape0_flat_float = tf.to_float(ashape0_flat)
      newshape0 = tf.sqrt(ashape0_flat_float)
      # convert [3, 3] Python list into [3, 3] Tensor
      newshape = tf.pack([newshape0, newshape0])
      # tf.reshape doesn't accept float, so convert back to int
      newshape_int = tf.to_int32(newshape)
      a_reshaped = tf.reshape(a, newshape_int)
      sess.run(a_reshaped, feed_dict={a: np.ones((9))})
      

      你应该看到

      array([[1, 1, 1],
             [1, 1, 1],
             [1, 1, 1]], dtype=int32)
      

      【讨论】:

      • 我在这个解决方案或 Tensorflow 中没有看到任何方法 tf.batch...
      猜你喜欢
      • 2016-06-13
      • 2021-02-10
      • 1970-01-01
      • 1970-01-01
      • 2018-04-04
      • 2021-08-08
      • 2021-12-25
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多