【问题标题】:Numba, strange type modificationNumba,奇怪的类型修改
【发布时间】:2020-01-27 21:53:43
【问题描述】:

我正在尝试进一步加快一些用 python 编写的代码的速度,这些代码使用 Numba 编译。在查看 numba 生成的程序集时,我注意到正在生成双精度操作,我觉得这很奇怪,因为输入和输出都应该是 float32。

我在 jitter 循环之外将变量/数组类型声明为 float32 并将它们传递给函数。奇怪的是,我发现在运行我的测试之后,变量“scalarout”被转换为python float,这实际上是一个64位的值。

我的代码:

from scipy import ndimage, misc
import matplotlib.pyplot as plt
import numpy.fft
from timeit import default_timer as timer
import numba
# numba.config.DUMP_ASSEMBLY = 1
from numba import float32
from numba import jit, njit, prange
from numba import cuda
import numpy as np
import scipy as sp

# import llvmlite.binding as llvm
# llvm.set_option('', '--debug-only=loop-vectorize')

@njit(fastmath=True, parallel=False)
def mydot(a, b, xlen, ylen, scalarout):
    scalarout = (np.float32)(0.0)
    for y in prange(ylen):
        for x in prange(xlen):
            scalarout += a[y, x] * b[y, x]
    return scalarout

# ======================================== TESTS ========================================

print()
xlen = 100000
ylen = 16
a = np.random.rand(ylen, xlen).astype(np.float32)
b = np.random.rand(ylen, xlen).astype(np.float32)
print("a type = ", type(a[1,1]))
scalarout = (np.float32)(0.0)
print("scalarout type, before execution = ", type(scalarout))
iters=1000

time = 100.0
for n in range(iters):
    start = timer()
    scalarout = mydot(a, b, xlen, ylen, scalarout)
    end = timer()
    if(end-start < time):
        time = end-start
print("Numba njit function time, in us = %16.10f" % ((end-start)*10**6))
print("function output = %f" % scalarout)
print("scalarout type, after execution = ", type(scalarout))

【问题讨论】:

  • 您可以使用mydot.inspect_types()获取内部类型。 scalarout = np.float32(0.) 或定义本地人,例如。 @njit(fastmath=True, parallel=False,locals={"scalarout": numba.types.float32})。只有最大值。 5% 的性能提升(问题主要受内存带宽限制),由于精度较低导致的结果差异更明显。

标签: python numba


【解决方案1】:

这更像是一个扩展评论而不是一个答案。如果将 scalarout 更改为长度为 1 的 float32 数组并对其进行修改,则输出为 float32。

@njit(fastmath=True, parallel=False)
def mydot(a, b, xlen, ylen):
    scalarout = np.array([0.0], dtype=np.float32)
    for y in prange(ylen):
        for x in prange(xlen):
            scalarout[0] += a[y, x] * b[y, x]
    return scalarout

如果将return scalarout 更改为return scalarout[0],那么输出又是一个python 浮点数。

mydot 的原始代码中,即使您编写 return np.float32(scalarout),结果也是一个 python 浮点数。

【讨论】:

  • 这是一个有趣的想法,它确实将输出类型强制为 np.float32。不幸的是,它导致函数执行速度慢了大约 5 倍。
猜你喜欢
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2014-05-31
  • 2013-08-15
  • 2017-11-29
相关资源
最近更新 更多