【问题标题】:How to use tf.gather_nd to slice a tensor in tensorflow?如何使用 tf.gather_nd 在张量流中对张量进行切片?
【发布时间】:2018-12-07 19:37:02
【问题描述】:

我正在 numpy.xml 中寻找以下代码的 tensorflow 等效项。给出了aidx_2。目标是构造b

# A float Tensor obtained somehow
a = np.arange(3*5).reshape(3,5)                    

# An int Tensor obtained somehow
idx_2 = np.array([[1,2,3,4],[0,2,3,4],[0,2,3,4]])  

# An int Tensor, constructed for indexing
idx_1 = np.arange(a.shape[0]).reshape(-1,1)        

# The goal
b = a[idx_1, idx_2]

print(b)
>>> [[ 1  2  3  4]
     [ 5  7  8  9]
     [10 12 13 14]]

我尝试直接索引张量并使用tf.gather_nd,但我不断收到错误,所以我决定在这里询问如何做。我到处寻找人们使用tf.gather_nd(因此标题)来解决类似问题的答案,但是为了应用这个函数,我必须以某种方式重塑索引,以便它们可以用来分割第一个维度。我该怎么做呢?请帮忙。

【问题讨论】:

    标签: python tensorflow


    【解决方案1】:

    当涉及到 NumPy 中非常简单和 Pythonic 的东西时,Tensorflow 可能会非常难看。以下是我如何使用 tf.gather_nd 在 TensorFlow 中重新创建您的问题。不过可能有更好的方法。

    import tensorflow as tf
    import numpy as np
    
    with tf.Session() as sess:
        # Define 'a'
        a = tf.reshape(tf.range(15),(3,5))
        # Define both index tensors 
        idx_1 = tf.reshape(tf.range(a.get_shape().as_list()[0]),(-1,1)).eval()
        idx_2 = tf.constant([[1,2,3,4],[0,2,3,4],[0,2,3,4]]).eval()
        # get indices for use with gather_nd
        gather_idx = tf.constant([(x[0],y) for (i,x) in enumerate(idx_1) for y in idx_2[i]])
        # extract elements and reshape to desired dimensions
        b = tf.gather_nd(a, gather_idx)
        b = tf.reshape(b,(idx_1.shape[0], idx_2.shape[1]))
        print(sess.run(b))
    
    [[ 1  2  3  4]
    [ 5  7  8  9]
    [10 12 13 14]]
    

    【讨论】:

      猜你喜欢
      • 2018-12-28
      • 2017-12-01
      • 2019-04-13
      • 1970-01-01
      • 2019-08-06
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2019-11-02
      相关资源
      最近更新 更多