【问题标题】:Iterate over last axis of a numpy array迭代numpy数组的最后一个轴
【发布时间】:2021-09-08 17:32:18
【问题描述】:

假设我们有一个 (20, 5) 数组。我们可以非常Python地遍历每个

import numpy as np
xs = np.array(range(100)).reshape(20, 5)
for x in xs:
    print(x)

如果我们想迭代另一个轴(在此示例中,迭代列,但我正在为 ndarray 中的每个可能轴寻找解决方案),它不太直接,我们可以使用 @987654321 中的方法@:

for i in range(xs.shape[-1]):
    x = xs[..., i]
    print(x)

是否有更直接的方法来迭代另一个轴,例如(伪代码):

for x in xs.iterator(axis=-1):
    print(x) 

?

【问题讨论】:

  • 您可以将该轴移到前面。但我认为索引方法是最好的
  • @hpaulj 我正在寻找一个 n 维数组的解决方案,假设我们想在 4 轴 3 上进行迭代。
  • @basj 转置适用于 n 维数组。请注意,(惰性默认)转置以及直接索引在大数组上效率低下。如果您计划在目标数组上执行许多类似的操作,那么急切的转置可能是一个好主意。
  • 为什么不the second answer in the link you sharednp.rollaxis
  • @Basj 根据它创建视图的答案,这意味着没有额外的内存。至少我是这么理解的。但是,文档没有说明,所以我不确定。

标签: python arrays numpy multidimensional-array numpy-ndarray


【解决方案1】:

我认为 stride tricks 模块中的as_strided 应该在这里完成工作。

它在数组中创建一个视图,而不是一个副本(如文档所述)。

这是as_stided功能的简单演示:

from numpy.lib.stride_tricks import as_strided
import numpy as np
xs = np.array(range(3 *3 * 4)).reshape(3,3, 4)
for x in xs:
    print(x)

输出:

[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]]
[[12 13 14 15]
 [16 17 18 19]
 [20 21 22 23]]
[[24 25 26 27]
 [28 29 30 31]
 [32 33 34 35]]

遍历数组特定轴的函数:

def iterate_over_axis(arr, axis=0):
    strides = arr.strides
    strides_ = [strides[axis], *strides[0:axis], *strides[(axis+1):]]
    shape = arr.shape
    shape_ = [shape[axis], *shape[0:axis], *shape[(axis+1):]]
    return as_strided(arr,  strides=strides_, shape=shape_)

for x in iterate_over_axis(xs, axis=1):
    print(x)

输出:

[[ 0  1  2  3]
 [12 13 14 15]
 [24 25 26 27]]
[[ 4  5  6  7]
 [16 17 18 19]
 [28 29 30 31]]
[[ 8  9 10 11]
 [20 21 22 23]
 [32 33 34 35]]

  

【讨论】:

    猜你喜欢
    • 2019-04-26
    • 1970-01-01
    • 2015-06-12
    • 1970-01-01
    • 1970-01-01
    • 2021-04-16
    • 2020-08-01
    • 2019-04-07
    相关资源
    最近更新 更多