【问题标题】:Nontransitive subclassing with numpy and jax使用 numpy 和 jax 进行非传递子类化
【发布时间】:2025-12-04 19:20:05
【问题描述】:

我的问题很简单:

>>> isinstance(x, jax.numpy.ndarray)
True
>>> issubclass(jax.numpy.ndarray, numpy.ndarray)
True
>>> isinstance(x, numpy.ndarray)
False

?

现在我会闲逛,以便 SE 接受我合理的问题。

【问题讨论】:

    标签: python numpy jax


    【解决方案1】:

    出现这种情况的原因是jax.numpy.ndarray 使用元类覆盖了实例检查:

    class _ArrayMeta(type(np.ndarray)):  # type: ignore
      """Metaclass for overriding ndarray isinstance checks."""
    
      def __instancecheck__(self, instance):
        try:
          return isinstance(instance.aval, _arraylike_types)
        except AttributeError:
          return isinstance(instance, _arraylike_types)
    
    class ndarray(np.ndarray, metaclass=_ArrayMeta):
      dtype: np.dtype
      shape: Tuple[int, ...]
      size: int
    
      def __init__(shape, dtype=None, buffer=None, offset=0, strides=None,
                   order=None):
        raise TypeError("jax.numpy.ndarray() should not be instantiated explicitly."
                        " Use jax.numpy.array, or jax.numpy.zeros instead.")
    

    (view source)

    您的代码返回它所做的事情的原因是因为您有一个x 值,它不是numpy.ndarray 的一个实例,但是这个__instancecheck__ 方法返回true。

    为什么在 JAX 中有这种诡计?好吧,出于 JIT 编译、自动微分和其他转换的目的,JAX 使用称为 tracers 的替代对象,这些对象看起来和行为都像一个数组,尽管实际上并不是一个数组。这种对实例检查的覆盖是 JAX 用来进行此类跟踪的技巧之一。

    【讨论】:

    • 很好的答案。谢谢!