【发布时间】:2018-09-13 10:32:17
【问题描述】:
我遇到了一个问题,我认为这与自定义类 __eq__/__hash__ 函数的不当实现有关。
我创建了一个自定义 Line 类,其中一条线包含一个斜率和 y 截距,它们是从 2 个点计算得出的。我正在对两行之间的相等性进行测试,这会产生如下所示的意外结果。
我正在寻找解释为什么我在下面包含的测试代码中的前 2 行不相等,但第 2 组 2 行是相等的,尽管两组行的斜率和你拦截?
class Point:
def __init__(self, x1, y1):
self.x = x1
self.y = y1
def to_string(self):
return '{},{}'.format(self.x, self.y)
class Line:
def __init__(self, pt1, pt2):
self.m = (pt1.y - pt2.y)/(pt1.x - pt2.x)
self.b = pt1.y - self.m * pt1.x
def __eq__(self, other):
if isinstance(other, self.__class__):
return self.m == other.m and self.b == other.b
else:
return False
def __ne__(self, other):
return not self.__eq__(other)
def __hash__(self):
return hash((self.m, self.b))
def print_line(self):
print('y = {} x + {}'.format(self.m, self.b))
测试代码:
pt_a = Point(0.1, 1.0)
pt_b = Point(1.1, 1.1)
pt_c = Point(2.1, 1.2)
line1 = Line(pt_a, pt_b)
print('line1:')
line1.print_line()
line2 = Line(pt_b, pt_c)
print('line2:')
line2.print_line()
if line1 == line2:
print('lines equal')
else:
print('lines not equal')
pt_x = Point(0.5, 1)
pt_y = Point(1.5, 2)
pt_z = Point(2.5, 3)
line1 = Line(pt_x, pt_y)
print('line1:')
line1.print_line()
line2 = Line(pt_y, pt_z)
print('line2:')
line2.print_line()
if line1 == line2:
print('lines equal')
else:
print('lines not equal')
这个测试产生输出:
line1:
y = 0.1 x + 0.99
line2:
y = 0.1 x + 0.99
lines not equal
line1:
y = 1.0 x + 0.5
line2:
y = 1.0 x + 0.5
lines equal
【问题讨论】:
-
我不是在比较 Points,Lines 不会存储它们创建时使用的 Points。
-
您依赖于浮点舍入误差是一致的还是没有发生。
-
这个问题不仅仅是通常的“浮点数不精确”问题,因为 OP 正在编写
__hash__方法,这意味着他不能只使用isclose或等效的. -
旁注:如果你想让其他类使用你的
Lines,你应该return NotImplemented,而不是return False,当你不认识other的类型时在__eq__。而您的__ne__应该是ret = self.__eq__(other)、return ret if ret is NotImplemented else not ret(Python 2 上__ne__的规范主体,以根据__eq__正确实现它;Python 3 会自动正确地执行此操作)。
标签: python python-2.7 hash line