【问题标题】:Faster way to check if elements in numpy array windows are finite检查numpy数组窗口中的元素是否有限的更快方法
【发布时间】:2021-02-23 14:55:58
【问题描述】:

我有一个很长的 NumPy 数组,其中包含 1_000_000_000 元素,我想在数组中滑动一个 50 元素窗口,并询问窗口中的所有元素是否都是有限的。如果50 元素窗口内的所有元素都是有限的,则返回True(对于该窗口),否则,如果50 元素窗口内的一个或多个元素不是有限的,则返回False(对于该窗口)。继续此评估,直到评估所有窗口。一个很好的方法是:

import numpy as np

def rolling_window(a, window):
    a = np.asarray(a)
    shape = a.shape[:-1] + (a.shape[-1] - window + 1, window)
    strides = a.strides + (a.strides[-1],)

    return np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)

if __name__ == "__main__":
    a = np.random.rand(100_000_000)  # This is 10x shorter than my real data
    w = 50
    idx = np.random.randint(0, len(a), size=len(a)//10)  # Simulate having np.nan in my array
    a[idx] = np.nan
    print(np.all(rolling_window(np.isfinite(a), w), axis=1))

但是,当我的数组长度为 1_000_000_000 时,这会很慢。有没有更快的方法来完成这个,而且不需要大量的内存?

【问题讨论】:

  • 无限元素的频率/稀疏度是多少?如果它们是稀疏的,最好先取更大的块,只有在存在无限元素时才细分它们

标签: python arrays performance numpy


【解决方案1】:

方法 #1: 直接在 isfinite-mask 中滥用跨步窗口进行分配 -

def strided_allfinite(a, w):
    m = np.isfinite(a)
    p = rolling_window(m, w)
    nmW = ~m[:w]
    if nmW.any():
        m[:np.flatnonzero(nmW).max()] = False
    p[~m[w-1:]] = False
    return m[:-w+1]

给定样本数据的时间安排:

In [323]: N = 100_000_000
     ...: w = 50
     ...: 
     ...: np.random.seed(0)
     ...: a = np.random.rand(N)  # This is 10x shorter than my real data
     ...: idx = np.random.randint(0, len(a), size=len(a)//10)  # Simulate...
     ...: a[idx] = np.nan

# Original soln
In [324]: %timeit np.all(rolling_window(np.isfinite(a), w), axis=1)
1.61 s ± 14.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [325]: %timeit strided_allfinite(a, w)
556 ms ± 87.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

方法 #2

我们可以利用convolution -

np.convolve(np.isfinite(a), np.ones(w),'valid')==w

方法#3

binary-erosion -

from scipy.ndimage.morphology import binary_erosion

m = np.isfinite(a)
out = binary_erosion(m, np.ones(w, dtype=bool))[w//2:len(a)-w+1+w//2]

【讨论】:

  • 在我的测试中,as_strided 的速度是convolve 的两倍。
  • 我喝完了咖啡,无法理解strided_allfinite 背后的逻辑。但是np.all(strided_allfinite(a,w) == np.all(rolling_window(a,w), axis=1)) 给了我False
  • @QuangHoang 这是检查相等性的错误方法。试试np.array_equal(strided_allfinite(a, w), np.all(rolling_window(np.isfinite(a), w), axis=1))。也许更多的咖啡? :)
  • 我不明白,这可能不是最好的方法,但肯定是正确的方法,因为两个数组都不包含 nan。你能解释一下为什么吗?
  • @QuangHoang 你应该这样做np.all(strided_allfinite(a,w) == np.all(rolling_window(np.isfinite(a), w), axis=1))。现在,我需要咖啡。
猜你喜欢
  • 2011-10-20
  • 2012-12-16
  • 2019-02-22
  • 2014-02-07
  • 2017-11-01
  • 1970-01-01
  • 2023-02-10
  • 2011-04-01
相关资源
最近更新 更多