【问题标题】:How to disable model/weight serialization fully with AllenNLP settings?如何使用 AllenNLP 设置完全禁用模型/权重序列化?
【发布时间】:2021-01-26 14:55:27
【问题描述】:

我希望通过使用 jsonnet 配置文件在标准 AllenNLP 模型训练中禁用序列化所有模型/状态权重。

原因是我正在使用 Optuna 运行自动超参数优化。 测试数十个模型很快就会填满驱动器。 我已经通过将num_serialized_models_to_keep 设置为0 来禁用检查点:

trainer +: {
    checkpointer +: {
        num_serialized_models_to_keep: 0,
    },

我不希望将serialization_dir 设置为None,因为我仍然想要关于记录中间指标等的默认行为。我只想禁用默认模型状态、训练状态和最佳模型权重写作

除了我上面设置的选项之外,是否有任何默认的训练器或检查点选项来禁用模型权重的所有序列化?我检查了 API 文档和网页,但找不到任何内容。

如果我需要自己定义此类选项的功能,我应该在我的模型子类中覆盖 AllenNLP 中的哪些基本函数?

或者,它们在训练结束时对清理中间模型状态有什么用处吗?

编辑:@petew's answer 显示了自定义检查点的解决方案,但我不清楚如何让 allennlp train 找到我的用例中的此代码。

我希望通过如下配置文件调用 custom_checkpointer:

trainer +: {
    checkpointer +: {
        type: empty,
    },

调用allennlp train --include-package <$my_package> 时加载检查点的最佳做法是什么?

我的 my_package 包含子目录中的子模块,例如 my_package/modelss 和 my_package/training。 我想将自定义检查点代码放在my_package/training/custom_checkpointer.py 我的主模型位于my_package/models/main_model.py。 我是否必须在我的 main_model 类中编辑或导入任何代码/函数才能使用自定义检查点?

【问题讨论】:

    标签: allennlp


    【解决方案1】:

    您可以创建并注册一个基本上什么都不做的自定义Checkpointer

    @Checkpointer.register("empty")
    class EmptyCheckpointer(Registrable):
        def maybe_save_checkpoint(
            self, trainer: "allennlp.training.trainer.Trainer", epoch: int, batches_this_epoch: int
        ) -> None:
            pass
    
        def save_checkpoint(
            self,
            epoch: Union[int, str],
            trainer: "allennlp.training.trainer.Trainer",
            is_best_so_far: bool = False,
            save_model_only=False,
        ) -> None:
            pass
    
        def find_latest_checkpoint(self) -> Optional[Tuple[str, str]]:
            pass
    
        def restore_checkpoint(self) -> Tuple[Dict[str, Any], Dict[str, Any]]:
            return {}, {}
    
        def best_model_state(self) -> Dict[str, Any]:
            return {}
    

    【讨论】:

    • 谢谢,这可能会起作用,但我不知道 AllenNLP 找到这个自定义检查点的最佳做法是什么。有关我为什么需要澄清的更多详细信息,请参阅我在 OP 中的编辑。
    • 如果您的自定义检查点位于my_package/training/custom_checkpointer.py,那么您可以使用--include-package my_package.training.custom_checkpointer 调用allennlp train,或者您可以在您的存储库中创建一个名为.allennlp_plugins 的文件并将“my_package.training. custom_checkpointer" 在该文件中它自己的行上。
    • 这是一个示例 .allennlp_plugins 文件:github.com/allenai/allennlp-template-config-files/blob/master/…
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2021-04-20
    • 2014-10-20
    • 2016-09-28
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多