【发布时间】:2020-08-15 22:36:01
【问题描述】:
我正在尝试以特定的索引顺序获取 3D 张量的行。以下是输入:
import tensorflow as tf
matrix = tf.constant([
[[0, 1], [2, 3], [4, 5], [6, 7]],
[[8, 9], [10, 11], [12, 13], [14, 15]],
[[16, 17], [18, 19], [20, 21], [22, 23]],
[[24, 25], [26, 27], [28, 29], [30, 31]],
[[32, 33], [34, 35], [36, 37], [38, 39]]
])
indx = tf.constant([[3,2,1,0], [0,1,2,3], [1,0,3,2], [0,3,1,2], [1,2,3,0]])
# required output tensor:
[[[6, 7], [4, 5], [2, 3], [0, 1]],
[[8, 9], [10, 11], [12, 13], [14, 15]],
[[18, 19], [16, 17], [22, 23], [20, 21]],
[[24, 25], [30, 31], [26, 27], [28, 29]],
[[34, 35], [36, 37], [38, 39], [32, 33]]]
我正在为tf.gather_nd() 苦苦挣扎。有什么建议吗?我可以看到它在这里发生,但我不确定如何在不使用for 循环或tf.map_fn 的情况下应用于整个矩阵
print(tf.gather_nd(matrix[0], tf.expand_dims(indx, -1)[0]).numpy().tolist())
print(tf.gather_nd(matrix[1], tf.expand_dims(indx, -1)[1]).numpy().tolist())
print(tf.gather_nd(matrix[2], tf.expand_dims(indx, -1)[2]).numpy().tolist())
print(tf.gather_nd(matrix[3], tf.expand_dims(indx, -1)[3]).numpy().tolist())
print(tf.gather_nd(matrix[4], tf.expand_dims(indx, -1)[4]).numpy().tolist())
"""
[[6, 7], [4, 5], [2, 3], [0, 1]]
[[8, 9], [10, 11], [12, 13], [14, 15]]
[[18, 19], [16, 17], [22, 23], [20, 21]]
[[24, 25], [30, 31], [26, 27], [28, 29]]
[[34, 35], [36, 37], [38, 39], [32, 33]]
"""
编辑:我问了一个关于 numpy 的类似问题。一个聪明的索引答案确实解决了 numpy 版本,但很难将它应用于张量。随意看看这里接受的答案:How can I get elements from 3D matrix using specified indices in numpy?
【问题讨论】:
-
@Ehsan 是的,我问过 numpy,但它不相关,因为张量不允许像 numpy 数组那样索引
-
一旦答案解决了所提出的问题,通过接受答案来结束问题总是好的。您也可以参考它并提及您在 tensorflow 中特别需要它而不是 numpy(并且可能删除 numpy 标签)
-
是的,我正在关闭 numpy 版本,因为它在那里得到了 w.r.t numpy 数组的回答,并在此处添加了关于此的注释。谢谢
标签: python tensorflow tensorflow2.0