【问题标题】:Save python random forest model to file将python随机森林模型保存到文件
【发布时间】:2014-01-06 20:38:45
【问题描述】:

在R中,运行“随机森林”模型后,我可以使用save.image("***.RData")来存储模型。之后,我可以直接加载模型进行预测。

你能在 python 中做类似的事情吗?我将模型和预测分成两个文件。在模型文件中:

rf= RandomForestRegressor(n_estimators=250, max_features=9,compute_importances=True)
fit= rf.fit(Predx, Predy)

我尝试返回rffit,但仍然无法在预测文件中加载模型。

你能使用 sklearn 随机森林包将模型和预测分开吗?

【问题讨论】:

  • 请注意,R 的 save.image 会保存您工作区中的所有内容,包括数据集、工作变量等。如果您只想要拟合模型,请使用 save
  • 哇!感谢这个有用的答案! Bc 每次我 save.image 时,该文件都应该非常大。谢谢!

标签: python machine-learning scikit-learn random-forest


【解决方案1】:
...
import cPickle

rf = RandomForestRegresor()
rf.fit(X, y)

with open('path/to/file', 'wb') as f:
    cPickle.dump(rf, f)


# in your prediction file                                                                                                                                                                                                           

with open('path/to/file', 'rb') as f:
    rf = cPickle.load(f)


preds = rf.predict(new_X)

【讨论】:

  • 进一步问题:'path/to/file',我应该使用什么格式来保存文件?谢谢
  • @user3013706 你的意思是什么文件扩展名?没关系。我认为约定是使用.cpickle
  • scikit learn docs recommand joblib.dump,也包含在sklearn.externals.joblib
  • 这个答案是否仍然与 python3 相关?我看到 cPickle 现在是 _pickle。
  • @lamecicle 据我所知,它只是pickle,默认实现应该在C
【解决方案2】:

我使用 dill,它存储所有数据,我认为可能是模块信息?也许不吧。我记得尝试使用pickle 来存储这些非常复杂的对象,但它对我不起作用。 cPickle 可能与 dill 做同样的工作,但我从未尝试过 cpickle。看起来它以完全相同的方式工作。我使用“obj”扩展名,但这绝不是传统的......这对我来说很有意义,因为我正在存储一个对象。

import dill
wd = "/whatever/you/want/your/working/directory/to/be/"
rf= RandomForestRegressor(n_estimators=250, max_features=9,compute_importances=True)
rf.fit(Predx, Predy)
dill.dump(rf, open(wd + "filename.obj","wb"))

顺便说一句,不确定您是否使用 iPython,但有时以这种方式编写文件并不能,因此您必须这样做:

with open(wd + "filename.obj","wb") as f:
    dill.dump(rf,f)

再次调用对象:

model = dill.load(open(wd + "filename.obj","rb"))

【讨论】:

    【解决方案3】:

    对于模型存储,您也可以使用 .sav 格式。它存储了完整的模型和信息。

    【讨论】:

      【解决方案4】:

      您可以使用joblib 来保存和加载来自 scikit-learn 的随机森林(实际上是来自 scikit-learn 的任何模型)

      例子:

      import joblib
      from sklearn.ensemble import RandomForestClassifier
      # create RF
      rf = RandomForestClassifier()
      # fit on some data
      rf.fit(X, y)
      
      # save
      joblib.dump(rf, "my_random_forest.joblib")
      
      # load
      loaded_rf = joblib.load("my_random_forest.joblib")
      
      

      更重要的是,joblib.dump has compress 参数,因此可以压缩模型。我在 iris 数据集上做了非常简单的testcompress=3 将文件大小减少了大约 5.6 倍。

      【讨论】:

      • joblib.save 是 joblib.dump
      【解决方案5】:

      我要重申,joblib 做得很好,it provides really good compression options(即 lzma)。

      with open("clf.pkl", "wb") as out: pickle.dump(clf, out)
      with open("clf.dill", "wb") as out: dill.dump(clf, out)
      joblib.dump(clf, "clf.jbl")
      joblib.dump(clf, "clf.jbl.lzma")
      joblib.dump(clf, "clf.jbl.gz")
      
      !du clf.*
      24576   clf.dill
      24576   clf.jbl
      5120    clf.jbl.gz
      3072    clf.jbl.lzma
      24576   clf.pkl
      

      【讨论】:

        猜你喜欢
        • 2020-01-16
        • 2018-10-09
        • 2019-07-10
        • 2017-08-11
        • 2016-04-09
        • 2019-08-02
        • 2020-02-25
        • 2018-09-26
        • 2017-07-27
        相关资源
        最近更新 更多