【问题标题】:Tensorflow: how to keep batch dimension when using tf.where()?Tensorflow:使用 tf.where() 时如何保持批量维度?
【发布时间】:2019-09-13 09:37:45
【问题描述】:

我正在尝试选择与零不同的元素并稍后使用它们。我的输入张量具有批量维度,因此我想保留它并且不要将数据混合到批次中。我认为tf.gather_nd() 对我有用,但首先我必须获取所需数据的索引,然后我找到了tf.where()。我尝试了以下方法:

img = tf.constant([[[1., 0., 0.], 
                    [0., 0., 2.],
                    [0., 3, 0.]], 
                   [[1., 2., 3.], 
                    [0., 0., 1.], 
                    [0., 0., 0.]]], dtype='float32') # shape [2, 3, 3]

indexes = tf.where(tf.not_equal(img, 0.))

我希望indexes 保持批量维度,但它的形状为[7, 2]。我怀疑问题出在不同批次满足条件的点数不同。

有没有办法让索引保持批量维度?提前致谢。

编辑: indexes 的形状为 [7, 3],其中第一个 dim 指的是点数,第二个 dim 指的是点的位置(包括它属于哪个批次)。但是我需要indexes 来拥有特定的批次维度,因为稍后我想用它来收集来自img 的数据:

Y = tf.gather_nd(img, indexes)

我希望 Y 具有批次维度,但由于 indexes 没有,我得到一个扁平张量,其中混合了来自不同批次的数据。

【问题讨论】:

    标签: python tensorflow


    【解决方案1】:

    实际上,您可能做错了什么:当我运行您的代码时,indexes 的维度为 (7,3) 而不是 (7,2)3 对应于您的 3 个维度,而 7 对应于 img 中非零元素的数量。

    sess.run(indexes) 的完整结果:

    array([[0, 0, 0],
          [0, 1, 2],
          [0, 2, 1],
          [1, 0, 0],
          [1, 0, 1],
          [1, 0, 2],
          [1, 1, 2]])
    

    【讨论】:

    • 是的,你是对的,很抱歉。但即便如此,所有批次的数据都是混合的,尽管我有相应批次的信息。我想要的是indexes 有一个单独的批次维度:[2, X, 3]。那是因为我稍后将indexesY = gather_nd(img, indexes) 一起使用,并且需要Y 具有批量维度。
    • 正如您在问题中提到的,这是不可能的,因为批次中的所有样本都没有相同数量的非零元素。但是,您可能希望对 tf.where 的输出进行一些后处理,以获得您想要的格式
    猜你喜欢
    • 2021-02-07
    • 1970-01-01
    • 1970-01-01
    • 2019-08-17
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2017-08-08
    • 1970-01-01
    相关资源
    最近更新 更多