【问题标题】:Alternatives to numpy einsumnumpy einsum 的替代品
【发布时间】:2016-07-15 13:30:34
【问题描述】:

当我用N 行和n 列计算矩阵X 的三阶矩时,我通常使用einsum

M3 = sp.einsum('ij,ik,il->jkl',X,X,X) /N

这通常可以正常工作,但现在我正在使用更大的值,即n = 120N = 100000,而einsum 返回以下错误:

ValueError: 迭代器太大

做3个嵌套循环的替代方案是不可行的,所以我想知道是否有任何替代方案。

【问题讨论】:

    标签: python numpy matrix


    【解决方案1】:

    请注意,计算它至少需要进行 ~n3 × N = 1730 亿次操作(不考虑对称性),因此除非 numpy 可以访问 GPU 或其他东西,否则它会很慢。在具有 ~3 GHz CPU 的现代计算机上,假设没有 SIMD/并行加速,整个计算预计需要大约 60 秒才能完成。


    对于测试,让我们从 N = 1000 开始。我们将使用它来检查正确性和性能:

    #!/usr/bin/env python3
    
    import numpy
    import time
    
    numpy.random.seed(0)
    
    n = 120
    N = 1000
    X = numpy.random.random((N, n))
    
    start_time = time.time()
    
    M3 = numpy.einsum('ij,ik,il->jkl', X, X, X)
    
    end_time = time.time()
    
    print('check:', M3[2,4,6], '= 125.401852515?')
    print('check:', M3[4,2,6], '= 125.401852515?')
    print('check:', M3[6,4,2], '= 125.401852515?')
    print('check:', numpy.sum(M3), '= 218028826.631?')
    print('total time =', end_time - start_time)
    

    这大约需要 8 秒。这是基线。

    让我们从 3 嵌套循环作为替代方案开始:

    M3 = numpy.zeros((n, n, n))
    for j in range(n):
        for k in range(n):
            for l in range(n):
                M3[j,k,l] = numpy.sum(X[:,j] * X[:,k] * X[:,l])
    # ~27 seconds
    

    这大约需要半分钟,不好!一个原因是因为这实际上是四个嵌套循环:numpy.sum 也可以被认为是一个循环。

    我们注意到,可以将总和转换为点积来移除第 4 个循环:

    M3 = numpy.zeros((n, n, n))
    for j in range(n):
        for k in range(n):
            for l in range(n):
                M3[j,k,l] = X[:,j] * X[:,k] @ X[:,l]
    # 14 seconds
    

    现在好多了,但仍然很慢。但是我们注意到点积可以变成矩阵乘法来去掉一个循环:

    M3 = numpy.zeros((n, n, n))
    for j in range(n):
        for k in range(n):
            M3[j,k] = X[:,j] * X[:,k] @ X
    # ~0.5 seconds
    

    嗯?现在这甚至比einsum 更有效率!我们还可以检查答案是否确实正确。

    我们可以更进一步吗?是的!我们可以通过以下方式消除k 循环:

    M3 = numpy.zeros((n, n, n))
    for j in range(n):
        Y = numpy.repeat(X[:,j], n).reshape((N, n))
        M3[j] = (Y * X).T @ X
    # ~0.3 seconds
    

    我们还可以使用广播(即 X 的每一行使用a * [b,c] == [a*b, a*c])来避免使用numpy.repeat(感谢@Divakar):

    M3 = numpy.zeros((n, n, n))
    for j in range(n):
        Y = X[:,j].reshape((N, 1))
        ## or, equivalently: 
        # Y = X[:, numpy.newaxis, j]
        M3[j] = (Y * X).T @ X
    # ~0.16 seconds
    

    如果我们将其缩放到 N = 100000,则程序预计需要 16 秒,这在理论限制范围内,因此消除 j 可能没有太大帮助(但这可能会使代码非常难以理解) .我们可以接受这是最终的解决方案。


    注意:如果您使用的是 Python 2,a @ b 等效于 a.dot(b)

    【讨论】:

    • 真是个好主意。如果我可以在这里添加一点广播,我们可以避免创建Y,直接得到迭代输出:(X[:,None,j]*X).T @ X。这应该会给我们带来一些进一步的性能提升。
    • @Divakar:谢谢!已更新。
    猜你喜欢
    • 2016-03-23
    • 2021-03-02
    • 2023-03-20
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2016-11-15
    • 1970-01-01
    • 2017-02-01
    相关资源
    最近更新 更多