【问题标题】:Compiling njit nopython version of function fails due to data types由于数据类型,编译 njit nopython 版本的函数失败
【发布时间】:2017-07-29 04:03:11
【问题描述】:

我正在 njit 中编写一个函数来加速非常慢的水库操作优化代码。该函数根据水库水位和闸门可用性返回溢出释放的最大值。我传入了一个参数大小,它指定要计算的流数(在某些调用中它是一个,在某些调用中它是多个)。我还传入了一个 numpy.zeros 数组,然后我可以用函数输出填充该数组。函数的简化版写法如下:

import numpy as np
from numba import njit

@njit(cache=True)
def fncMaxFlow(elev, flag, size, MaxQ):
    if (flag == 1): # SPOG2 running
        if size==0:
            if (elev>367.28):
                return 861.1 
            else: return 0
        else:
            for i in range(size):
                if((elev[i]>367.28) & (elev[i]<385)):
                    MaxQ[i]=861.1
            return MaxQ
    else:
        if size==0: return 0
        else: return MaxQ

fncMaxFlow(np.random.randint(368, 380, 3), 1, 3, np.zeros(3))

我得到的错误:

Can't unify return type from the following types: array(float64, 1d, C), float64, int32

这是什么原因?是否有任何解决方法或我缺少某些步骤,以便我可以使用 numba 来加快速度?这个函数和其他类似函数被调用了数百万次,因此它们是计算效率的主要因素。任何建议都会有所帮助 - 我对 python 很陌生。

【问题讨论】:

    标签: python performance jit numba


    【解决方案1】:

    numba 函数中的变量必须具有一致的类型,包括返回变量。在您的代码中,您可以返回 MaxQ(一个数组)、861.1(一个浮点数)或 0(一个 int)。

    您需要重构此代码,使其始终返回一致的类型,而不管代码路径如何。

    还请注意,在您将 numpy 数组与标量 (elev &gt; 367.28) 进行比较的几个地方,您得到的是一个布尔值数组,这会导致您出现问题。因此,您的示例函数不能作为纯 python 函数(删除 numba 装饰器)运行。

    【讨论】:

    • 谢谢。我在 else 循环中为大于 1 值的大小添加了 elev 的索引值。如何在不使用 numpy 的情况下将 int 和 float 返回转换为数组?我试图将 MaxQ 作为零数组传入,然后将 MaxQ[0] 的值设置为 861.1,将 MaxQ 的值设置为 861.1 或将数组保留为零,但我仍然收到错误
    猜你喜欢
    • 1970-01-01
    • 2016-03-27
    • 1970-01-01
    • 1970-01-01
    • 2015-06-14
    • 2020-02-23
    • 1970-01-01
    • 2021-10-16
    相关资源
    最近更新 更多