当你写得高效时,我假设你在这里想要的实际上是快。
我将尝试简要讨论渐近效率。
在这种情况下,我们将N 称为输入大小,将K 称为唯一值的数量。
我的方法解决方案是结合使用 np.argsort() 和专门针对 NumPy 输入优化的定制 groupby_np():
import numpy as np
def groupby_np(arr, both=True):
n = len(arr)
extrema = np.nonzero(arr[:-1] != arr[1:])[0] + 1
if both:
last_i = 0
for i in extrema:
yield last_i, i
last_i = i
yield last_i, n
else:
yield 0
yield from extrema
yield n
def labeling_groupby_np(values, labels):
slicing = labels.argsort()
sorted_labels = labels[slicing]
sorted_values = values[slicing]
del slicing
result = {}
for i, j in groupby_np(sorted_labels, True):
result[sorted_labels[i]] = sorted_values[i:j]
return result
这具有复杂性O(N log N + K)。
N log N 来自排序步骤,K 来自最后一个循环。
有趣的是,N-dependent 和 K-dependent 步骤都很快,因为 N-dependent 部分是在低级别执行的,K-dependent 部分是 O(1) 和也很快。
类似以下的解决方案(非常类似于@theEpsilon 的答案):
import numpy as np
def labeling_loop(values, labels):
labeled = {}
for x, l in zip(values, labels):
if l not in labeled:
labeled[l] = [x]
else:
labeled[l].append(x)
return {k: np.array(v) for k, v in labeled.items()}
使用两个循环并具有O(N + K)。我认为您不能轻易避免第二个循环(没有明显的速度损失)。至于第一个循环,这是在 Python 中执行的,它本身会带来很大的速度损失。
另一种可能性是使用np.unique() 将主循环 带到较低的级别。然而,这带来了其他挑战,因为一旦提取了唯一值,没有一些NumPy advanced indexing,即O(N),就没有有效的方法来提取信息来构造你想要的数组。这些解决方案的总体复杂度为O(K * N),但由于 NumPy 高级索引是在较低级别完成的,因此可以实现相对较快的解决方案,尽管其渐近复杂度比替代方案更差。
可能的实现包括(类似于@AjayVerma's 和@AKX's 的答案):
import numpy as np
def labeling_unique_bool(values, labels):
return {l: values[l == labels] for l in np.unique(labels)}
import numpy as np
def labeling_unique_nonzero(values, labels):
return {l: values[np.nonzero(l == labels)] for l in np.unique(labels)}
此外,可以考虑预先排序步骤,然后通过避免 NumPy 高级索引来加速切片部分。
然而,排序步骤可能比高级索引更昂贵,而且一般而言,对于我测试的输入,所提出的方法往往更快。
import numpy as np
def labeling_unique_argsort(values, labels):
uniques, counts = np.unique(labels, return_counts=True)
sorted_values = values[labels.argsort()]
bound = 0
result = {}
for x, c in zip(uniques, counts):
result[x] = sorted_values[bound:bound + c]
bound += c
return result
另一种方法,原则上很简洁(与我提出的方法相同),但在实践中很慢是使用排序和itertools.groupby():
import itertools
from operator import itemgetter
def labeling_groupby(values, labels):
slicing = labels.argsort()
sorted_labels = labels[slicing]
sorted_values = values[slicing]
del slicing
result = {}
for x, g in itertools.groupby(zip(sorted_labels, sorted_values), itemgetter(0)):
result[x] = np.fromiter(map(itemgetter(1), g), dtype=sorted_values.dtype)
return result
最后,一种基于 Pandas 的方法,对于较大的输入非常简洁且相当快,但对于较小的输入则表现不佳(类似于 @Ehsan's answer):
def labeling_groupby_pd(values, labels):
df = pd.DataFrame({'values': values, 'labels': labels})
return df.groupby('labels').values.apply(lambda x: x.values).to_dict()
现在,说话很便宜,所以让我们将一些数字附加到 fast 和 slow 并为不同的输入大小生成一些图。 K 的值上限为 52(英文字母的大小写字母)。当N远大于K时,达到封顶值的概率很高。
输入是通过以下方式以编程方式生成的:
def gen_input(n, p, labels=string.ascii_letters):
k = len(labels)
values = np.arange(n)
labels = np.array([string.ascii_letters[i] for i in np.random.randint(0, int(k * p), n)])
return values, labels
基准是针对p 的值从(1.0, 0.5, 0.1, 0.05) 生成的,这会改变K 的最大值。下面的图表按该顺序引用了p 值。
p=1.0(最多K = 52)
...并以最快的速度放大
p=0.5(最多K = 26)
p=0.1(最多K = 5)
p=0.05(最多K = 2)
...并以最快的速度放大
我们可以看到,除了非常小的输入外,所提出的方法如何优于迄今为止针对测试输入提出的其他方法。
(提供完整的基准测试here)。
也可以考虑将循环的某些部分移至 Numba / Cython,但我会将其留给感兴趣的读者。