【问题标题】:Slicing Tensorflow Tensor with Tensor用张量切片张量流张量
【发布时间】:2017-12-01 06:36:06
【问题描述】:

我正在尝试使用 PR 中添加的“高级”、numpy 样式切片,但是我遇到了 same issue as the user here

ValueError: Shape must be rank 1 but is rank 2 for 'strided_slice_15' (op: 'StridedSlice') with input shapes: [3,2], [1,2], [1,2], [1].

也就是说,我想做与这个 numpy 操作等效的操作(在 numpy 中工作):

A = np.array([[1,2],[3,4],[5,6]]) 
id_rows = np.array([0,2])
A[id_rows]

但是对于上述错误,这在 TF 中不起作用:

A = tf.constant([[1,2],[3,4],[5,6]])
id_rows = tf.constant([0,2])
A[id_rows]

【问题讨论】:

    标签: tensorflow


    【解决方案1】:

    你正在寻找这样的东西:

    A = tf.constant([[1,2],[3,4],[5,6]])
    id_rows = tf.constant([[0],[2]]) #Notice the brackets
    out = tf.gather_nd(A,id_rows)
    

    【讨论】:

    • 好吧,这行得通,但是有没有办法使用“切片”来完成这个(即__getitem__ 使用gather_nd)?
    【解决方案2】:

    您可以按如下方式对张量进行切片。

    A = tf.constant([[1,2],[3,4],[5,6]])
    id_rows = tf.constant(np.array([0, 2]).reshape(-1, 1))
    out = tf.gather_nd(A,id_rows)
    with tf.Session() as session: 
        print(session.run(out))
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2018-03-15
      • 1970-01-01
      • 2019-11-05
      • 2016-07-04
      • 2019-05-30
      • 2018-01-31
      • 1970-01-01
      • 2017-11-05
      相关资源
      最近更新 更多