【问题标题】:How to speed up a python function with numba如何使用 numba 加速 python 函数
【发布时间】:2021-10-28 05:23:30
【问题描述】:

我正在尝试使用 numba 加快我对 Floyd-Steinberg's dithering algorithm 的实施。在阅读完初学者指南后,我将@jit 装饰器添加到我的代码中:

def findClosestColour(pixel):
    colors = np.array([[255, 255, 255], [255, 0, 0], [0, 0, 255], [255, 255, 0], [0, 128, 0], [253, 134, 18]])
    distances = np.sum(np.abs(pixel[:, np.newaxis].T - colors), axis=1)
    shortest = np.argmin(distances)
    closest_color = colors[shortest]
    return closest_color

@jit(nopython=True) # Set "nopython" mode for best performance, equivalent to @njit
def floydDither(img_array):
    height, width, _ = img_array.shape
    for y in range(0, height-1):
        for x in range(1, width-1):
            old_pixel = img_array[y, x, :]
            new_pixel = findClosestColour(old_pixel)
            img_array[y, x, :] = new_pixel
            quant_error = new_pixel - old_pixel
            img_array[y, x+1, :] =  img_array[y, x+1, :] + quant_error * 7/16
            img_array[y+1, x-1, :] =  img_array[y+1, x-1, :] + quant_error * 3/16
            img_array[y+1, x, :] =  img_array[y+1, x, :] + quant_error * 5/16
            img_array[y+1, x+1, :] =  img_array[y+1, x+1, :] + quant_error * 1/16
    return img_array

但是,我收到以下错误:

Untyped global name 'findClosestColour': Cannot determine Numba type of <class 'function'>

我想我理解 numba 不知道 findClosestColour 的类型,但我刚刚开始使用 numba,不知道如何处理错误。

这是我用来测试函数的代码:

image = cv2.imread('logo.jpeg')
img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
im_out = floydDither(img)

这是我使用的测试图像。

【问题讨论】:

  • 您是否尝试将@jit 装饰器也应用于findClosestColour
  • 您是否尝试应用 @jit 装饰器并将 nopython 参数设置为 False?
  • @folgerit 我将@jit 装饰器应用到findClosestColour 并收到此错误:numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
  • @yannziselman 我尝试将 @jit 装饰器参数设置为 false,但我收到了一些弃用警告。与使用普通 python 相比,整体性能没有提高。
  • 请提供一个小型可运行的img_array 以供进一步调查。

标签: python numpy numba


【解决方案1】:

首先,不可能从 Numba nopython jitted 函数调用纯 Python 函数(又名 njit 函数)。这是因为 Numba 需要在编译时跟踪类型以生成高效的二进制文件。

此外,Numba 无法编译表达式 pixel[:, np.newaxis].T,因为 np.newaxis 似乎尚不受支持(可能是因为 np.newaxisNone)。您可以改用pixel.reshape(3, -1).T

请注意,您应该注意类型,因为当两个变量的类型均为 np.uint8 时执行 a - b 会导致可能的溢出(例如 0 - 1 == 255,甚至更令人惊讶: 0 - 256 = 65280b 是字面整数且a 类型为np.uint8)。请注意,数组是就地计算的,并且像素是在之前写入的


尽管 Numba 做得很好,但生成的代码效率不会很高。您可以使用循环自己迭代颜色以找到最小索引。这要好一些,因为它不会生成许多小的临时数组。您还可以指定类型,以便 Numba 提前编译函数。话虽如此。这也使代码级别更低,因此更冗长/更难维护。

这是一个优化的实现

@nb.njit('int32[::1](uint8[::1])')
def nb_findClosestColour(pixel):
    colors = np.array([[255, 255, 255], [255, 0, 0], [0, 0, 255], [255, 255, 0], [0, 128, 0], [253, 134, 18]], dtype=np.int32)
    r,g,b = pixel.astype(np.int32)
    r2,g2,b2 = colors[0]
    minDistance = np.abs(r-r2) + np.abs(g-g2) + np.abs(b-b2)
    shortest = 0
    for i in range(1, colors.shape[0]):
        r2,g2,b2 = colors[i]
        distance = np.abs(r-r2) + np.abs(g-g2) + np.abs(b-b2)
        if distance < minDistance:
            minDistance = distance
            shortest = i
    return colors[shortest]

@nb.njit('uint8[:,:,::1](uint8[:,:,::1])')
def nb_floydDither(img_array):
    assert(img_array.shape[2] == 3)
    height, width, _ = img_array.shape
    for y in range(0, height-1):
        for x in range(1, width-1):
            old_pixel = img_array[y, x, :]
            new_pixel = nb_findClosestColour(old_pixel)
            img_array[y, x, :] = new_pixel
            quant_error = new_pixel - old_pixel
            img_array[y, x+1, :] =  img_array[y, x+1, :] + quant_error * 7/16
            img_array[y+1, x-1, :] =  img_array[y+1, x-1, :] + quant_error * 3/16
            img_array[y+1, x, :] =  img_array[y+1, x, :] + quant_error * 5/16
            img_array[y+1, x+1, :] =  img_array[y+1, x+1, :] + quant_error * 1/16
    return img_array

naive 版本快 14 倍,而最后一个版本快 19 倍

【讨论】:

  • 干得好,先生。起首!
猜你喜欢
  • 2018-07-28
  • 1970-01-01
  • 2015-09-27
  • 2021-10-17
  • 2018-12-23
  • 2022-09-23
  • 2023-04-04
  • 1970-01-01
  • 1970-01-01
相关资源
最近更新 更多