这是 numba 的一个很好的用例,因为它可以让您将其表示为一个简单的双循环,而不会造成很大的性能损失,这反过来又可以让您避免使用 np.tile 复制数据时产生过多的额外内存第三个维度只是以矢量化的方式进行。
从另一个答案中借用标准矢量化 numpy 实现,我有这两个实现:
import numba
import numpy as np
def kmeans_assignment(centroids, points):
num_centroids, dim = centroids.shape
num_points, _ = points.shape
# Tile and reshape both arrays into `[num_points, num_centroids, dim]`.
centroids = np.tile(centroids, [num_points, 1]).reshape([num_points, num_centroids, dim])
points = np.tile(points, [1, num_centroids]).reshape([num_points, num_centroids, dim])
# Compute all distances (for all points and all centroids) at once and
# select the min centroid for each point.
distances = np.sum(np.square(centroids - points), axis=2)
return np.argmin(distances, axis=1)
@numba.jit
def kmeans_assignment2(centroids, points):
P, C = points.shape[0], centroids.shape[0]
distances = np.zeros((P, C), dtype=np.float32)
for p in range(P):
for c in range(C):
distances[p, c] = np.sum(np.square(centroids[c] - points[p]))
return np.argmin(distances, axis=1)
然后对于一些样本数据,我做了几个时序实验:
In [12]: points = np.random.rand(10000, 50)
In [13]: centroids = np.random.rand(30, 50)
In [14]: %timeit kmeans_assignment(centroids, points)
196 ms ± 6.78 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [15]: %timeit kmeans_assignment2(centroids, points)
127 ms ± 12.1 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
我不会说 numba 版本肯定比 np.tile 版本快,但显然它非常接近,同时不会产生 np.tile 的额外内存成本。
事实上,我在笔记本电脑上注意到,当我将形状变大并使用 (10000, 1000) 来表示 points 的形状时,使用 (200, 1000) 来表示 centroids 的形状,然后是 np.tile生成了一个MemoryError,同时numba函数在5秒内运行,没有内存错误。
另外,我实际上注意到在第一个版本(withnp.tile)上使用numba.jit 时速度变慢了,这可能是由于在 jitted 函数内部创建了额外的数组以及没有太多 numba 可以优化的事实当您已经在调用所有矢量化函数时。
当我尝试使用广播来缩短代码时,我也没有注意到第二个版本有任何显着改进。例如。将双循环缩短为
for p in range(P):
distances[p, :] = np.sum(np.square(centroids - points[p, :]), axis=1)
并没有真正帮助任何事情(并且在整个centroids 中重复广播points[p, :] 时会占用更多内存)。
这是 numba 的真正好处之一。您真的可以以一种非常直接、基于循环的方式编写算法,该方式符合算法的标准描述,并允许更好地控制语法如何分解为内存消耗或广播......所有这些都不会放弃运行时性能。