【问题标题】:How can I call scikit-learn classifiers from Java?如何从 Java 调用 scikit-learn 分类器?
【发布时间】:2012-09-26 04:07:06
【问题描述】:

我有一个使用 Python 的 scikit-learn 训练的分类器。如何使用 Java 程序中的分类器?我可以使用 Jython 吗?有没有办法在 Python 中保存分类器并在 Java 中加载它?有没有其他的使用方式?

【问题讨论】:

    标签: java python jython scikit-learn


    【解决方案1】:

    或者,您也可以从经过训练的模型生成 Python 代码。这是一个可以帮助您解决问题的工具https://github.com/BayesWitnesses/m2cgen

    【讨论】:

      【解决方案2】:

      我发现自己处于类似的境地。 我会推荐创建一个分类器微服务。您可以有一个在 python 中运行的分类器微服务,然后通过一些 RESTFul API 公开对该服务的调用,从而产生 JSON/XML 数据交换格式。我认为这是一种更清洁的方法。

      【讨论】:

        【解决方案3】:

        您可以使用搬运工,我已经测试了 sklearn-porter (https://github.com/nok/sklearn-porter),它适用于 Java。

        我的代码如下:

        import pandas as pd
        from sklearn import tree
        from sklearn_porter import Porter
        
        train_dataset = pd.read_csv('./result2.csv').as_matrix()
        
        X_train = train_dataset[:90, :8]
        Y_train = train_dataset[:90, 8:]
        
        X_test = train_dataset[90:, :8]
        Y_test = train_dataset[90:, 8:]
        
        print X_train.shape
        print Y_train.shape
        
        
        clf = tree.DecisionTreeClassifier()
        clf = clf.fit(X_train, Y_train)
        
        porter = Porter(clf, language='java')
        output = porter.export(embed_data=True)
        print(output)
        

        就我而言,我使用的是 DecisionTreeClassifier,以及

        的输出

        打印(输出)

        以下代码是控制台中的文本:

        class DecisionTreeClassifier {
        
          private static int findMax(int[] nums) {
            int index = 0;
            for (int i = 0; i < nums.length; i++) {
                index = nums[i] > nums[index] ? i : index;
            }
            return index;
          }
        
        
          public static int predict(double[] features) {
            int[] classes = new int[2];
        
            if (features[5] <= 51.5) {
                if (features[6] <= 21.0) {
        
                    // HUGE amount of ifs..........
        
                }
            }
        
            return findMax(classes);
          }
        
          public static void main(String[] args) {
            if (args.length == 8) {
        
                // Features:
                double[] features = new double[args.length];
                for (int i = 0, l = args.length; i < l; i++) {
                    features[i] = Double.parseDouble(args[i]);
                }
        
                // Prediction:
                int prediction = DecisionTreeClassifier.predict(features);
                System.out.println(prediction);
        
            }
          }
        }
        

        【讨论】:

        • 感谢您的信息。您能否分享您对如何使用 sklearn porter 执行腌制的 sklearn 模型的想法,并将其用于 Java 中的预测 - @gustavoresque
        【解决方案4】:

        以下是 JPMML 解决方案的一些代码:

        --Python部分--

        # helper function to determine the string columns which have to be one-hot-encoded in order to apply an estimator.
        def determine_categorical_columns(df):
            categorical_columns = []
            x = 0
            for col in df.dtypes:
                if col == 'object':
                    val = df[df.columns[x]].iloc[0]
                    if not isinstance(val,Decimal):
                        categorical_columns.append(df.columns[x])
                x += 1
            return categorical_columns
        
        categorical_columns = determine_categorical_columns(df)
        other_columns = list(set(df.columns).difference(categorical_columns))
        
        
        #construction of transformators for our example
        labelBinarizers = [(d, LabelBinarizer()) for d in categorical_columns]
        nones = [(d, None) for d in other_columns]
        transformators = labelBinarizers+nones
        
        mapper = DataFrameMapper(transformators,df_out=True)
        gbc = GradientBoostingClassifier()
        
        #construction of the pipeline
        lm = PMMLPipeline([
            ("mapper", mapper),
            ("estimator", gbc)
        ])
        

        --JAVA部分--

        //Initialisation.
        String pmmlFile = "ScikitLearnNew.pmml";
        PMML pmml = org.jpmml.model.PMMLUtil.unmarshal(new FileInputStream(pmmlFile));
        ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();
        MiningModelEvaluator evaluator = (MiningModelEvaluator) modelEvaluatorFactory.newModelEvaluator(pmml);
        
        //Determine which features are required as input
        HashMap<String, Field>() inputFieldMap = new HashMap<String, Field>();
        for (int i = 0; i < evaluator.getInputFields().size();i++) {
          InputField curInputField = evaluator.getInputFields().get(i);
          String fieldName = curInputField.getName().getValue();
          inputFieldMap.put(fieldName.toLowerCase(),curInputField.getField());
        }
        
        
        //prediction
        
        HashMap<String,String> argsMap = new HashMap<String,String>();
        //... fill argsMap with input
        
        Map<FieldName, ?> res;
        // here we keep only features that are required by the model
        Map<FieldName,String> args = new HashMap<FieldName, String>();
        Iterator<String> iter = argsMap.keySet().iterator();
        while (iter.hasNext()) {
          String key = iter.next();
          Field f = inputFieldMap.get(key);
          if (f != null) {
            FieldName name =f.getName();
            String value = argsMap.get(key);
            args.put(name, value);
          }
        }
        //the model is applied to input, a probability distribution is obtained
        res = evaluator.evaluate(args);
        SegmentResult segmentResult = (SegmentResult) res;
        Object targetValue = segmentResult.getTargetValue();
        ProbabilityDistribution probabilityDistribution = (ProbabilityDistribution) targetValue;
        

        【讨论】:

          【解决方案5】:

          为此目的有JPMML项目。

          首先,您可以直接从 python 使用sklearn2pmml 库将 scikit-learn 模型序列化为 PMML(内部是 XML),或者先将其转储到 python 中,然后在 java 中使用 jpmml-sklearn 或从这里提供的命令行进行转换图书馆。接下来,您可以在 Java 代码中使用 jpmml-evaluator 加载 pmml 文件、反序列化并执行加载的模型。

          这种方式并非适用于所有 scikit-learn 模型,但适用于其中的 many

          【讨论】:

          • 如何确保特征转换部分在 Python 中用于训练的部分与在 Java 中(使用 pmml)用于服务的部分之间保持一致?
          • 我试过这个,它绝对适用于将 sklearn 转换器和 xgboost 模型转换为 Java。但是,由于 AGPL 许可证,我们没有在生产环境中选择这个。 (也有商业许可,但协商许可不符合我们的项目时间表。)
          • 我试过这个,通过Java程序保留了所有的特征提取、清理、转换逻辑。它在 Java 端(jpmml-evaluator)运行良好。容器化 Spring boot 应用的一个不错的选择,大大降低了 devops 的复杂性,因为 python 训练的频率和时间线无法与 Java 程序的持续集成同步
          【解决方案6】:

          您不能使用 jython,因为 scikit-learn 严重依赖 numpy 和 scipy,它们有许多编译的 C 和 Fortran 扩展,因此无法在 jython 中工作。

          在 java 环境中使用 scikit-learn 的最简单方法是:

          • 将分类器公开为 HTTP / Json 服务,例如使用 flaskbottlecornice 等微框架,并使用 HTTP 客户端库从 java 调用它

            李>
          • 在 python 中编写一个命令行包装应用程序,该应用程序读取标准输入上的数据并使用某种格式(如 CSV 或 JSON(或某些较低级别的二进制表示))在标准输出上输出预测,并从 java 调用 python 程序,例如使用 @ 987654324@.

          • 让 python 程序输出拟合时学习到的原始数值参数(通常作为浮点值数组)并在 java 中重新实现预测函数(这对于预测通常很容易的线性预测模型只是一个阈值点积)。

          如果您还需要在 Java 中重新实现特征提取,则最后一种方法的工作量会更大。

          最后,您可以使用实现所需算法的 Java 库(例如 Weka 或 Mahout),而不是尝试从 Java 中使用 scikit-learn。

          【讨论】:

          • 我的一位同事刚刚建议 Jepp...这对这个有用吗?
          • 可能我不知道 jepp。它确实看起来适合这项任务。
          • 对于一个网络应用来说,我个人更喜欢http暴露方式。 @user939259 然后可以为各种应用程序使用分类器池并更轻松地对其进行扩展(根据需求调整池大小)。我只会考虑将 Jepp 用于桌面应用程序。和我一样是 python 爱好者,除非 scikit-lear 的性能明显优于 Weka 或 Mahout,否则我会选择单语言解决方案。拥有不止一种语言/框架应该被视为技术债。
          • 我同意多语言技术债务:在一个团队中工作很困难,因为所有开发人员都知道 java 和 python,并且必须从一种技术文化切换到另一种技术文化,这在管理中增加了无用的复杂性项目。
          • 也许这是技术债务 - 但延伸一下比喻,在机器学习中,你总是不断宣布破产,因为你正在尝试一些东西,发现它不起作用,然后调整它/抛出它离开。所以在这样的情况下,也许债务并不是什么大不了的事。
          猜你喜欢
          • 1970-01-01
          • 2015-09-25
          • 2015-05-09
          • 1970-01-01
          • 2018-07-01
          • 2016-05-22
          • 1970-01-01
          • 2013-07-09
          • 2015-03-27
          相关资源
          最近更新 更多