【问题标题】:How to read Torch Tensor from C [closed]如何从 C 中读取 Torch 张量 [关闭]
【发布时间】:2015-07-15 08:58:37
【问题描述】:

我必须使用 Torch 框架训练一个卷积神经网络,然后用 C 语言编写相同的网络。 为此,我必须以某种方式从我的 C 程序中读取网络的学习参数,但我找不到将 Torch 张量转换或写入文件以使其在 C 中可读的方法。 理想情况下,我想在 C 中将张量转换为双精度数组。

有人知道怎么做吗?在此先感谢:)

【问题讨论】:

    标签: c lua neural-network luajit torch


    【解决方案1】:

    我找不到将 Torch 张量转换或写入文件以使其在 C 中可读的方法。理想情况下,我想在 C 中将张量转换为双精度数组。

    最基本的(也是直接的)方式是直接在C中fread将你之前写入的数据写入二进制文件。在这种情况下,您通常会连接每一层的权重和偏差(如果有)。

    在 Lua/Torch 方面,您可以使用 File 实用程序来逐字地 fwrite 每个张量数据。例如,这是一个基本功能:

    local fwrite = function(tensor, file)
      if not tensor then return false end
      local n = tensor:nElement()
      local s = tensor:storage()
      return assert(file:writeDouble(s) == n)
    end
    

    例如,如果m 引用了包含权重的torch/nn 模块,您可以按如下方式使用它:

    local file = torch.DiskFile("net.bin", "w"):binary()
    fwrite(m.weight, file)
    fwrite(m.bias, file)
    

    当然,您需要编写自己的逻辑来确保您fwrite 并连接所有层的所有权重。在C端,除了net.bin,你还需要知道你的网络的结构(nb.layers,内核大小等参数)来知道double-s到fread有多少块。

    作为示例(在 Lua 中),您可以查看 overfeat-torch(非官方项目),它说明了如何读取这样一个普通的二进制文件:请参阅 ParamBank 工具。

    请记住,稳健的解决方案包括使用适当的二进制序列化格式,如 msgpackProtocol Buffers,这将使导出/导入过程干净且可移植。

    --

    这是一个玩具示例:

    -- EXPORT
    require 'nn'
    
    local fwrite = function(tensor, file)
      if not tensor then return false end
      local n = tensor:nElement()
      local s = tensor:storage()
      return assert(file:writeDouble(s) == n)
    end
    
    local m = nn.Linear(2, 2)
    
    print(m.weight)
    print(m.bias)
    
    local file = torch.DiskFile("net.bin", "w"):binary()
    fwrite(m.weight, file)
    fwrite(m.bias, file)
    

    然后在 C 中:

    /* IMPORT */
    #include <stdio.h>
    #include <stdlib.h>
    #include <assert.h>
    
    int
    main(void)
    {
      const int N = 2; /* nb. neurons */
    
      double *w = malloc(N*N*sizeof(*w)); /* weights */
      double *b = malloc(N*sizeof(*w));   /* biases */
    
      FILE *f = fopen("net.bin", "rb");
      assert(fread(w, sizeof(*w), N*N, f) == N*N);
      assert(fread(b, sizeof(*w), N, f) == N);
      fclose(f);
    
      int i, j;
      for (i = 0; i < N; i++)
        for (j = 0; j < N; j++)
          printf("w[%d,%d] = %f\n", i, j, w[N*i+j]);
    
      for (i = 0; i < N; i++)
          printf("b[%d] = %f\n", i, b[i]);
    
      free(w);
      free(b);
    
      return 0;
    }
    

    【讨论】:

    • 感谢您的回答!我已经尝试从 Lua 写入文件,然后按照您的建议使用 fread 从 C 读取,但我只得到了垃圾!我今天再次尝试使用浮点数而不是双精度数,它可以工作:D 再次感谢您!
    • 刚刚添加了一个玩具示例来说明该过程。
    猜你喜欢
    • 1970-01-01
    • 2021-10-07
    • 2016-07-02
    • 2016-03-15
    • 2017-09-11
    • 2019-06-09
    • 1970-01-01
    • 2023-04-02
    • 2016-07-05
    相关资源
    最近更新 更多