【问题标题】:Replacing chunks of numpy array on condition根据条件替换大块 numpy 数组
【发布时间】:2025-12-06 01:30:01
【问题描述】:

假设我有以下 numpy 数组,仅包含 1 和 0:

import numpy as np

example = np.array([0,1,1,0,1,0,0,1,1], dtype=np.uint8)

我想将所有元素分组为 3 的块,并根据条件将这些块替换为单个值。 假设我希望 [0,1,1] 变为 5[0,1,0] 变为 10 。 因此,所需的输出将是:

[5,10,5]

一个块中所有可能的 1 和 0 组合都有一个对应的唯一值,应该替换该块。最快的方法是什么?

【问题讨论】:

  • 也许您可以有一个 dict 将 0 和 1 的 8 个可能值映射到您想要的任意数字,然后根据一次读取数组三个索引并从中获取值进行翻译字典
  • 我可以,但这对于我的用例来说太慢了。我正在寻找一种矢量化的方式来实现这一点。
  • np.packbits 做你想做的事

标签: python numpy numpy-ndarray


【解决方案1】:

我建议您将数组重塑为 3 by something 数组。现在我们可以将每一行视为一个二进制数,它是您想要的值列表的索引。您将其转换为该数字并索引到值中。

arr = np.array([0,1,1,0,1,0,0,1,1], dtype=np.uint8).reshape(-1,3)

idx = 2**0*arr[:,0]+2**1*arr[:,1]+2**2*arr[:,2]

values = np.zeros(2**3)
values[0 *2**0+ 1 *2**1+ 1 *2**2] = 5
values[0 *2**0+ 1 *2**1+ 0 *2**2] = 10

values[idx]

这给了

array([ 5., 10.,  5.])

或者,如果您希望更简洁地编写转换,尽管有点不那么基本(感谢@mozway 的想法):

def bin_vect_to_int(arr):
    bin_units = 2**np.arange(arr.shape[1])
    return np.dot(arr,bin_units)


arr = np.array([0,1,1,0,1,0,0,1,1,0,1,1], dtype=np.uint8).reshape(-1,3)
idx = binVecToInt(arr)

values = np.zeros(2**3)
values[bin_vect_to_int(np.array([[0,1,1]]))] = 5
values[bin_vect_to_int(np.array([[0,1,0]]))] = 10

values[idx]

【讨论】:

  • 不错 (+1),更简单的二进制计算方法:(example.reshape(3,-1)*2**np.arange(3)).sum(1)
  • @mozway 好吧,当您已经在进行乘法和总结时,您不妨使用np.dot 否?
  • 是的,非常正确,你可以做到np.dot(example.reshape(3,-1),2**np.arange(3)) ;)
  • @mozway 我很好
  • 我看到了(抱歉,小屏幕,在手机上)
【解决方案2】:

正如其他答案所示,您可以从重塑数组开始(实际上,您可能应该从一开始就使用正确的形状生成它,但这是另一个问题):

example = np.array([0, 1, 1, 0, 1, 0, 0, 1, 1], dtype=np.uint8)
data = example.reshape(-1, 3)

现在在数组上运行自定义 python 函数会很慢,但幸运的是 numpy 支持你。您可以使用np.packbits 将每一行直接转换为数字:

data = np.packbits(data, axis=1, bitorder='little').ravel() # [6, 2, 6]

如果您希望 101 映射到 5110 映射到 6,那么您的工作就完成了。否则,您将需要提出一个映射。由于您有三位,因此映射数组中只需要 8 个数字:

mapping = np.array([7, 4, 3, 8, 124, 1, 5, 0])

您可以将data 用作直接指向mapping 的索引。输出的类型为mapping,但形状为data

result = mapping[data]  # [5, 3, 5]

你可以在一行中做到这一点:

mapping[np.packbits(example.reshape(-1, 3), axis=1, bitorder='little').ravel()]

【讨论】:

    【解决方案3】:

    您可以使用shape(3, -1) 的连续数组视图,查找唯一出现的位置并在这些位置替换它们:

    def view_ascontiguous(a): # a is array
        a = np.ascontiguousarray(a)
        void_dt = np.dtype((np.void, a.dtype.itemsize * a.shape[1]))
        return a.view(void_dt).ravel()
    
    def replace(x, args, subs, viewer):
        u, inv = np.unique(viewer(x), return_inverse=True)
        idx = np.searchsorted(viewer(args), u)
        return subs[idx][inv]
    
    >>> replace(x=np.array([1, 0, 1, 0, 0, 1, 1, 0, 1]).reshape(-1, 3),
            args=np.array([[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1], [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1]]),
            subs=np.array([ 5, 57, 58, 44, 67, 17, 77,  1]),
            viewer=view_ascontiguous)
    array([17, 57, 17])
    

    注意这里idx表示Power Set{0, 1}^N中长度为N的唯一contiguous块的位置。

    如果viewer(args)args 映射到np.searchsorted 方法内的[0, 1, 2, 3, ...],则将其替换为np.arange(len(args)) 有助于提高性能。

    这个算法也可以用于更一般的问题:


    您将获得 dtype=np.uint8 的数组,其中 M*N 的值为 0 和 1。您还获得了 Power Set [0, 1]^N(所有可能的长度为 0 和 1 的 N 的块)和一些标量值之间的映射.按照以下步骤查找 M 值的数组:

    • 将您分配到M 长度为N 的连续块中的拆分数组
    • 使用给定的映射将每个块替换为标量值

    现在,有趣的部分:您可以使用自己的viewer。需要将您传入 args 的数组映射到任何类型的升序索引,如下所示:

    viewer=lambda arr: np.ravel_multi_index(arr.T, (2,2,2)) #0, 1, 2, 3, 4, 5, 6, 7
    viewer=lambda arr: np.sum(arr * [4, 2, 1], axis=1) #0, 1, 2, 3, 4, 5, 6, 7
    viewer=lambda arr: np.dot(arr, [4, 2, 1]) #0, 1, 2, 3, 4, 5, 6, 7
    

    或者更有趣:

    viewer=lambda arr: 2*np.dot(arr, [4, 2, 1]) + 1 #1, 3, 5, 7, 9, 11, 13, 15
    viewer=lambda arr: np.vectorize(chr)(97+np.dot(arr, [4, 2, 1])) #a b c d e f g h
    

    因为你也可以映射

    [[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1], [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1]]

    到任何你能想到的升序,比如[1, 3, 5, 7, 9, 11, 13, 15]['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'] 结果还是一样。

    补充说明

    感谢@MadPhysicist,

    np.packbits(example.reshape(-1, N), axis=1, bitorder='little').ravel()

    也能做到这一点。它假装是最快的解决方案,因为np.packbitsnumpy 中得到了很好的优化。

    【讨论】:

    • @mathfux 真的“它假装是最快的解决方案”?所以它实际上不是最快的?我认为这应该改写。但这是这个想法的一个很酷的实现。
    • 我认为我们的前提略有不同。您解决了将可能的模式(此时实际上不必是 0-1)映射到其他任意数量的一般问题。我假设 0 和 1 的所有组合都是可能的,这 A) 在大多数系统上将 N 限制为 64,并且 B) 意味着用户必须在输出上提供完整的 2**N 大小的映射。我的解决方案更快,因为它解决了一个更简单的问题。您使用排序是因为 numpy 不支持哈希表,我在数组上使用了几个线性传递。
    • 就复杂性而言,我会使用哈希表来解决您正在解决的问题。实际上,它在 python 中总是会更慢,但在复杂性方面,它会提供一个 O(n) 解决方案,但由于 numpy 的限制,你不得不使用 O(n log n)。
    • @mathfux。请记住,np.unique 是穿着风衣的np.sort, np.diff, np.flatnonzero
    • @mathfux。字面意思可能不是这样。重要的部分是它是基于排序的,这使它成为O(n log n)。之后的步骤相当于np.diffnp.flatnonzero,但它们可能在C 中实现。我从未检查过,因为O(n) 部分在做什么并不重要。跨度>