【发布时间】:2016-11-25 16:02:56
【问题描述】:
有谁知道如何提取 2 阶张量每行的前 n 个最大值?
例如,如果我想要形状为 [2,4] 的张量的前 2 个值,其值:
[[40, 30, 20, 10], [10, 20, 30, 40]]
所需的条件矩阵如下所示: [[真,真,假,假],[假,假,真,真]]
一旦有了条件矩阵,我就可以使用 tf.select 来选择实际值。
感谢您的帮助!
【问题讨论】:
有谁知道如何提取 2 阶张量每行的前 n 个最大值?
例如,如果我想要形状为 [2,4] 的张量的前 2 个值,其值:
[[40, 30, 20, 10], [10, 20, 30, 40]]
所需的条件矩阵如下所示: [[真,真,假,假],[假,假,真,真]]
一旦有了条件矩阵,我就可以使用 tf.select 来选择实际值。
感谢您的帮助!
【问题讨论】:
您可以使用内置的tf.nn.top_k 函数来做到这一点:
a = tf.convert_to_tensor([[40, 30, 20, 10], [10, 20, 30, 40]])
b = tf.nn.top_k(a, 2)
print(sess.run(b))
TopKV2(values=array([[40, 30],
[40, 30]], dtype=int32), indices=array([[0, 1],
[3, 2]], dtype=int32))
print(sess.run(b).values))
array([[40, 30],
[40, 30]], dtype=int32)
要获取布尔值True/False,可以先获取第k个值,然后使用tf.greater_equal:
kth = tf.reduce_min(b.values)
top2 = tf.greater_equal(a, kth)
print(sess.run(top2))
array([[ True, True, False, False],
[False, False, True, True]], dtype=bool)
【讨论】:
你也可以使用tf.contrib.framework.argsort
a = [[40, 30, 20, 10], [10, 20, 30, 40]]
idx = tf.contrib.framework.argsort(a, direction='DESCENDING') # sorted indices
ranks = tf.contrib.framework.argsort(idx, direction='ASCENDING') # ranks
b = ranks < 2
# [[ True True False False] [False False True True]]
此外,您可以将2 替换为一维张量,以便每行/列可以具有不同的n 值。
【讨论】: