【发布时间】:2026-02-19 03:00:02
【问题描述】:
Tensorflow 覆盖 Tensor 类、including __lt__, __ge__ 等的多个运算符。
但是,__eq__ seems to be conspicuously absent 的实现:
ops.Tensor._override_operator("__lt__", gen_math_ops.less)
ops.Tensor._override_operator("__le__", gen_math_ops.less_equal)
ops.Tensor._override_operator("__gt__", gen_math_ops.greater)
ops.Tensor._override_operator("__ge__", gen_math_ops.greater_equal)
为什么 tensorflow 的张量的 == 的行为方式与 numpy 数组的行为方式不同?
代码示例:
a = tf.constant([1,2])
b = tf.constant([3,4])
a == b
>>> False
a < b
>>> <tf.Tensor 'Less:0' shape=(2,) dtype=bool>
另一方面,使用 numpy:
a = np.asarray([1,2])
b = np.asarray([3, 4])
a == b
>>> array([False, False], dtype=bool)
【问题讨论】:
-
您是否基于
__eq__并非仅根据这些行定义的断言?因为我看到了other code that handles operator overrides in a generic manner 例如。 -
import tensorflow as tf,然后__eq__ in vars(tf.Tensor)产生True,所以它确实定义了钩子。它被定义为directly on the class。 -
@MartijnPieters 不,我的观察是基于我的代码没有按照我的预期去做。这些链接是在一些挖掘之后产生的。另外,我知道定义了张量的相等性。但是,它不符合
numpy数组。我希望添加的代码可以澄清问题。 -
为什么在测试相等性时张量应该广播?该项目明确决定改为测试身份。
-
是的,有一些特别之处;我发现了一个解释原因的 github 问题。
标签: python numpy tensorflow