【问题标题】:TensorFlow: Split a tensor into `batch_size` slicesTensorFlow:将张量拆分为“batch_size”切片
【发布时间】:2021-05-31 20:21:46
【问题描述】:

我有一个名为 tensor 的 rank-3 张量,形状为 [batch_size, axis_1, axis_2],并希望将其沿第一个轴拆分为 batch_size 切片,如下所示:

batch_size = tf.shape(tensor)[0]

batch_items = tf.split(tensor, num_or_size_splits=batch_size, axis=0)

很遗憾,这不起作用,因为batch_size 的值在构建图的过程中是未知的。

我该如何解决这个问题?

我收到此错误:

TypeError: Expected int for argument 'num_split' not <tf.Tensor 'decoded_predictions/strided_slice_15:0' shape=() dtype=int32>.

奇怪的是,尝试在其他 TensorFlow 函数中使用 batch_size 似乎有效:

tensor = tf.reshape(tensor, [batch_size, -1])

尽管batch_size 的值在图形构建过程中是未知的,但工作正常。

问题是tf.split() 的问题吗?

【问题讨论】:

    标签: python tensorflow


    【解决方案1】:

    解决方法是:

    batch_items = tf.map_fn(fn=lambda k: tensor[...,k],
                            elems=tf.range(batch_size),
                            dtype=tf.float32)
    

    不过,我仍然对更好的解决方案感兴趣。

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2021-04-28
      • 2019-10-21
      • 2019-06-14
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2021-02-11
      相关资源
      最近更新 更多