【发布时间】:2019-06-14 09:35:30
【问题描述】:
我正在尝试让这段代码运行得更快,但我找不到更多可以加快速度的技巧。
我得到了大约 3 微秒的运行时间,问题是我调用了这个函数几百万次,而这个过程最终需要很长时间。我在 Java 中有相同的实现(只有基本的 for 循环),基本上,即使是大型训练数据,计算也是即时的(这是用于 ANN)
有没有办法加快速度?
我在 Windows 10 上运行 Python 2.7、numba 0.43.1 和 numpy 1.16.3
x = True
expected = 0.5
eligibility = np.array([0.1,0.1,0.1])
positive_weight = np.array([0.2,0.2,0.2])
total_sq_grad_positive = np.array([0.1,0.1,0.1])
learning_rate = 1
@nb.njit(fastmath= True, cache = True)
def update_weight_from_post_post_jit(x, expected,eligibility,positive_weight,total_sq_grad_positive,learning_rate):
if x:
g = np.multiply(eligibility,(1-expected))
else:
g = np.negative(np.multiply(eligibility,expected))
gg = np.multiply(g,g)
total_sq_grad_positive = np.add(total_sq_grad_positive,gg)
#total_sq_grad_positive = np.where(divide_by_zero,total_sq_grad_positive, tsgp_temp)
temp = np.multiply(learning_rate, g)
temp2 = np.sqrt(total_sq_grad_positive)
#temp2 = np.where(temp2 == 0,1,temp2 )
temp2[temp2 == 0] = 1
temp = np.divide(temp,temp2)
positive_weight = np.add(positive_weight, temp)
return [positive_weight, total_sq_grad_positive]
【问题讨论】:
-
"我在 Java 中有相同的实现(只有基本的 for 循环)" 用你的 Python 代码做同样的事情。每个向量化操作都转换为一个带有不必要临时数组的 for 循环。由于 Numba (Python->LLVM-IR->LLVM-Backend) 与 Clang(C->LLVM-IR->LLVM-Backend) 非常相似,因此请像 C 一样编写代码。
标签: python optimization numba