【问题标题】:Slicing a numpy array along a dynamically specified axis沿动态指定的轴切片 numpy 数组
【发布时间】:2014-08-15 10:16:59
【问题描述】:

我想沿特定轴动态切片一个 numpy 数组。鉴于此:

axis = 2
start = 5
end = 10

我想达到和这个一样的结果:

# m is some matrix
m[:,:,5:10]

使用这样的东西:

slc = tuple(:,) * len(m.shape)
slc[axis] = slice(start,end)
m[slc]

但是: 值不能放在元组中,所以我不知道如何构建切片。

【问题讨论】:

标签: python numpy


【解决方案1】:

我认为一种方法是使用slice(None):

>>> m = np.arange(2*3*5).reshape((2,3,5))
>>> axis, start, end = 2, 1, 3
>>> target = m[:, :, 1:3]
>>> target
array([[[ 1,  2],
        [ 6,  7],
        [11, 12]],

       [[16, 17],
        [21, 22],
        [26, 27]]])
>>> slc = [slice(None)] * len(m.shape)
>>> slc[axis] = slice(start, end)
>>> np.allclose(m[slc], target)
True

我有一种模糊的感觉,我以前用过这个函数,但现在我似乎找不到它..

【讨论】:

  • 谢谢——这解决了问题。 slice(None) 显然等同于 :
  • 很好的解决方案,尽管使用 m[slc] 中的列表进行索引现在在 Numpy 中已被弃用并抛出 FutureWarningFutureWarning 中的建议修复方法是将列表转换为元组,而不是 m[tuple(slc)]
  • 不错。当使用 netCDF4 从 netCDF 数据集中提取时,此答案也适用,其中 numpy.take 不可用。
  • len(m.shape)代替m.ndim
【解决方案2】:

这有点晚了,但默认的 Numpy 方法是numpy.take。然而,一个总是复制数据(因为它支持花哨的索引,它总是假设这是可能的)。为避免这种情况(在许多情况下,您需要数据的 视图,而不是副本),回退到另一个答案中已经提到的 slice(None) 选项,可能将其包装在一个不错的函数中:

def simple_slice(arr, inds, axis):
    # this does the same as np.take() except only supports simple slicing, not
    # advanced indexing, and thus is much faster
    sl = [slice(None)] * arr.ndim
    sl[axis] = inds
    return arr[tuple(sl)]

【讨论】:

  • 如果您能阐明您期望 inds 参数是什么数据类型会有所帮助。
【解决方案3】:

由于没有足够清楚地提及(我也在寻找它):

相当于:

a = my_array[:, :, :, 8]
b = my_array[:, :, :, 2:7]

是:

a = my_array.take(indices=8, axis=3)
b = my_array.take(indices=range(2, 7), axis=3)

【讨论】:

  • 这应该是答案。
  • 使用 np.take 创建一个从原始数组复制数据的新数组。这可能不是您想要的(额外的内存使用对于大型数组可能很重要)
【解决方案4】:

有一种访问数组x 的任意轴n 的优雅方法:使用numpy.moveaxis¹ 将感兴趣的轴移到前面。

x_move = np.moveaxis(x, n, 0)  # move n-th axis to front
x_move[start:end]              # access n-th axis

要注意的是,您可能必须将moveaxis 应用于与x_move[start:end] 的输出一起使用的其他数组,以保持轴顺序一致。数组x_move 只是一个视图,因此您对其前轴所做的每一次更改都对应于n-th 轴中x 的更改(即您可以读取/写入x_move)。


1) 你也可以使用swapaxes 来不用担心n0 的顺序,这与moveaxis(x, n, 0) 相反。我更喜欢moveaxis 而不是swapaxes,因为它只会改变与n 相关的顺序。

【讨论】:

    【解决方案5】:

    非常晚了,但我有一个替代切片功能,其性能略好于其他答案:

    def array_slice(a, axis, start, end, step=1):
        return a[(slice(None),) * (axis % a.ndim) + (slice(start, end, step),)]
    

    这是测试每个答案的代码。每个版本都标有发布答案的用户的姓名:

    import numpy as np
    from timeit import timeit
    
    def answer_dms(a, axis, start, end, step=1):
        slc = [slice(None)] * len(a.shape)
        slc[axis] = slice(start, end, step)
        return a[slc]
    
    def answer_smiglo(a, axis, start, end, step=1):
        return a.take(indices=range(start, end, step), axis=axis)
    
    def answer_eelkespaak(a, axis, start, end, step=1):
        sl = [slice(None)] * m.ndim
        sl[axis] = slice(start, end, step)
        return a[tuple(sl)]
    
    def answer_clemisch(a, axis, start, end, step=1):
        a = np.moveaxis(a, axis, 0)
        a = a[start:end:step]
        return np.moveaxis(a, 0, axis)
    
    def answer_leland(a, axis, start, end, step=1):
        return a[(slice(None),) * (axis % a.ndim) + (slice(start, end, step),)]
    
    if __name__ == '__main__':
        m = np.arange(2*3*5).reshape((2,3,5))
        axis, start, end = 2, 1, 3
        target = m[:, :, 1:3]
        for answer in (answer_dms, answer_smiglo, answer_eelkespaak,
                       answer_clemisch, answer_leland):
            print(answer.__name__)
            m_copy = m.copy()
            m_slice = answer(m_copy, axis, start, end)
            c = np.allclose(target, m_slice)
            print('correct: %s' %c)
            t = timeit('answer(m, axis, start, end)',
                       setup='from __main__ import answer, m, axis, start, end')
            print('time:    %s' %t)
            try:
                m_slice[0,0,0] = 42
            except:
                print('method:  view_only')
            finally:
                if np.allclose(m, m_copy):
                    print('method:  copy')
                else:
                    print('method:  in_place')
            print('')
    

    结果如下:

    answer_dms
    
    Warning (from warnings module):
      File "C:\Users\leland.hepworth\test_dynamic_slicing.py", line 7
        return a[slc]
    FutureWarning: Using a non-tuple sequence for multidimensional indexing is 
    deprecated; use `arr[tuple(seq)]` instead of `arr[seq]`. In the future this will be 
    interpreted as an array index, `arr[np.array(seq)]`, which will result either in an 
    error or a different result.
    correct: True
    time:    2.2048302
    method:  in_place
    
    answer_smiglo
    correct: True
    time:    5.9013344
    method:  copy
    
    answer_eelkespaak
    correct: True
    time:    1.1219435999999998
    method:  in_place
    
    answer_clemisch
    correct: True
    time:    13.707583699999999
    method:  in_place
    
    answer_leland
    correct: True
    time:    0.9781496999999995
    method:  in_place
    
    • DSM's answer 包含一些改进 cmets 的建议。
    • EelkeSpaak's answer 应用了这些改进,从而避免了警告并且速度更快。
    • 涉及np.takeŚmigło's answer 会产生更差的结果,虽然它不是仅供查看的,但它确实会创建一个副本。
    • 涉及np.moveaxisclemisch's answer 完成时间最长,但令人惊讶的是它引用了前一个数组的内存位置。
    • 我的回答消除了对中间切片列表的需要。当切片轴朝向开头时,它还使用较短的切片索引。这会产生最快的结果,并且随着轴接近 0 会得到额外的改进。

    我还在每个版本中添加了一个step 参数,以防您需要。

    【讨论】:

      【解决方案6】:

      这确实是非常非常晚了!但我得到了Leland's answer 并对其进行了扩展,使其适用于多个轴和切片参数。这是函数的详细版本

      from numpy import *
      
      def slicer(a, axis=None, slices=None):
          if not hasattr(axis, '__iter__'):
              axis = [axis]
          if not hasattr(slices, '__iter__') or len(slices) != len(axis):
              slices = [slices]
          slices = [ sl if isinstance(sl,slice) else slice(*sl) for sl in slices ]
          mask = []
          fixed_axis = array(axis) % a.ndim
          case = dict(zip(fixed_axis, slices))
          for dim, size in enumerate(a.shape):
              mask.append( case[dim] if dim in fixed_axis else slice(None) )
          return a[tuple(mask)]
      

      它适用于可变数量的轴和切片元组作为输入

      >>> a = array( range(10**4) ).reshape(10,10,10,10)
      >>> slicer( a, -2, (1,3) ).shape
      (10, 10, 2, 10)
      >>> slicer( a, axis=(-1,-2,0), slices=((3,), s_[:5], slice(3,None)) ).shape
      (7, 10, 5, 3)
      

      稍微紧凑的版本

      def slicer2(a, axis=None, slices=None):
          ensure_iter = lambda l: l if hasattr(l, '__iter__') else [l]
          axis = array(ensure_iter(axis)) % a.ndim
          if len(ensure_iter(slices)) != len(axis):
              slices = [slices]
          slice_selector = dict(zip(axis, [ sl if isinstance(sl,slice) else slice(*sl) for sl in ensure_iter(slices) ]))
          element = lambda dim_: slice_selector[dim_] if dim_ in slice_selector.keys() else slice(None)
          return a[( element(dim) for dim in range(a.ndim) )]
      

      【讨论】:

        猜你喜欢
        • 1970-01-01
        • 2016-12-20
        • 1970-01-01
        • 2011-06-29
        • 2019-09-18
        • 2019-06-25
        • 1970-01-01
        • 2022-10-30
        • 1970-01-01
        相关资源
        最近更新 更多