【问题标题】:How can I speed up python's dictionary with Numba如何使用 Numba 加速 python 的字典
【发布时间】:2019-12-26 16:09:33
【问题描述】:

我需要在布尔值数组中存储一些单元格。起初我使用 numpy,但是当数组开始占用大量内存时,我有了一个想法,即在字典中以元组作为键存储非零元素(因为它是可散列类型)。例如: {(0, 0, 0): True, (1, 2, 3): True}(这是“3D 数组”中的两个单元格,索引分别为 0,0,0 和 1,2,3,但维数是预先未知的,并在我运行算法时定义)。 这很有帮助,因为非零单元格只填充了数组的一小部分。

为了从这个字典中写入和获取值,我需要使用循环:

def fill_cells(indices, area_dict):
    for i in indices:
        area_dict[tuple(i)] = 1

def get_cells(indices, area_dict):
    n = len(indices)
    out = np.zeros(n, dtype=np.bool)
    for i in range(n):
        out[i] = tuple(indices[i]) in area_dict.keys()
    return out

现在我需要使用 Numba 加快速度。 Numba 不支持原生 Python 的 dict(),所以我使用了 numba.typed.Dict。 问题是 Numba 在定义函数阶段想知道键的大小,所以我什至无法创建字典(键的长度事先未知,在调用函数时定义):

@njit
def make_dict(n):
    out = {(0,)*n:True}
    return out

Numba 无法正确推断字典键的类型并返回错误:

Compilation is falling back to object mode WITH looplifting enabled because Function "make_dict" failed type inference due to: Invalid use of Function(<built-in function mul>) with argument(s) of type(s): (tuple(int64 x 1), int64)

如果我在函数中将 n 更改为具体数字,它会起作用。我用这个技巧解决了它:

n = 10
s = '@njit\ndef make_dict():\n\tout = {(0,)*%s:True}\n\treturn out' % n
exec(s)

但我认为这是错误的低效方式。而且我需要将我的 fill_cells 和 get_cells 函数与 @njit 装饰器一起使用,但 Numba 返回相同的错误,因为我试图在此函数中从 numpy 数组创建元组。

我了解 Numba 的基本限制(以及一般的编译),但也许有一些方法可以加快功能,或者,也许您对我的单元存储问题有另一种解决方案?

【问题讨论】:

  • 你考虑过稀疏矩阵吗?
  • @Marat 是的,我自己实现了基于键字典的稀疏矩阵(函数 fill_cells 和 get_cells 是此实现的一部分)。我意识到这是稀疏矩阵非常常见的解决方案。问题是我需要加快这个实现。另外,我不需要对其进行矩阵运算,只需存储和获取值,也许它可以扩展可能的解决方案集。
  • 像dicts这样的原生数据结构效率很低。 scipy.sparse 提供了 C 实现,它的性能可能会比原生结构高一个数量级。
  • @Marat 是的,我发现 scipy.sparse 比我的解决方案更快,但它只适用于 2D 矩阵。我需要使用任意维度。我没有找到比自己编写并使用 Numba 加快速度更好的解决方案(这是我现在正在尝试做的事情,我遇到了我在问题中描述的问题)。
  • 你以前见过gist.github.com/sklam/830fe01343ba95828c3b24c391855c86吗?当我想使用数组作为矩阵的索引时,我遇到了同样的问题。它只需要在顶部稍作调整,因为dicts没有ndim。

标签: python arrays python-3.x dictionary numba


【解决方案1】:

最终解决方案:

主要问题是 Numba 在定义创建元组的函数时需要知道元组的长度。诀窍是每次都重新定义功能。我需要使用定义函数的代码生成字符串并使用 exec() 运行它:

n = 10
s = '@njit\ndef arr_to_tuple(a):\n\treturn (' + ''.join('a[%i],' % i for i in range(n)) + ')'
exec(s)

之后我可以调用 arr_to_tuple(a) 来创建可以在另一个 @njit - 修饰函数中使用的元组。

例如,创建元组键的空字典,需要解决问题:

@njit
def make_empty_dict():
    tpl = arr_to_tuple(np.array([0]*5))
    out = {tpl:True}
    del out[tpl]
    return out

我在字典中写了一个元素,因为它是 Numba 推断类型的一种方式。

另外,我需要使用问题中描述的 fill_cellsget_cells 函数。这就是我用 Numba 重写它们的方式:

书写元素。刚刚把 tuple() 改成了 arr_to_tuple():

@njit
def fill_cells_nb(indices, area_dict):
    for i in range(len(indices)):
        area_dict[arr_to_tuple(indices[i])] = True

从字典中获取元素需要一些令人毛骨悚然的代码:

@njit
def get_cells_nb(indices, area_dict):
    n = len(indices)
    out = np.zeros(n, dtype=np.bool_)
    for i in range(n):
        new_len = len(area_dict)
        tpl = arr_to_tuple(indices[i])
        area_dict[tpl] = True
        old_len = len(area_dict)
        if new_len == old_len:
            out[i] = True
        else:
            del area_dict[tpl]
    return out

我的 Numba (0.46) 版本不支持 .contains (in) 运算符和 try-except 构造。如果您有支持它的版本,您可以为其编写更多“常规”解决方案。

所以当我想检查字典中是否存在具有某个索引的元素时,我会记住它的长度,然后我在字典中写一些带有提到索引的东西。如果长度改变,我得出结论该元素不存在。否则元素存在。看起来很慢的解决方案,但事实并非如此。

速度测试:

解决方案运行速度惊人。我用 %timeit 对它进行了测试,并与本机 Python 优化代码进行了比较:

  1. arr_to_tuple() 比常规 tuple() 函数快 5 倍
  2. get_cells with numba原生 Python 编写的 get_cells 相比,一个元素的速度提高 3 倍,大型元素数组的速度提高 40 倍
  3. 原生 Python 编写的 fill_cells 相比,使用 numba 的 fill_cells 一个元素快 4 倍,大型元素数组快 40 倍

【讨论】:

  • 您是否将性能与键入的 List 进行了比较?看来您实际上并不需要存储“True”,因为存储索引已经暗示了这一点。您还可以考虑编写一个类似于 Numpy 的 unravel_indexravel_multi_index 函数,使索引始终存储一维。
  • @RutgerKassies 我在您发表评论后进行了比较。使用类型化列表要慢得多,因为需要在循环中检查列表的元素才能进行检查。函数的执行时间取决于列表的大小,与由于散列键而具有恒定时间的 Dict 相反。
  • 顺便说一句,我发现对于这样的问题,使用本机 Python hash() 函数并在循环中使用它而不是 ravel_multi_index 可能是合理的(更快)。这个循环可以用@njit 修饰,而不需要对代码进行重大更改。
猜你喜欢
  • 1970-01-01
  • 2018-07-28
  • 2015-09-27
  • 2018-12-23
  • 2022-09-23
  • 2023-04-04
  • 1970-01-01
  • 2019-11-08
  • 1970-01-01
相关资源
最近更新 更多