【问题标题】:Fastest way to find the maximum minimum value of three 'connected' matrices找到三个“连接”矩阵的最大最小值的最快方法
【发布时间】:2021-11-26 09:39:38
【问题描述】:

question 给出了两个矩阵的答案,但我不确定如何将此逻辑应用于三个成对连接的矩阵,因为没有“免费”索引。我想最大化以下功能:

f(i, j, k) = min(A(i, j), B(j, k), C(i,k))

其中ABC 是矩阵,ijk 是范围高达矩阵各自维度的索引。我想找到(i, j, k) 使得f(i, j, k) 最大化。我目前这样做如下:

import numpy as np
import itertools

I = 100
J = 150
K = 200

A = np.random.rand(I, J)
B = np.random.rand(J, K)
C = np.random.rand(I, K)

# All the different i,j,k
combinations = itertools.product(np.arange(I), np.arange(J), np.arange(K))
combinations = np.asarray(list(combinations))

A_vals = A[combinations[:,0], combinations[:,1]]
B_vals = B[combinations[:,1], combinations[:,2]]
C_vals = C[combinations[:,0], combinations[:,2]]

f = np.min([A_vals,B_vals,C_vals],axis=0)

best_indices = combinations[np.argmax(f)]
print(best_indices)

[ 49 14 136]

这比遍历所有(i, j, k) 更快,但是很多(和大部分)时间都花在了构建_vals 矩阵上。这是不幸的,因为它们包含许多重复值,因为相同的ijk 多次出现。有没有办法做到这一点,(1) numpy 的矩阵计算速度可以保持,(2) 我不必构造内存密集型_vals 矩阵。

在其他语言中,您可以构造矩阵,以便它们包含指向 ABC 的指针,但我不知道如何在 Python 中实现这一点。

编辑:查看后续问题以获取更多索引here

【问题讨论】:

  • 您是否尝试过numba 并遍历矩阵?在小矩阵示例上编译它并在更大的矩阵上计时
  • @dankal444 我试过了,但是很多时间都花在了numba无法转换的itertools上。不过,我会尝试将它与下面的答案结合起来!
  • 是的,完全不需要itertools,只需要在numba修饰函数中使用三个for函数,分别对应i、j和k

标签: python numpy matrix optimization memory


【解决方案1】:

您可以使用重复和拼贴“构建”组合,而不是使用 itertools:

A_=np.repeat(A.reshape((-1,1)),K,axis=0).T
B_=np.tile(B.reshape((-1,1)),(I,1)).T
C_=np.tile(C,J).reshape((-1,1)).T

并将它们传递给 np.min:

print((t:=np.argmax(np.min([A_,B_,C_],axis=0)) , t//(K*J),(t//K)%J, t%K,))

使用 timeit 重复 10 次代码大约需要 18 秒,而使用 numpy 只需大约 1 秒。

【讨论】:

  • 谢谢,有没有办法扩展这个?所以假设我们没有f(i,j,k),但我们有f(i,j,k,l) 和6 个成对矩阵。是否需要更复杂的重复和拼贴?
【解决方案2】:

我们可以使用numpy 广播暴力破解它,或者尝试一些智能分支切割:

import numpy as np

def bf(A,B,C):
    I,J = A.shape
    J,K = B.shape
    return np.unravel_index((np.minimum(np.minimum(A[:,:,None],C[:,None,:]),B[None,:,:])).argmax(),(I,J,K))

def cut(A,B,C):
    gmx = min(A.min(),B.min(),C.min())
    I,J = A.shape
    J,K = B.shape
    Y,X = np.unravel_index(A.argsort(axis=None)[::-1],A.shape)
    for y,x in zip(Y,X):
        if A[y,x] <= gmx:
            return gamx
        curr = np.minimum(B[x,:],C[y,:])
        camx = curr.argmax()
        cmx = curr[camx]
        if cmx >= A[y,x]:
            return y,x,camx
        if gmx < cmx:
            gmx = cmx
            gamx = y,x,camx
    return gamx
            
from timeit import timeit

I = 100
J = 150
K = 200

for rep in range(4):
    print("trial",rep+1)
    A = np.random.rand(I, J)
    B = np.random.rand(J, K)
    C = np.random.rand(I, K)

    print("results identical",cut(A,B,C)==bf(A,B,C))
    print("brute force",timeit(lambda:bf(A,B,C),number=2)*500,"ms")
    print("branch cut",timeit(lambda:cut(A,B,C),number=10)*100,"ms")

事实证明,在给定的尺寸下,分支切割是非常值得的:

trial 1
results identical True
brute force 169.74265850149095 ms
branch cut 1.951422297861427 ms
trial 2
results identical True
brute force 180.37619898677804 ms
branch cut 2.1000938024371862 ms
trial 3
results identical True
brute force 181.6371419990901 ms
branch cut 1.999850495485589 ms
trial 4
results identical True
brute force 217.75578951928765 ms
branch cut 1.5871295996475965 ms

树枝切割是如何工作的?

我们选择一个数组(比如 A)并从大到小对它进行排序。然后,我们逐个遍历数组,将每个值与其他数组中的适当值进行比较,并跟踪最小值的运行最大值。只要最大值不小于 A 中的剩余值,我们就完成了。由于这通常会很快发生,因此我们会节省大量资金。

【讨论】:

  • 谢谢,这令人印象深刻!有没有办法让这个更普遍?因此,假设我们没有 f(i,j,k),但我们有 f(i,j,k,l) 和 6 个成对矩阵 (A(i,j), B(i,k), C(i,l), D(j,k), E(j,l) F(k,l))。切割函数可以泛化吗?
  • 可能有但不是在我的脑海中。我想你可以尝试一些递归方案。这很复杂,足以保证 IMO 提出一个新问题。
  • 暴力破解方法可以推广吗?我试过了,但到目前为止还没有成功
  • 您必须构建索引(如[:,None,None,:])。它们必须是元组,您必须使用slice(None) 来表示:。要获得您需要的所有模式,您可以使用 itertools.combinationsnp.triu_indices
  • 再次感谢您,我已经使蛮力技术适用于任意索引。我在这里发布了一个关于任意数量索引的后续问题:stackoverflow.com/questions/69476527/…
【解决方案3】:

great answer of loopy walt 为基础 - 使用 numba 可以获得轻微的加速 (~20%):

import numba
@numba.jit(nopython=True)
def find_gamx(A, B, C, X, Y, gmx):
    gamx = (0, 0, 0)
    for y, x in zip(Y, X):
        if A[y, x] <= gmx:
            return gamx
        curr = np.minimum(B[x, :], C[y, :])
        camx = curr.argmax()
        cmx = curr[camx]
        if cmx >= A[y, x]:
            return y, x, camx
        if gmx < cmx:
            gmx = cmx
            gamx = y, x, camx
    return gamx


def cut_numba(A, B, C):
    gmx = min(A.min(), B.min(), C.min())
    I, J = A.shape
    J, K = B.shape
    Y, X = np.unravel_index(A.argsort(axis=None)[::-1], A.shape)

    gamx = find_gamx(A, B, C, X, Y, gmx)
    return gamx

from timeit import timeit

I = 100
J = 150
K = 200

for rep in range(40):
    print("trial", rep + 1)
    A = np.random.rand(I, J)
    B = np.random.rand(J, K)
    C = np.random.rand(I, K)

    print("results identical", cut(A, B, C) == bf(A, B, C))
    print("results identical", cut_numba(A, B, C) == bf(A, B, C))
    print("brute force", timeit(lambda: bf(A, B, C), number=2) * 500, "ms")
    print("branch cut", timeit(lambda: cut(A, B, C), number=10) * 100, "ms")
    print("branch cut_numba", timeit(lambda: cut_numba(A, B, C), number=10) * 100, "ms")
trial 1
results identical True
results identical True
brute force 38.774325 ms
branch cut 1.7196750999999955 ms
branch cut_numba 1.3950291999999864 ms
trial 2
results identical True
results identical True
brute force 38.77167049999996 ms
branch cut 1.8655760999999993 ms
branch cut_numba 1.4977325999999902 ms
trial 3
results identical True
results identical True
brute force 39.69611449999999 ms
branch cut 1.8876490000000024 ms
branch cut_numba 1.421615300000001 ms
trial 4
results identical True
results identical True
brute force 44.338816499999936 ms
branch cut 1.614051399999994 ms
branch cut_numba 1.3842962000000014 ms

【讨论】:

    猜你喜欢
    • 2021-11-27
    • 2021-11-25
    • 1970-01-01
    • 2019-05-21
    • 1970-01-01
    • 2014-06-17
    • 2018-10-21
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多