【问题标题】:Slicing a tensor with a tensor of indices and tf.gather使用索引张量和 tf.gather 对张量进行切片
【发布时间】:2020-09-26 18:48:20
【问题描述】:

我正在尝试使用索引张量对张量进行切片。为此,我尝试使用tf.gather。 但是,我很难理解 documentation 并且没有像我期望的那样让它工作:

我有两个张量。一个形状为[1,240,4]activations 张量和一个形状为[1,1,120]ids 张量。我想用ids 张量的第三维中提供的索引对activations 张量的第二维进行切片:

downsampled_activations = tf.gather(activations, ids, axis=1)

我给了它axis=1 选项,因为这是我要切片的activations 张量中的轴。

但是,这不会呈现预期的结果,只会给我以下错误:

tensorflow.python.framework.errors_impl.InvalidArgumentError: indices[0,0,1] = 1 is not in [0, 1)

我尝试了axisbatch_dims 选项的各种组合,但到目前为止都无济于事,并且文档并没有真正帮助我。任何愿意更详细地解释参数或在上面的示例中的人都会非常有帮助!

编辑: ID 在运行前预先计算并通过输入管道输入:

features = tf.io.parse_single_example(
            serialized_example,
            features={ 'featureIDs': tf.io.FixedLenFeature([], tf.string)}

然后它们被重新塑造成以前的格式:

feature_ids_raw = tf.decode_raw(features['featureIDs'], tf.int32)
feature_ids_shape = tf.stack([batch_size, (num_neighbours * 4)])
feature_ids = tf.reshape(feature_ids_raw, feature_ids_shape)
feature_ids = tf.expand_dims(feature_ids, 0)

之后它们具有前面提到的形状(batch_size = 1num_neighbours = 30 -> [1,1,120]),我想用它们来切片 activations 张量。

Edit2:我希望输出为[1,120,4]。 (所以我想根据存储在我的ids张量中的ID,沿着activations张量的第二维收集条目。)

【问题讨论】:

  • 你是如何生成张量 ID 的?
  • @qmeeus:我已经编辑了原帖来回答你的问题。
  • @budekatude 为了不让自己感到困惑,在这种情况下您期望的输出维度是多少?
  • @Daveldito:我在上面添加了另一个编辑来回答您的问题。
  • 你确定使用axis=1的时候报错吗?该错误表示您提供的索引(= 1)超出了轴的大小(轴= 1大小为240 !!!)。我认为如果轴= 0,则会引发错误。

标签: python tensorflow


【解决方案1】:

你可以使用:

downsampled_activations =tf.gather(activations , tf.squeeze(ids) ,axis = 1)
downsampled_activations.shape #  [1,120,4]

在大多数情况下,tf.gather 方法需要 1d 索引,这在您的情况下是正确的,而不是 3d (1,1,120) 的索引,1d 就足够了 (120,)。 tf.gather 方法将查看轴(= 1)并返回索引张量提供的每个索引处的元素。

【讨论】:

  • 您可以将axis=[0, 1] 传递给tf.squeeze 以确保压缩前两个维度,并且只有那些维度被压缩,如果ids 没有预期的形状@987654325,这对于捕获错误很有用@.
【解决方案2】:

tf.gather 根据索引从paramsaxis 收集切片。

当然,文档不是最具表现力的,重点应该放在 slices 上(因为您从 axis 索引切片而不是元素,我想这是您错误的拿去)。

让我们举个小得多的例子:

activations_small = tf.convert_to_tensor([[[1, 2, 3, 4], [11, 22, 33, 44]]])
print(activations_small.shape) # [1, 2, 4]

让我们想象一下这个张量:

    XX 4  XX 44 XX XX
  XX  3 XX  33 X  XX
XXX 2 XX   22XX  XX
X-----X-----+X  XX
|  1  |  11 | XX
+-----+-----+X

tf.gather(activations1, [0, 0], axis=1) 将返回

<tf.Tensor: shape=(1, 2, 4), dtype=int32, numpy=
array([[[1, 2, 3, 4],
        [1, 2, 3, 4]]], dtype=int32)>

tf.gather 所做的是从轴 1 看,然后选择索引 0(ofc,两次,即[0, 0])。如果你运行tf.gather(activations1, [0, 0, 0, 0, 0], axis=1).shape,你会得到TensorShape([1, 5, 4])

您的错误 现在让我们尝试触发您遇到的错误。

tf.gather(activations1, [0, 2], axis=1)

InvalidArgumentError: indices[1] = 2 is not in [0, 2) [Op:GatherV2]

这里发生的情况是,当tf.gather 从轴 1 的角度看时,没有索引 = 2 的项目(如果你愿意的话,列)。

我猜这就是documentation 暗示的意思

param:&lt;indices&gt; 索引张量。必须是以下类型之一:int32、int64。必须在 [0, params.shape[axis]) 范围内。

您的(潜在)解决方案

indices的维度,以及你的问题的预期结果,我不确定以上对你来说是否很明显。

tf.gather(activations, indices=[0, 1, 2, 3], axis=2) 或任何索引在[0, activations.shape[2]) 范围内的索引,即[0, 4) 都可以。其他任何事情都会给您带来您遇到的错误。

下面有一个逐字逐句的答案,以防这是您的预期结果。

【讨论】:

  • 非常感谢您提供全面的答案和示例!
猜你喜欢
  • 2020-02-16
  • 1970-01-01
  • 1970-01-01
  • 2017-06-02
  • 2019-04-13
  • 2019-01-12
  • 2018-03-27
  • 2018-11-07
  • 1970-01-01
相关资源
最近更新 更多