【发布时间】:2020-09-26 18:48:20
【问题描述】:
我正在尝试使用索引张量对张量进行切片。为此,我尝试使用tf.gather。
但是,我很难理解 documentation 并且没有像我期望的那样让它工作:
我有两个张量。一个形状为[1,240,4] 的activations 张量和一个形状为[1,1,120] 的ids 张量。我想用ids 张量的第三维中提供的索引对activations 张量的第二维进行切片:
downsampled_activations = tf.gather(activations, ids, axis=1)
我给了它axis=1 选项,因为这是我要切片的activations 张量中的轴。
但是,这不会呈现预期的结果,只会给我以下错误:
tensorflow.python.framework.errors_impl.InvalidArgumentError: indices[0,0,1] = 1 is not in [0, 1)
我尝试了axis 和batch_dims 选项的各种组合,但到目前为止都无济于事,并且文档并没有真正帮助我。任何愿意更详细地解释参数或在上面的示例中的人都会非常有帮助!
编辑: ID 在运行前预先计算并通过输入管道输入:
features = tf.io.parse_single_example(
serialized_example,
features={ 'featureIDs': tf.io.FixedLenFeature([], tf.string)}
然后它们被重新塑造成以前的格式:
feature_ids_raw = tf.decode_raw(features['featureIDs'], tf.int32)
feature_ids_shape = tf.stack([batch_size, (num_neighbours * 4)])
feature_ids = tf.reshape(feature_ids_raw, feature_ids_shape)
feature_ids = tf.expand_dims(feature_ids, 0)
之后它们具有前面提到的形状(batch_size = 1 和 num_neighbours = 30 -> [1,1,120]),我想用它们来切片 activations 张量。
Edit2:我希望输出为[1,120,4]。 (所以我想根据存储在我的ids张量中的ID,沿着activations张量的第二维收集条目。)
【问题讨论】:
-
你是如何生成张量 ID 的?
-
@qmeeus:我已经编辑了原帖来回答你的问题。
-
@budekatude 为了不让自己感到困惑,在这种情况下您期望的输出维度是多少?
-
@Daveldito:我在上面添加了另一个编辑来回答您的问题。
-
你确定使用axis=1的时候报错吗?该错误表示您提供的索引(= 1)超出了轴的大小(轴= 1大小为240 !!!)。我认为如果轴= 0,则会引发错误。
标签: python tensorflow