pytorch查看模型model参数parameters

示例1:pytorch自带的faster r-cnn模型

import torchimport torchvisionmodel = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)for name, p in model.named_parameters():    print(name)    print(p.requires_grad)    print(...)#或者for p in model.parameters():    print(p)    print(...)

示例2:自定义网络模型

class Net(nn.Module):    def __init__(self):        super(Net, self).__init__()        cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512]        self.features = self._vgg_layers(cfg)    def _vgg_layers(self, cfg):        layers = []        in_channels = 3        for x in cfg:            if x == 'M':                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]            else:                layers += [nn.Conv2d(in_channels, x ,kernel_size=3, padding=1),                        nn.BatchNorm2d(x),                        nn.ReLU(inplace=True)                        ]                in_channels = x                    return nn.Sequential(*layers)    def forward(self, data):        out_map = self.features(data)        return out_map    Model = Net()for name, p in model.named_parameters():    print(name)    print(p.requires_grad)    print(...)#或者for p in model.parameters():    print(p)    print(...)

在自定义网络中,model.parameters()方法继承自nn.Module

pytorch查看模型参数总结

1:DNN_printer

其中(3, 32, 32)是输入的大小,其他方法中的参数同理

from DNN_printer import DNN_printerbatch_size = 512def train(epoch):    print('\nEpoch: %d' % epoch)    net.train()    train_loss = 0    correct = 0    total = 0    // put the code here and you can get the result    DNN_printer(net, (3, 32, 32),batch_size)

结果

在pytorch中如何查看模型model参数parameters

2:parameters

def cnn_paras_count(net):    """cnn参数量统计, 使用方式cnn_paras_count(net)"""    # Find total parameters and trainable parameters    total_params = sum(p.numel() for p in net.parameters())    print(f'{total_params:,} total parameters.')    total_trainable_params = sum(p.numel() for p in net.parameters() if p.requires_grad)    print(f'{total_trainable_params:,} training parameters.')    return total_params, total_trainable_paramscnn_paras_count(net)

直接输出参数量,然后自己计算

需要注意的是,一般模型中参数是以float32保存的,也就是一个参数由4个bytes表示,那么就可以将参数量转化为存储大小。

例如:

  • 44426个参数*4 / 1024 ≈ 174KB

3:get_model_complexity_info()

from ptflops import get_model_complexity_infofrom torchvision import modelsnet = models.mobilenet_v2()ops, params = get_model_complexity_info(net, (3, 224, 224), as_strings=True, 										print_per_layer_stat=True, verbose=True)

在pytorch中如何查看模型model参数parameters

4:torchstat

from torchstat import statimport torchvision.models as modelsmodel = models.resnet152()stat(model, (3, 224, 224))

输出

在pytorch中如何查看模型model参数parameters

以上为个人经验,希望能给大家一个参考,也希望大家多多支持。

原文地址:https://blog.csdn.net/qq_38600065/article/details/105552816

相关文章:

  • 2021-06-19
  • 2021-12-16
  • 2022-12-23
  • 2021-10-19
  • 2022-12-23
  • 2021-10-04
  • 2021-11-14
猜你喜欢
  • 2022-12-23
  • 2022-12-23
  • 2022-01-01
  • 2021-06-25
  • 2021-08-02
  • 2021-04-14
  • 2022-12-23
相关资源
相似解决方案