【Question title】: Tensorflow: Filtering 3D index duplicates by their maximum values
【Posted】: 2025-12-29 11:50:12
【Question】:

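I'm trying to create a filter mask that removes duplicate indices from a vector by comparing which of their respective values is larger, so that only the entry with the largest value survives for each index.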

My current approach is:

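  1. Convert the 3-D indices to 1-D.
  2. Check the 1-D indices for uniqueness.
  3. Compute the maximum value of each unique index.
  4. Compare the maximum values with the original values. If a value matches, keep that 3-D index.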
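Step 1 is a row-major linearization. Writing Y for max_dim_y and A for num_anchors (the constants in the code below), the 1-D index and one worked value are:

```latex
\mathrm{idx}(x, y, a) = x \cdot (Y A) + y \cdot A + a,
\qquad 0 \le y < Y,\quad 0 \le a < A

\mathrm{idx}(1, 4, 1) = 1 \cdot (10 \cdot 5) + 4 \cdot 5 + 1 = 71
```

The bounds on y and a are what make the mapping one-to-one: if iou_index could reach num_anchors, then idx(x, y, A) would equal idx(x, y + 1, 0) and two different 3-D indices would collide.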

I want to get a filter array so that I can also apply boolean_mask to other tensors. For this example, the mask should look like this: [False True True True True].
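A minimal sketch of why a boolean mask is convenient here, assuming TensorFlow 2.x with eager execution: the same mask can be passed to tf.boolean_mask for every tensor whose first dimension has the same length.

```python
import tensorflow as tf

# Target mask from the question: drop the first entry, keep the rest.
mask = tf.constant([False, True, True, True, True])

x_cells = tf.constant([1, 2, 3, 4, 1], dtype=tf.int32)
iou_max = tf.constant([1., 2., 3., 4., 5.], dtype=tf.float32)

# One mask, applied to several tensors of length 5.
x_kept = tf.boolean_mask(x_cells, mask)      # -> [2, 3, 4, 1]
iou_kept = tf.boolean_mask(iou_max, mask)    # -> [2., 3., 4., 5.]
```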

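My current code more or less works, unless the values themselves repeat (for example, when a discarded value happens to equal the maximum of a different index). Unfortunately, that seems to happen in my real data, so I need to find a better solution.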
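As a ground truth for testing any TensorFlow version, here is a plain-NumPy reference sketch. It uses a different, sort-based technique rather than the segment-max approach above, and reference_mask is a hypothetical helper name. Ties within one index are assumed to go to the earliest position.

```python
import numpy as np

def reference_mask(x_cells, y_cells, iou_index, iou_max, max_dim_y, num_anchors):
    """Hypothetical helper: keep exactly one entry per 3-D index, the one
    with the largest value; ties inside one index go to the earliest position."""
    x = np.asarray(x_cells)
    y = np.asarray(y_cells)
    a = np.asarray(iou_index)
    values = np.asarray(iou_max, dtype=np.float64)
    flat = x * (max_dim_y * num_anchors) + y * num_anchors + a   # step 1
    positions = np.arange(flat.size)
    # np.lexsort uses the LAST key as the primary key: sort by flat index,
    # then by value descending, then by position ascending.
    order = np.lexsort((positions, -values, flat))
    # In that order, the first row of each flat-index group is its winner.
    _, first = np.unique(flat[order], return_index=True)
    mask = np.zeros(flat.size, dtype=bool)
    mask[order[first]] = True
    return mask

x_cells = [1, 2, 3, 4, 1]
y_cells = [4, 4, 4, 4, 4]
iou_index = [1, 2, 3, 4, 1]
mask = reference_mask(x_cells, y_cells, iou_index, [2., 2., 3., 4., 5.], 10, 5)
# -> [False, True, True, True, True]
```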

Here is an example of what my code looks like:

import tensorflow as tf

# Dummy input values with the same structure as the real data
x_cells   = tf.constant([1,2,3,4,1], dtype=tf.int32)   # Index_1
y_cells   = tf.constant([4,4,4,4,4], dtype=tf.int32)   # Index_2
iou_index = tf.constant([1,2,3,4,1], dtype=tf.int32) # Index_3
iou_max   = tf.constant([1.,2.,3.,4.,5.], dtype=tf.float32) # Values

# My output should be the mask [False True True True True],
# so that filtering gives e.g. x_cells = [2,3,4,1] or iou_max = [2.,3.,4.,5.]

max_dim_y = tf.constant(10)
max_dim_x = tf.constant(20)
num_anchors = 5
stride = 32

# 1. Transforming the 3D-Index to 1D
tmp = tf.stack([x_cells, y_cells, iou_index], axis=1)
indices = tf.matmul(tmp, [[max_dim_y * num_anchors], [num_anchors], [1]])

# 2. Looking for unique / duplicate indices
y, idx = tf.unique(tf.squeeze(indices))

# 3. Calculating the maximum values of each unique index.
# A function like unsorted_segment_argmax() would be awesome here
num_segments = tf.shape(y)[0]
ious = tf.math.unsorted_segment_max(iou_max, idx, num_segments)

iou_max_length = tf.shape(iou_max)[0]
ious_length = tf.shape(ious)[0]

# 4. Compare all max values to original values.
iou_max_tiled = tf.tile(iou_max, [ious_length])
iou_reshaped = tf.reshape(iou_max_tiled, [ious_length, iou_max_length])
iou_max_reshaped = tf.transpose(iou_reshaped)
filter_mask = tf.reduce_any(tf.equal(iou_max_reshaped, ious), -1)
filter_mask = tf.reshape(filter_mask, shape=[-1])
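To make steps 1 to 3 concrete, this sketch prints nothing but recomputes the intermediate tensors for the first input, assuming TensorFlow 2.x eager execution. It swaps the tf.matmul of step 1 for plain element-wise arithmetic, which gives the same 1-D indices.

```python
import tensorflow as tf

x_cells = tf.constant([1, 2, 3, 4, 1])
y_cells = tf.constant([4, 4, 4, 4, 4])
iou_index = tf.constant([1, 2, 3, 4, 1])
iou_max = tf.constant([1., 2., 3., 4., 5.])
max_dim_y, num_anchors = 10, 5

# Step 1: same linearization as the matmul, written element-wise.
flat = x_cells * (max_dim_y * num_anchors) + y_cells * num_anchors + iou_index
# -> [71, 122, 173, 224, 71]

# Step 2: uniq holds the distinct indices, idx maps each entry to its segment.
uniq, idx = tf.unique(flat)
# uniq -> [71, 122, 173, 224], idx -> [0, 1, 2, 3, 0]

# Step 3: ious[k] is the maximum value of segment k.
ious = tf.math.unsorted_segment_max(iou_max, idx, tf.shape(uniq)[0])
# -> [5., 2., 3., 4.]
```

Note that step 4 then compares each value with every entry of ious, not only with ious[idx[i]], the maximum of its own segment.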

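The problem shows up if we simply change the iou_max values at the beginning as shown below. The mask then becomes [True True True True True] instead of [False True True True True]: step 4 compares the first entry's value 2. with the maxima of all indices, and it equals the maximum of a different index (the second entry), so it is kept even though its own index has the larger value 5. Modified input: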

x_cells = tf.constant([1,2,3,4,1], dtype=tf.int32)
y_cells = tf.constant([4,4,4,4,4], dtype=tf.int32)
iou_index = tf.constant([1,2,3,4,1], dtype=tf.int32)
iou_max = tf.constant([2.,2.,3.,4.,5.], dtype=tf.float32)
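A minimal sketch of the failure and of a narrower comparison, assuming TensorFlow 2.x eager execution. The broadcast comparison below is equivalent to the tile/transpose version of step 4; the second mask compares each value only with the maximum of its own segment via tf.gather. That second mask still keeps both entries if one index has two entries that share the same maximum value.

```python
import tensorflow as tf

iou_max = tf.constant([2., 2., 3., 4., 5.])
idx = tf.constant([0, 1, 2, 3, 0])   # segment ids from tf.unique for this input
ious = tf.math.unsorted_segment_max(iou_max, idx, 4)   # -> [5., 2., 3., 4.]

# Step 4 as in the question: compare each value with EVERY segment maximum.
mask_any = tf.reduce_any(
    tf.equal(iou_max[:, tf.newaxis], ious[tf.newaxis, :]), axis=-1)
# -> [True, True, True, True, True]; entry 0 matches the maximum of segment 1

# Compare each value only with the maximum of its own segment.
mask_own = tf.equal(iou_max, tf.gather(ious, idx))
# -> [False, True, True, True, True]
```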

【Discussion】:

    Tags: python numpy tensorflow indexing tensor


    【Solution 1】:

    My current workaround changes step 4 of my question:

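    Basically, I now compare tuples instead of single values. This lets me check logically whether both the index and the value are among the remaining values from step 3.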

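    # 4. Compare (max value, index) pairs with the original (value, index) pairs
    rem_index_val_pair = tf.stack([ious, tf.cast(y, dtype=tf.float32)], axis=1)
    # indices has shape [N, 1] from tf.matmul, so flatten it before stacking with iou_max
    orig_val_index_pair = tf.stack([iou_max, tf.cast(tf.squeeze(indices, axis=1), dtype=tf.float32)], axis=1)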
    
    orig_val_index_pair_t = tf.tile(orig_val_index_pair, [1, ious_length])
    orig_val_index_pair_s = tf.reshape(orig_val_index_pair_t, [iou_max_length, ious_length, 2])
    filter_mask_1 = tf.equal(orig_val_index_pair_s, rem_index_val_pair)
    filter_mask_2 = tf.reduce_all(filter_mask_1, -1)
    filter_mask_3 = tf.reduce_any(filter_mask_2, -1)
    

    # For iou_max = [2., 2., 3., 4., 5.], orig_val_index_pair_s (shape [5, 4, 2]) looks like this:
    # a = [[[  2.  71.][  2.  71.][  2.  71.][  2.  71.]]
    #      [[  2. 122.][  2. 122.][  2. 122.][  2. 122.]]
    #      [[  3. 173.][  3. 173.][  3. 173.][  3. 173.]]
    #      [[  4. 224.][  4. 224.][  4. 224.][  4. 224.]]
    #      [[  5.  71.][  5.  71.][  5.  71.][  5.  71.]]]
    # I then compare it with rem_index_val_pair (shape [4, 2]), which looks like this:
    # b = [[  5.  71.][  2. 122.][  3. 173.][  4. 224.]]
    #
    # tf.equal(a, b) broadcasts b over the first axis and compares element-wise, resulting in:
    # c = [[[False  True][ True False][False False][False False]]
    #      [[False False][ True  True][False False][False False]]
    #      [[False False][False False][ True  True][False False]]
    #      [[False False][False False][False False][ True  True]]
    #      [[ True  True][False False][False False][False False]]]
    #
    # tf.reduce_all(c, -1) combines each bool pair with a logical AND, so a pair only
    # counts when both the value and the index match.
    # (This kicks out my false positives from before.)
    # Afterwards, tf.reduce_any() checks whether a row has any True value.
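The same tuple comparison can be written with broadcasting instead of tf.tile and tf.reshape. This is a sketch assuming TensorFlow 2.x eager execution; pair_filter_mask is a hypothetical helper that takes the flat 1-D indices from step 1. The second usage line shows a limitation of the tuple approach: when one index has two entries with the same maximum value, both are kept.

```python
import tensorflow as tf

def pair_filter_mask(flat_indices, iou_max):
    """Hypothetical helper: steps 2-4 of the tuple approach using broadcasting."""
    y, idx = tf.unique(flat_indices)
    ious = tf.math.unsorted_segment_max(iou_max, idx, tf.shape(y)[0])
    # (max value, index) pair for every unique index: shape [U, 2]
    rem = tf.stack([ious, tf.cast(y, tf.float32)], axis=1)
    # (value, index) pair for every original entry: shape [N, 2]
    orig = tf.stack([iou_max, tf.cast(flat_indices, tf.float32)], axis=1)
    # [N, 1, 2] compared with [1, U, 2] broadcasts to [N, U, 2]
    pair_equal = tf.equal(orig[:, tf.newaxis, :], rem[tf.newaxis, :, :])
    # An entry survives if both components match some (max, index) pair.
    return tf.reduce_any(tf.reduce_all(pair_equal, axis=-1), axis=-1)

flat = tf.constant([71, 122, 173, 224, 71])
mask = pair_filter_mask(flat, tf.constant([2., 2., 3., 4., 5.]))
# -> [False, True, True, True, True]
tie_mask = pair_filter_mask(flat, tf.constant([5., 2., 3., 4., 5.]))
# -> [True, True, True, True, True]: both entries for index 71 are kept
```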
    

IMO this solution is still a dirty workaround, so if you have any suggestions for a better solution, please share them. :)
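One possible alternative is to emulate the unsorted_segment_argmax() wished for in the question's step 3. This is a sketch assuming TensorFlow 2.x eager execution; TensorFlow has no built-in of that name, so unsorted_segment_argmax and max_filter_mask below are hypothetical helpers. It marks positions equal to the maximum of their own segment, then takes the smallest such position per segment with tf.math.unsorted_segment_min, so exactly one entry survives per index even when values tie.

```python
import tensorflow as tf

def unsorted_segment_argmax(data, segment_ids, num_segments):
    """Hypothetical helper: position of each segment's maximum;
    ties go to the lowest position."""
    n = tf.shape(data)[0]
    positions = tf.range(n)
    seg_max = tf.math.unsorted_segment_max(data, segment_ids, num_segments)
    # True where an element equals the maximum of its OWN segment.
    is_max = tf.equal(data, tf.gather(seg_max, segment_ids))
    # Non-maximal positions get n, which is larger than any valid position.
    candidates = tf.where(is_max, positions, tf.fill(tf.shape(positions), n))
    return tf.math.unsorted_segment_min(candidates, segment_ids, num_segments)

def max_filter_mask(x_cells, y_cells, iou_index, iou_max, max_dim_y, num_anchors):
    """Hypothetical helper: True for exactly one entry (the maximum) per 3-D index."""
    flat = x_cells * (max_dim_y * num_anchors) + y_cells * num_anchors + iou_index
    uniq, idx = tf.unique(flat)
    winners = unsorted_segment_argmax(iou_max, idx, tf.shape(uniq)[0])
    positions = tf.range(tf.shape(iou_max)[0])
    # An entry is kept only if it is the winner of its own segment.
    return tf.equal(positions, tf.gather(winners, idx))

x_cells = tf.constant([1, 2, 3, 4, 1])
y_cells = tf.constant([4, 4, 4, 4, 4])
iou_index = tf.constant([1, 2, 3, 4, 1])
mask = max_filter_mask(x_cells, y_cells, iou_index,
                       tf.constant([2., 2., 3., 4., 5.]), 10, 5)
# -> [False, True, True, True, True]
tie_mask = max_filter_mask(x_cells, y_cells, iou_index,
                           tf.constant([5., 2., 3., 4., 5.]), 10, 5)
# -> [True, True, True, True, False]
```

The resulting mask can be passed to tf.boolean_mask for x_cells, iou_max, or any other tensor of the same length.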
