【问题标题】:Extracting/gathering elements from a tensor according to two index vectors along with one dimension in Tensorflow根据两个索引向量以及Tensorflow中的一维从张量中提取/收集元素
【发布时间】:2026-02-17 08:55:01
【问题描述】:

很抱歉,问题标题冗长乏味,难以理解。基本上,我想在 tensorflow 中实现一个函数:

例如 对于维度为 [10, 10, 7, 1] 的张量 A 和索引矩阵 B = array([[1,3,5],[2,4,6]])。我想根据B每行的索引提取A中的元素以及axis = 2(遵循Python约定,A有0,1,2,3四个轴)。

所以示例的结果应该是一个维度为 [10, 10, 3, 2] 的张量 C,其中第三个维度是由于根据索引 [1,3,5] 沿轴 = 2 选择 A 中的元素] 或 [2,4,6],并且第四维等于 B 的第一个维度(即此处 B 的行数),因为我们在这里沿该维度进行了两次选择。

在张量流中实现这一点的任何“张量青睐”线索,而不是分两步完成?我没有看到使用 tf.gather_nd() 或 tf.gather() 的方法。任何想法?非常感谢!

另一个例子:

A = [[[1],     # A is (3, 5, 1)
       [2],
       [3],
       [4],
       [5]]],
     [[[10],
       [20],
       [30],
       [40],
       [50]]],
     [[[100],
       [200],
       [300],
       [400],
       [500]]]

B = [[1,4,3],     # B is (2,3)
     [2,3,5]]

C = [[[1, 2],     # C is (3, 3, 2)
       [4, 3],
       [3, 5]]],
     [[[10, 20],
       [40, 30],
       [30, 50]]],
     [[[100, 200],
       [400, 300],
       [300, 500]]]

【问题讨论】:

  • "根据索引 [1,3,5] 沿轴 = 2..."... 这是什么意思? axis=2 索引的形式应该是(开始、停止、步进),例如A[...,1:4:2,:] 的形状为 (10,10,2,1)。轴 2 上的 [1,3,5] 索引是多少?
  • 对不起,我的意思是这句话,例如,考虑 A[0,0],它应该是一个 7×1 矩阵,并通过引用索引 [1,3 ,5],我想提取 A[0,0,1]、A[0,0,3] 和 A[0,0,5]。我不是要使用包含所有奇数的[1,3,5],索引矩阵B也可以像数组([[1,4,5],[2,3,5]])。所以像 (start, stop, step) 这样的东西都行不通。谢谢!
  • 这很有帮助。还有一个问题:在您的示例中,A 的第四维度在哪里?是否只是为了“容纳”B 的索引?即。如果A 的形状为 (10,10,7,5),C 的形状是什么?
  • 嗨,是的,最后一个维度仅用于“容纳”B,因为 B 有两行,就像 A 的两个“查询”一样。如果 A 是 (10,10,7,5) , 那么 A 将首先被重新整形为 (10,10,7,1,5),最终的 C 将是 (10,10,3,2,5)。

标签: python tensorflow tensor


【解决方案1】:

B 张量的形状看起来不对,你的问题很难解析。但无论如何,TF 在这个问题上并不是很优雅。它需要一个非常特定的 B 形状。尝试类似于

import tensorflow as tf
import numpy as np

A = np.random.randn(10, 10, 7, 1).astype(np.float32)
A[0, 0, 1, 0] = 100001
A[0, 0, 3, 0] = 100002
A[0, 0, 5, 0] = 100003
A[0, 0, 2, 0] = 100004
A[0, 0, 4, 0] = 100005
A[0, 0, 6, 0] = 100006
A = tf.convert_to_tensor(A)

sess = tf.InteractiveSession()

B = np.array([
    [1, 3, 5],
    [2, 4, 6]
])

B = tf.convert_to_tensor(B)
B = tf.reshape(B, [-1])
B = tf.concat([tf.zeros_like(B), tf.zeros_like(B), B, tf.zeros_like(B)], axis=-1)
B = tf.reshape(B, [4, -1])
B = tf.transpose(B, [1, 0])
B = tf.reshape(B, [1, 2, 3, -1])


C = tf.gather_nd(A, B)
C = sess.run(C)
print C.shape
print C

输出是

[[[100001. 100002. 100003.]
  [100004. 100005. 100006.]]]

【讨论】:

  • 您好,感谢您的帮助,如有任何困惑,我们深表歉意。您的代码段根据 B 中的索引成功地“提取”了 A[0,0] 中的六个元素,但是,对于其他 A[i,j] 没有做任何事情。我想对所有 A[i,j] 进行这种提取和收集,并用 (10,10,3,2) 获得 C,其中 3 是由于 B 的列号,而 2 是由于B的行号,最后一个维度从A中的1扩展到C中的2。假设我不想为所有i,j循环A,我认为你的答案并没有真正达到我想要的。谢谢!
  • 尝试发布一个最小示例:输入数据和预期输出。这样可以更轻松地为您提供帮助。
  • 您好,请参考新示例。