【问题标题】:NumPy - fast stable arg-sort of large array by frequencyNumPy - 按频率对大型数组进行快速稳定的 arg 排序
【发布时间】:2020-10-07 06:09:08
【问题描述】:

我有任何可比较的dtype 的大型一维NumPy 数组a,它的一些元素可能会重复。

我如何找到排序索引ix,它将按降序/升序的值频率进行稳定排序(a sense described here 中的稳定性)a

我想找到最快和最简单的方法来做到这一点。也许有现有的标准 numpy 函数可以做到这一点。

还有另一个相关的question here,但它专门要求删除数组重复项,即只输出唯一的排序值,我需要原始数组的所有值,包括重复项。

我已经编写了我的第一次试验来完成这项任务,但它不是最快的(使用 Python 的循环),并且可能不是最短/最简单的形式。如果相等元素的重复率不高且数组很大,则此 python 循环可能非常昂贵。如果在 NumPy 中可用(例如想象的np.argsort_by_freq()),那么有一个简短的函数来做这一切也很好。

Try it online!

import numpy as np
np.random.seed(1)
hi, n, desc = 7, 24, True
a = np.random.choice(np.arange(hi), (n,), p = (
    lambda p = np.random.random((hi,)): p / p.sum()
)())
us, cs = np.unique(a, return_counts = True)
af = np.zeros(n, dtype = np.int64)
for u, c in zip(us, cs):
    af[a == u] = c
if desc:
    ix = np.argsort(-af, kind = 'stable') # Descending sort
else:
    ix = np.argsort(af, kind = 'stable') # Ascending sort
print('rows: i_col(0) / original_a(1) / freqs(2) / sorted_a(3)')
print('    / sorted_freqs(4) / sorting_ix(5)')
print(np.stack((
    np.arange(n), a, af, a[ix], af[ix], ix,
), 0))

输出:

rows: i_col(0) / original_a(1) / freqs(2) / sorted_a(3)
    / sorted_freqs(4) / sorting_ix(5)
[[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23]
 [ 1  1  1  1  3  0  5  0  3  1  1  0  0  4  6  1  3  5  5  0  0  0  5  0]
 [ 7  7  7  7  3  8  4  8  3  7  7  8  8  1  1  7  3  4  4  8  8  8  4  8]
 [ 0  0  0  0  0  0  0  0  1  1  1  1  1  1  1  5  5  5  5  3  3  3  4  6]
 [ 8  8  8  8  8  8  8  8  7  7  7  7  7  7  7  4  4  4  4  3  3  3  1  1]
 [ 5  7 11 12 19 20 21 23  0  1  2  3  9 10 15  6 17 18 22  4  8 16 13 14]]

【问题讨论】:

  • 您当前的解决方案有什么问题?
  • @Nick 我在上面写了原因:1)它不是最快的(使用纯python循环)2)可能不是最短的3)重要的是降序不稳定(in this sense),但稳定我刚刚在一分钟前解决了here
  • 那么你对这个问题的预期输出是什么?
  • @Nick 在最后一个输出示例中,我需要在第 3 行末尾 4 6 而不是 6 4,因为它们的频率相同,并且它们按原顺序排列 4 6大批。但是稳定性不是问题,我已经像上面的评论一样解决了它,只是为了 argsort 否定值。我会解决我的问题,不要打扰人们关于稳定性的问题。重要的是我想要最快的解决方案(没有 Python 循环)并且尽可能最短。
  • @Nick 如果数组很大并且其中重复的元素很少,那么这个 Python 循环将花费很长时间。

标签: python arrays numpy sorting frequency


【解决方案1】:

我可能遗漏了一些东西,但似乎有了Counter,您可以根据每个元素的值的计数对每个元素的索引进行排序,使用元素值,然后使用索引来打破平局。例如:

from collections import Counter

a = [ 1,  1,  1,  1,  3,  0,  5,  0,  3,  1,  1,  0,  0,  4,  6,  1,  3,  5,  5,  0,  0,  0,  5,  0]
counts = Counter(a)

t = [(counts[v], v, i) for i, v in enumerate(a)]
t.sort()
print([v[2] for v in t])
t.sort(reverse=True)
print([v[2] for v in t])

输出:

[13, 14, 4, 8, 16, 6, 17, 18, 22, 0, 1, 2, 3, 9, 10, 15, 5, 7, 11, 12, 19, 20, 21, 23]
[23, 21, 20, 19, 12, 11, 7, 5, 15, 10, 9, 3, 2, 1, 0, 22, 18, 17, 6, 16, 8, 4, 14, 13]

如果您想保持具有相同计数的组的索引升序,您可以使用 lambda 函数进行降序:

t.sort(key = lambda x:(-x[0],-x[1],x[2]))
print([v[2] for v in t])

输出:

[5, 7, 11, 12, 19, 20, 21, 23, 0, 1, 2, 3, 9, 10, 15, 6, 17, 18, 22, 4, 8, 16, 14, 13]

如果你想保持元素的顺序与它们最初出现在数组中的顺序相同如果它们的计数相同,那么不要按值排序,而是按它们的索引排序数组中的第一次出现:

a = [ 1,  1,  1,  1,  3,  0,  5,  0,  3,  1,  1,  0,  0,  4,  6,  1,  3,  5,  5,  0,  0,  0,  5,  0]
counts = Counter(a)

idxs = {}
t = []
for i, v in enumerate(a):
    if not v in idxs:
        idxs[v] = i
    t.append((counts[v], idxs[v], i))

t.sort()
print([v[2] for v in t])
t.sort(key = lambda x:(-x[0],x[1],x[2]))
print([v[2] for v in t])

输出:

[13, 14, 4, 8, 16, 6, 17, 18, 22, 0, 1, 2, 3, 9, 10, 15, 5, 7, 11, 12, 19, 20, 21, 23]
[5, 7, 11, 12, 19, 20, 21, 23, 0, 1, 2, 3, 9, 10, 15, 6, 17, 18, 22, 4, 8, 16, 13, 14]

按照计数排序,然后在数组中定位,你根本不需要值或第一个索引:

from collections import Counter

a = [ 1,  1,  1,  1,  3,  0,  5,  0,  3,  1,  1,  0,  0,  4,  6,  1,  3,  5,  5,  0,  0,  0,  5,  0]
counts = Counter(a)

t = [(counts[v], i) for i, v in enumerate(a)]
t.sort()
print([v[1] for v in t])
t.sort(key = lambda x:(-x[0],x[1]))
print([v[1] for v in t])

这会为您的字符串数组生成与示例数据的先前代码相同的输出:

a = ['g',  'g',  'c',  'f',  'd',  'd',  'g',  'a',  'a',  'a',  'f',  'f',  'f',
     'g',  'f',  'c',  'f',  'a',  'e',  'b',  'g',  'd',  'c',  'b',  'f' ]

这会产生输出:

[18, 19, 23, 2, 4, 5, 15, 21, 22, 7, 8, 9, 17, 0, 1, 6, 13, 20, 3, 10, 11, 12, 14, 16, 24]
[3, 10, 11, 12, 14, 16, 24, 0, 1, 6, 13, 20, 7, 8, 9, 17, 2, 4, 5, 15, 21, 22, 19, 23, 18]

【讨论】:

  • 当数组很大时,我需要专门针对 NumPy 的解决方案。对于这种情况,它应该非常快。甚至使用标准库的 Counter 也可能很慢,因为它引用了诸如 intstr 之类的 python 对象。
  • 如果原始数组中唯一元素的数量不小于所有元素的数量(即重复级别低),那么计数器解决方案或python循环都会很慢。
  • 我很想看看你的解决方案与这个的速度比较。
  • @Arty 我还添加了另一个版本的代码,我认为它可以实现您正在寻找的稳定性。
  • 我几乎可以肯定,即使没有测量,对于几乎独特元素的情况,您的解决方案应该比我的解决方案更快。但更有效的面向 numpy 的矢量化解决方案可能是我想要的。顺便说一句,Counter 可用于任何类型的元素,包括任何可比较的类型,例如 str?
【解决方案2】:

我只是认为自己对于任何 dtype 都可能非常快速的解决方案,只使用没有 python 循环的 numpy 函数,它可以在 O(N log N) 时间工作。使用的 numpy 函数:np.uniquenp.argsort 和数组索引。

虽然在最初的问题中没有被问到,但我实现了额外的标志 equal_order_by_val,如果它是 False,那么具有相同频率的数组元素被排序为相等的稳定范围,这意味着可能有 c d d c d c 输出,如下面的输出转储,因为这是元素以相同频率进入原始数组的顺序。当 flag 为 True 时,这些元素还按原始数组的值排序,结果为 c c c d d d。换句话说,在 False 的情况下,我们仅按键 freq 稳定排序,当它为 True 时,我们按 (freq, value) 升序排序,(-freq, value) 降序排序。

Try it online!

import string, math
import numpy as np
np.random.seed(0)

# Generating input data

hi, n, desc = 7, 25, True
letters = np.array(list(string.ascii_letters), dtype = np.object_)[:hi]
a = np.random.choice(letters, (n,), p = (
    lambda p = np.random.random((letters.size,)): p / p.sum()
)())

for equal_order_by_val in [False, True]:
    # Solving task

    us, ui, cs = np.unique(a, return_inverse = True, return_counts = True)
    af = cs[ui]
    sort_key = -af if desc else af
    if equal_order_by_val:
        shift_bits = max(1, math.ceil(math.log(us.size) / math.log(2)))
        sort_key = ((sort_key.astype(np.int64) << shift_bits) +
            np.arange(us.size, dtype = np.int64)[ui])
    ix = np.argsort(sort_key, kind = 'stable') # Do sorting itself

    # Printing results

    print('\nequal_order_by_val:', equal_order_by_val)
    for name, val in [
        ('i_col', np.arange(n)),  ('original_a', a),
        ('freqs', af),            ('sorted_a', a[ix]),
        ('sorted_freqs', af[ix]), ('sorting_ix', ix),
    ]:
        print(name.rjust(12), ' '.join([str(e).rjust(2) for e in val]))

输出:

equal_order_by_val: False
       i_col  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
  original_a  g  g  c  f  d  d  g  a  a  a  f  f  f  g  f  c  f  a  e  b  g  d  c  b  f
       freqs  5  5  3  7  3  3  5  4  4  4  7  7  7  5  7  3  7  4  1  2  5  3  3  2  7
    sorted_a  f  f  f  f  f  f  f  g  g  g  g  g  a  a  a  a  c  d  d  c  d  c  b  b  e
sorted_freqs  7  7  7  7  7  7  7  5  5  5  5  5  4  4  4  4  3  3  3  3  3  3  2  2  1
  sorting_ix  3 10 11 12 14 16 24  0  1  6 13 20  7  8  9 17  2  4  5 15 21 22 19 23 18

equal_order_by_val: True
       i_col  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
  original_a  g  g  c  f  d  d  g  a  a  a  f  f  f  g  f  c  f  a  e  b  g  d  c  b  f
       freqs  5  5  3  7  3  3  5  4  4  4  7  7  7  5  7  3  7  4  1  2  5  3  3  2  7
    sorted_a  f  f  f  f  f  f  f  g  g  g  g  g  a  a  a  a  c  c  c  d  d  d  b  b  e
sorted_freqs  7  7  7  7  7  7  7  5  5  5  5  5  4  4  4  4  3  3  3  3  3  3  2  2  1
  sorting_ix  3 10 11 12 14 16 24  0  1  6 13 20  7  8  9 17  2 15 22  4  5 21 19 23 18

【讨论】:

  • @Nick 它不是按值排序,而是按频率键排序,即行sorted_freqscd 只是频率相同。并且对于相同频率的稳定顺序(与原始数组相同)应由初始任务完成。
  • 啊,我的代码不是这样工作的。它将所有cs 排序在所有ds 之前,因为第一个c 出现在第一个d 之前。我添加了另一个(更简单)的解决方案,它的排序方式与您的代码相同
  • @Nick Stop,不要更新。顺便说一句,这是一个很好的观点。尽管最初的任务与此无关,但最好在我和您的代码中作为额外标志来实现。基本上,如果像equal_freqs_sort_by_value = True 这样的标志,那么同样的频率也应该通过按值排序来分开。 IE。使用此标志进行排序应按元组(freq, a_val) 而不仅仅是freq
  • 我刚刚添加了另一个解决方案。原始代码仍然存在,将两个解决方案与标志合并以控制操作会相当容易。
  • 直接 python 解决方案大约快 3 倍:ideone.com/gp4DVj
最近更新 更多