【Posted】: 2009-12-21 10:02:03
【Question】:
I am trying to do the following:
>>> from numpy import *
>>> x = array([[3,2,3],[4,4,4]])
>>> y = set(x)
TypeError: unhashable type: 'numpy.ndarray'
How can I easily and efficiently create a set containing all the elements of a NumPy array?
【Question Discussion】:
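If you want a set of the elements, here is another, probably faster way:
y = set(x.flatten())
PS: After comparing x.flat, x.flatten(), and x.ravel() on a 10x100 array, I found that they all run at about the same speed. For a 3x3 array, the fastest version is the iterator version:
y = set(x.flat)
I would recommend it because it is the version with the lowest memory cost (it scales well with the size of the array).
PPS: There is also a NumPy function that does something similar:
y = numpy.unique(x)
This produces the same elements as set(x.flat), but as a NumPy array rather than a set. It is very fast (almost 10 times faster), but if you need a set, then set(numpy.unique(x)) is a bit slower than the other approaches (building a set carries a large overhead).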
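As a minimal sketch of the comparison above (variable names are illustrative, and no timings are claimed here), the following checks that the three flattening approaches and numpy.unique all yield the same elements:

```python
import numpy as np

x = np.array([[3, 2, 3], [4, 4, 4]])

# Three ways to feed every element of the array to set().
from_flat = set(x.flat)          # iterator, no intermediate array
from_flatten = set(x.flatten())  # always copies into a new 1-D array
from_ravel = set(x.ravel())      # returns a view when possible

# numpy.unique returns a sorted 1-D array of the distinct values.
unique_values = np.unique(x)

# .tolist() converts NumPy scalars to plain Python ints before building the set.
as_python_ints = set(unique_values.tolist())
print(sorted(as_python_ints))  # [2, 3, 4]
```

To measure the relative speed on your own array sizes, the standard timeit module can be wrapped around each of these expressions.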
【Discussion】:
The immutable counterpart of an array is the tuple, so try converting the array of arrays into a sequence of tuples:
>>> from numpy import *
>>> x = array([[3,2,3],[4,4,4]])
>>> x_hashable = map(tuple, x)
>>> y = set(x_hashable)
>>> y
set([(3, 2, 3), (4, 4, 4)])
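A short sketch of what the set of row tuples is useful for: deduplicating rows and testing membership. It assumes Python 3, and `has_row` is a hypothetical helper introduced only for illustration. Calling `.tolist()` first turns the NumPy scalars into plain Python ints.

```python
import numpy as np

x = np.array([[3, 2, 3], [4, 4, 4], [3, 2, 3]])

# Each row becomes a tuple of plain Python ints, which is hashable.
rows = set(map(tuple, x.tolist()))

def has_row(row_set, row):
    """Hypothetical helper: test whether a row (any sequence) is in the set."""
    return tuple(row) in row_set

print(len(rows))                 # 2 distinct rows
print(has_row(rows, [4, 4, 4]))  # True
print(has_row(rows, [1, 2, 3]))  # False
```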
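【Discussion】:
The answers above work if you want to create a set out of the elements contained in an ndarray. If you want to create a set of ndarray objects, or use ndarray objects as keys in a dictionary, then you have to provide a hashable wrapper for them. See the code below for a simple example: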
from hashlib import sha1
from numpy import all, array, uint8

class hashable(object):
    r'''Hashable wrapper for ndarray objects.

    Instances of ndarray are not hashable, meaning they cannot be added to
    sets, nor used as keys in dictionaries. This is by design - ndarray
    objects are mutable, and therefore cannot reliably implement the
    __hash__() method.

    The hashable class allows a way around this limitation. It implements
    the required methods for hashable objects in terms of an encapsulated
    ndarray object. This can be either a copied instance (which is safer)
    or the original object (which requires the user to be careful enough
    not to modify it).
    '''
    def __init__(self, wrapped, tight=False):
        r'''Creates a new hashable object encapsulating an ndarray.

        wrapped
            The wrapped ndarray.
        tight
            Optional. If True, a copy of the input ndarray is created.
            Defaults to False.
        '''
        self.__tight = tight
        self.__wrapped = array(wrapped) if tight else wrapped
        # The hash is computed once, from the raw bytes of the array.
        self.__hash = int(sha1(wrapped.view(uint8)).hexdigest(), 16)

    def __eq__(self, other):
        return all(self.__wrapped == other.__wrapped)

    def __hash__(self):
        return self.__hash

    def unwrap(self):
        r'''Returns the encapsulated ndarray.

        If the wrapper is "tight", a copy of the encapsulated ndarray is
        returned. Otherwise, the encapsulated ndarray itself is returned.
        '''
        if self.__tight:
            return array(self.__wrapped)
        return self.__wrapped
Using the wrapper class is simple enough:
>>> from numpy import arange
>>> a = arange(0, 1024)
>>> d = {}
>>> d[a] = 'foo'
Traceback (most recent call last):
  File "<input>", line 1, in <module>
TypeError: unhashable type: 'numpy.ndarray'
>>> b = hashable(a)
>>> d[b] = 'bar'
>>> d[b]
'bar'
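If a full wrapper class is more than you need, a lighter alternative (a different technique, not part of the answer above) is to key the dictionary on an immutable snapshot of the array. The sketch below assumes a hypothetical helper `array_key` that combines shape, dtype, and raw bytes. Since ndarray.tobytes() copies the data, later changes to the array do not affect the key.

```python
import numpy as np

def array_key(arr):
    """Hypothetical helper: build a hashable, immutable key from an ndarray."""
    arr = np.asarray(arr)
    # Shape and dtype are included so that arrays sharing the same raw bytes
    # but differing in layout or element type do not produce the same key.
    return (arr.shape, arr.dtype.str, arr.tobytes())

a = np.arange(0, 1024)
d = {array_key(a): 'bar'}

same_values = np.arange(0, 1024)
print(d[array_key(same_values)])          # bar
print(array_key(a.reshape(32, 32)) in d)  # False: the shape differs
```

The trade-off is that the key holds a full copy of the data, and the original array has to be rebuilt from the key (for example with numpy.frombuffer) if it is needed later.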
【Discussion】:
If you want a set of the elements:
>>> y = set(e for r in x
...         for e in r)
>>> y
set([2, 3, 4])
For a set of the rows:
>>> y = set(tuple(r) for r in x)
>>> y
set([(3, 2, 3), (4, 4, 4)])
【Discussion】:
I like xperroni's idea, but I think the implementation can be simplified by inheriting directly from ndarray instead of wrapping it.
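from hashlib import sha1
from numpy import ndarray, uint8, array

class HashableNdarray(ndarray):
    def __hash__(self):
        # Cache the digest on first use. A single leading underscore is used
        # because a double underscore is name-mangled, so a hasattr() check
        # against the string '__hash' would never find the attribute.
        if not hasattr(self, '_hash'):
            self._hash = int(sha1(self.view(uint8)).hexdigest(), 16)
        return self._hash

    def __eq__(self, other):
        if not isinstance(other, HashableNdarray):
            # Plain element-wise comparison against other objects.
            return super(HashableNdarray, self).__eq__(other)
        # Compare as plain ndarrays and reduce to a single truth value.
        return (self.view(ndarray) == other.view(ndarray)).all()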
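A NumPy ndarray can be viewed as the derived class and then used as a hashable object. view(ndarray) can be used for the reverse conversion, but in most cases it is not even needed.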
>>> a = array([1,2,3])
>>> b = array([2,3,4])
>>> c = array([1,2,3])
>>> s = set()
>>> s.add(a.view(HashableNdarray))
>>> s.add(b.view(HashableNdarray))
>>> s.add(c.view(HashableNdarray))
>>> print(s)
{HashableNdarray([2, 3, 4]), HashableNdarray([1, 2, 3])}
>>> d = next(iter(s))
>>> print(d == a)
[False False False]
>>> import ctypes
>>> print(d.ctypes.data_as(ctypes.POINTER(ctypes.c_double)))
<__main__.LP_c_double object at 0x7f99f4dbe488>
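One caveat applies to both the wrapper and the subclass: the hash is derived from the array contents and then stored, so modifying an array after it has been added to a set leaves a stale hash behind. A possible guard, sketched here under the assumption that making a copy is acceptable, is to store a read-only copy. `freeze` is a hypothetical helper, not something from the answers above.

```python
import numpy as np

def freeze(arr):
    """Hypothetical helper: return a read-only copy of arr."""
    frozen = np.array(arr)        # copy, so the caller's array stays writable
    frozen.setflags(write=False)  # in-place writes now raise ValueError
    return frozen

a = np.array([1, 2, 3])
f = freeze(a)

try:
    f[0] = 99
except ValueError as exc:
    print("write rejected:", exc)

a[0] = 99          # the original array is unaffected by the freeze
print(f.tolist())  # [1, 2, 3]
```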
【Discussion】:
Adding to @Eric Lebigot's great post.
Here is a trick for building a tensor lookup table:
import numpy as np
a = np.array([[1, 0, 0], [1, 0, 0], [2, 3, 4]])
np.unique(a, axis=0)
Output:
array([[1, 0, 0],
       [2, 3, 4]])
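To make the lookup-table idea concrete, here is a minimal sketch that also asks numpy.unique for the inverse indices, which map every original row to its position in the table. The names `table`, `inverse`, and `reconstructed` are illustrative.

```python
import numpy as np

a = np.array([[1, 0, 0], [1, 0, 0], [2, 3, 4]])

# table holds each distinct row once; inverse maps every original row
# to its index in table.
table, inverse = np.unique(a, axis=0, return_inverse=True)

# Flatten defensively: the shape of the inverse has varied between NumPy versions.
inverse = inverse.ravel()

print(table.tolist())    # [[1, 0, 0], [2, 3, 4]]
print(inverse.tolist())  # [0, 0, 1]

# Indexing the table with the inverse reconstructs the original array.
reconstructed = table[inverse]
```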
【Discussion】: