【问题标题】:How to change activation layer in Pytorch pretrained module?如何更改 Pytorch 预训练模块中的激活层?
【发布时间】:2020-02-06 09:08:33
【问题描述】:

如何更改 Pytorch 预训练网络的激活层? 这是我的代码:

print("All modules")
for child in net.children():
    if isinstance(child,nn.ReLU) or isinstance(child,nn.SELU):
        print(child)

print('Before changing activation')
for child in net.children():
    if isinstance(child,nn.ReLU) or isinstance(child,nn.SELU):
        print(child)
        child=nn.SELU()
        print(child)
print('after changing activation')
for child in net.children():
    if isinstance(child,nn.ReLU) or isinstance(child,nn.SELU):
        print(child)

这是我的输出:

All modules
ReLU(inplace=True)
Before changing activation
ReLU(inplace=True)
SELU()
after changing activation
ReLU(inplace=True)

【问题讨论】:

标签: python neural-network deep-learning pytorch activation-function


【解决方案1】:

._modules 为我解决了问题。

for name,child in net.named_children():
    if isinstance(child,nn.ReLU) or isinstance(child,nn.SELU):
        net._modules['relu'] = nn.SELU()

【讨论】:

    【解决方案2】:

    我假设您使用模块接口nn.ReLU 来创建激活层,而不是使用功能接口F.relu。如果是这样,setattr 对我有用。

    import torch
    import torch.nn as nn
    
    # This function will recursively replace all relu module to selu module. 
    def replace_relu_to_selu(model):
        for child_name, child in model.named_children():
            if isinstance(child, nn.ReLU):
                setattr(model, child_name, nn.SELU())
            else:
                replace_relu_to_selu(child)
    
    ########## A toy example ##########
    net = nn.Sequential(
                nn.Conv2d(3, 32, kernel_size=3, stride=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(3, 32, kernel_size=3, stride=1),
                nn.ReLU(inplace=True)
              )
    
    ########## Test ##########
    print('Before changing activation')
    for child in net.children():
        if isinstance(child,nn.ReLU) or isinstance(child,nn.SELU):
            print(child)
    # Before changing activation
    # ReLU(inplace=True)
    # ReLU(inplace=True)
    
    
    print('after changing activation')
    for child in net.children():
        if isinstance(child,nn.ReLU) or isinstance(child,nn.SELU):
            print(child)
    # after changing activation
    # SELU()
    # SELU(
    

    【讨论】:

    • 打印网络架构,在我的情况下,它不起作用。
    【解决方案3】:

    这是替换任何层的通用函数

    def replace_layers(model, old, new):
        for n, module in model.named_children():
            if len(list(module.children())) > 0:
                ## compound module, go inside it
                replace_layers(module, old, new)
                
            if isinstance(module, old):
                ## simple module
                setattr(model, n, new)
    
    replace_layer(model, nn.ReLU, nn.ReLU6())
    

    我为此苦苦挣扎了几天。所以,我做了一些挖掘并写了一个kaggle notebook 解释如何在 pytorch 中访问不同类型的层/模块。

    【讨论】:

      【解决方案4】:

      我将提供一个适用于任何层的更通用的解决方案(并避免其他问题,例如在循环遍历字典时修改字典或彼此内部存在递归 nn.modules 时)。

      def replace_bn(module, name):
          '''
          Recursively put desired batch norm in nn.module module.
      
          set module = net to start code.
          '''
          # go through all attributes of module nn.module (e.g. network or layer) and put batch norms if present
          for attr_str in dir(module):
              target_attr = getattr(m, attr_str)
              if type(target_attr) == torch.nn.BatchNorm2d:
                  print('replaced: ', name, attr_str)
                  new_bn = torch.nn.BatchNorm2d(target_attr.num_features, target_attr.eps, target_attr.momentum, target_attr.affine,
                                                track_running_stats=False)
                  setattr(module, attr_str, new_bn)
      
          # iterate through immediate child modules. Note, the recursion is done by our code no need to use named_modules()
          for name, immediate_child_module in module.named_children():
              replace_bn(immediate_child_module, name)
      
      replace_bn(model, 'model')
      

      关键是您需要递归地不断更改层(主要是因为有时您会遇到本身具有模块的属性)。我认为比上面更好的代码是添加另一个 if 语句(在批处理规范之后)检测是否必须递归,如果是,则递归。上面的方法首先改变了外层的批处理规范(即第一个循环),然后用另一个循环确保没有其他应该递归的对象被遗漏(然后递归)。

      原帖:https://discuss.pytorch.org/t/how-to-modify-a-pretrained-model/60509/10

      学分https://discuss.pytorch.org/t/replacing-convs-modules-with-custom-convs-then-notimplementederror/17736/3?u=brando_miranda

      【讨论】:

        猜你喜欢
        • 2021-05-20
        • 2022-10-07
        • 1970-01-01
        • 2020-06-13
        • 2019-09-11
        • 1970-01-01
        • 1970-01-01
        • 2022-08-10
        • 2018-02-20
        相关资源
        最近更新 更多