【问题标题】:numpy vectorized way to count non-zero bits in array of integersnumpy矢量化方法来计算整数数组中的非零位
【发布时间】:2021-01-05 07:05:28
【问题描述】:

我有一个整数数组:

[int1, int2, ..., intn]

我想计算这些整数的二进制表示中有多少非零位。

例如:

bin(123) -> 0b1111011, there are 6 non-zero bits

当然,我可以遍历整数列表,使用 bin()count('1') 函数,但我正在寻找矢量化的方式来做到这一点。

【问题讨论】:

标签: python arrays numpy vectorization


【解决方案1】:

假设你的数组是a,你可以这样做:

np.unpackbits(a.view('uint8')).sum()

示例:

a = np.array([123, 44], dtype=np.uint8)
#bin(a) is [0b1111011, 0b101100]
np.unpackbits(a.view('uint8')).sum()
#9

比较使用benchit

#@Ehsan's solution
def m1(a):
  return np.unpackbits(a.view('uint8')).sum()

#@Valdi_Bo's solution
def m2(a):
  return sum([ bin(n).count('1') for n in a ])

in_ = [np.random.randint(100000,size=(n)) for n in [10,100,1000,10000,100000]]

m1 明显更快。

【讨论】:

  • 您能把我的方法添加到您的精美图表中吗?它在大尺寸下表现出良好的性能。
【解决方案2】:

似乎 np.unpackbits 运行比计算总和 bin(n).count('1') 在源数组的每个元素上。

%timeit 衡量的执行时间是:

  • 8.35 µs 对于np.unpackbits(a.view('uint8')).sum()
  • 对于sum([ bin(n).count('1') for n in a ])2.45 µs(快 3 倍以上)。

所以也许你应该坚持最初的概念。

【讨论】:

  • a 这里有多大?
  • 我使用了与 Ehsan 的答案相同的数组(2 个元素)。也许对于更大的数组,执行时间的关系会有所不同。
  • @Valdi_Bo 请查看我的帖子添加的比较。谢谢。
【解决方案3】:

在 Forth 中,我使用了一个查找表来计算每个字节的位数。 我正在寻找是否有一个 numpy 函数来计算位数并找到了这个答案。

256 字节查找比这里的两种方法快。 16 位(65536 字节查找)再次更快。我用完了 32 位查找 4.3G 的空间 :-)

这可能对找到此答案的其他人有用。其他答案中的一个衬线打字速度要快得多。

import numpy as np

def make_n_bit_lookup( bits = 8 ):
    """ Creates a lookup table of bits per byte ( or per 2 bytes for bits = 16).
        returns a count function that uses the table generated.
    """
    try:
        dtype = { 8: np.uint8, 16: np.uint16 }[ bits ]
    except KeyError:
        raise ValueError( 'Parameter bits must be 8, 16.')

    bits_per_byte = np.zeros( 2**bits, dtype = np.uint8 )

    i = 1
    while i < 2**bits:
        bits_per_byte[ i: i*2 ] = bits_per_byte[ : i ] + 1
        i += i
        # Each power of two adds one bit set to the bit count in the 
        # corresponding index from zero.
        #  n       bits   ct  derived from   i
        #  0       0000   0                  
        #  1       0001   1 = bits[0] + 1    1
        #  2       0010   1 = bits[0] + 1    2
        #  3       0011   2 = bits[1] + 1    2
        #  4       0100   1 = bits[0] + 1    4
        #  5       0101   2 = bits[1] + 1    4
        #  6       0110   2 = bits[2] + 1    4
        #  7       0111   3 = bits[3] + 1    4
        #  8       1000   1 = bits[0] + 1    8
        #  9       1001   2 = bits[1] + 1    8
        #  etc...

    def count_bits_set( arr ):
        """ The function using the lookup table. """
        a = arr.view( dtype )
        return bits_per_byte[ a ].sum()

    return count_bits_set

count_bits_set8  = make_n_bit_lookup( 8 )
count_bits_set16 = make_n_bit_lookup( 16 )

# The two original answers as functions.
def text_count( arr ):
    return sum([ bin(n).count('1') for n in arr ])  

def unpack_count( arr ):
    return np.unpackbits(arr.view('uint8')).sum()   


np.random.seed( 1234 )

max64 = 2**64
arr = np.random.randint( max64, size = 100000, dtype = np.uint64 )

count_bits_set8( arr ), count_bits_set16( arr ), text_count( arr ), unpack_count( arr )                                             
# (3199885, 3199885, 3199885, 3199885) - All the same result

%timeit n_bits_set8( arr )                                                                                                          
# 3.63 ms ± 17.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

%timeit n_bits_set16( arr )                                                                                                         
# 1.78 ms ± 15.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%timeit text_count( arr )                                                                                                           
# 83.9 ms ± 1.05 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

%timeit unpack_count( arr )                                                                                                         
# 8.73 ms ± 87.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

【讨论】:

    【解决方案4】:

    由于 numpy 与 python 不同,整数大小有限,因此您可以将Óscar López 提出的比特旋转解决方案调整为Fast way of counting non-zero bits in positive integer(最初来自here),以获得可移植、快速的解决方案:

    def bit_count(arr):
         # Make the values type-agnostic (as long as it's integers)
         t = arr.dtype.type
         mask = t(-1)
         s55 = t(0x5555555555555555 & mask)  # Add more digits for 128bit support
         s33 = t(0x3333333333333333 & mask)
         s0F = t(0x0F0F0F0F0F0F0F0F & mask)
         s01 = t(0x0101010101010101 & mask)
    
         arr = arr - ((arr >> 1) & s55)
         arr = (arr & s33) + ((arr >> 2) & s33)
         arr = (arr + (arr >> 4)) & s0F
         return (arr * s01) >> (8 * (arr.itemsize - 1))
    

    函数的第一部分将数量 0x5555...、0x3333... 等截断为 arr 实际包含的整数类型。剩下的只是做一组位旋转操作。

    对于 100000 个元素的数组,此函数比 Ehsan 的方法快 4.5 倍,比 Valdi Bo 的方法快 60 倍:

    a = np.random.randint(0, 0xFFFFFFFF, size=100000, dtype=np.uint32)
    %timeit bit_count(a).sum()
    # 846 µs ± 16.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
    %timeit m1(a)
    # 3.81 ms ± 24 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
    %timeit m2(a)
    # 49.8 ms ± 97.5 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
    

    m1m2 照原样取自 @Ehsan's answer

    【讨论】:

      猜你喜欢
      • 2017-09-28
      • 1970-01-01
      • 2018-01-09
      • 2015-07-21
      • 1970-01-01
      • 2013-12-23
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多