【发布时间】:2021-09-17 20:37:44
【问题描述】:
- 保存的模型
net= Net()
model= torch.nn.DataParallel(net)
############################
# Training
############################
torch.save(model,'./model_shear_pre.pkl')
- 模型加载
net = Net()
model = torch.nn.DataParallel(net, device_ids=[0,1])
model = torch.load('./model_shear_finish.pkl', map_location={'cuda:0':'cuda:0', 'cuda:1':'cuda:0', 'cuda:2':'cuda:1', 'cuda:3':'cuda:1'})
问题是我在训练时使用了 4 个 GPU 的机器,保存模型后,我想在只有 2 个 GPU 的新机器上进行测试。
加载保存的模型后,我预计模型的device_ids 将是[0,1],但它仍然是[0,1,2,3],这是旧设置。 保存或加载有什么问题吗?
【问题讨论】:
标签: pytorch