参考

model.state_dict()中保存了{参数名:参数值}的字典

import torchvision.models as models

resnet34 = models.resnet34(pretrained=True)
resnet34.state_dict().keys()
for param in resnet34.parameters():
    param.requires_grad = False
resnet.fc = nn.Linear(resnet.fc.in_features, 100)

# resnet.fc = nn.Sequential(nn.Linear(512, 100),
#                          nn.ReLU(),
#                          nn.Linear(100, 10))

保存模型
torch.save(model.state_dict(), PATH) # 保存模型为pth

导入模型

model = ModelClass()   # 需要先建立模型
model.load_state_dict(torch.load(PATH)) # 加载模型

相关文章:

  • 2022-12-23
  • 2021-09-19
  • 2021-04-04
  • 2022-12-23
  • 2023-03-17
  • 2022-03-12
  • 2022-01-11
  • 2021-12-04
猜你喜欢
  • 2021-10-25
  • 2021-12-07
  • 2022-01-07
  • 2022-02-15
  • 2021-09-23
  • 2022-12-23
相关资源
相似解决方案