【问题标题】:Mlflow log_model, not able to predict with spark_udf but with python worksMlflow log_model,无法使用 spark_udf 进行预测,但可以使用 python 进行预测
【发布时间】:2022-01-14 15:56:33
【问题描述】:

我想在 mlflow 上记录一个模型,一旦我这样做了,我就可以用 python 加载的模型预测概率,但不能用 spark_udf。问题是,我仍然需要在模型内部有一个预处理功能。这是一个玩具可复制示例,供您查看它何时失败:

import mlflow
from mlflow.models.signature import infer_signature
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
import pandas as pd
import numpy as np

X, y = make_classification(n_samples=1000, n_features=10, n_informative=2, n_classes=2, shuffle=True, random_state=1995)
X, y = pd.DataFrame(X), pd.DataFrame(y,columns=["target"])
# geerate column names
X.columns = [f"col_{idx}" for idx in range(len(X.columns))]
X["categorical_column"] = np.random.choice(["a","b","c"], size=len(X) )


def encode_catcolumn(X):
  X = X.copy()
  # replace cat values [a,b,c] for [-10,0,35] respectively
  X['categorical_column'] = np.select([X["categorical_column"] == "a", X["categorical_column"] == "b", X["categorical_column"] == "c"],  [-10, 0,35] ) 
  return X

# with catcolumn encoded; i need to use custom encoding , we'll do this within mlflow later
X_encoded = encode_catcolumn(X)

现在让我们为模型创建一个包装器,以便对模型中的函数进行编码。请注意,类内的encode_catcolumn函数和前面介绍的类外的函数是一样的。

class SklearnModelWrapper(mlflow.pyfunc.PythonModel):
  def __init__(self, model):
    self.model = model
  def encode_catcolumn(self,X):
    X = X.copy()
    # replace cat values [a,b,c] for [-10,0,35] respectively
    X['categorical_column'] = np.select([X["categorical_column"] == "a", X["categorical_column"] == "b", X["categorical_column"] == "c"],  [-10, 0,35] ) 
    return X 
  def predict(self, context, model_input):
    # encode catvariable
    model_input = self.encode_catcolumn(model_input)
    # predict probabilities
    predictions = self.model.predict_proba(model_input)[:,1]
    return predictions

现在让我们记录模型

with mlflow.start_run(run_name="reproductible_example") as run:
  clf = RandomForestClassifier()
  clf.fit(X_encoded,y)
  # wrappmodel with pyfunc, does the encoding inside the class 
  wrappedModel = SklearnModelWrapper(clf)
  # When the model is deployed, this signature will be used to validate inputs.
  mlflow.pyfunc.log_model("reproductible_example_model", python_model=wrappedModel)

model_uuid = run.info.run_uuid
model_path = f'runs:/{model_uuid}/reproductible_example_model'

在没有火花的情况下进行推理并完美运行:

model_uuid = run.info.run_uuid
model_path = f'runs:/{model_uuid}/reproductible_example_model'
# Load model as a PyFuncModel.
loaded_model = mlflow.pyfunc.load_model(model_path)
# predictions without spark , encodes the variables INSIDE; this WORKS
loaded_model.predict(X)

现在使用 spark_udf 进行推理并得到错误:

# create spark dataframe to test it on spark
X_spark = spark.createDataFrame(X)
# Load model as a Spark UDF.
loaded_model_spark = mlflow.pyfunc.spark_udf(spark, model_uri=model_path)

# Predict on a Spark DataFrame.
columns = list(X_spark.columns)
# this does not work
X_spark.withColumn('predictions', loaded_model_spark(*columns)).collect()

错误是:

PythonException: An exception was thrown from a UDF: 'KeyError: 'categorical_column'', from <command-908038>, line 7. Full traceback below:

我需要了解如何在类中对变量进行编码和预处理。是否有任何解决方案或任何解决方法使此代码能够与火花一起工作? 到目前为止我尝试过的:

  1. 将 encode_catcolumn 合并到 sklearn 管道中(使用自定义编码器 sklearn)-> 失败;
  2. 在 sklearn 包装类(本示例)中创建一个函数 -> 失败 3 ) 使用 log_model 然后创建一个 pandas_udf 以便用 spark 来做 --> 可以,但这不是我想要的。我希望能够通过调用 .predict() 方法或类似的方法在 spark 上运行模型。
  3. 当移除预处理函数并在类外执行时 --> 这确实有效,但不是这样

【问题讨论】:

    标签: apache-spark pyspark scikit-learn mlflow mlops


    【解决方案1】:

    当我加载 spark_udf 模型并执行推理时,我只需更改问题的最后一部分即可解决此问题。这是问题的可能答案。只需将 F.struct() 传递给 spark_udf 而不是列列表。就像在下面的块中一样:

    import pyspark.sql.functions as F
    # create spark dataframe to test it on spark
    X_spark = spark.createDataFrame(X)
    # Load model as a Spark UDF.
    loaded_model_spark = mlflow.pyfunc.spark_udf(spark, model_uri=model_path)
    
    # Predict on a Spark DataFrame.
    # columns = list(X_spark.columns) --> delete this
    columns = F.struct(X_spark.columns) # use struct
    # this does not work
    X_spark.withColumn('predictions', loaded_model_spark(columns)).collect()
    
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2019-05-09
      • 2018-12-31
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2013-01-23
      • 2018-03-04
      • 2018-09-24
      相关资源
      最近更新 更多