【问题标题】:Tensorflow: cross index slicing of a tensorTensorflow:张量的交叉索引切片
【发布时间】:2017-11-30 10:09:41
【问题描述】:

我有两个如下形状的张量:

tensor1 => shape(10, 99, 106)
tensor2 => shape(10, 99)

tensor2 包含范围从0 - 105 的值,我希望用它来分割tensor1 的最后一个维度并获得形状的tensor3

tensor3 => shape(10, 99, 99)

我尝试过使用:

tensor4 = tf.gather(tensor1, tensor2)
# this causes tensor4 to be of shape (10, 99, 99, 106)

另外,使用

tensor4 = tf.gather_nd(tensor1, tensor2)
# gives the error: last dimension of tensor2 (which is 99) must be 
# less than the rank of the tensor1 (which is 3).

我正在寻找类似于 numpy 的 cross_indexing 的东西。

【问题讨论】:

  • 你确定tensor3的形状吗?不应该是简单的 (10,99) 吗?
  • 是的。我希望使用来自tensor2 的 99 维向量,仅使用来自tensor1 的第三维 (106) 的 99 值。

标签: python tensorflow tensor


【解决方案1】:

你可以使用tf.map_fn:

 tensor3 = tf.map_fn(lambda u: tf.gather(u[0],u[1],axis=1),[tensor1,tensor2],dtype=tensor1.dtype)

你可以把这条线想象成一个循环,它在tensor1tensor2 的第一个维度上运行,对于第一个维度中的每个索引i,它在tensor1[i,:,:] 和@ 上应用tf.gather 987654329@.

【讨论】:

  • 谢谢。这行得通!我仍然需要了解 map_fn 是如何做到的。干杯!
猜你喜欢
  • 1970-01-01
  • 1970-01-01
  • 2017-06-02
  • 2019-10-21
  • 2019-05-30
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2017-05-21
相关资源
最近更新 更多