【问题标题】:Optimize Numba and Numpy function优化 Numba 和 Numpy 功能
【发布时间】:2019-06-14 09:35:30
【问题描述】:

我正在尝试让这段代码运行得更快,但我找不到更多可以加快速度的技巧。

我得到了大约 3 微秒的运行时间,问题是我调用了这个函数几百万次,而这个过程最终需要很长时间。我在 Java 中有相同的实现(只有基本的 for 循环),基本上,即使是大型训练数据,计算也是即时的(这是用于 ANN)

有没有办法加快速度?

我在 Windows 10 上运行 Python 2.7、numba 0.43.1 和 numpy 1.16.3

x = True
expected = 0.5
eligibility = np.array([0.1,0.1,0.1])
positive_weight = np.array([0.2,0.2,0.2])
total_sq_grad_positive = np.array([0.1,0.1,0.1])
learning_rate = 1

@nb.njit(fastmath= True, cache = True)
def update_weight_from_post_post_jit(x, expected,eligibility,positive_weight,total_sq_grad_positive,learning_rate):
        if x:
            g = np.multiply(eligibility,(1-expected))
        else:
            g = np.negative(np.multiply(eligibility,expected))
        gg = np.multiply(g,g)
        total_sq_grad_positive = np.add(total_sq_grad_positive,gg)
        #total_sq_grad_positive = np.where(divide_by_zero,total_sq_grad_positive, tsgp_temp)

        temp = np.multiply(learning_rate, g)
        temp2 = np.sqrt(total_sq_grad_positive)
        #temp2 = np.where(temp2 == 0,1,temp2 )
        temp2[temp2 == 0] = 1
        temp = np.divide(temp,temp2)
        positive_weight = np.add(positive_weight, temp)
        return [positive_weight, total_sq_grad_positive]

【问题讨论】:

  • "我在 Java 中有相同的实现(只有基本的 for 循环)" 用你的 Python 代码做同样的事情。每个向量化操作都转换为一个带有不必要临时数组的 for 循环。由于 Numba (Python->LLVM-IR->LLVM-Backend) 与 Clang(C->LLVM-IR->LLVM-Backend) 非常相似,因此请像 C 一样编写代码。

标签: python optimization numba


【解决方案1】:

编辑:看来@max9111 是对的。不必要的临时数组是开销的来源。

对于您的函数的当前语义,似乎有两个无法避免的临时数组——返回值[positive_weight, total_sq_grad_positive]。但是,令我震惊的是,您可能正计划使用此函数来更新这两个输入数组。如果是这样,通过就地执行所有操作,我们可以获得最大的加速。像这样:

import numba as nb
import numpy as np

x = True
expected = 0.5
eligibility = np.array([0.1,0.1,0.1])
positive_weight = np.array([0.2,0.2,0.2])
total_sq_grad_positive = np.array([0.1,0.1,0.1])
learning_rate = 1

@nb.njit(fastmath= True, cache = True)
def update_weight_from_post_post_jit(x, expected,eligibility,positive_weight,total_sq_grad_positive,learning_rate):
    for i in range(eligibility.shape[0]):
        if x:
            g = eligibility[i] * (1-expected)
        else:
            g = -(eligibility[i] * expected)
        gg = g * g
        total_sq_grad_positive[i] = total_sq_grad_positive[i] + gg

        temp = learning_rate * g
        temp2 = np.sqrt(total_sq_grad_positive[i])
        if temp2 == 0: temp2 = 1
        temp = temp / temp2
        positive_weight[i] = positive_weight[i] + temp

@nb.jit
def test(n, *args):
    for i in range(n): update_weight_from_post_post_jit(*args)

如果更新输入数组不是您想要的,您可以使用函数开始

positive_weight = positive_weight.copy()
total_sq_grad_positive = total_sq_grad_positive.copy()

并按照原始代码返回它们。这几乎没有那么快,但仍然更快。


我不确定是否可以优化为“即时”; Java 能做到这一点让我有点惊讶,因为这对我来说是一个相当复杂的函数,需要像sqrt 这样的耗时操作。

但是,您是否在调用此函数的函数上使用了nb.jit?像这样:

@nb.jit
def test(n):
    for i in range(n): update_weight_from_post_post_jit(x, expected,eligibility,positive_weight,total_sq_grad_positive,learning_rate)

在我的计算机上,这将运行时间减少了一半,这是有道理的,因为 Python 函数调用的开销非常高。

【讨论】:

  • 您好,感谢您的回复。事实上,我确实从另一个 jit 函数中调用了该函数,但是,我在这里使用了类,我做了一个技巧,我基本上从类中提取了这个方法,所以在某种程度上我可以做到这一点,因为我可以#不要在课堂上这样做。也许有更好的方法来规避它?
  • 只是对编译器标志的一点评论。很多时候,但不是在这种情况下(只有 3 个除法),通过零检查进行除法会降低性能,但不是必需的。您可以使用 error_model="numpy" 禁用它。
  • 我实施了上述更改,我绝对可以看到速度加快,谢谢!对此进行改进,是否有某种方法可以跳过中间并在类外调用函数并让 jit 类返回临时矩阵?基本上更新“自我”变量,但在类之外?谢谢!
  • 天哪,我刚刚刷新了关于 python 传递引用的知识,只要参数是一个对象(例如一个 numpy 数组),它们总是通过引用传递!
猜你喜欢
  • 1970-01-01
  • 1970-01-01
  • 2019-01-13
  • 2014-01-07
  • 2021-03-24
  • 1970-01-01
  • 2021-04-27
  • 1970-01-01
  • 1970-01-01
相关资源
最近更新 更多