【问题标题】:PyTorch C++ Extensions: Accessing data for Half TensorsPyTorch C++ 扩展:访问半张量的数据
【发布时间】:2020-02-25 01:08:06
【问题描述】:

我正在尝试使用 C++ Tensor API 为 PyTorch 编写 C++/CUDA 扩展,并且我希望我的代码能够同时使用 float32 和 float16(半精度)。我不确定如何访问来自 Python 的半张量的数据指针。

这是我对浮点张量的处理方式:

// Access data pointer for float Tensor A
torch::Tensor A;
float* ptr = A.data<float>();

这是我尝试过的半张量:

// CUDA float 16 type
// undefined symbol: _ZNK2at6Tensor4dataI6__halfEEPT_v
A.data<__half>();

// PyTorch float16 type
// error: no instance of function template "at::Tensor::data" 
A.data<torch::ScalarType::Half>();

// Casting to __half*
// This compiles but throws and error if the requested pointer type doesn't match the Tensor type:
// RuntimeError: expected scalar type Float but found Half
(__half*)(A.data<float>());

我尝试查看 C++ api 源代码,但找不到任何其他看起来像 float16 类型的内容。

系统信息: Python 3.6.2 PyTorch 1.0.1

【问题讨论】:

    标签: c++ templates pytorch


    【解决方案1】:

    正确的类型原来是at::Half

    【讨论】:

      猜你喜欢
      • 2019-07-08
      • 2021-02-10
      • 2022-10-17
      • 2021-02-01
      • 1970-01-01
      • 2021-08-20
      • 1970-01-01
      • 2022-11-26
      • 1970-01-01
      相关资源
      最近更新 更多