【发布时间】:2020-04-27 20:58:52
【问题描述】:
我致力于优化 python 模块中的一些代码。我已经确定了瓶颈,并且是一个代码 sn-p,它在numpy 中进行了一些计算。即如下代码:
xh = np.multiply(K_Rinv[0, 0], x )
xh += np.multiply(K_Rinv[0, 1], y)
xh += np.multiply(K_Rinv[0, 2], h)
yh = np.multiply(K_Rinv[1, 0], x)
yh += np.multiply(K_Rinv[1, 1], y)
yh += np.multiply(K_Rinv[1, 2], h)
q = np.multiply(K_Rinv[2, 0], x)
q += np.multiply(K_Rinv[2, 1], y)
q += np.multiply(K_Rinv[2, 2], h)
其中 x,y 和 h 是形状为 (4206,5749) 的 np.ndarray,K_Rinv 是形状为 (3,3) 的 np.ndarray。
这段代码 sn-p 被多次调用,占用了整个代码 50% 以上的时间。
有没有办法加快速度?还是就这样,无法加速。
编辑1:
感谢您的回答。在遇到 numba 问题后(请参阅最后的错误消息),我尝试了使用 numexpr 的建议。但是,使用此解决方案时我的代码中断了。所以我检查了结果是否相同而它们不同。这是我正在使用的代码:
xh_1 = numexpr.evaluate('a1*b1+a2*b2+a3*b3', {'a1': K_Rinv[0, 0], 'b1': x,
'a2': K_Rinv[0, 1], 'b2': y,
'a3': K_Rinv[0, 2], 'b3': h})
yh_1 = numexpr.evaluate('a1*b1+a2*b2+a3*b3', {'a1': K_Rinv[1, 0], 'b1': x,
'a2': K_Rinv[1, 1], 'b2': y,
'a3': K_Rinv[1, 2], 'b3': h})
q_1 = numexpr.evaluate('a1*b1+a2*b2+a3*b3', {'a1': K_Rinv[2, 0], 'b1': x,
'a2': K_Rinv[2, 1], 'b2': y,
'a3': K_Rinv[2, 2], 'b3': h})
xh_2 = np.multiply(K_Rinv[0, 0], x )
xh_2 += np.multiply(K_Rinv[0, 1], y)
xh_2 += np.multiply(K_Rinv[0, 2], h)
yh_2 = np.multiply(K_Rinv[1, 0], x)
yh_2 += np.multiply(K_Rinv[1, 1], y)
yh_2 += np.multiply(K_Rinv[1, 2], h)
q_2 = np.multiply(K_Rinv[2, 0], x)
q_2 += np.multiply(K_Rinv[2, 1], y)
q_2 += np.multiply(K_Rinv[2, 2], h)
check1 = xh_1.all() == xh_2.all()
check2 = yh_1.all() == yh_2.all()
check3 = q_2.all() == q_1.all()
print ( " Check1 :{} , Check2: {} , Check3:{}" .format (check1,check2,check3))
我对 numexpr 没有任何经验,它们通常不一样吗?
来自 numba 的错误:
File "/usr/local/lib/python3.6/dist-packages/numba/dispatcher.py", line 420, in _compile_for_args
raise e
File "/usr/local/lib/python3.6/dist-packages/numba/dispatcher.py", line 353, in _compile_for_args
return self.compile(tuple(argtypes))
File "/usr/local/lib/python3.6/dist-packages/numba/compiler_lock.py", line 32, in _acquire_compile_lock
return func(*args, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/numba/dispatcher.py", line 768, in compile
cres = self._compiler.compile(args, return_type)
File "/usr/local/lib/python3.6/dist-packages/numba/dispatcher.py", line 77, in compile
status, retval = self._compile_cached(args, return_type)
File "/usr/local/lib/python3.6/dist-packages/numba/dispatcher.py", line 91, in _compile_cached
retval = self._compile_core(args, return_type)
File "/usr/local/lib/python3.6/dist-packages/numba/dispatcher.py", line 109, in _compile_core
pipeline_class=self.pipeline_class)
File "/usr/local/lib/python3.6/dist-packages/numba/compiler.py", line 551, in compile_extra
return pipeline.compile_extra(func)
File "/usr/local/lib/python3.6/dist-packages/numba/compiler.py", line 327, in compile_extra
raise e
File "/usr/local/lib/python3.6/dist-packages/numba/compiler.py", line 321, in compile_extra
ExtractByteCode().run_pass(self.state)
File "/usr/local/lib/python3.6/dist-packages/numba/untyped_passes.py", line 67, in run_pass
bc = bytecode.ByteCode(func_id)
File "/usr/local/lib/python3.6/dist-packages/numba/bytecode.py", line 215, in __init__
self._compute_lineno(table, code)
File "/usr/local/lib/python3.6/dist-packages/numba/bytecode.py", line 237, in _compute_lineno
known = table[_FIXED_OFFSET].lineno
KeyError: 2
编辑2 cmets和答案的坦克。 因此,在再次查看代码后,numexpr 解决方案有效。非常感谢。我仍然在一个单独的 python 文件中进行了一些测试,以查看 numba 代码是否在那里工作并且它可以工作,但速度很慢。请参阅下面我使用的代码:
import numpy as np
import numba as nb
import numexpr
from datetime import datetime
def calc(x,y,h,K_Rinv):
xh_2 = np.multiply(K_Rinv[0, 0], x )
xh_2 += np.multiply(K_Rinv[0, 1], y)
xh_2 += np.multiply(K_Rinv[0, 2], h)
yh_2 = np.multiply(K_Rinv[1, 0], x)
yh_2 += np.multiply(K_Rinv[1, 1], y)
yh_2 += np.multiply(K_Rinv[1, 2], h)
q_2 = np.multiply(K_Rinv[2, 0], x)
q_2 += np.multiply(K_Rinv[2, 1], y)
q_2 += np.multiply(K_Rinv[2, 2], h)
return xh_2, yh_2, q_2
def calc_numexpr(x,y,h,K_Rinv):
xh = numexpr.evaluate('a1*b1+a2*b2+a3*b3', {'a1': K_Rinv[0, 0], 'b1': x,
'a2': K_Rinv[0, 1], 'b2': y,
'a3': K_Rinv[0, 2], 'b3': h})
yh = numexpr.evaluate('a1*b1+a2*b2+a3*b3', {'a1': K_Rinv[1, 0], 'b1': x,
'a2': K_Rinv[1, 1], 'b2': y,
'a3': K_Rinv[1, 2], 'b3': h})
q = numexpr.evaluate('a1*b1+a2*b2+a3*b3', {'a1': K_Rinv[2, 0], 'b1': x,
'a2': K_Rinv[2, 1], 'b2': y,
'a3': K_Rinv[2, 2], 'b3': h})
return xh,yh,q
@nb.njit(fastmath=True,parallel=True)
def calc_nb(x,y,h,K_Rinv):
xh=np.empty_like(x)
yh=np.empty_like(x)
q=np.empty_like(x)
for i in nb.prange(x.shape[0]):
for j in range(x.shape[1]):
xh[i,j]=K_Rinv[0, 0]*x[i,j]+K_Rinv[0, 1]* y[i,j]+K_Rinv[0, 2]*h[i,j]
yh[i,j]=K_Rinv[1, 0]*x[i,j]+K_Rinv[1, 1]* y[i,j]+K_Rinv[1, 2]*h[i,j]
q[i,j] =K_Rinv[2, 0]*x[i,j]+K_Rinv[2, 1]* y[i,j]+K_Rinv[2, 2]*h[i,j]
return xh,yh,q
x = np.random.random((4206, 5749))
y = np.random.random((4206, 5749))
h = np.random.random((4206, 5749))
K_Rinv = np.random.random((3, 3))
start = datetime.now()
x_calc,y_calc,q_calc = calc(x,y,h,K_Rinv)
end = datetime.now()
print("Calc took: {} ".format(end - start))
start = datetime.now()
x_numexpr,y_numexpr,q_numexpr = calc_numexpr(x,y,h,K_Rinv)
end = datetime.now()
print("Calc_numexpr took: {} ".format(end - start))
start = datetime.now()
x_nb,y_nb,q_nb = calc_nb(x,y,h,K_Rinv)
end = datetime.now()
print("Calc nb took: {} ".format(end - start))
check_nb_q = (q_calc==q_nb).all()
check_nb_y = (y_calc==y_nb).all()
check_nb_x = (x_calc==x_nb).all()
check_numexpr_q = (q_calc==q_numexpr).all()
check_numexpr_y = (y_calc==y_numexpr).all()
check_numexpr_x = (x_calc==x_numexpr).all()
print("Checks for numexpr: {} , {} ,{} \nChecks for nb: {} ,{}, {}" .format(check_numexpr_x,check_numexpr_y,check_numexpr_q,check_nb_x,check_nb_y,check_nb_q))
其输出如下:
Calc took: 0:00:01.944150
Calc_numexpr took: 0:00:00.616224
Calc nb took: 0:00:01.553058
Checks for numexpr: True , True ,True
Checks for nb: False ,False, False
因此 numba 版本无法按预期工作。知道我做错了什么吗?希望让 numba 解决方案也能正常工作。
附言。 nb.版本是'0.47.0'
【问题讨论】:
-
@LucaNeri,
numpy是一个特殊的标签,因为矢量化计算是整个包的重点——所以我们对 CodeReview 没有太多兴趣。 -
您使用的是哪个 Numba 版本
nb.__version__?您是否尝试过完全复制我的示例? -
这是一个时间问题。 Numba 是一个 jit 编译器,编译需要 aprox。 1s,但仅在第一次调用时。所有进一步的调用都快得多。
-
@max9111 再次感谢。认为是时候让我谷歌一下 jit 编译器是什么了 :) 。附言。第二次调用确实更快。 0.8s ,仍然不如 numexpr 快。顺便说一下结果不一样,有什么想法吗?
-
fastmath 标志 clang.llvm.org/docs/… 会稍微改变结果。详细地说,由于 SIMD 向量化(FMA3 指令)en.wikipedia.org/wiki/…,我预计会产生一些不同的结果,如果您使用 np.allclose() 进行检查,您将得到 True,但您也可以禁用 fastmath。关于计时:1)检查 fastmath=False,2)您在开始计时之前是否添加了额外的一行
x_nb,y_nb,q_nb = calc_nb(x,y,h,K_Rinv)?
标签: python performance numpy optimization multiplication