【问题标题】:Skip every nth index of numpy array跳过numpy数组的每个第n个索引
【发布时间】:2017-04-17 04:59:24
【问题描述】:

为了进行 K 折验证,我想对一个 numpy 数组进行切片,以便生成原始数组的视图,但删除每个第 n 个元素。

例如:

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

如果n = 4 那么结果是

[1, 2, 4, 5, 6, 8, 9]

注意:numpy 的要求是因为它被用于固定依赖关系的机器学习任务。

【问题讨论】:

  • 对于交叉验证的用例,这种方法看起来很可怕。然后有一些关于数据顺序的隐藏假设。一般来说,我更喜欢一些基于 shuffle/random_permutation 的方法,但也会坚持scikit-learn 中可用的功能,因为还有更强大的东西,比如分层采样(如果需要)。旁注:清理你的标签,因为fold(函数式编程)和k(编程语言)是错误的。
  • 我同意萨沙。特别是,看看交叉验证迭代器。 scikit-learn.org/stable/modules/…
  • @sascha 我同意使用现有的库会更好,但是我应该提到我只能使用 numpy 作为依赖项,因为这是机器学习任务对不起!为了实现随机性,我使用np.random.shuffle 对行进行洗牌。
  • 我明白了。但是在洗牌之后,你是每 4 次取还是前 N/4 取值都没有关系。后者可能更容易实现。

标签: python numpy slice


【解决方案1】:

使用modulus的方法#1

a[np.mod(np.arange(a.size),4)!=0]

示例运行 -

In [255]: a
Out[255]: array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

In [256]: a[np.mod(np.arange(a.size),4)!=0]
Out[256]: array([1, 2, 3, 5, 6, 7, 9])

使用masking 的方法#2:要求为view

考虑到视图要求,如果想法是节省内存,我们可以存储等效的布尔数组,在 Linux 系统上占用 8 倍的内存。因此,这种基于掩码的方法就像这样 -

# Create mask
mask = np.ones(a.size, dtype=bool)
mask[::4] = 0

这是内存需求统计信息 -

In [311]: mask.itemsize
Out[311]: 1

In [312]: a.itemsize
Out[312]: 8

然后,我们可以使用布尔索引作为视图 -

In [313]: a
Out[313]: array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

In [314]: a[mask] = 10

In [315]: a
Out[315]: array([ 0, 10, 10, 10,  4, 10, 10, 10,  8, 10])

使用NumPy array strides 的方法#3:要求为view

如果输入数组的长度是n 的倍数,您可以使用np.lib.stride_tricks.as_strided 创建这样的视图。如果它不是倍数,它仍然可以工作,但不是安全的做法,因为我们将超出为输入数组分配的内存。请注意,这样创建的视图将是2D

因此,获得这样一个视图的实现将是 -

def skipped_view(a, n):
    s = a.strides[0]
    strided = np.lib.stride_tricks.as_strided
    return strided(a,shape=((a.size+n-1)//n,n),strides=(n*s,s))[:,1:]

示例运行 -

In [50]: a = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]) # Input array

In [51]: a_out = skipped_view(a, 4)

In [52]: a_out
Out[52]: 
array([[ 1,  2,  3],
       [ 5,  6,  7],
       [ 9, 10, 11]])

In [53]: a_out[:] = 100 # Let's prove output is a view indeed

In [54]: a
Out[54]: array([  0, 100, 100, 100,   4, 100, 100, 100,   8, 100, 100, 100])

【讨论】:

  • 很好的答案谢谢@Divakar #2 看起来对我来说是最好的解决方案
  • @BenHazelwood 我同意,这是一个通用的解决方案。
【解决方案2】:

numpy.delete

In [18]: arr = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

In [19]: arr = np.delete(arr, np.arange(0, arr.size, 4))

In [20]: arr
Out[20]: array([1, 2, 3, 5, 6, 7, 9])

【讨论】:

  • 这看起来不像是视图
  • 我同意@sascha 的观点,如果存在内存效率更高的方法会更好
猜你喜欢
  • 1970-01-01
  • 2022-09-23
  • 1970-01-01
  • 1970-01-01
  • 2022-07-24
  • 2023-01-11
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
相关资源
最近更新 更多