【问题标题】:Load pytorch model from S3 bucket从 S3 存储桶加载 pytorch 模型
【发布时间】:2021-08-14 18:57:21
【问题描述】:

我想从 S3 存储桶加载 pytorch 模型 (model.pt)。我写了以下代码:

from smart_open import open as smart_open
import io

load_path = "s3://serial-no-images/yolo-models/model4/model.pt"
with smart_open(load_path) as f:
    buffer = io.BytesIO(f.read())
    model.load_state_dict(torch.load(buffer))

这会导致以下错误:

UnicodeDecodeError: 'utf-8' codec can't decode byte 0x80 in position 64: invalid start byte

一种解决方案是在本地下载模型,但我想避免这种情况并直接从 S3 加载模型。不幸的是,我在网上找不到一个好的解决方案。有人可以帮我吗?

【问题讨论】:

    标签: amazon-s3 pytorch torch


    【解决方案1】:

    AFAIK torch.load 期望 filename 作为参数 - 而不是文件的内容。您的 buffer 是否可能已经等同于 torch.loading 文件的本地副本的结果?
    如果你尝试model.load_state_dict(buffer) 会发生什么?

    【讨论】:

    • 感谢您的回答。我对 S3 的了解非常有限,但据我所知 .pt 文件是一个 io.BytesIO 对象。 torch.load 默认情况下无法处理此问题。我刚刚找到解决方案,会尽快发布。
    • @spadel 你能分享解决方案吗?我也在尝试从 s3 加载 pytorch 模型。
    【解决方案2】:

    根据documentation,以下工作:

    from smart_open import open as smart_open
    import io
    
    load_path = "s3://serial-no-images/yolo-models/model4/model.pt"
    with smart_open(load_path, 'rb') as f:
        buffer = io.BytesIO(f.read())
        model.load_state_dict(torch.load(buffer))
    

    我之前尝试过这个,但没有看到我必须将 'rb' 设置为参数。

    【讨论】:

      猜你喜欢
      • 2021-11-14
      • 2021-02-07
      • 1970-01-01
      • 1970-01-01
      • 2014-09-25
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多