【问题标题】:Slicing tensor with list - TensorFlow使用列表切片张量 - TensorFlow
【发布时间】:2017-10-23 02:06:18
【问题描述】:

有没有办法在 Tensorflow 中完成这种切片方法(使用 numpy 显示的示例)?

z = np.random.random((3,7,7,12))
x = z[...,[0,5]]

这样

x_hat = np.concatenate([z[...,[0]], z[...,[5]]], 3)
assert np.all(x == x_hat)
x.shape # (3, 7, 7, 2)

在 TensorFlow 中,这个操作

tfz = tf.constant(z)
i = np.array([0,5] dtype=np.int32)
tfx = tfz[...,i]

抛出错误

ValueError: Shapes must be equal rank, but are 0 and 1
From merging shape 0 with other shapes. for 'strided_slice/stack_1' (op: 'Pack') with input shapes: [], [2].

【问题讨论】:

    标签: python tensorflow


    【解决方案1】:

    怎么样:

    x = tf.stack([tfz[..., i] for i in [0,5]], axis=-1) 
    

    这对我有用:

    z = np.random.random((3,7,7,12))
    tfz = tf.constant(z)
    x = tf.stack([tfz[..., i] for i in [0,5]], axis=-1)
    
    x_hat = np.concatenate([z[...,[0]], z[...,[5]]], 3)
    
    with tf.Session() as sess:
        x_run = sess.run(x)
    
    assert np.all(x_run == x_hat)
    

    【讨论】:

      【解决方案2】:

      您需要重新整形以使连接的结果与原始形状(前 3 个维度)一致。

      z = np.arange(36)
      tfz = tf.reshape(tf.constant(z), [2, 3, 2, 3])
      slice1 = tf.reshape(tfz[:,:,:,1], [2, 3, -1, 1])
      slice2 = tf.reshape(tfz[:,:,:,2], [2, 3, -1, 1])
      slice = tf.concat([slice1, slice2], axis=3)
      
      with tf.Session() as sess:
        print sess.run([tfz, slice])
      
      
      > [[[[ 0,  1,  2],
           [ 3,  4,  5]],
      
          [[ 6,  7,  8],
           [ 9, 10, 11]],
      
          [[12, 13, 14],
           [15, 16, 17]]],
      
         [[[18, 19, 20],
           [21, 22, 23]],
      
          [[24, 25, 26],
           [27, 28, 29]],
      
          [[30, 31, 32],
           [33, 34, 35]]]]
      
        # Get the last two columns
      > [[[[ 1,  2],
           [ 4,  5]],
      
          [[ 7,  8],
           [10, 11]],
      
          [[13, 14],
           [16, 17]]],
      
         [[[19, 20],
           [22, 23]],
      
          [[25, 26],
           [28, 29]],
      
          [[31, 32],
           [34, 35]]]]
      

      【讨论】:

        【解决方案3】:

        就像格林斯所说的那样,是形状错误。不幸的是,似乎没有像我希望的那样简单的方法,但这是我想出的通用解决方案:

        def list_slice(tensor, indices, axis):
            """
            Args
            ----
            tensor (Tensor) : input tensor to slice
            indices ( [int] ) : list of indices of where to perform slices
            axis (int) : the axis to perform the slice on
            """
        
            slices = []   
        
            ## Set the shape of the output tensor. 
            # Set any unknown dimensions to -1, so that reshape can infer it correctly. 
            # Set the dimension in the slice direction to be 1, so that overall dimensions are preserved during the operation
            shape = tensor.get_shape().as_list()
            shape[shape==None] = -1
            shape[axis] = 1
        
            nd = len(shape)
        
            for i in indices:   
                _slice = [slice(None)]*nd
                _slice[axis] = slice(i,i+1)
                slices.append(tf.reshape(tensor[_slice], shape))
        
            return tf.concat(slices, axis=axis)
        
        
        
        z = np.random.random(size=(3, 7, 7, 12))
        x = z[...,[0,5]]
        tfz = tf.constant(z)
        tfx_hat = list_slice(tfz, [0, 5], axis=3)
        x_hat = tfx_hat.eval()
        
        assert np.all(x == x_hat)
        

        【讨论】:

        • 我喜欢你的概括。
        猜你喜欢
        • 2018-11-07
        • 2019-10-21
        • 2019-10-11
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        • 2017-06-02
        相关资源
        最近更新 更多