【问题标题】:scikit-learn model persistence: pickle vs pmml vs ...?scikit-learn 模型持久性:pickle vs pmml vs ...?
【发布时间】:2017-12-17 03:15:31
【问题描述】:

我构建了一个 scikit-learn 模型,我想在日常的 python cron 作业中重复使用(注意:不涉及其他平台 - 没有 R,没有 Java 等)。

pickled 它(实际上,我腌制了我自己的对象,其一个字段是 GradientBoostingClassifier),然后我在 cron 作业中取消腌制它。到目前为止一切顺利(已经在Save classifier to disk in scikit-learnModel persistence in Scikit-Learn? 中讨论过)。

但是,我升级了 sklearn,现在收到以下警告:

.../.local/lib/python2.7/site-packages/sklearn/base.py:315: 
UserWarning: Trying to unpickle estimator DecisionTreeRegressor from version 0.18.1 when using version 0.18.2. This might lead to breaking code or invalid results. Use at your own risk.
UserWarning)
.../.local/lib/python2.7/site-packages/sklearn/base.py:315: 
UserWarning: Trying to unpickle estimator PriorProbabilityEstimator from version 0.18.1 when using version 0.18.2. This might lead to breaking code or invalid results. Use at your own risk.
UserWarning)
.../.local/lib/python2.7/site-packages/sklearn/base.py:315: 
UserWarning: Trying to unpickle estimator GradientBoostingClassifier from version 0.18.1 when using version 0.18.2. This might lead to breaking code or invalid results. Use at your own risk.
UserWarning)

我现在该怎么办?

  • 我可以降级到 0.18.1 并坚持使用它,直到我准备好重建模型。由于各种原因,我认为这是不可接受的。

  • 我可以取消腌制文件并再次重新腌制。这适用于 0.18.2,但 与 0.19 中断。 NFG。 joblib 看起来也好不到哪里去。

  • 我希望我可以将数据保存为与版本无关的 ASCII 格式(例如 JSON 或 XML)。显然,这是最佳解决方案,但似乎没有方法可以做到这一点(另请参阅Sklearn - model persistence without pkl file)。

  • 我可以将模型保存到PMML,但它的支持充其量是不冷不热的: 我可以使用sklearn2pmml保存模型(虽然不容易),并使用augustus/lightpmmlpredictor应用(虽然不是加载)模型。但是,pip 无法直接使用这些,这使得部署成为一场噩梦。此外,augustuslightpmmlpredictor 项目似乎已经死了。 Importing PMML models into Python (Scikit-learn) - 不。

  • 上述变体:使用sklearn2pmml 保存PMML,并使用openscoring 进行评分。需要与外部进程交互。呵呵。

建议?

【问题讨论】:

    标签: python python-2.7 scikit-learn pmml


    【解决方案1】:

    跨不同版本 scikit-learn 的模型持久性通常是不可能的。原因很明显:你用一个定义腌制Class1,并想用另一个定义将它解压缩成Class2

    你可以:

    • 仍然尝试坚持使用一个版本的 sklearn。
    • 忽略警告并希望对Class1 有效的方法也适用于Class2
    • 编写你自己的类,可以序列化你的GradientBoostingClassifier,并从这个序列化的形式中恢复它,希望它比pickle更好。

    我做了一个示例,说明如何将单个 DecisionTreeRegressor 转换为纯列表和字典格式,完全兼容 JSON,然后将其恢复。

    import numpy as np
    from sklearn.tree import DecisionTreeRegressor
    from sklearn.datasets import make_classification
    
    ### Code to serialize and deserialize trees
    
    LEAF_ATTRIBUTES = ['children_left', 'children_right', 'threshold', 'value', 'feature', 'impurity', 'weighted_n_node_samples']
    TREE_ATTRIBUTES = ['n_classes_', 'n_features_', 'n_outputs_']
    
    def serialize_tree(tree):
        """ Convert a sklearn.tree.DecisionTreeRegressor into a json-compatible format """
        encoded = {
            'nodes': {},
            'tree': {},
            'n_leaves': len(tree.tree_.threshold),
            'params': tree.get_params()
        }
        for attr in LEAF_ATTRIBUTES:
            encoded['nodes'][attr] = getattr(tree.tree_, attr).tolist()
        for attr in TREE_ATTRIBUTES:
            encoded['tree'][attr] = getattr(tree, attr)
        return encoded
    
    def deserialize_tree(encoded):
        """ Restore a sklearn.tree.DecisionTreeRegressor from a json-compatible format """
        x = np.arange(encoded['n_leaves'])
        tree = DecisionTreeRegressor().fit(x.reshape((-1,1)), x)
        tree.set_params(**encoded['params'])
        for attr in LEAF_ATTRIBUTES:
            for i in range(encoded['n_leaves']):
                getattr(tree.tree_, attr)[i] = encoded['nodes'][attr][i]
        for attr in TREE_ATTRIBUTES:
            setattr(tree, attr, encoded['tree'][attr])
        return tree
    
    ## test the code
    
    X, y = make_classification(n_classes=3, n_informative=10)
    tree = DecisionTreeRegressor().fit(X, y)
    encoded = serialize_tree(tree)
    decoded = deserialize_tree(encoded)
    assert (decoded.predict(X)==tree.predict(X)).all()
    

    有了这个,你可以继续序列化和反序列化整个GradientBoostingClassifier

    from sklearn.ensemble import GradientBoostingClassifier
    from sklearn.ensemble.gradient_boosting import PriorProbabilityEstimator
    
    def serialize_gbc(clf):
        encoded = {
            'classes_': clf.classes_.tolist(),
            'max_features_': clf.max_features_, 
            'n_classes_': clf.n_classes_,
            'n_features_': clf.n_features_,
            'train_score_': clf.train_score_.tolist(),
            'params': clf.get_params(),
            'estimators_shape': list(clf.estimators_.shape),
            'estimators': [],
            'priors':clf.init_.priors.tolist()
        }
        for tree in clf.estimators_.reshape((-1,)):
            encoded['estimators'].append(serialize_tree(tree))
        return encoded
    
    def deserialize_gbc(encoded):
        x = np.array(encoded['classes_'])
        clf = GradientBoostingClassifier(**encoded['params']).fit(x.reshape(-1, 1), x)
        trees = [deserialize_tree(tree) for tree in encoded['estimators']]
        clf.estimators_ = np.array(trees).reshape(encoded['estimators_shape'])
        clf.init_ = PriorProbabilityEstimator()
        clf.init_.priors = np.array(encoded['priors'])
        clf.classes_ = np.array(encoded['classes_'])
        clf.train_score_ = np.array(encoded['train_score_'])
        clf.max_features_ = encoded['max_features_']
        clf.n_classes_ = encoded['n_classes_']
        clf.n_features_ = encoded['n_features_']
        return clf
    
    # test on the same problem
    clf = GradientBoostingClassifier()
    clf.fit(X, y);
    encoded = serialize_gbc(clf)
    decoded = deserialize_gbc(encoded)
    assert (decoded.predict(X) == clf.predict(X)).all()
    

    这适用于 scikit-learn v0.19,但不要问我下一个版本会出现什么来破坏此代码。我既不是预言家也不是 sklearn 的开发者。

    如果你想完全独立于新版本的 sklearn,最安全的做法是编写一个遍历序列化树并进行预测的函数,而不是重新创建 sklearn 树。

    【讨论】:

    • 这怎么比泡菜靠谱? pickle 的问题在于,如果 sklearn 更改了类定义(例如,删除或重命名插槽),我将不得不重写 serialize_*deserialize_* 函数,更重要的是,编写将序列化转换为old 版本到 new 版本。我同意这可能比泡菜噩梦要好,但几乎没有。
    • 这不能保证您将与 sklearn 的 20 或 200 版本兼容。但它至少可以让你更好地控制局势。例如。如果 sklearn 完全重写了它的ClassificationLossFunction,你不会受到影响。
    猜你喜欢
    • 2017-03-16
    • 2017-02-24
    • 2014-04-20
    • 2018-12-07
    • 1970-01-01
    • 2020-11-11
    • 1970-01-01
    • 2020-10-04
    • 2011-03-30
    相关资源
    最近更新 更多