【问题标题】:Performance nested loop in numbanumba 中的性能嵌套循环
【发布时间】:2016-12-09 01:10:59
【问题描述】:

出于性能原因,除了 NumPy 之外,我还开始使用 Numba。我的 Numba 算法正在运行,但我觉得它应该更快。有一点正在减慢它的速度。这是代码sn-p:

@nb.njit
def rfunc1(ws, a, l):
    gn = a**l
    for x1 in range(gn):
        for x2 in range(gn):
            for x3 in range(gn):
                y = 0.0
                for i in range(1, l):
                    if numpy.all(ws[x1][0:i] == ws[x2][0:i]) and
                    numpy.all(ws[x1][i:l] == ws[x3][i:l]):
                        y += 1
                    if numpy.all(ws[x1][0:i] == ws[x2][0:i]) and 
                    numpy.all(ws[x1][i:l] == ws[x3][i:l]):
                        y += 1

在我看来,if 命令正在减慢它的速度。有没有更好的办法? (我在这里尝试实现的与之前发布的问题有关:Count possibilites for single crossoversws 是一个大小为 (gn, l) 的 NumPy 数组,其中包含 01

【问题讨论】:

  • 你意识到这与gn...的大小可怕地缩放?
  • 是的,当然,l 的最大大小为 9,a 始终为 2
  • 您使用的是 Python 2 还是 Python 3?
  • python 3 使用 anaconda

标签: python numpy numba


【解决方案1】:

鉴于想要确保所有项目相等的逻辑,您可以利用如果任何不相等的事实,您可以短路(即停止比较)计算。我稍微修改了您的原始函数,以便(1)您不会重复相同的比较两次,以及(2)对所有嵌套循环求和,因此可以比较返回:

@nb.njit
def rfunc1(ws, a, l):
    gn = a**l
    ysum = 0
    for x1 in range(gn):
        for x2 in range(gn):
            for x3 in range(gn):
                y = 0.0
                for i in range(1, l):
                    if np.all(ws[x1][0:i] == ws[x2][0:i]) and np.all(ws[x1][i:l] == ws[x3][i:l]):
                        y += 1
                        ysum += 1

    return ysum


@nb.njit
def rfunc2(ws, a, l):
    gn = a**l
    ysum = 0
    for x1 in range(gn):
        for x2 in range(gn):
            for x3 in range(gn):
                y = 0.0
                for i in range(1, l):

                    incr_y = True
                    for j in range(i):
                        if ws[x1,j] != ws[x2,j]:
                            incr_y = False
                            break

                    if incr_y is True:
                        for j in range(i,l):
                            if ws[x1,j] != ws[x3,j]:
                                incr_y = False
                                break
                    if incr_y is True:
                        y += 1
                        ysum += 1
    return ysum

我不知道完整的功能是什么样的,但希望这可以帮助您走上正确的道路。

现在是一些时间:

l = 7
a = 2
gn = a**l
ws = np.random.randint(0,2,size=(gn,l))
In [23]:

%timeit rfunc1(ws, a , l)
1 loop, best of 3: 2.11 s per loop


%timeit rfunc2(ws, a , l)
1 loop, best of 3: 39.9 ms per loop

In [27]: rfunc1(ws, a , l)
Out[27]: 131919

In [30]: rfunc2(ws, a , l)
Out[30]: 131919

这使您的速度提高了 50 倍。

【讨论】:

  • 如何将jitnopython=True 一起使用?
  • njit 等价于jit(nopython=True)
  • 哦,我实际上得到了一个错误,当我尝试代码时:numba.errors.LoweringError: Failed at nopython (nopython mode backend) Internal error: NotImplementedError: No definition for lowering is(bool, bool) -> bool这可能是什么原因?
  • 代码行出现错误:if incr_y is True:
  • 代码有效,如果我使用整数而不是布尔值,但我不知道为什么
【解决方案2】:

为什么不概要分析你的代码并找到确切哪里?

分析的第一个目的是测试一个有代表性的系统,以确定哪些是慢的(或使用过多的 RAM,或导致过多的磁盘 I/O 或网络 I/O)。

分析通常会增加开销(通常会降低 10 到 100 倍的速度),并且您仍然希望代码的使用尽可能与实际情况相似。提取测试用例并隔离您需要测试的系统部分。最好,它已经被编写成它自己的一组模块。

基本技术包括 IPython 中的 %timeit 魔法、time.time(),timing decorator(参见下面的示例)。您可以使用这些技术来了解语句和函数的行为。

然后您有cProfile,它将为您提供问题的高级视图,以便您可以将注意力集中在关键功能上。

接下来,查看line_profiler,,它将逐行分析您选择的函数。结果将包括每行被调用的次数和在每行上花费的时间百分比。这正是您了解运行缓慢的原因和原因所需的信息。

perf stat 帮助您了解最终在 CPU 上执行的指令数量以及 CPU 缓存的利用率。这允许对矩阵运算进行高级调整。

heapy 可以跟踪 Python 内存中的所有对象。这对于寻找奇怪的内存泄漏非常有用。如果您使用的是长时间运行的系统, 那么dowser 会让您感兴趣:它允许您通过 Web 浏览器界面在长时间运行的过程中内省活动对象。

为了帮助您了解为什么您的 RAM 使用率很高,请查看 memory_profiler. 它对于在标记图表上随时间跟踪 RAM 使用情况特别有用,因此您可以向同事(或您自己)解释为什么某些功能使用更多 RAM超出预期。

示例:定义装饰器以自动进行计时测量

from functools import wraps

def timefn(fn):
    @wraps(fn)
    def measure_time(*args, **kwargs):
        t1 = time.time()
        result = fn(*args, **kwargs)
        t2 = time.time()
        print ("@timefn:" + fn.func_name + " took " + str(t2 - t1) + " seconds")
        return result
    return measure_time

@timefn
def your_func(var1, var2):
    ...

有关更多信息,我建议阅读以上内容的来源High performance Python(Micha Gorelick;Ian Ozsvald)。

【讨论】:

  • 这些都是很好的一般建议,但没有一个真正适用于这个问题。例如,您不能在 numba 函数内部使用line_profiler,也不能在nopython 模式下调用time.time。最初的问题是关于提高用 numba 编码的函数(可能已经确定为热点)的性能。通常情况下,您必须对 Numba 可以将什么转换为高性能 llvm 代码有直觉,而许多通用技术无法弄清楚这一点。
  • @JoshAdel:我想建议 OP 不必猜测瓶颈在哪里,但可以通过剖析来确定。为了未来读者的利益,我试图使分析选项有些完整(即使并非所有选项都适用于 OP 的情况)。
  • 这是 Google 图书中的章节链接:books.google.com/…
猜你喜欢
  • 1970-01-01
  • 2019-11-13
  • 2013-01-29
  • 2016-09-18
  • 2013-02-25
  • 1970-01-01
  • 1970-01-01
  • 2020-11-08
相关资源
最近更新 更多