sklearn 估计器实现的方法使您可以轻松地保存估计器的相关训练属性。一些估算器自己实现 __getstate__ 方法,但其他估算器,例如 GMM 只使用 base implementation ,它只是保存对象内部字典:
def __getstate__(self):
try:
state = super(BaseEstimator, self).__getstate__()
except AttributeError:
state = self.__dict__.copy()
if type(self).__module__.startswith('sklearn.'):
return dict(state.items(), _sklearn_version=__version__)
else:
return state
将模型保存到光盘的推荐方法是使用pickle 模块:
from sklearn import datasets
from sklearn.svm import SVC
iris = datasets.load_iris()
X = iris.data[:100, :2]
y = iris.target[:100]
model = SVC()
model.fit(X,y)
import pickle
with open('mymodel','wb') as f:
pickle.dump(model,f)
但是,您应该保存额外的数据,以便将来重新训练您的模型,否则会遭受可怕的后果(例如被锁定到旧版本的 sklearn)。
来自documentation:
为了用未来版本重建一个类似的模型
scikit-learn,额外的元数据应该沿着腌制保存
型号:
训练数据,例如对不可变快照的引用
用于生成模型的python源码
scikit-learn 的版本及其依赖项
在训练数据上得到的交叉验证分数
对于依赖于用 Cython 编写的 tree.pyx 模块(例如 IsolationForest)的 Ensemble 估计器尤其如此,因为它创建了与实现的耦合,这不能保证在 sklearn 版本之间保持稳定。过去它已经看到了向后不兼容的变化。
如果您的模型变得非常大并且加载变得麻烦,您还可以使用更高效的joblib。来自文档:
在scikit的具体情况下,使用起来可能会更有趣
joblib 替换 pickle (joblib.dump & joblib.load),即
在内部携带大型 numpy 数组的对象上更有效
拟合的 scikit-learn 估计器通常是这种情况,但只能
pickle 到磁盘而不是字符串: