【发布时间】:2018-12-07 19:37:02
【问题描述】:
我正在 numpy.xml 中寻找以下代码的 tensorflow 等效项。给出了a 和idx_2。目标是构造b。
# A float Tensor obtained somehow
a = np.arange(3*5).reshape(3,5)
# An int Tensor obtained somehow
idx_2 = np.array([[1,2,3,4],[0,2,3,4],[0,2,3,4]])
# An int Tensor, constructed for indexing
idx_1 = np.arange(a.shape[0]).reshape(-1,1)
# The goal
b = a[idx_1, idx_2]
print(b)
>>> [[ 1 2 3 4]
[ 5 7 8 9]
[10 12 13 14]]
我尝试直接索引张量并使用tf.gather_nd,但我不断收到错误,所以我决定在这里询问如何做。我到处寻找人们使用tf.gather_nd(因此标题)来解决类似问题的答案,但是为了应用这个函数,我必须以某种方式重塑索引,以便它们可以用来分割第一个维度。我该怎么做呢?请帮忙。
【问题讨论】:
标签: python tensorflow