【问题标题】:Pytorch: why print(model) does not show the activation functions?Pytorch:为什么 print(model) 不显示激活函数?
【发布时间】:2020-06-20 00:26:09
【问题描述】:

我需要从 pytorch 中经过训练的 NN 中提取权重、偏差和至少激活函数的类型。

我知道提取权重和偏差的命令是:

model.parameters()

但我不知道如何提取层上使用的激活函数。这是我的网络

class NetWithODE(torch.nn.Module):
    def __init__(self, n_feature, n_hidden, n_output, sampling_interval, scaler_features):
        super(NetWithODE, self).__init__()
        self.hidden = torch.nn.Linear(n_feature, n_hidden)  # hidden layer
        self.predict = torch.nn.Linear(n_hidden, n_output)  # output layer
        self.sampling_interval = sampling_interval
        self.device = torch.device("cpu")
        self.dtype = torch.float
        self.scaler_features = scaler_features

    def forward(self, x):
        x0 = x.clone().requires_grad_(True)
        # activation function for hidden layer
        x = F.relu(self.hidden(x))
        # linear output, here r should be the output
        r = self.predict(x)
        # Now the r enters the integrator
        x = self.integrate(r, x0)

        return x

    def integrate(self, r, x0):
        # RK4 steps per interval
        M = 4
        DT = self.sampling_interval / M
        X = x0

        for j in range(M):
            k1 = self.ode(X, r)
            k2 = self.ode(X + DT / 2 * k1, r)
            k3 = self.ode(X + DT / 2 * k2, r)
            k4 = self.ode(X + DT * k3, r)
            X = X + DT / 6 * (k1 + 2 * k2 + 2 * k3 + k4)

        return X

    def ode(self, x0, r):
        qF = r[0, 0]
        qA = r[0, 1]
        qP = r[0, 2]
        mu = r[0, 3]

        FRU = x0[0, 0]
        AMC = x0[0, 1]
        PHB = x0[0, 2]
        TBM = x0[0, 3]

        fFRU = qF * TBM  
        fAMC = qA * TBM  
        fPHB = qP - mu * PHB
        fTBM = mu * TBM

        return torch.stack((fFRU, fAMC, fPHB, fTBM), 0)

如果我运行命令

print(model)

我明白了

NetWithODE(
  (hidden): Linear(in_features=4, out_features=10, bias=True)
  (predict): Linear(in_features=10, out_features=4, bias=True)
)

但是我在哪里可以获得激活函数(在本例中为 Relu)?

我有 pytorch 1.4。

【问题讨论】:

    标签: python pytorch


    【解决方案1】:

    向网络图添加操作有两种方式:低级函数方式和更高级的对象方式。您需要后者来使您的结构可观察,在第一种情况下只是调用(不完全是,但是......)一个函数而不存储有关它的信息。所以,而不是

        def forward(self, x):
        ...
            x = F.relu(self.hidden(x))
    

    一定是这样的

    def __init__(...):
        ...
        self.myFirstRelu= torch.nn.ReLU()
    
    def forward(self, x):
        ...
        x1 = self.hidden(x)
        x2 = self.myFirstRelu(x1)
    

    无论如何,两种方式的混合通常是个坏主意,尽管即使torchvision 模型也有这样的不一致之处:models.inception_v3 例如不注册池 >:-( (编辑:它已在 2020 年 6 月修复,谢谢, mitmul!)。


    更新: - 谢谢,这行得通,现在如果我打印,我会看到 ReLU()。但这似乎只以 init 中定义的相同顺序打印函数。有没有办法获得层和激活函数之间的关联?例如,我想知道哪个激活应用于第 1 层,哪个激活应用于第 2 层,依此类推...

    没有统一的方法,但这里有一些技巧: 对象方式:

    -按顺序初始化它们

    -使用 torch.nn.Sequential

    -像这样的节点上的钩子回调-

    def hook( m, i, o):
        print( m._get_name() )
    
    for ( mo ) in model.modules():
        mo.register_forward_hook(hook)
    

    函数式和对象式:

    -使用基于前向传递的内部模型图,就像torchviz 做的(https://github.com/szagoruyko/pytorchviz/blob/master/torchviz/dot.py),或者只使用上述torchviz 生成的图。

    【讨论】:

    • 谢谢,这行得通,现在如果我打印,我会看到 ReLU()。但这似乎只以__init__ 中定义的相同顺序打印函数。有没有办法获得层和激活函数之间的关联?例如,我想知道哪个激活应用于第 1 层,哪个激活应用于第 2 层,依此类推...
    猜你喜欢
    • 1970-01-01
    • 2020-10-11
    • 2018-12-10
    • 2019-09-30
    • 1970-01-01
    • 2022-01-16
    • 1970-01-01
    • 1970-01-01
    • 2020-07-17
    相关资源
    最近更新 更多