虽然不是很优雅,但实现此目的的一种方法是使用 broadcasting 和 fancy/advanced indexing:
import numpy as np
arr = np.array([[1,2,3,8], [3,0,2,1],[5, 4, 25, 67], [11, 1, 6, 10]])
首先得到按列总和排序的中间数组。
arr1 = arr[:, arr.sum(axis = 0).argsort()]
print(arr1)
# array([[ 2, 1, 3, 8],
# [ 0, 3, 2, 1],
# [ 4, 5, 25, 67],
# [ 1, 11, 6, 10]])
接下来获取每列中出现最大值的位置。
idx = arr1.argmax(axis = 0)
print(idx)
# array([2, 3, 2, 2])
现在准备行和列索引数组以从arr1 切片。请注意,计算rows 的行本质上对上面idx 中的每个元素执行{0, 1, 2, 3} 的集合差异(通常是arr 中的行数),并将它们存储在rows 矩阵的列。
k = np.arange(arr1.shape[0]) # original number of rows
rows = np.nonzero(k != idx[:, None])[1].reshape(-1, arr1.shape[0] - 1).T
cols = np.arange(arr1.shape[1])
print(rows)
# array([[0, 0, 0, 0],
# [1, 1, 1, 1],
# [3, 2, 3, 3]])
请注意,cols 将被广播成rows 的形状,同时由它们索引arr1。为了您的理解,cols 看起来与rows 兼容:
print(np.broadcast_to(cols, rows.shape))
# array([[0, 1, 2, 3],
# [0, 1, 2, 3],
# [0, 1, 2, 3]])
基本上,当您(fancy) 索引arr1 时,您将获得第 0、1 和 3 行的第 0 列;第 0、1 和 2 行的第一列,依此类推。希望你能明白。
arr2 = arr1[rows, cols]
print(arr2)
# array([[ 2, 1, 3, 8],
# [ 0, 3, 2, 1],
# [ 1, 5, 6, 10]])
您可以编写一个简单的函数来组合这些步骤,以方便执行乘法运算。