【问题标题】:How to avoid Numpy type conversions?如何避免 Numpy 类型转换?
【发布时间】:2020-01-18 18:56:21
【问题描述】:

对于从整数和32 bit float arrays64 bit float arrays 的自动 Numpy 类型转换,是否可以避免或发出警告?

我的用例是我正在开发一个大型分析包(20k 行 Python 和 Numpy),目前混合了 float 32 和 64 以及一些 int dtypes,很可能导致性能欠佳和浪费内存,基本上我想在任何地方都使用 float32。

我知道在 Tensorflow 中组合两个不同 dtype 的数组会产生错误 - 正是因为隐式转换为 float64 会导致性能不佳,并且在所有计算张量上都具有“传染性”,并且很难找到位置如果隐式完成,则会引入它。

寻找 Numpy 中的选项或猴子修补 Numpy 的方法,使其在这方面的行为类似于 Tensorflow,即在 np.addnp.mul 等操作上的隐式类型转换时发出错误,甚至更好的是,发出带有打印回溯的警告,以便继续执行,但我看到它发生在哪里。可能吗?

【问题讨论】:

  • ufunc like np.add 采用casting 参数。看起来默认值是same_kind' https://docs.scipy.org/doc/numpy/reference/ufuncs.html#casting-rules, https://docs.scipy.org/doc/numpy/reference/generated/numpy.can_cast.html#numpy.can_cast. I think you want casting='no'`。
  • 提供out 参数也可能会有所帮助。
  • 但是在测试中,np.multiply(x,2., casting='no') 给了我一个错误,因为它无法将 np.array(2.) (float64) 转换为 float32(以匹配 x。所以这个转换参数可能需要做的更少使用生成的dtype,以及更多作为输入的内容。

标签: python numpy tensorflow floating-point


【解决方案1】:

免责声明:我没有以任何认真的方式对此进行测试,但这似乎是一条有希望的路线。

操纵 ufunc 行为的相对轻松的方法似乎是 subclassing ndarray 并覆盖 __array_ufunc__。例如,如果您满足于捕获任何产生float64

class no64(np.ndarray):
    def __array_ufunc__(self, ufunc, method, *inputs, **kwds):
        ret = getattr(ufunc, method)(*map(np.asarray,inputs), **kwds)
        # some ufuncs return multiple arrays:
        if isinstance(ret,tuple):
            if any(x.dtype == np.float64 for x in ret):
                raise ValueError
            return (*(x.view(no64) for x in ret),)
        if ret.dtype == np.float64:
             raise ValueError
        return ret.view(no64)

x = np.arange(6,dtype=np.float32).view(no64)

现在让我们看看我们的类能做什么:

x*x
# no64([ 0.,  1.,  4.,  9., 16., 25.], dtype=float32)
np.sin(x)
# no64([ 0.        ,  0.84147096,  0.9092974 ,  0.14112   , -0.7568025 ,
#       -0.9589243 ], dtype=float32)
np.frexp(x)
# (no64([0.   , 0.5  , 0.5  , 0.75 , 0.5  , 0.625], dtype=float32), no64([0, 1, 2, 2, 3, 3], dtype=int32))

现在让我们将它与 64 位参数配对:

x + np.arange(6)
# Traceback (most recent call last):
#   File "<stdin>", line 1, in <module>
#   File "<stdin>", line 9, in __array_ufunc__
# ValueError

np.multiply.outer(x, np.arange(2.0))
# Traceback (most recent call last):
#   File "<stdin>", line 1, in <module>
#   File "<stdin>", line 9, in __array_ufunc__
# ValueError

什么不起作用(我相信还有更多)

np.outer(x, np.arange(2.0))  # not a ufunc, so slips through
# array([[0., 0.],
#        [0., 1.],
#        [0., 2.],
#        [0., 3.],
#        [0., 4.],
#        [0., 5.]])

__array_function__ 似乎吸引了那些人。

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 2019-10-04
    • 2014-11-26
    • 2020-12-29
    • 1970-01-01
    • 2020-12-05
    • 2022-01-22
    • 2022-07-31
    • 1970-01-01
    相关资源
    最近更新 更多