【问题标题】:how pytorch nn.module save submodulepytorch nn.module 如何保存子模块
【发布时间】:2018-04-29 16:16:53
【问题描述】:

我对 pytorch nn.module 的工作原理有一些疑问

import torch
import torch.nn as nn



class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.sub_module = nn.Linear(10, 5)
        self.value = 3

net = Net()
print(net.__dict__)

输出

{'_modules': OrderedDict([('sub_module', Linear (10 -> 5))]),  'value': 3, ...}

我知道一个类的每一个属性都应该存储在__dict__中,为什么value(a int value)在里面,而sub_module(a nn.Module)没有,相反,sub_module是存储在 _modules

我阅读了 nn.Module 实现的代码,但我没有弄明白。有人有什么想法吗?

谢谢!!

【问题讨论】:

    标签: python pytorch


    【解决方案1】:

    我会尽量保持简单。

    每次您在类Net 中创建一个新项目时,例如:self.sub_module = nn.Linear(10, 5),它都会调用其父类的方法__setattr__,在本例中为nn.Module。然后,在__setattr__ 方法中,参数被存储到它们所属的字典中。在这种情况下,由于nn.Linear 是一个模块,它被存储到_modules 字典中。

    这是在 Modulehttps://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py#L389 中执行此操作的一段代码

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 2021-06-23
      • 1970-01-01
      • 2021-03-11
      • 2020-08-18
      • 2020-10-04
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多