【问题标题】:pytorch load model without torchvision没有torchvision的pytorch加载模型
【发布时间】:2021-08-17 08:00:26
【问题描述】:

是否可以在不依赖torchvision的情况下加载pytorch模型(来自.pth文件,包含架构+state_dict)?

import os
import torch
assert os.path.exists(r'.\vgg.pth')
model = torch.load(r'.\vgg.pth')

---------------------------------------------------------------------------
ModuleNotFoundError                       Traceback (most recent call last)
<ipython-input-4-e26863d95688> in <module>
      2 import torch
      3 assert os.path.exists(r'.\vgg.pth')
----> 4 model = torch.load(r'.\vgg.pth')

~\Anaconda3\envs\pytorch_save\lib\site-packages\torch\serialization.py in load(f, map_location, pickle_module, **pickle_load_args)
    590                     opened_file.seek(orig_position)
    591                     return torch.jit.load(opened_file)
--> 592                 return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
    593         return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
    594 

~\Anaconda3\envs\pytorch_save\lib\site-packages\torch\serialization.py in _load(zip_file, map_location, pickle_module, pickle_file, **pickle_load_args)
    849     unpickler = pickle_module.Unpickler(data_file, **pickle_load_args)
    850     unpickler.persistent_load = persistent_load
--> 851     result = unpickler.load()
    852 
    853     torch._utils._validate_loaded_sparse_tensors()

ModuleNotFoundError: No module named 'torchvision'

我已经研究了 torch/serialization.py,但我认为没有理由需要 torchvision。该文件中的导入如下:

import difflib
import os
import io
import shutil
import struct
import sys
import torch
import tarfile
import tempfile
import warnings
from contextlib import closing, contextmanager
from ._utils import _import_dotted_name
from ._six import string_classes as _string_classes
from torch._sources import get_source_lines_and_file
from torch.types import Storage
from typing import Any, BinaryIO, cast, Dict, Optional, Type, Tuple, Union, IO
import copyreg
import pickle
import pathlib

【问题讨论】:

  • torchvision 是您尝试加载的vgg 模型的依赖项,它与serialization 无关。如果没有torchvision,则无法加载vgg 模型。
  • @Kishore 这完全是错误的,torchvision.vgg 模型不需要加载名为vgg.pth 的文件。使用torch.load 加载的文件包含模型dictnot 模块定义...为什么您甚至会认为torchvision 的实现包含在vgg.pth 中?这没有任何意义。
  • 您是提供整个代码还是运行其他代码?
  • @Ivan 您可以使用torch.save(the_model, PATH) 保存整个模型,然后使用the_model = torch.load(PATH) 加载它。这类似于将整个模型保存在 tensorflow 或只是 pickling 模型中。请参阅here。并回答你的问题“你为什么会认为它是包含在 vgg.pth 中的 torchvision 的实现”你认为还有什么导致错误的,这很简单。
  • @Ivan 从错误堆栈中我们可以看到ModuleNotFoundError 是从result = unpickler.load() 行生成的。当pytorch 尝试从.pth 文件加载保存的模型时,保存的model 必须使用torchvision 作为依赖项,并且由于我们没有安装torchvision,因此引发了“ModuleNotFoundError” .而且我也接受将模型参数保存在字典中是序列化模型的更好方法。

标签: python pytorch pickle


【解决方案1】:

是什么导致了我的问题

我的问题中的vgg.pth 文件生成如下:

import torchvision
vgg = models.vgg16(pretrained=True, init_weights=False)
torch.save(vgg, r'.\vgg.pth')

这样,文件vgg.pth 不仅包含模型参数,还包含模型架构(参见pytorch: save/load entire model)。然而,正如@Kishore 在 cmets 中指出的那样,这种架构似乎也需要 torchvision 作为依赖项。

我是如何解决的

  • 在使用 torchvision 的环境中,我将预训练的 VGG 模型加载到内存中并保存了 state_dict
from torchvision.models.vgg import vgg16
import torch

model = vgg16(pretrained=True)
torch.save(model.state_dict, r'.\state_dict.pth')
  • 在没有torchvision 的环境中,我通过检查torchvision.models.vgg 代码重建了模型。
    然后我将此 state_dict 文件加载到我的模型的 state_dict 中。
    最后,我将此模型(包括架构)保存到 .pth 文件中。
import torch

# a file where I pasted the torchvision.models.vgg code
# and commented out the torchvision dependencies I don't need
# in this case: 'from .._internally_replaced_utils import load_state_dict_from_url'
from torch_save import *

model = vgg16()
model.load_state_dict(torch.load(r'.\state_dict.pth'))
torch.save(model, r'.\entire_model.pth')

当我在无 torchvision 的环境中再次加载它时,我没有收到任何错误。

【讨论】:

    猜你喜欢
    • 2019-07-23
    • 2021-07-30
    • 2021-11-30
    • 1970-01-01
    • 2020-05-24
    • 2021-04-05
    • 2021-12-15
    • 2019-09-26
    • 2019-01-19
    相关资源
    最近更新 更多