【发布时间】:2020-06-15 09:41:34
【问题描述】:
我有两个数组,分别如下:
import numpy as np
np.random.seed(42)
a = (np.random.uniform(size=[2, 5, 3]) * 100).astype(int)
b = (np.random.uniform(size=[2, 5, 3]) * 100).astype(int)
数组a的输出:
array([[[37, 95, 73],
[59, 15, 15],
[ 5, 86, 60],
[70, 2, 96],
[83, 21, 18]],
[[18, 30, 52],
[43, 29, 61],
[13, 29, 36],
[45, 78, 19],
[51, 59, 4]]])
数组b的输出如下:
array([[[60, 17, 6],
[94, 96, 80],
[30, 9, 68],
[44, 12, 49],
[ 3, 90, 25]],
[[66, 31, 52],
[54, 18, 96],
[77, 93, 89],
[59, 92, 8],
[19, 4, 32]]])
现在我可以使用以下代码获取数组a 的argmax:
idx = np.argmax(a, axis=0)
print(idx)
输出:
array([[0, 0, 0],
[0, 1, 1],
[1, 0, 0],
[0, 1, 0],
[0, 1, 0]], dtype=int64)
是否可以使用数组a 的argmax 输出对数组b 进行切片,以便得到以下输出:
array([[60, 17, 6],
[94, 18, 96],
[77, 9, 68],
[44, 92, 49],
[ 3, 4, 25]])
我尝试了不同的方法,但没有成功。请帮忙。
【问题讨论】:
-
相当肯定有一些欺骗,但你可以做
np.take_along_axis(b, idx[None], axis=0)[0]。