【发布时间】:2019-09-13 09:37:45
【问题描述】:
我正在尝试选择与零不同的元素并稍后使用它们。我的输入张量具有批量维度,因此我想保留它并且不要将数据混合到批次中。我认为tf.gather_nd() 对我有用,但首先我必须获取所需数据的索引,然后我找到了tf.where()。我尝试了以下方法:
img = tf.constant([[[1., 0., 0.],
[0., 0., 2.],
[0., 3, 0.]],
[[1., 2., 3.],
[0., 0., 1.],
[0., 0., 0.]]], dtype='float32') # shape [2, 3, 3]
indexes = tf.where(tf.not_equal(img, 0.))
我希望indexes 保持批量维度,但它的形状为[7, 2]。我怀疑问题出在不同批次满足条件的点数不同。
有没有办法让索引保持批量维度?提前致谢。
编辑: indexes 的形状为 [7, 3],其中第一个 dim 指的是点数,第二个 dim 指的是点的位置(包括它属于哪个批次)。但是我需要indexes 来拥有特定的批次维度,因为稍后我想用它来收集来自img 的数据:
Y = tf.gather_nd(img, indexes)
我希望 Y 具有批次维度,但由于 indexes 没有,我得到一个扁平张量,其中混合了来自不同批次的数据。
【问题讨论】:
标签: python tensorflow