【发布时间】:2019-09-23 23:38:12
【问题描述】:
我有一个大小为 386MB 的 pytorch 模型,但是当我加载模型时
state = torch.load(f, flair.device)
我的 GPU 内存占用高达 900MB,为什么会出现这种情况?有没有办法解决这个问题?
这就是我保存模型的方式
model_state = self._get_state_dict()
# additional fields for model checkpointing
model_state["optimizer_state_dict"] = optimizer_state
model_state["scheduler_state_dict"] = scheduler_state
model_state["epoch"] = epoch
model_state["loss"] = loss
torch.save(model_state, str(model_file), pickle_protocol=4)
【问题讨论】:
标签: pytorch