【问题标题】:PyTorch: load weights from another model without savingPyTorch:从另一个模型加载权重而不保存
【发布时间】:2021-03-11 02:01:38
【问题描述】:

假设我在 PyTorch 中有两个模型,如何在不保存权重的情况下按模型 2 的权重加载模型 1 的权重?

像这样:

model1.weights = model2.weights

在 TensorFlow 中我可以这样做:

variables1 = model1.trainable_variables
variables2 = model2.trainable_variables
for v1, v2 in zip(variables1, variables2):
    v1.assign(v2.numpy())

【问题讨论】:

    标签: python pytorch


    【解决方案1】:

    假设您有两个相同模型的实例(必须是nn.Module 的子类),那么您可以使用nn.Module.state_dict()nn.Module.load_state_dict()。可以找到状态词典的简要介绍here

    model1.load_state_dict(model2.state_dict())
    

    【讨论】:

      【解决方案2】:

      这里有两种方法可以做到这一点。

      # Use load state dict
      model_source = Model()
      model_dest = Model()
      model_dest.load_state_dict(model_source.state_dict())
      
      # Use deep copy
      model_source = Model()
      model_dest = copy.deepcopy(model_source )
      

      【讨论】:

        猜你喜欢
        • 2021-12-15
        • 2019-09-26
        • 2021-06-13
        • 2022-01-16
        • 1970-01-01
        • 2020-12-26
        • 1970-01-01
        • 1970-01-01
        • 2020-05-03
        相关资源
        最近更新 更多