【发布时间】:2021-03-17 12:44:49
【问题描述】:
我有一个 n-by-m 矩阵 X 和一个 n-by-r 索引矩阵 I。我想知道哪些相关的 TensorFlow 运算符可以让我获得一个 n-by-r 矩阵 R 使得 R[ i,j] = X[i,I[i,j]]。例如,让我们说
X = tf.constant([[1,2,3],
[4,5,6],
[7,8,9]])
I = tf.constant([[1,2],
[1,0],
[0,2]])
期望的结果是张量
R = [[2, 3],
[5, 4],
[7, 9]]
我尝试使用矩阵 I 的每一列作为索引并执行 tf.diag_part(tf.gather(X', index)),如果我有相同的行数,这似乎给了我一列 R作为 X。例如,
idx = tf.transpose(I)[0] #[1,1,0]
res = tf.diag_part(tf.gather(tf.transpose(X), idx))
# res will be [2,5,7], i,e, first colum of R
另一个尝试:
res = tf.transpose(tf.gather(tf.transpose(X), I),[0,2,1])
print(res.eval())
array([[[2, 3],
[5, 6],
[8, 9]],
[[2, 1],
[5, 4],
[8, 7]],
[[3, 1],
[6, 4],
[7, 9]]], dtype=int32)
从这里我只需要能够选择“对角线条目”res[0,0]、res[1,1] 和 res[2,2] 来获得 R。不过我被困在这里......
【问题讨论】:
-
你能用
tf.gather提供你的尝试吗? -
@jakub 我尝试使用矩阵 I 的每一行作为索引并执行 tf.diag_part(tf.gather(X', index)),这似乎给了我一行 R。但我不知道如何将所有内容聚合在一起......
-
您能否将您的尝试添加到您的问题中?如果我们知道您的尝试,我们会更容易提供帮助。
-
@jakub 我做了一个编辑
标签: python tensorflow