【问题标题】:Possible to add numpy arrays to python sets?可以将 numpy 数组添加到 python 集中吗?
【发布时间】:2016-02-16 20:03:16
【问题描述】:

我知道,为了将元素添加到集合中,它必须是可散列的,而 numpy 数组似乎不是。这给我带来了一些问题,因为我有以下代码:

fill_set = set()
for i in list_of_np_1D:
    vecs = i + np_2D
    for j in range(N):
        tup = tuple(vecs[j,:])
        fill_set.add(tup)

# list_of_np_1D is a list of 1D numpy arrays
# np_2D is a 2D numpy array
# np_2D could also be converted to a list of 1D arrays if it helped.

我需要让它运行得更快,将近 50% 的运行时间用于将 2D numpy 数组的切片转换为元组,以便将它们添加到集合中。

所以我一直在尝试找出以下内容

  • 是否有任何方法可以使 numpy 数组或类似 numpy 数组(具有向量加法)功能的东西可散列化,以便将它们添加到集合中?
  • 如果没有,有什么方法可以加快元组转换的过程吗?

感谢您的帮助!

【问题讨论】:

  • NumPy 数组不仅不可散列,它们甚至都不是真正的 equatable。如果ab 中的任何一个是数组,并且set 不知道如何处理元素比较结果数组或如何拨打np.array_equal
  • 您真的需要将数组转换为 Python 集吗? Numpy 原生支持对数组的各种集合操作(​​参见numpy.lib.arraysetops)。
  • @ali_m 我不知道谢谢,我现在去看看。最终我有两个大的一维整数数组集合,我需要能够向这些集合添加更多数组并执行与集合具有的.difference_update 操作等效的操作。
  • 您可以使用tuple(vecs[j,:].tolist()) 来减少转换时间。如果您只想将数组保存在集合中,您甚至可以通过vecs[j, :].tobytes() 将数组转换为字节对象。
  • @HYRY 谢谢,我现在就去试试。

标签: python numpy casting set tuples


【解决方案1】:

先创建一些数据:

import numpy as np
np.random.seed(1)
list_of_np_1D = np.random.randint(0, 5, size=(500, 6))
np_2D = np.random.randint(0, 5, size=(20, 6))

运行您的代码:

%%time
fill_set = set()
for i in list_of_np_1D:
    vecs = i + np_2D
    for v in vecs:
        tup = tuple(v)
        fill_set.add(tup)
res1 = np.array(list(fill_set))

输出:

CPU times: user 161 ms, sys: 2 ms, total: 163 ms
Wall time: 167 ms

这是一个加速版,它使用广播,.view()方法将dtype转换为字符串,调用set()后将字符串转换回数组:

%%time
r = list_of_np_1D[:, None, :] + np_2D[None, :, :]
stype = "S%d" % (r.itemsize * np_2D.shape[1])
fill_set2 = set(r.ravel().view(stype).tolist())
res2 = np.zeros(len(fill_set2), dtype=stype)
res2[:] = list(fill_set2)
res2 = res2.view(r.dtype).reshape(-1, np_2D.shape[1])

输出:

CPU times: user 13 ms, sys: 1 ms, total: 14 ms
Wall time: 14.6 ms

检查结果:

np.all(res1[np.lexsort(res1.T), :] == res2[np.lexsort(res2.T), :])

您也可以使用lexsort() 删除重复数据:

%%time
r = list_of_np_1D[:, None, :] + np_2D[None, :, :]
r = r.reshape(-1, r.shape[-1])

r = r[np.lexsort(r.T)]
idx = np.where(np.all(np.diff(r, axis=0) == 0, axis=1))[0] + 1
res3 = np.delete(r, idx, axis=0)

输出:

CPU times: user 13 ms, sys: 3 ms, total: 16 ms
Wall time: 16.1 ms

检查结果:

np.all(res1[np.lexsort(res1.T), :] == res3)

【讨论】:

    猜你喜欢
    • 2022-01-16
    • 2021-07-16
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2011-04-22
    • 2011-02-10
    • 1970-01-01
    相关资源
    最近更新 更多