【问题标题】:TensorFlow: Efficient way to get the index of the smallest element in a tensor that isn't zero?TensorFlow:获取非零张量中最小元素索引的有效方法?
【发布时间】:2019-05-19 09:40:20
【问题描述】:

我使用 TensorFlow 1.12。我有一个一维张量tag_mask_sizes,它主要包含零,但也包含一些正整数。如何有效地获取不为零的最小元素的索引?我尝试了以下方法:

tag_mask_sizes_suppressed = tf.map_fn(lambda x: x if tf.not_equal(x, tf.constant(0, dtype=tf.uint8)) else 9999999, tag_mask_sizes)
        smallest_mask_index = tf.argmin(tag_mask_sizes_suppressed)

但是,tf.not_equal() 会产生一个布尔张量,我无法在 lambda 内的 if-else 条件下有效地评估它。还有其他类似的优雅解决方案吗?

虽然我通常急切地执行,但这个问题发生在我在tf.Dataset.map() 中使用的一个函数中,它没有急切地执行。

【问题讨论】:

    标签: python tensorflow boolean tensor


    【解决方案1】:

    其实你的代码等价于下面的代码。

    tag_mask_sizes_suppressed = tf.where(tf.not_equal(tag_mask_sizes, 0),tag_mask_sizes,tag_mask_sizes+9999999)
    smallest_mask_index1 = tf.argmin(tag_mask_sizes_suppressed)
    

    矢量化方法将明显快于tf.map_fn()。此外,还有一些矢量化方法可以获取一维张量中不为零的最小元素的索引。一个例子:

    import tensorflow as tf
    # tf.enable_eager_execution()
    
    tag_mask_sizes = tf.constant([2,0,1,3,1,32,0,0,0], dtype=tf.int32)
    
    # approach 1, the disadvantage is that the maximum must be specified and only the first minimum can be found.
    tag_mask_sizes_suppressed = tf.where(tf.not_equal(tag_mask_sizes, 0),tag_mask_sizes,tag_mask_sizes+9999999)
    smallest_mask_index1 = tf.argmin(tag_mask_sizes_suppressed)
    
    # approach 2, only the first minimum can be found.
    tag_mask_sizes_nozeroidx = tf.where(tf.not_equal(tag_mask_sizes, 0))
    tag_mask_sizes_suppressed = tf.gather_nd(tag_mask_sizes,tag_mask_sizes_nozeroidx)
    smallest_mask_index2 = tag_mask_sizes_nozeroidx[tf.argmin(tag_mask_sizes_suppressed)]
    
    # approach 3, find all minimum
    tag_mask_sizes_suppressed = tf.boolean_mask(tag_mask_sizes,tf.not_equal(tag_mask_sizes, 0))
    smallest_mask_index3 = tf.squeeze(tf.where(tf.equal(tag_mask_sizes,tf.reduce_min(tag_mask_sizes_suppressed))))
    
    with tf.Session() as sess:
        print(sess.run(smallest_mask_index1))
        print(sess.run(smallest_mask_index2))
        print(sess.run(smallest_mask_index3))
    
    # print
    2
    [2]
    [2 4]
    

    【讨论】:

    • 感谢您的回复!现在使用fill() 时出现此错误:File "/home/.../tensorflow/python/ops/gen_array_ops.py", line 2979, in fill "Fill", dims=dims, value=value, name=name) File "/home/.../tensorflow/python/framework/op_def_library.py", line 528, in _apply_op_helper (input_name, err)) ValueError: Tried to convert 'dims' to a tensor and failed. Error: Cannot convert a partially known TensorShape to a Tensor: (?,) 我事先不知道tag_mask_sizes 的长度。您碰巧知道解决方法吗?也许复制数组并覆盖它?
    • 没关系,我得到了你的第二个解决方案!非常感谢!
    • @EmielBoss 你可以使用tag_mask_sizes+9999999替换tf.fill()来避免这个错误。
    猜你喜欢
    • 1970-01-01
    • 2011-08-28
    • 1970-01-01
    • 1970-01-01
    • 2021-01-26
    • 2017-01-06
    • 2018-05-22
    • 1970-01-01
    • 2017-12-13
    相关资源
    最近更新 更多