【问题标题】:How to speed up numpy dot product with masks?如何使用掩码加速 numpy dot 积?
【发布时间】:2016-09-30 02:38:34
【问题描述】:

我有 2 个 numpy 数组,m1m2 其中m1 是大小 (nx1) 和 m2 是大小 (1xn),我想执行乘法 m1.dot(m2) 得到一个矩阵 @987654326 @ 大小 (nxn)

我想计算一个近似值m_approx,只使用m1m2 中的最高k 个元素并使所有其他元素为0(所有元素都是正数)。

我正在尝试加快乘法速度,因为 n 的大小对我来说很大(~10k)。我想选择一个小的k 说 100 并真正加快乘法速度。我尝试使用 numpy 稀疏矩阵,它确实使点积更快,但将 m1 和 m2 转换为稀疏向量非常慢。我怎样才能做到这一点?我觉得面具可能是实现这一目标的一种方式,但不确定如何?

【问题讨论】:

    标签: arrays performance numpy matrix-multiplication


    【解决方案1】:

    这可以通过使用np.argpartition 来获得最大k 元素的索引和np.ix_ 来解决,用于从m1m2 中选择和设置所选元素的点积。因此,我们将基本上有两个阶段来实现这一点,如下所述。

    首先,获取m1m2中最大的k元素对应的索引,像这样-

    m1_idx = np.argpartition(-m1,k,axis=0)[:k].ravel()
    m2_idx = np.argpartition(-m2,k)[:,:k].ravel()
    

    最后,设置输出数组。使用np.ix_ 分别沿行和列广播m1m2 索引,用于选择输出数组中要设置的元素。接下来,计算来自m1m2 的最高k 元素之间的点积,可以使用m1_idxm2_idx 的索引从m1m2 获得,就像这样 -

    out = np.zeros((n,n))
    out[np.ix_(m1_idx,m2_idx)] = np.dot(m1[m1_idx],m2[:,m2_idx])
    

    让我们通过一个示例来验证该实现,方法是针对另一个实现将较低的n-k 元素显式设置为0s 在m1m2 中,然后执行点积。这是执行检查的示例运行 -

    1) 输入:

    In [170]: m1
    Out[170]: 
    array([[ 0.26980423],
           [ 0.30698416],
           [ 0.60391089],
           [ 0.73246763],
           [ 0.35276247]])
    
    In [171]: m2
    Out[171]: array([[ 0.30523552, 0.87411242, 0.01071218, 0.81835438, 0.21693231]])
    
    In [172]: k = 2
    

    2) 运行建议的实现:

    In [173]: # Proposed solution code
         ...: m1_idx = np.argpartition(-m1,k,axis=0)[:k].ravel()
         ...: m2_idx = np.argpartition(-m2,k)[:,:k].ravel()
         ...: out = np.zeros((n,n))
         ...: out[np.ix_(m1_idx,m2_idx)] = np.dot(m1[m1_idx],m2[:,m2_idx])
         ...: 
    

    3) 使用替代实现来获取输出:

    In [174]: # Explicit setting of lower n-k elements to zeros for m1 and m2
         ...: m1[np.argpartition(-m1,k,axis=0)[k:]] = 0
         ...: m2[:,np.argpartition(-m2,k)[:,k:].ravel()] = 0
         ...: 
    
    In [175]: m1  # Verify m1 and m2 have lower n-k elements set to 0s
    Out[175]: 
    array([[ 0.        ],
           [ 0.        ],
           [ 0.60391089],
           [ 0.73246763],
           [ 0.        ]])
    
    In [176]: m2
    Out[176]: array([[ 0.       , 0.87411242, 0.        , 0.81835438, 0.        ]])
    
    In [177]: m1.dot(m2)  # Use m1.dot(m2) to directly get output. This is expensive.
    Out[177]: 
    array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
           [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
           [ 0.        ,  0.52788601,  0.        ,  0.49421312,  0.        ],
           [ 0.        ,  0.64025905,  0.        ,  0.59941809,  0.        ],
           [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ]])
    

    4) 验证我们提议的实现:

    In [178]: out   # Print output from proposed solution obtained earlier
    Out[178]: 
    array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
           [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
           [ 0.        ,  0.52788601,  0.        ,  0.49421312,  0.        ],
           [ 0.        ,  0.64025905,  0.        ,  0.59941809,  0.        ],
           [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ]])
    

    【讨论】:

    • 正是我要找的……不知道 np.ix_!
    • @Adi 很高兴为您提供帮助! :)
    猜你喜欢
    • 2017-02-23
    • 1970-01-01
    • 2021-02-25
    • 1970-01-01
    • 1970-01-01
    • 2021-03-01
    • 2018-06-01
    • 1970-01-01
    • 2013-05-19
    相关资源
    最近更新 更多