【发布时间】:2021-08-17 08:00:26
【问题描述】:
是否可以在不依赖torchvision的情况下加载pytorch模型(来自.pth文件,包含架构+state_dict)?
import os
import torch
assert os.path.exists(r'.\vgg.pth')
model = torch.load(r'.\vgg.pth')
---------------------------------------------------------------------------
ModuleNotFoundError Traceback (most recent call last)
<ipython-input-4-e26863d95688> in <module>
2 import torch
3 assert os.path.exists(r'.\vgg.pth')
----> 4 model = torch.load(r'.\vgg.pth')
~\Anaconda3\envs\pytorch_save\lib\site-packages\torch\serialization.py in load(f, map_location, pickle_module, **pickle_load_args)
590 opened_file.seek(orig_position)
591 return torch.jit.load(opened_file)
--> 592 return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
593 return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
594
~\Anaconda3\envs\pytorch_save\lib\site-packages\torch\serialization.py in _load(zip_file, map_location, pickle_module, pickle_file, **pickle_load_args)
849 unpickler = pickle_module.Unpickler(data_file, **pickle_load_args)
850 unpickler.persistent_load = persistent_load
--> 851 result = unpickler.load()
852
853 torch._utils._validate_loaded_sparse_tensors()
ModuleNotFoundError: No module named 'torchvision'
我已经研究了 torch/serialization.py,但我认为没有理由需要 torchvision。该文件中的导入如下:
import difflib
import os
import io
import shutil
import struct
import sys
import torch
import tarfile
import tempfile
import warnings
from contextlib import closing, contextmanager
from ._utils import _import_dotted_name
from ._six import string_classes as _string_classes
from torch._sources import get_source_lines_and_file
from torch.types import Storage
from typing import Any, BinaryIO, cast, Dict, Optional, Type, Tuple, Union, IO
import copyreg
import pickle
import pathlib
【问题讨论】:
-
torchvision是您尝试加载的vgg模型的依赖项,它与serialization无关。如果没有torchvision,则无法加载vgg模型。 -
@Kishore 这完全是错误的,
torchvision.vgg模型不需要加载名为vgg.pth的文件。使用torch.load加载的文件包含模型dict 的not 模块定义...为什么您甚至会认为torchvision的实现包含在vgg.pth中?这没有任何意义。 -
您是提供整个代码还是运行其他代码?
-
@Ivan 您可以使用
torch.save(the_model, PATH)保存整个模型,然后使用the_model = torch.load(PATH)加载它。这类似于将整个模型保存在tensorflow或只是pickling模型中。请参阅here。并回答你的问题“你为什么会认为它是包含在 vgg.pth 中的 torchvision 的实现”你认为还有什么导致错误的,这很简单。 -
@Ivan 从错误堆栈中我们可以看到
ModuleNotFoundError是从result = unpickler.load()行生成的。当pytorch尝试从.pth文件加载保存的模型时,保存的model必须使用torchvision作为依赖项,并且由于我们没有安装torchvision,因此引发了“ModuleNotFoundError” .而且我也接受将模型参数保存在字典中是序列化模型的更好方法。