【问题标题】:Size mismatch error during VGG finetuningVGG微调期间的大小不匹配错误
【发布时间】:2018-07-25 18:04:28
【问题描述】:

我一直在关注 PyTorch 官方文档 (http://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html) 中的蚂蚁和蜜蜂迁移学习教程。我正在尝试通过更改最后一层来预测两个类之一来微调 VGG19 模型。我可以使用以下代码修改最后一个 fc 层。

但在执行 train_model 函数时出现错误。错误是“/opt/conda/conda-bld/pytorch_1513368888240/work/torch/lib/THC/generic/THCTensorMathBlas.cu:243 的大小不匹配”。知道问题是什么吗?

model_conv = torchvision.models.vgg19(pretrained=True)
for param in model_conv.parameters():
    param.requires_grad = False

model_conv = nn.Sequential(*list(model_conv.classifier.children())[:-1] +
                     [nn.Linear(in_features=4096, out_features=2)])
if use_gpu:
    model_conv = model_conv.cuda()

criterion = nn.CrossEntropyLoss()

optimizer_conv = optim.SGD(model_conv._modules['6'].parameters(), lr=0.001, momentum=0.9)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)

model_conv = train_model(model_conv, criterion, optimizer_conv, exp_lr_scheduler, num_epochs=25)

【问题讨论】:

标签: python deep-learning pytorch vision vgg-net


【解决方案1】:

在定义模型时,您只是在考虑 classifier,它仅包含在网络的完全连接部分上。然后,当将 224*224*3 图像输入模型时,它会尝试“通过”一个以 25K 特征作为输入的线性层。要解决它,您只需要在之前添加卷积部分,然后重新定义模型,如下所示:

class newModel(nn.Module):
    def __init__(self, old_model):
        super(newModel, self).__init__()

        self.features = old_model.features
        self.classifier = nn.Sequential(*list(old_model.classifier.children())[:-1] +
                                         [nn.Linear(in_features=4096, out_features=2)])

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

model_conv = newModel(model_conv)

现在您也只需告诉要优化的参数,如果您只想训练最后一层(新添加的层):

optimizer_conv = optim.SGD(model_conv.classifier._modules['6'].parameters(), lr=0.001, momentum=0.9)

其余代码保持不变。

希望对您有所帮助!

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 2019-06-16
    • 1970-01-01
    • 1970-01-01
    • 2019-06-01
    • 2020-05-08
    • 2017-02-23
    • 1970-01-01
    相关资源
    最近更新 更多