【问题标题】:How to check if all elements of a numpy array are in another numpy array如何检查一个numpy数组的所有元素是否在另一个numpy数组中
【发布时间】:2018-12-17 12:27:34
【问题描述】:

我有两个 2D numpy 数组,例如:

A = numpy.array([[1, 2, 4, 8], [16, 32, 32, 8], [64, 32, 16, 8]])

B = numpy.array([[1, 2], [32, 32]])

我想拥有来自A 的所有行,我可以在其中找到来自B 的任何行的所有元素。如果B 的一行中有 2 个相同元素,则来自A 的行也必须至少包含 2 个。在我的例子中,我想实现这个:

A_filtered = [[1, 2, 4, 8], [16, 32, 32, 8]]

我可以控制值表示,所以我选择了二进制表示只有一个位置的数字 1(例如:0b000000010b00000010 等...)这样我可以轻松检查是否所有通过使用np.logical_or.reduce() 函数,值的类型在行中,但我无法检查A 行中相同元素的数量是否更大或相等。我真的希望我可以避免简单的for 循环和数组的深层副本,因为性能对我来说是一个非常重要的方面。

如何在 numpy 中以有效的方式做到这一点?


更新:

here 的解决方案可能有效,但我认为性能对我来说是一个大问题,A 可以非常大(>300000 行),B 可以适中(>30):

[set(row).issuperset(hand) for row in A.tolist() for hand in B.tolist()]

更新 2:

set() 解决方案不起作用,因为set() 删除了所有重复值。

【问题讨论】:

  • B中相同元素的2个B的任意行?
  • “至少 2”重要吗?如果 A 和 B 中的行需要 B 中的每个标记的数量相等,我想我知道一个优雅的解决方案。
  • 任意行的同一元素的 2 个。 “至少2”非常重要。在A中,接受[32, 32, 32,8]这一行
  • 我不认为集合完全符合你的描述,因为单一和多次出现之间的区别会丢失。

标签: python arrays numpy vectorization


【解决方案1】:

我希望你的问题是正确的。至少它适用于您在问题中描述的问题。如果输出的顺序应该与输入的顺序相同,请更改就地排序。

代码看起来很丑,但性能应该不错,应该不难理解。

代码

import time
import numba as nb
import numpy as np

@nb.njit(fastmath=True,parallel=True)
def filter(A,B):
  iFilter=np.zeros(A.shape[0],dtype=nb.bool_)

  for i in nb.prange(A.shape[0]):
    break_loop=False

    for j in range(B.shape[0]):
      ind_to_B=0
      for k in range(A.shape[1]):
        if A[i,k]==B[j,ind_to_B]:
          ind_to_B+=1

        if ind_to_B==B.shape[1]:
          iFilter[i]=True
          break_loop=True
          break

      if break_loop==True:
        break

  return A[iFilter,:]

衡量绩效

####First call has some compilation overhead####
A=np.random.randint(low=0, high=60, size=300_000*4).reshape(300_000,4)
B=np.random.randint(low=0, high=60, size=30*2).reshape(30,2)

t1=time.time()
#At first sort the arrays
A.sort()
B.sort()
A_filtered=filter(A,B)
print(time.time()-t1)

####Let's measure the second call too####
A=np.random.randint(low=0, high=60, size=300_000*4).reshape(300_000,4)
B=np.random.randint(low=0, high=60, size=30*2).reshape(30,2)

t1=time.time()
#At first sort the arrays
A.sort()
B.sort()
A_filtered=filter(A,B)
print(time.time()-t1)

结果

46ms after the first run on a dual-core Notebook (sorting included)
32ms (sorting excluded)

【讨论】:

  • 我试过了,它确实提升了我的代码,谢谢!我是numba 的新手,我只是读了一点。在这种情况下我可以使用nopython 模式还是numpy 需要object 模式?
  • @doodoroma (at)njit 是 (at)jit(nopython=True) 的快捷方式
【解决方案2】:

我认为这应该可行:

首先,按如下方式对数据进行编码(假设“令牌”数量有限,您的二进制方案似乎也暗示了这一点):

制作一个形状 [n_rows, n_tokens],dtype int8,其中每个元素计算令牌的数量。以同样的方式编码 B,形状为 [n_hands, n_tokens]

这允许您的输出的单个矢量化表达式;匹配 = (A[None, :, :] >= B[:, None, :]).all(axis=-1)。 (究竟如何将此匹配数组映射到所需的输出格式留给读者作为练习,因为该问题未定义多个匹配项)。

但我们在这里讨论的是每个令牌 > 10Mbyte 的内存。即使有 32 个令牌也不应该是不可想象的;但在这种情况下,最好不要对 n_tokens 或 n_hands 或两者上的循环进行矢量化; for 循环适用于较小的 n,或者如果主体中有足够的工作要做,那么循环开销是微不足道的。

只要 n_tokens 和 n_hands 保持适度,我认为这将是最快的解决方案,如果保持在纯 python 和 numpy 的范围内。

【讨论】:

  • 你说的每个元素计算token的数量是什么意思
  • 您的示例 A 数组有 7 个唯一标记。为每个标记分配一列,并为每个元素分配该标记在行中的计数,给出 A = numpy.array([[1, 1, 1, 1, 0, 0, 0], [0, 0, 0, 1, 1, 2, 0], [0, 0, 0, 1, 1, 1, 1]])
  • 是的,这就是我一直在寻找的解决方案。我只能离线转换AB每次都要转换,我刚数了一下,我有71个token,好像很多。
  • 尽管我正在考虑类似这个答案的解决方案,但由于内存限制,我选择了@max9111 的想法。拥有 71 个令牌需要太大的矩阵来处理 (300000*71)。总之,真的很有用!
猜你喜欢
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2021-06-04
  • 1970-01-01
  • 1970-01-01
相关资源
最近更新 更多