【发布时间】:2021-09-22 21:35:50
【问题描述】:
我正在尝试过滤一个 numpy 数组,我已经完成了如下函数:
@nb.njit
def numpy_filter (npX):
n = np.full (npX.shape[0], True)
for npo_index in range(npX.shape[0]):
n[npo_index] = npX[npo_index][0] < 2000 and npX[npo_index][1] < 4000 and npX[npo_index][2] < 5000
return npX[n]
array = 600K 的 len 需要 1.75 秒(numba njit 模式),而如果 x[0]
有没有更好的实现可以有过滤功能让它运行得更快?
【问题讨论】:
-
您到底想过滤什么?有很多方法可以从 numpy 数组中提取信息,但这取决于您要做什么?如果您使用一维 numpy 数组的列表,您真的应该将它们组合成一个二维数组。
-
具有 randint(0,10000) 形状 (600000,3) 的
numpy数组a。a[(a < np.array([2000,4000,5000])).all(1)]100 次循环,最好的 5 次:每个循环 19 毫秒 -
我将重复已经说过的话。如果您发现自己在遍历一个 numpy 数组,那么您可能一开始就不应该使用 numpy。
-
你预热了jitted函数吗?我得到了
@njit100 个循环,最好的 5:每个循环 3.65 毫秒,list comprehension1 个循环,最好的 5:每个循环 375 毫秒,pure numpy100 次循环,最好的 5 次:每个循环 19 毫秒,在 google colab 上有两个内核。 ~100x 速度高达python和numba -
@MichaelSzczesny,你是对的,我没有把它热身……热身后,它下降到 0.07 秒左右。这很酷……现在。我需要找出合并结果的最快方法
标签: python numpy list-comprehension numba