【发布时间】:2022-01-02 02:29:44
【问题描述】:
我已经部署了一个 sagemaker 端点,现在想在端点上运行预测。端点代表 sagemaker 管道和模型。我遵循了教程here。我设置预测器并进行预测的代码如下:
from sagemaker.predictor import Predictor
predictor = Predictor(endpoint_name=endpoint_name)
data_df = data_df.drop("LABEL_NAME", axis=1)
pred_count = 1
payload = data_df.iloc[:pred_count].to_string(header=False, index=False).replace(" ", ",")
p = predictor.predict(payload, initial_args={"ContentType": "text/csv"})
这段代码几乎就是他们在我链接的示例中显示的内容,对我来说很有意义。我的管道的 preprocess.py 代码包括我所包含的以下函数(尽管不确定它们是否相关):
def input_fn(input_data, content_type):
print("BAHHHHHH")
if content_type == "text/csv":
# Read the raw input data as CSV.
df = pd.read_csv(StringIO(input_data), header=None)
return df
else:
raise ValueError("{} not supported by script!".format(content_type))
def output_fn(prediction, accept):
print("BAHHHHHH")
if accept == "application/json":
instances = []
for row in prediction.tolist():
instances.append(row)
json_output = {"instances": instances}
return worker.Response(json.dumps(json_output), mimetype=accept)
elif accept == "text/csv":
return worker.Response(encoders.encode(prediction, accept), mimetype=accept)
else:
raise RuntimeException("{} accept type is not supported by this script.".format(accept))
def predict_fn(input_data, model):
print("BAHHHHHH")
features = model.transform(input_data)
return features
def model_fn(model_dir):
print("BAHHHHHH")
"""Deserialize fitted model"""
preprocessor = joblib.load(os.path.join(model_dir, "model.joblib"))
return preprocessor
运行 predictor.predict() 方法时出现以下错误:
botocore.errorfactory.ModelError: An error occurred (ModelError) when calling the InvokeEndpoint operation: Received client error (400) from primary with message "{
"error": "JSON Parse error: Missing a comma or ']' after an array element. at offset: 16"
我在将 payload 变量传递给 predict 方法之前打印了它,它看起来像这样(我截断了它,因为它很长,但这应该足以看看是什么样子的:
0 999.105105 888.607813 6.0 1 los angeles 2431.666667 1.0 NaN 1177.813623 1.076833e+06 los angeles$1$6 0 60376511012 0.0 0.0 0.0 0.0 0.0 0.0 ............
错误消息还提供了一个 URL 以查看更多信息。它是端点的云监视日志。查看这些日志,我没有看到任何额外信息,只是一个 400 错误,除了 400 错误之外没有其他信息。
所以我传入的数据格式显然存在一些问题。 input_fn、output_fn、predict_fn 和 model_fn 方法在方法开始时都有一个打印语句,但这些都没有出现在日志,所以我认为这些都没有达到。
我做错了什么?
【问题讨论】:
-
您可以尝试将
serializer设置为像from sagemaker.serializers import CSVSerializer predictor.serializer = CSVSerializer()一样的CSVSerializer
标签: python amazon-web-services prediction amazon-sagemaker