【问题标题】:Fast check for NaN in NumPy在 NumPy 中快速检查 NaN
【发布时间】:2023-03-07 15:01:01
【问题描述】:

我正在寻找最快的方法来检查 NumPy 数组 X 中是否出现 NaN (np.nan)。 np.isnan(X) 是不可能的,因为它构建了一个形状为 X.shape 的布尔数组,这可能是巨大的。

我试过np.nan in X,但这似乎不起作用,因为np.nan != np.nan。有没有一种快速且节省内存的方法来做到这一点?

(对于那些会问“多么巨大”的人:我说不出来。这是库代码的输入验证。)

【问题讨论】:

  • 在这种情况下验证用户输入不起作用吗?如在插入前检查 NaN
  • @Woot4Moo:不,该库将 NumPy 数组或 scipy.sparse 矩阵作为输入。
  • 如果你经常这样做,我听说过关于瓶颈的好消息 (pypi.python.org/pypi/Bottleneck)

标签: python performance numpy nan


【解决方案1】:

这里有两种通用方法:

  • 检查每个数组项是否有nan 并取any
  • 应用一些保留nans(如sum)的累积运算并检查其结果。

虽然第一种方法肯定是最干净的,但对一些累积操作(尤其是在 BLAS 中执行的操作,如 dot)的大量优化可以使这些操作变得非常快。请注意,dot 与其他一些 BLAS 操作一样,在某些条件下是多线程的。这解释了不同机器之间的速度差异。

import numpy as np
import perfplot


def min(a):
    return np.isnan(np.min(a))


def sum(a):
    return np.isnan(np.sum(a))


def dot(a):
    return np.isnan(np.dot(a, a))


def any(a):
    return np.any(np.isnan(a))


def einsum(a):
    return np.isnan(np.einsum("i->", a))


b = perfplot.bench(
    setup=np.random.rand,
    kernels=[min, sum, dot, any, einsum],
    n_range=[2 ** k for k in range(25)],
    xlabel="len(a)",
)
b.save("out.png")
b.show()

【讨论】:

    【解决方案2】:

    除了 @nico-schlömer 和 @mseifert 的答案之外,我还计算了带有提前停止的 numba-test has_nan 的性能,与解析整个数组的一些函数相比。

    在我的机器上,对于没有 nans 的数组,收支平衡发生在 ~10^4 个元素。

    
    import perfplot
    import numpy as np
    import numba
    import math
    
    def min(a):
        return np.isnan(np.min(a))
    
    def dot(a):
        return np.isnan(np.dot(a, a))
    
    def einsum(a):
        return np.isnan(np.einsum("i->", a))
    
    @numba.njit
    def has_nan(a):
        for i in range(a.size - 1):
            if math.isnan(a[i]):
                return True
        return False
    
    
    def array_with_missing_values(n, p):
        """ Return array of size n,  p : nans ( % of array length )
        Ex : n=1e6, p=1 : 1e4 nan assigned at random positions """
        a = np.random.rand(n)
        p = np.random.randint(0, len(a), int(p*len(a)/100))
        a[p] = np.nan
        return a
    
    
    #%%
    perfplot.show(
        setup=lambda n: array_with_missing_values(n, 0),
        kernels=[min, dot, has_nan],
        n_range=[2 ** k for k in range(20)],
        logx=True,
        logy=True,
        xlabel="len(a)",
    )
    
    

    如果数组有 nans 会发生什么?我调查了数组的 nan-coverage 的影响。

    对于长度为 1,000,000 的数组,has_nan 成为更好的选择,因为数组中有 ~10^-3 % nans(所以 ~10 nans)。

    
    #%%
    N = 1000000  # 100000
    perfplot.show(
        setup=lambda p: array_with_missing_values(N, p),
        kernels=[min, dot, has_nan],
        n_range=np.array([2 ** k for k in range(20)]) / 2**20 * 0.01, 
        logy=True,
        xlabel=f"% of nan in array (N = {N})",
    )
    

    如果在您的应用程序中大多数数组都有nan,而您正在寻找没有的数组,那么has_nan 是最好的方法。 别的; dot 似乎是最好的选择。

    【讨论】:

      【解决方案3】:
      1. 使用 .any()

        if numpy.isnan(myarray).any()

      2. numpy.isfinite 在检查方面可能比 isnan 更好

        if not np.isfinite(prop).all()

      【讨论】:

        【解决方案4】:

        如果您对 感到满意,它允许创建快速短路(一旦发现 NaN 就停止)功能:

        import numba as nb
        import math
        
        @nb.njit
        def anynan(array):
            array = array.ravel()
            for i in range(array.size):
                if math.isnan(array[i]):
                    return True
            return False
        

        如果没有NaN,该函数实际上可能比np.min 慢,我认为这是因为np.min 对大型数组使用了多处理:

        import numpy as np
        array = np.random.random(2000000)
        
        %timeit anynan(array)          # 100 loops, best of 3: 2.21 ms per loop
        %timeit np.isnan(array.sum())  # 100 loops, best of 3: 4.45 ms per loop
        %timeit np.isnan(array.min())  # 1000 loops, best of 3: 1.64 ms per loop
        

        但如果数组中有一个 NaN,特别是如果它的位置在低索引处,那么它会快得多:

        array = np.random.random(2000000)
        array[100] = np.nan
        
        %timeit anynan(array)          # 1000000 loops, best of 3: 1.93 µs per loop
        %timeit np.isnan(array.sum())  # 100 loops, best of 3: 4.57 ms per loop
        %timeit np.isnan(array.min())  # 1000 loops, best of 3: 1.65 ms per loop
        

        使用 Cython 或 C 扩展可能会获得类似的结果,这些结果稍微复杂一些(或者很容易通过 bottleneck.anynan 获得),但最终与我的 anynan 函数相同。

        【讨论】:

          【解决方案5】:

          与此相关的是如何找到第一次出现的 NaN 的问题。这是我所知道的最快的处理方法:

          index = next((i for (i,n) in enumerate(iterable) if n!=n), None)
          

          【讨论】:

            【解决方案6】:

            Ray 的解决方案很好。但是,在我的机器上,使用 numpy.sum 代替 numpy.min 大约快 2.5 倍:

            In [13]: %timeit np.isnan(np.min(x))
            1000 loops, best of 3: 244 us per loop
            
            In [14]: %timeit np.isnan(np.sum(x))
            10000 loops, best of 3: 97.3 us per loop
            

            min 不同,sum 不需要分支,这在现代硬件上往往相当昂贵。这可能是sum 更快的原因。

            edit 上面的测试是在数组中间使用单个 NaN 执行的。

            有趣的是,min 在存在 NaN 时比不存在时慢。随着 NaN 越来越接近数组的开头,它似乎也变慢了。另一方面,sum 的吞吐量似乎是恒定的,无论是否存在 NaN 以及它们位于何处:

            In [40]: x = np.random.rand(100000)
            
            In [41]: %timeit np.isnan(np.min(x))
            10000 loops, best of 3: 153 us per loop
            
            In [42]: %timeit np.isnan(np.sum(x))
            10000 loops, best of 3: 95.9 us per loop
            
            In [43]: x[50000] = np.nan
            
            In [44]: %timeit np.isnan(np.min(x))
            1000 loops, best of 3: 239 us per loop
            
            In [45]: %timeit np.isnan(np.sum(x))
            10000 loops, best of 3: 95.8 us per loop
            
            In [46]: x[0] = np.nan
            
            In [47]: %timeit np.isnan(np.min(x))
            1000 loops, best of 3: 326 us per loop
            
            In [48]: %timeit np.isnan(np.sum(x))
            10000 loops, best of 3: 95.9 us per loop
            

            【讨论】:

            • 当数组不包含 NaN 时,np.min 更快,这是我的预期输入。但我还是决定接受这个,因为它也捕获了infneginf
            • 只有在输入包含两者时才会捕获inf-inf,并且如果输入包含较大但有限的值相加时会溢出,则会出现问题。
            • min 和 max 不需要为支持 sse 的 x86 芯片上的浮点数据分支。因此,从 numpy 开始,1.8 分钟不会比 sum 慢,在我的 AMD 现象上,它甚至快 20%。
            • 在我的 Intel Core i5 上,在 OSX 上使用 numpy 1.9.2,np.sum 仍然比 np.min 快大约 30%。
            • np.isnan(x).any(0) 在我的机器上比np.sumnp.min 稍快,尽管可能会有一些不需要的缓存。
            【解决方案7】:

            即使存在公认的答案,我也想证明以下内容(在 Vista 上使用 Python 2.7.2 和 Numpy 1.6.0):

            In []: x= rand(1e5)
            In []: %timeit isnan(x.min())
            10000 loops, best of 3: 200 us per loop
            In []: %timeit isnan(x.sum())
            10000 loops, best of 3: 169 us per loop
            In []: %timeit isnan(dot(x, x))
            10000 loops, best of 3: 134 us per loop
            
            In []: x[5e4]= NaN
            In []: %timeit isnan(x.min())
            100 loops, best of 3: 4.47 ms per loop
            In []: %timeit isnan(x.sum())
            100 loops, best of 3: 6.44 ms per loop
            In []: %timeit isnan(dot(x, x))
            10000 loops, best of 3: 138 us per loop
            

            因此,真正有效的方法可能在很大程度上取决于操作系统。不管怎样,基于dot(.) 似乎是最稳定的。

            【讨论】:

            • 我怀疑它不太依赖于操作系统,而是依赖于底层 BLAS 实现和 C 编译器。谢谢,但是当x 包含大值时,点积更有可能溢出,我还想检查 inf。
            • 好吧,你总是可以用一个做点积并使用isfinite(.)。我只是想指出巨大的性能差距。谢谢
            • 我的机器上也是这样。
            • 聪明,不是吗? 正如Fred Foo 所暗示的,基于点积的方法的任何效率提升几乎肯定要归功于与优化的 BLAS 实施相关联的本地 NumPy 安装像 ATLAS、MKL 或 OpenBLAS。例如,Anaconda 就是这种情况。鉴于此,这个点积将在所有个可用内核上并行化。对于基于min- 或sum 的方法,同样的不能说是,它们仅限于单个内核运行。因此,性能差距。
            【解决方案8】:

            我认为np.isnan(np.min(X)) 应该做你想做的事。

            【讨论】:

            • 嗯...这总是 O(n) 而可能是 O(1)(对于某些数组)。
            猜你喜欢
            • 1970-01-01
            • 2020-09-29
            • 1970-01-01
            • 1970-01-01
            • 1970-01-01
            • 1970-01-01
            • 1970-01-01
            • 2015-08-03
            • 1970-01-01
            相关资源
            最近更新 更多