【问题标题】:TensorFlow: using a tensor to index another tensorTensorFlow:使用一个张量来索引另一个张量
【发布时间】:2016-06-20 22:27:38
【问题描述】:

我有一个关于如何在 TensorFlow 中进行索引的基本问题。

在 numpy 中:

x = np.asarray([1,2,3,3,2,5,6,7,1,3])
e = np.asarray([0,1,0,1,1,1,0,1])
#numpy 
print x * e[x]

我可以得到

[1 0 3 3 0 5 0 7 1 3]

如何在 TensorFlow 中做到这一点?

x = np.asarray([1,2,3,3,2,5,6,7,1,3])
e = np.asarray([0,1,0,1,1,1,0,1])
x_t = tf.constant(x)
e_t = tf.constant(e)
with tf.Session():
    ????

谢谢!

【问题讨论】:

标签: python numpy tensorflow


【解决方案1】:

幸运的是,tf.gather() TensorFlow 支持您所询问的确切情况:

result = x_t * tf.gather(e_t, x_t)

with tf.Session() as sess:
    print sess.run(result)  # ==> 'array([1, 0, 3, 3, 0, 5, 0, 7, 1, 3])'

tf.gather() 操作不如NumPy's advanced indexing 强大:它只支持在第 0 维提取张量的完整切片。已请求支持更通用的索引,并在this GitHub issue 中进行跟踪。

【讨论】:

  • Tensorflow 现在拥有更强大的tf.gather_nd() op.
  • 另外,tf.gather 现在支持任何轴,不仅是第 0 维,参数为 axis
猜你喜欢
  • 1970-01-01
  • 2018-03-03
  • 2022-01-21
  • 2020-09-23
  • 1970-01-01
  • 1970-01-01
  • 2018-03-13
  • 1970-01-01
  • 1970-01-01
相关资源
最近更新 更多