【问题标题】:Slicing a matrix by an index matrix in TensorFlow在 TensorFlow 中通过索引矩阵对矩阵进行切片
【发布时间】:2021-03-17 12:44:49
【问题描述】:

我有一个 n-by-m 矩阵 X 和一个 n-by-r 索引矩阵 I。我想知道哪些相关的 TensorFlow 运算符可以让我获得一个 n-by-r 矩阵 R 使得 R[ i,j] = X[i,I[i,j]]。例如,让我们说

X = tf.constant([[1,2,3],
                [4,5,6],
                [7,8,9]])

I = tf.constant([[1,2],
                 [1,0],
                 [0,2]])

期望的结果是张量

R = [[2, 3], 
     [5, 4],
     [7, 9]]

我尝试使用矩阵 I 的每一列作为索引并执行 tf.diag_part(tf.gather(X', index)),如果我有相同的行数,这似乎给了我一列 R作为 X。例如,

idx = tf.transpose(I)[0] #[1,1,0]
res = tf.diag_part(tf.gather(tf.transpose(X), idx))

# res will be [2,5,7], i,e, first colum of R

另一个尝试:

res = tf.transpose(tf.gather(tf.transpose(X), I),[0,2,1])
print(res.eval()) 
        array([[[2, 3],
                [5, 6],
                [8, 9]],

               [[2, 1],
                [5, 4],
                [8, 7]],

               [[3, 1],
                [6, 4],
                [7, 9]]], dtype=int32)

从这里我只需要能够选择“对角线条目”res[0,0]、res[1,1] 和 res[2,2] 来获得 R。不过我被困在这里......

【问题讨论】:

  • 你能用tf.gather提供你的尝试吗?
  • @jakub 我尝试使用矩阵 I 的每一行作为索引并执行 tf.diag_part(tf.gather(X', index)),这似乎给了我一行 R。但我不知道如何将所有内容聚合在一起......
  • 您能否将您的尝试添加到您的问题中?如果我们知道您的尝试,我们会更容易提供帮助。
  • @jakub 我做了一个编辑

标签: python tensorflow


【解决方案1】:

使用tf.gatherbatch_dims 参数:

res = tf.gather(X, I, batch_dims=1)

【讨论】:

  • 太棒了!我能够通过 (1) res = tf.gather(tf.transpose(X), I) (2) res = tf.transpose(res, [1,2,0]) (3) res = tf.transpose(tf.linalg.diag_part(res))
猜你喜欢
  • 2015-06-14
  • 2015-07-10
  • 2021-06-29
  • 1970-01-01
  • 1970-01-01
  • 2013-03-18
  • 1970-01-01
  • 2022-07-10
  • 2021-06-17
相关资源
最近更新 更多