【问题标题】:Elementwise operations in mpmath slow compared to numpy and its solution与 numpy 及其解决方案相比,mpmath 中的元素操作速度较慢
【发布时间】:2014-10-24 12:57:44
【问题描述】:

我有一些计算涉及快速爆炸的阶乘,因此我决定使用任意精度库mpmath

我的代码如下所示:

import numpy as np
import mpmath as mp
import time

a    = np.linspace( 0, 100e-2, 100 )
b    = np.linspace( 0, np.pi )
c    = np.arange( 30 )

t    = time.time()
M    = np.ones( [ len(a), len(b), len(c) ] )
A, B = np.meshgrid( a, b, indexing = 'ij' )
temp = A**2 + B
temp = np.reshape( temp, [ len(a), len(b), 1 ] )
temp = np.repeat( temp, len(c), axis = 2 )
M   *= temp
print 'part1:      ', time.time() - t
t    = time.time()

temp = np.array( [ mp.fac(x) for x in c ] )
temp = np.reshape( temp, [ 1, 1, len(c) ] )
temp = np.repeat(  temp, len(a), axis = 0 )
temp = np.repeat(  temp, len(b), axis = 1 )
print 'part2 so far:', time.time() - t
M   *= temp
print 'part2 finally', time.time() - t
t    = time.time()

似乎花费最多时间的是最后一行,我怀疑这是因为M 有一堆floats 和temp 有一堆mp.mpfs。我尝试使用mp.mpfs 初始化M,但随后一切都变慢了。

这是我得到的输出:

part1:        0.00429606437683
part2 so far: 0.00184297561646
part2 finally 1.9477159977

有什么想法可以加快速度吗?

【问题讨论】:

    标签: python numpy factorial mpmath


    【解决方案1】:

    gmpy2 在这种类型的计算中明显快于mpmath。以下代码在我的机器上运行速度大约快 12 倍。

    import numpy as np
    import gmpy2 as mp
    import time
    
    a = np.linspace(0, 100e-2, 100)
    b = np.linspace(0, np.pi)
    c = np.arange(30)
    
    t = time.time()
    M = np.ones([len(a), len(b), len(c)])
    A, B = np.meshgrid( a, b, indexing = 'ij' )
    temp = A**2+B
    temp = np.reshape(temp, [len(a), len(b), 1])
    temp = np.repeat(temp, len(c), axis=2)
    M *= temp
    print 'part1:', time.time() - t
    t = time.time()
    
    temp = np.array([mp.factorial(x) for x in c])
    temp = np.reshape(temp, [1, 1, len(c)])
    temp = np.repeat(temp, len(a), axis=0)
    temp = np.repeat(temp, len(b), axis=1)
    print 'part2 so far:', time.time() - t
    M *= temp
    print 'part2:', time.time() - t
    t = time.time()
    

    mpmath 是用 Python 编写的,通常使用 Python 的本机整数进行计算。如果gmpy2 可用,它将使用gmpy2 提供的更快的整数类型。如果你只需要gmpy2直接提供的功能之一,那么直接使用gmpy2通常会更快。

    更新

    我进行了一些实验。实际发生的事情可能不是你所期望的。计算 temp 时,值可以是整数(math.factorialgmpy.facgmpy2.fac)或浮点值(gmpy2.factorialmpmath.fac)。当numpy 计算M *= temp 时,temp 中的所有值都将转换为 64 位浮点数。如果该值是整数,则转换会引发溢出错误。如果该值是浮点数,则转换返回无穷大。您可以通过将c 更改为np.arange(300) 并在最后打印M 来看到这一点。如果你使用gmpy.facmath.factorial,你会得到OverflowError。如果您使用mpmath.factorialgmpy2.factorial,您将不会得到OverflowError,但生成的M 将包含无穷大。

    如果您试图避免 OverflowError,则需要使用浮点值计算 temp,以便转换为 64 位浮点数将导致无穷大。

    如果您没有遇到OverflowError,那么math.factorial 是最快的选择。

    如果您试图同时避免OverflowError 和无穷大,那么您将需要始终使用mpmath.mpfgmpy2.mpfr 浮点类型。 (不要尝试使用gmpy.mpf。)

    更新 #2

    这是一个使用 gmpy2.mpfr 的示例,精度为 200 位。使用c=np.arange(30),它比原始示例快约 5 倍。我使用c = np.arange(300) 展示它,因为这会生成OverflowError 或无穷大。较大范围的总运行时间与您的原始代码大致相同。

    import numpy as np
    import gmpy2
    import time
    
    from gmpy2 import mpfr
    
    gmpy2.get_context().precision = 200
    
    a = np.linspace(mpfr(0), mpfr(1), 100)
    b = np.linspace(mpfr(0), gmpy2.const_pi())
    c = np.arange(300)
    
    t = time.time()
    M = np.ones([len(a), len(b), len(c)], dtype=object)
    A, B = np.meshgrid( a, b, indexing = 'ij' )
    temp = A**2+B
    temp = np.reshape(temp, [len(a), len(b), 1])
    temp = np.repeat(temp, len(c), axis=2)
    M *= temp
    print 'part1:', time.time() - t
    t = time.time()
    
    temp = np.array([gmpy2.factorial(x) for x in c], dtype=object)
    temp = np.reshape(temp, [1, 1, len(c)])
    temp = np.repeat(temp, len(a), axis=0)
    temp = np.repeat(temp, len(b), axis=1)
    print 'part2 so far:', time.time() - t
    M *= temp
    print 'part2:', time.time() - t
    t = time.time()
    

    免责声明:我维护gmpy2

    【讨论】:

    • 哇! 12 倍!我用 gmpy 运行它(我之前设置过MPMATH_NOGMPY),但它最多只能加快 40% 的速度。还有,我的mp.libmp.BACKENDgmpy不是gmpy2有区别吗?
    • 糟糕。我粘贴了错误的代码。我直接使用gmpy2。我会在一分钟内编辑答案。 gmpy2 提供对 MPFR(正确舍入的实浮点)和 MPC(正确舍入的复数浮点)任意精度库的完全访问。旧的gmpy 不支持 MPFR 和 MPC。
    • hm... 我刚刚注意到您所做的唯一更改是使用了阶乘函数。使我的代码变慢的是最后一个 M*=temp 而不是 mp.fac(x) 操作。在这种情况下,后端有这么重要吗?
    • 请仔细检查导入语句...我根本没有使用 mpmath。这不仅仅是改变后端:它使用 gmpy2 而不是 mpmath。
    • 我想我明白了...从某种意义上说,您是说gmpy2 元素会更快地执行M *= temp 操作,对吗?
    最近更新 更多