【问题标题】:Saving the weights of a Pytorch .pth model into a .txt or .json将 Pytorch .pth 模型的权重保存到 .txt 或 .json
【发布时间】:2022-08-05 22:11:54
【问题描述】:

我正在尝试将 pytorch 模型的权重保存到 .txt 或 .json 中。将其写入 .txt 时,

#import torch
model = torch.load(\"model_path\")
string = str(model)
with open(\'some_file.txt\', \'w\') as fp:
     fp.write(string)

我得到一个未保存所有权重的文件,即整个文本文件中都有省略号。我无法将其写入 JSON,因为模型具有不可序列化 JSON 的张量 [除非有一种我不知道的方法?] 如何将 .pth 文件中的权重保存为某种格式,以便没有信息丢失了,很容易被看到吗?

谢谢

  • 我假设你正在做的实际上是str(model.state_dict()),对吧?否则无论如何都不会打印权重

标签: python machine-learning pytorch


【解决方案1】:

当你在做str(model.state_dict()) 时,它递归地使用它包含的元素的str 方法。所以问题是如何构建单个元素字符串表示。您应该增加以单个字符串表示形式打印的行数限制:

torch.set_printoptions(profile="full")

查看与此的区别:

import torch
import torchvision.models as models
mobilenet_v2 = models.mobilenet_v2()
torch.set_printoptions(profile="default")
print(mobilenet_v2.state_dict()['features.2.conv.0.0.weight'])
torch.set_printoptions(profile="full")
print(mobilenet_v2.state_dict()['features.2.conv.0.0.weight'])

张量目前不是 JSON 可序列化的。

【讨论】:

  • 有点晚了,但希望这项工作
  • 我犯了一个错误,我要添加一个答案。
【解决方案2】:

有点晚了,但希望这会有所帮助。这是您存储它的方式:

import torch
from torch.utils.data import Dataset

from json import JSONEncoder
import json

class EncodeNumpyArray(JSONEncoder,Dataset):
    def default(self, obj):
        if isinstance(obj, torch.Tensor):
            return obj.cpu().detach().numpy().tolist()
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super(NpEncoder, self).default(obj)

with open('torch_weights.json', 'w') as json_file:
    json.dump(model.state_dict(), json_file,cls=EncodeNumpyArray)

考虑到存储的值是list 类型,所以当你要使用权重时你必须使用torch.Tensor(list)

【讨论】:

    猜你喜欢
    • 2018-10-05
    • 2021-06-13
    • 2022-01-16
    • 2021-10-20
    • 1970-01-01
    • 1970-01-01
    • 2019-11-24
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多