【问题标题】:How do I quickly find the last number of elements matching some criterion in numpy?如何在 numpy 中快速找到匹配某个条件的最后一个元素?
【发布时间】:2018-05-05 07:47:01
【问题描述】:

假设我有一些数组:

x = numpy.array([1, 2, 3, 4, 5, 1, 1, 1])

我想在数组末尾找到连续的1s 的数量。一种方法是使用循环:

i = len(x) - 1
while x[i] == 1:
    i = i - 1

现在我可以查看i 并计算出x 后面的1s 的数量。但是,在我的真实示例中,x 可能非常大,1s 的数量也可能非常大,所以我想要一个解决方案:

  1. 不使用循环,并且
  2. 不遍历整个数组

【问题讨论】:

  • 编写一个 C 扩展(使用 Cython)或使用 numba。两者都会使用循环。
  • 如果数组的最后一个元素不是1,根据您的要求,您需要得到什么值?

标签: python numpy


【解决方案1】:

我同意使用 Cython 或 numba 来加快循环并仅遍历数组尾部,但如果您想尝试使用纯 numpy 进行尝试,我会说类似以下的操作:

np.argwhere(x[::-1] != 1).ravel()[0]

反转数组并取第一个非 1 值。虽然它正在遍历整个数组......所以它可能无法满足您的需求。

编辑:这是一个numba 的完整性方法

from numba import jit
@jit
def count_trailing_ones(array):
    count = 0
    a = array[::-1]
    for i in range(array.shape[0]):
        if a[i] == 1:
            count += 1
        else:
            return count

这是一个基准测试,还包括 @J...S 和 @Kasramvd 解决方案,用于 800MB 阵列和几百万个尾随阵列。 numba 显然会赢,但如果你要选择 numpy,我会说 @J...S 的 argmax 是最好的。

In [102]: %timeit np.argwhere(x[::-1] != 1).ravel()[0]
631 ms ± 1.83 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [103]: %timeit np.argmax(x[::-1] != 1)
117 ms ± 417 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

In [104]: %timeit kas(x)
915 ms ± 3.41 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [105]: %timeit count_trailing_ones(x)
4.62 ms ± 16.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

【讨论】:

  • array([1, 2, 3, 4, 5, 1, 1, 1, 4, 5])测试。
  • @Kasramvd 请阅读整个答案,有关于假设的评论
  • 这个答案仍然不完全正确。请更新并告诉我们该检查是如何工作的,谢谢。
  • 另外,不确定你在暗示什么,我的两个解决方案都为你的测试数组提供0
  • @Kasramvd 结果证明根本不需要检查,它与0 一起工作太完美了......将更新答案
【解决方案2】:

如果最后一个元素始终是1,您可以先反转数组,然后使用argmax,如

np.argmax(x[::-1]!=1)

对于给定的数组会给出哪个

3

您可以先使用检查来确定最后一个元素是否为 1 之类的

if(x[-1] == 1):
    print(np.argmax(x[::-1]!=1))
else:
    print(0)

【讨论】:

  • array([1, 2, 3, 4, 5, 1, 1, 1, 4, 5])测试。
  • @Kasramvd 你的意思是不是最后的1s 也应该被考虑?如果是这样,我认为 OP 给出的循环只考虑数组末尾的 1s。
  • 是的,这是我第一次想到的,然后我更新了我的答案。
【解决方案3】:

这是一种矢量化方法:

In [50]: mask = x == 1

In [51]: T_inds = np.where(mask)[0]

In [52]: F_inds = np.where(~mask)[0]

In [53]: last_f_ind = np.where(T_inds[-1] > F_inds)[0][-1]

# x = np.array([1, 2, 3, 4, 5, 1, 1, 1, 4, 5])
In [54]: T_inds[-1] - F_inds[last_f_ind]
Out[54]: 3

诀窍是在低于最新一项的索引的非一项中找到最新项的索引。

另请注意,这种方法适用于偶数不在数组尾部的所有情况(最后 1 秒后没有其他数字)。但是,如果您想检查 1 位于数组末尾的特殊情况,这是一种更简洁的方法:

x.size - np.where(x != 1)[0][-1] - 1
Out[27]: 3

# x != 1 will give you a mask of the indices where their value is not
# equal to one. Then you can use np.where() to find the index of last 
# occurrence of not one items. By subtracting it from size of array you 
# can get the number of consecutive ones.

【讨论】:

  • @downvoter 请解释下投票的原因是什么,如果有的话我可以解决它。
【解决方案4】:

看看searchsorted:https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.searchsorted.html#numpy.searchsorted

在最后一个元素上与 np.where 结合使用,如果 != 1 则返回 0。

注意如果您的数组包含 0,这将不起作用,因为它会尝试在该点插入值。

import numpy as np

x = np.array([1, 2, 3, 4, 5, 1, 1, 1])
y = np.array([1, 2, 3, 4, 5, 1, 1, 1, 4, 5])

# Create a lambda function that accepts x(array)
f = lambda x: np.where(x[-1] != 1, 0, np.searchsorted(x[::-1], 1, side='right'))

print(f(x)) # 3
print(f(y)) # 0

【讨论】:

  • 你不是假设x >= 1吗?
  • array([1, 2, 3, 4, 5, 1, 1, 1, 4, 5])测试。
猜你喜欢
  • 1970-01-01
  • 2020-02-23
  • 2013-02-22
  • 2010-12-23
  • 2012-05-07
  • 2012-05-07
  • 2013-04-11
  • 2016-04-28
  • 1970-01-01
相关资源
最近更新 更多