【问题标题】:Slicing a numpy array based on argmax of another numpy array [duplicate]基于另一个numpy数组的argmax切片一个numpy数组[重复]
【发布时间】:2020-06-15 09:41:34
【问题描述】:

我有两个数组,分别如下:

import numpy as np

np.random.seed(42)
a = (np.random.uniform(size=[2, 5, 3]) * 100).astype(int)
b = (np.random.uniform(size=[2, 5, 3]) * 100).astype(int)

数组a的输出:

array([[[37, 95, 73],
        [59, 15, 15],
        [ 5, 86, 60],
        [70,  2, 96],
        [83, 21, 18]],

       [[18, 30, 52],
        [43, 29, 61],
        [13, 29, 36],
        [45, 78, 19],
        [51, 59,  4]]])

数组b的输出如下:

array([[[60, 17,  6],
        [94, 96, 80],
        [30,  9, 68],
        [44, 12, 49],
        [ 3, 90, 25]],

       [[66, 31, 52],
        [54, 18, 96],
        [77, 93, 89],
        [59, 92,  8],
        [19,  4, 32]]])

现在我可以使用以下代码获取数组aargmax

idx = np.argmax(a, axis=0)
print(idx)

输出:

array([[0, 0, 0],
       [0, 1, 1],
       [1, 0, 0],
       [0, 1, 0],
       [0, 1, 0]], dtype=int64)

是否可以使用数组a 的argmax 输出对数组b 进行切片,以便得到以下输出:

array([[60, 17,  6],
       [94, 18, 96],
       [77,  9, 68],
       [44, 92, 49],
       [ 3, 4, 25]])

我尝试了不同的方法,但没有成功。请帮忙。

【问题讨论】:

  • 相当肯定有一些欺骗,但你可以做np.take_along_axis(b, idx[None], axis=0)[0]

标签: python numpy


【解决方案1】:

使用 numpy 高级索引:

import numpy as np

np.random.seed(42)
a = (np.random.uniform(size=[2, 5, 3]) * 100).astype(int)
b = (np.random.uniform(size=[2, 5, 3]) * 100).astype(int)

idx = np.argmax(a, axis=0)

_, m, n = a.shape
b[idx, np.arange(m)[:,None], np.arange(n)]
array([[60, 17,  6],
       [94, 18, 96],
       [77,  9, 68],
       [44, 92, 49],
       [ 3,  4, 25]])

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 2012-09-17
    • 2022-11-19
    • 1970-01-01
    • 2021-05-07
    • 1970-01-01
    • 1970-01-01
    • 2017-12-27
    • 1970-01-01
    相关资源
    最近更新 更多