【问题标题】:Removing rows from a multi dimensional numpy array从多维 numpy 数组中删除行
【发布时间】:2015-01-25 00:20:09
【问题描述】:

我有一个相当大的 3 维 numpy (2000,2500,32) 数组需要操作。有些行不好,所以我需要删除几行。 为了检测哪一行是“坏的”,我使用以下函数

def badDetect(x):
  for i in xrange(10,19):
    ptp = np.ptp(x[i*100:(i+1)*100])
    if ptp < 0.01:
      return True
  return False

这将 2000 的任何序列标记为坏,该序列具有 100 个值的范围,且峰峰值小于 0.01。 在这种情况下,我想删除 2000 个值的序列(可以使用 a[:,x,y] 从 numpy 中选择) Numpy delete 似乎接受索引,但仅适用于二维数组。

【问题讨论】:

  • 请更具体一点:这些行被认为是坏的标准是什么? 3D 上下文中的“行”也没有多大意义。您是指最后一个轴,即 32 个元素的“行”吗?
  • 我添加了检测“坏”的功能,并澄清了我的意思是2000行的值。提前感谢您的关注!
  • 您添加的有关程序应该执行的操作的信息与实际算法不符:x[i*100:(i+1)*100] 将采用(对于i 的第一次迭代)值x[1000:1100]。所以它甚至不是从零开始的。你确定这就是你的想法吗?
  • 是的,你是对的。这就是我想要的,因为只有 1000 到 2000 之间的值对我的分析很重要。你可以假设函数是正确的。不过很好,再次感谢!
  • 删除a[:,x,y]后你期望留下什么?这是数组的一部分,而不是一行。

标签: python numpy multidimensional-array


【解决方案1】:

您肯定需要重新调整输入数组的形状,因为从 3D 立方体中删除“行”会留下无法正确处理的结构。

由于我们没有您的数据,我将首先使用不同的示例来解释这种可能的解决方案的工作原理:

>>> import numpy as np
>>> from numpy.lib.stride_tricks import as_strided
>>> 
>>> threshold = 18
>>> a = np.arange(5*3*2).reshape(5,3,2)  # your dataset of 2000x2500x32
>>> # Taint the data:
... a[0,0,0] = 5
>>> a[a==22]=20
>>> print(a)
[[[ 5  1]
  [ 2  3]
  [ 4  5]]

 [[ 6  7]
  [ 8  9]
  [10 11]]

 [[12 13]
  [14 15]
  [16 17]]

 [[18 19]
  [20 21]
  [20 23]]

 [[24 25]
  [26 27]
  [28 29]]]
>>> a2 = a.reshape(-1, np.prod(a.shape[1:]))
>>> print(a2)  # Will prove to be much easier to work with!
[[ 5  1  2  3  4  5]
 [ 6  7  8  9 10 11]
 [12 13 14 15 16 17]
 [18 19 20 21 20 23]
 [24 25 26 27 28 29]]

如您所见,从上面的表示中,现在您想要计算峰峰值的窗口已经变得更加清晰了。如果您要从此数据结构中删除“行”(现在它们已转换为列),您将需要此表单,这是您在 3 维中无法做到的!

>>> isize = a.itemsize  # More generic, in case you have another dtype
>>> slice_size = 4  # How big each continuous slice is over which the Peak2Peak value is calculated
>>> slices = as_strided(a2,
...     shape=(a2.shape[0] + 1 - slice_size, slice_size, a2.shape[1]),
...     strides=(isize*a2.shape[1], isize*a2.shape[1], isize))
>>> print(slices)
[[[ 5  1  2  3  4  5]
  [ 6  7  8  9 10 11]
  [12 13 14 15 16 17]
  [18 19 20 21 20 23]]

 [[ 6  7  8  9 10 11]
  [12 13 14 15 16 17]
  [18 19 20 21 20 23]
  [24 25 26 27 28 29]]]

因此,我以 4 个元素的窗口大小为例:如果这 4 个元素切片中的任何一个(每个数据集,因此每列)的峰峰值小于某个阈值,我想排除它。可以这样做:

>>> mask = np.all(slices.ptp(axis=1) >= threshold, axis=0) # These are the ones that are of interest
>>> print(a2[:,mask])
[[ 1  2  3  5]
 [ 7  8  9 11]
 [13 14 15 17]
 [19 20 21 23]
 [25 26 27 29]]

您现在可以清楚地看到受污染的数据已被删除。但请记住,您不能简单地从 3D 数组中删除该数据(但您可以在那时屏蔽它)。

显然,您必须在用例中将threshold 设置为.01,并将slice_size 设置为100

请注意,虽然 as_strided 表单非常节省内存,但在您的情况下计算此数组的峰峰值并存储该结果确实需要大量内存:在完整情况下为 1901x(2500x32) ,所以当你不要忽略前 1000 个切片时。在您的情况下,您只对来自1000:1900 的切片感兴趣,您必须将其添加到代码中,如下所示:

mask = np.all(slices[1000:1900,:,:].ptp(axis=1) >= threshold, axis=0)

这会减少将此掩码存储为“仅”900x(2500x32) 值(您使用的任何数据类型)所需的内存。

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 2017-04-11
    • 1970-01-01
    • 2015-06-08
    • 2012-06-26
    • 1970-01-01
    • 1970-01-01
    • 2021-03-26
    相关资源
    最近更新 更多