【问题标题】:accessing elements of a tensor with another list of indices in tensorflow使用 tensorflow 中的另一个索引列表访问张量的元素
【发布时间】:2017-05-02 10:05:29
【问题描述】:

我需要使用我拥有的另一个索引列表来访问张量的元素,但目前使用简单的语法似乎是不可能的。我不确定它是否是一个错误,所以我把它贴在这里希望能修复我的语法。我的代码是:

import tensorflow as tf
import numpy as np

sess = tf.Session()
input = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
idx_list = np.array([0,2])
output = input[:, idx_list]

print(sess.run(output))

但我得到了错误:

ValueError: Shapes must be equal rank, but are 0 and 1 From 将形状 0 与其他形状合并。对于'strided_slice/stack_1'(操作: 'Pack') 输入形状:[], [2]。

我安装的tensorflow版本是tensorflow-1.1.0-cp35(pip安装)。

更新:

我通过 tf.fn_map 执行此操作,但我真的怀疑这是进行索引的正确方法:

output = tf.transpose(tf.map_fn(lambda x: input[:,x], idx_list),perm=[1,0])

更新:

对此有一个特定的issue registered,在最新的 cmets 中有一个不错的 sn-p,可能会有所帮助。同时这个操作并不像numpy那么简单......

【问题讨论】:

    标签: python tensorflow


    【解决方案1】:

    您可以使用tf.gathertf.transpose 执行此操作,如下所示:

    import tensorflow as tf
    import numpy as np
    
    sess = tf.Session()
    input = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    idx_list = np.array([0,2])
    output = tf.transpose(tf.gather(tf.transpose(input),idx_list))
    output.eval(session=sess)
    

    打印出来

    array([[1, 3],
           [4, 6],
           [7, 9]])
    

    【讨论】:

    • 谢谢!如果您还没有在 github 上看到问题,请查看我的更新。
    猜你喜欢
    • 2018-07-04
    • 2016-06-20
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2018-03-03
    • 2020-09-23
    • 1970-01-01
    • 2023-03-15
    相关资源
    最近更新 更多