【问题标题】:Python numba.jit typesPython numba.jit 类型
【发布时间】:2013-12-21 01:26:00
【问题描述】:

我整天都在尝试从 numba 文档中推断类型是如何设置的。我已经掌握了一些方法,但现在我想创建一个函数,它返回一个一维数组和一个二维数组,并采用一堆 args,我很难再进一步:

@jit
class name(object)
    @double[:,:], double[:](double[:], double, double, int64)
    def solve(self, u0, a, b, n):
        self.t = linspace(a, b, n+1)
        dt = abs((b-a)/float(n))
        u = zeros(n+1, len([u0]))
        u[0] = u0
        u = advance(u, t, n, dt)
        return u.transpose(), t.transpose()   

上面抛出了这些异常:

Traceback (most recent call last):
  File "/home/marius/dev/python/inf1100/test_ODE.py", line 2, in <module>
    from DE import *
  File "/home/marius/dev/python/inf1100/DE.py", line 13
    @double[:,:], double[:](double[:], double, double, int64)
           ^
SyntaxError: invalid syntax

如果你能告诉我出了什么问题会很好,但是如果你能推荐一个一劳永逸地严格解释这些语法的文档,那就更好了。

感谢您的宝贵时间。

亲切的问候, 马吕斯

【问题讨论】:

    标签: python numba memoryview


    【解决方案1】:

    这是一个返回元组的方法的简单版本。这适用于我在 OS X 上使用 Numba 0.11.1:

    import numba
    import numpy as np
    
    @numba.jit
    class name(object):
        @numba.object_(numba.double[:], numba.double)
        def solve(self, x, a):
            y = np.empty(x.shape[0], dtype=np.float64)
            z = np.empty(x.shape[0], dtype=np.float64)
            for k in xrange(x.shape[0]):
                y[k] = x[k] * a
                z[k] = x[k] + a
    
            return y, z 
    

    然后使用它:

    C = name()
    a, b = C.solve(np.arange(5, dtype=np.float64), 3.0)
    

    ab 分别是:

    In [24]: a
    Out[24]:
    array([  0.,   3.,   6.,   9.,  12.])
    In [22]: b
    Out[22]:
    array([ 3.,  4.,  5.,  6.,  7.])
    

    【讨论】:

    • 从 0.24.0 版开始,这似乎已过时。 numba.object_ 不存在,我不确定它被替换为什么。
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 2015-12-21
    • 2020-09-22
    • 1970-01-01
    • 1970-01-01
    • 2020-12-26
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多