【问题标题】:Problem with missing and unexpected keys while loading my model in Pytorch在 Pytorch 中加载我的模型时丢失和意外键的问题
【发布时间】:2019-05-23 05:48:00
【问题描述】:

我正在尝试使用本教程加载模型:https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-model-for-inference。不幸的是,我非常初学者,我面临一些问题。

我已经创建了检查点:

checkpoint = {'epoch': epochs, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(),'loss': loss}
torch.save(checkpoint, 'checkpoint.pth')

然后我为我的网络编写了类,我想加载文件:

class Network(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(9216, 4096)
        self.fc2 = nn.Linear(4096, 1000)
        self.fc3 = nn.Linear(1000, 102)

    def forward(self, x):
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = self.fc3(x)
        x = log(F.softmax(x, dim=1))
        return x

这样:

def load_checkpoint(filepath):
    checkpoint = torch.load(filepath)
    model = Network()
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']

model = load_checkpoint('checkpoint.pth')

我收到此错误(已编辑以显示整个通信):

RuntimeError: Error(s) in loading state_dict for Network:
    Missing key(s) in state_dict: "fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias", "fc3.weight", "fc3.bias". 
    Unexpected key(s) in state_dict: "features.0.weight", "features.0.bias", "features.3.weight", "features.3.bias", "features.6.weight", "features.6.bias", "features.8.weight", "features.8.bias", "features.10.weight", "features.10.bias", "classifier.fc1.weight", "classifier.fc1.bias", "classifier.fc2.weight", "classifier.fc2.bias", "classifier.fc3.weight", "classifier.fc3.bias". 

这是我的model.state_dict().keys()

odict_keys(['features.0.weight', 'features.0.bias', 'features.3.weight', 
'features.3.bias', 'features.6.weight', 'features.6.bias', 
'features.8.weight', 'features.8.bias', 'features.10.weight', 
'features.10.bias', 'classifier.fc1.weight', 'classifier.fc1.bias', 
'classifier.fc2.weight', 'classifier.fc2.bias', 'classifier.fc3.weight', 
'classifier.fc3.bias'])

这是我的模型:

AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)

((classifier): Sequential(
(fc1): Linear(in_features=9216, out_features=4096, bias=True)
(relu1): ReLU()
(fc2): Linear(in_features=4096, out_features=1000, bias=True)
(relu2): ReLU()
(fc3): Linear(in_features=1000, out_features=102, bias=True)
(output): LogSoftmax()
)
)

这是我的第一个网络,但我一直在犯错。感谢您引导我走向正确的方向!

【问题讨论】:

  • 如果只是重命名model.state_dict().keys()中的对应键,让features.3.weight变成fc3.weight,以此类推?
  • 我会尽快让你知道
  • 这很奇怪,但是当我这样做时,加载模型后是None
  • 啊,好吧,因为您没有在函数上使用return 值,所以当您调用load_checkpoint 时,它什么也不返回;因此NoneType。如果要从函数中返回模型,则需要在函数底部添加 return model。如果您不需要返回它,请从 model = load_checkpoint('checkpoint.pth') 中删除 model = ,它只会调用该函数。
  • 如果要返回多个变量,则需要单独返回它们。例如。 return checkpoint, model, epoc, loss 等等。在调用函数的地方,您需要将每个返回值捕获到另一个变量中。例如。 checkpoint, model, epoc, loss = load_checkpoint('checkpoint.pth')

标签: python machine-learning neural-network conv-neural-network pytorch


【解决方案1】:

所以您的Network 本质上是AlexNetclassifier 部分,您希望将预训练的AlexNet 权重加载到其中。问题是state_dict 中的键是“完全合格的”,这意味着如果您将网络视为嵌套模块的树,则键只是每个分支中的模块列表,并与@987654326 之类的点连接@。你想

  1. 只保留名称以“分类器”开头的张量。
  2. 删除“分类器”。部分键

所以试试

model = Network()
loaded_dict = checkpoint['model_state_dict']
prefix = 'classifier.'
n_clip = len(prefix)
adapted_dict = {k[n_clip:]: v for k, v in loaded_dict.items()
                if k.startswith(prefix)}
model.load_state_dict(adapted_dict)

【讨论】:

  • 它不会返回任何错误,但是当我print(model) 时它显示None。我的意思是在此之后model = load_checkpoint('checkpoint.pth')
  • 我试图以对我来说更容易理解的方式实现 1) 和 2),但结果仍然相同。加载后的模型为空。
  • 对了,上面Adam说过,要获取模型作为函数的返回值,需要返回模型。
  • 什么意思? AlexNet 由称为featuresclassifier 的两部分组成。您的 Network 仅实现 classifier 所以是的,您正在丢失 features 部分。我假设这就是你的意思
  • 我想保存整个模型以便将来训练它
猜你喜欢
  • 2019-03-27
  • 2019-11-12
  • 2022-01-17
  • 1970-01-01
  • 2020-12-16
  • 1970-01-01
  • 2021-12-15
  • 2019-09-26
  • 2021-12-18
相关资源
最近更新 更多