【问题标题】:fastest way to filter a 2d numpy array过滤二维 numpy 数组的最快方法
【发布时间】:2021-09-22 21:35:50
【问题描述】:

我正在尝试过滤一个 numpy 数组,我已经完成了如下函数:

@nb.njit
def numpy_filter (npX):
    n = np.full (npX.shape[0], True)

    for npo_index in range(npX.shape[0]):
        n[npo_index] = npX[npo_index][0] < 2000 and npX[npo_index][1] < 4000 and npX[npo_index][2] < 5000

    return npX[n]

array = 600K 的 len 需要 1.75 秒(numba njit 模式),而如果 x[0]

有没有更好的实现可以有过滤功能让它运行得更快?

【问题讨论】:

  • 您到底想过滤什么?有很多方法可以从 numpy 数组中提取信息,但这取决于您要做什么?如果您使用一维 numpy 数组的列表,您真的应该将它们组合成一个二维数组。
  • 具有 randint(0,10000) 形状 (600000,3) 的 numpy 数组 aa[(a &lt; np.array([2000,4000,5000])).all(1)] 100 次循环,最好的 5 次:每个循环 19 毫秒
  • 我将重复已经说过的话。如果您发现自己在遍历一个 numpy 数组,那么您可能一开始就不应该使用 numpy。
  • 你预热了jitted函数吗?我得到了@njit 100 个循环,最好的 5:每个循环 3.65 毫秒list comprehension 1 个循环,最好的 5:每个循环 375 毫秒pure numpy 100 次循环,最好的 5 次:每个循环 19 毫秒,在 google colab 上有两个内核。 ~100x 速度高达 pythonnumba
  • @MichaelSzczesny,你是对的,我没有把它热身……热身后,它下降到 0.07 秒左右。这很酷……现在。我需要找出合并结果的最快方法

标签: python numpy list-comprehension numba


【解决方案1】:

通常使用 Pandas/NumPy 数组,如果您这样做,您将获得最佳性能

  • 避免遍历数组
  • 仅创建基本阵列的软拷贝或视图
  • 创建最少数量的中间 Python 对象

Pandas 可能是您的朋友,它允许您创建支持 NumPy 数组的视图并通过共享索引对每个数组的行进行操作

起始数据

这会创建一个与源数据形状相同的随机数组,其值范围为 0-10000

>>> import numpy as np
>>> arr = np.random.rand(600000, 3) * 10000
>>> arr
array([[8079.54193993,  925.74430028, 2031.45569251],
       [8232.74161149, 2347.42814063, 7571.21287502],
       [7435.52165567,  756.74380534, 1023.12181186],
       ...,
       [2176.36643662, 5374.36584708,  637.43482263],
       [2645.0737415 , 9059.42475818, 3913.32941652],
       [3626.54923011, 1494.57126083, 6121.65034039]])

创建一个 Pandas 数据框

这会在您的源数据上创建视图,以便您可以使用共享索引一起处理所有行

>>> import pandas as pd
>>> df = pd.DataFrame(arr)
>>> df
                  0            1            2
0       8079.541940   925.744300  2031.455693
1       8232.741611  2347.428141  7571.212875
2       7435.521656   756.743805  1023.121812
3       4423.799649  2256.125276  7591.732828
4       6892.019075  3170.699818  1625.226953
...             ...          ...          ...
599995   642.104686  3164.107206  9508.818253
599996   102.819102  3068.249711  1299.341425
599997  2176.366437  5374.365847   637.434823
599998  2645.073741  9059.424758  3913.329417
599999  3626.549230  1494.571261  6121.650340

[600000 rows x 3 columns]

过滤器

这将获取每列索引的过滤视图,并使用组合结果过滤 DataFrame

>>> df[(df[0] < 2000) & (df[1] < 4000) & (df[2] < 5000)]
                  0            1            2
35      1829.777633  1333.083450  1928.982210
38       653.584288  3129.089395  4753.734920
71      1354.736876   279.202816     5.793797
97      1381.531847   551.465381  3767.436640
115      183.112455  1573.272310  1973.143995
...             ...          ...          ...
599963  1895.537096  1695.569792  1866.575164
599970  1061.011239    51.534961  1014.290040
599988  1780.535714  2311.671494  1012.828410
599994   878.643910   352.858091  3014.505666
599996   102.819102  3068.249711  1299.341425

[24067 rows x 3 columns]

可能会跟随基准,但速度非常快

【讨论】:

  • 我有 0.18 秒...比 numpy 慢,但比 list 快
  • 我也标记为在 nb.njit 模式下,但速度有点慢(0.22s)
  • 轶事,自从我用 numba 尝试任何东西以来已经有一分钟了(尽管它在某些测试中速度非常快).. 但我相信它往往会受到影响,并且在使用任何非 numba 时往往性能更差-numpy 对象由于一些复制(?)
  • 更具体地说,始终设置 nopython=True 并且不允许 Pandas!
【解决方案2】:

jit函数没有预热,第一次运行后,结果显示只需要0.07s就可以完成任务。

【讨论】:

  • 有道理;在这里指向文档可能很好(因为不清楚您的意思是 numba JIT 缓存),还可以查看 Michael's comment 是否更快(应该是可编译的,因为它是全 NumPy )
  • 我确定不是缓存的结果..
  • 啊,编译后的 sn-ps 的缓存,而不是结果(否则你当然可以用 functools.cache 之类的东西来欺骗基准测试)
【解决方案3】:

让你的 jit 函数只返回掩码 n,不要发送 npX[n]。 由于 jit 编译器无法修复过滤后数组的返回大小,因此它可能会变慢。

在 jit 函数之外进行过滤,即npX[n]。这应该会加快速度。

另外,为了更好地使用装饰器添加签名,这将强制进行即时编译。

numpy 和 numba 中的优化方式是不同的,所以你总是尝试哪个会更快。但是当速度几乎相同时,您可以添加并行选项,这将使其更快(我想您已经知道了)

【讨论】:

    猜你喜欢
    • 2017-10-29
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2017-01-26
    • 2019-11-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多