【问题标题】:Determine whether a model is pytorch model or a tensorflow model or scikit model判断一个模型是pytorch模型还是tensorflow模型还是scikit模型
【发布时间】:2021-02-27 03:36:00
【问题描述】:

如果我想确定模型的类型,即它是从哪个框架以编程方式制作的,有没有办法做到这一点?
我有一个以某种序列化方式的模型(例如泡菜文件)。为简单起见,假设我的模型可以是 tensorflow、pytorch 或 scikit learn 的。如何以编程方式确定这 3 个中的哪一个?

【问题讨论】:

    标签: python tensorflow machine-learning scikit-learn pytorch


    【解决方案1】:

    AFAIK,我从未听说过要使用 pickle 或 joblib 保存的 Tensorflow/Keras 和 Pytorch 模型 - 这些框架提供了自己的保存和加载模型的功能:请参阅 SO 线程 Tensorflow: how to save/restore a model?Best way to save a trained model in PyTorch?。此外,Github thread 在尝试使用 pickle 和 joblib 保存 TensorFlow 模型时报告各种问题。

    鉴于此,如果您加载了一个模型,比如说,pickle,那么查看它使用的是什么类型是微不足道的 type(model)model。以下是 scikit-learn 线性回归模型的简短演示:

    import numpy as np
    from sklearn.linear_model import LinearRegression
    
    X = np.array([[1, 1], [1, 2], [2, 2], [2, 3]])
    y = np.dot(X, np.array([1, 2])) + 3
    reg = LinearRegression()
    reg.fit(X, y)
    
    # save it
    
    import pickle
    
    filename = 'model1.pkl'
    pickle.dump(reg, open(filename, 'wb'))
    

    现在,加载模型:

    loaded_model = pickle.load(open(filename, 'rb'))
    
    type(loaded_model)
    # sklearn.linear_model._base.LinearRegression
    
    loaded_model
    # LinearRegression(copy_X=True, fit_intercept=True, n_jobs=None, normalize=False)
    

    这也适用于 XGBoost、LightGBM、CatBoost 等框架。

    【讨论】:

      猜你喜欢
      • 2017-10-11
      • 2019-09-17
      • 2013-04-25
      • 1970-01-01
      • 1970-01-01
      • 2021-03-20
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多