【问题标题】:Why is this numba code to store and process pointcloud data slower than the pure Python version?为什么这个存储和处理点云数据的 numba 代码比纯 Python 版本慢?
【发布时间】:2021-09-23 18:38:03
【问题描述】:

我需要存储一些这样的数据结构:

{'x1,y1,z1': [[p11_x,p11_y,p11_z], [p12_x,p12_y,p12_z], ..., [p1n_x,p1n_y,p1n_z]],
 'x2,y2,z2': [[p21_x,p21_y,p21_z], [p22_x,p22_y,p22_z], ..., [p2n_x,p2n_y,p2n_z]],
 ...
 'xn,yn,zn': [[pn1_x,pn1_y,pn1_z], [pn2_x,pn2_y,pn2_z], ..., [pnm_x,pnm_y,pnm_z]]}

每个键都是网格单元格索引,值是分类点列表。该列表可以是可变长度,但我可以将其设置为静态,例如 1000 个元素。

现在我尝试了这样的事情:

np.zeros(shape=(100,100,100,50,3))

但如果我将numba.jit 与该函数一起使用,则执行时间比纯 Python 差几倍。

我想做的简单 Python 示例:

def split_into_grid_py(points: np.array):
    grid = {}
    for point in points:
        r_x = round(point[0])
        r_y = round(point[1])
        r_z = round(point[2])
        try:
            grid[(r_x, r_y, r_z)].append(point)
        except KeyError:
            grid[(r_x, r_y, r_z)] = [point]
    return grid

numba 有什么有效的方法吗? 每 10 次执行循环时间如下:

  • 号码:7.050494909286499
  • 纯 Python:1.0014197826385498

使用相同的数据集,所以这是垃圾优化。

我的 numba 代码:

@numba.jit(nopython=True)
def split_into_grid(points: np.array):
    grid = np.zeros(shape=(100,100,100,50,3))
    for point in points:
        r_x = round(point[0])
        r_y = round(point[1])
        r_z = round(point[2])
        i = 0
        for cell in grid[r_x][r_y][r_z]:
            if not np.sum(cell):
               grid[r_x][r_y][r_z][i] = point
               break
            i += 1
    return grid

【问题讨论】:

    标签: python numpy performance numba


    【解决方案1】:

    纯 Python 版本在 O(1) 时间追加项目,这要归功于字典容器,而 Numba 版本使用 O(n) 数组搜索(以 50 为界)。此外,np.zeros(shape=(100,100,100,50,3)) 分配了一个大约 1 GiB 的数组,这会导致在 RAM 中完成的计算过程中出现许多缓存未命中。同时,纯 Python 版本可能适合 CPU 缓存。有两种策略可以解决这个问题。

    第一个策略是使用 3 个容器。数组keyGrid 将每个网格单元映射到第二个数组valueGrid 中的偏移量,如果没有与该单元关联的点,则为-1。 valueGrid 包含给定网格单元的所有点。最后,countingGrid 计算每个网格单元的点数。这是一个未经测试的示例:

    @numba.jit(nopython=True)
    def split_into_grid(points: np.array):
        # Note: use np.uint16 if the actual number of filled grid cell is less than 65536
        keyGrid = np.full(shape=(100,100,100), -1, dtype=np.uint32)
        i = 0
        for point in points:
            r_x = round(point[0])
            r_y = round(point[1])
            r_z = round(point[2])
            if keyGrid[r_x,r_y,r_z] < 0:
                keyGrid[r_x,r_y,r_z] = i
                i += 1
        uniqueCloundPointCount = i
        # Note the number of points per grid cell is also bounded by the type
        countingGrid = np.zeros(uniqueCloundPointCount, dtype=np.uint8)
        valueGrid = np.full((uniqueCloundPointCount, 50, 3), -1, dtype=np.int32)
        for point in points:
            r_x = round(point[0])
            r_y = round(point[1])
            r_z = round(point[2])
            key = keyGrid[r_x,r_y,r_z]
            addingPos = countingGrid[key]
            valueGrid[key, addingPos] = point
            countingGrid[key] += 1
        return (keyGrid, valueGrid, countingGrid)
    

    请注意,只要不是所有网格单元都包含点,从而减少缓存未命中,数组就会非常小。此外,每个点的映射都是在(小)恒定时间内完成的,因此代码速度更快。

    第二种策略是使用与纯 Python 实现相同的方法,但使用 Numba 类型。事实上,Numba 实验性地supports dictionaries。您可以用字典检查 ((r_x, r_y, r_z) in grid) 替换异常,这将减少编译问题并可能加快生成的代码。请注意,Numba dict 通常与 CPython 一样快(如果不是更慢的话)。所以生成的代码可能不会快多少。

    【讨论】:

      猜你喜欢
      • 2014-02-23
      • 2021-03-10
      • 2018-11-12
      • 1970-01-01
      • 2014-08-14
      • 2017-12-22
      • 1970-01-01
      • 1970-01-01
      • 2017-12-06
      相关资源
      最近更新 更多