【发布时间】:2021-01-08 07:29:07
【问题描述】:
我只想接收文本输入并尝试仅返回预测结果中的标签值。
例如。
curl -d '{"inputs":{"test": ["我今天很伤心"]}}'
-X POST http://{location}:predict
我想得到返回值“sad”
所以我看到this 并尝试了。
保存模型的时候是用decorate tf.function保存的
self.tf_model_wrapper = TFModel(model)
tf.saved_model.save(self.tf_model_wrapper.model, f'classifier/saved_models/{int(time.time())}',
signatures={'serving_default': self.tf_model_wrapper.prediction})
该函数只是简单地接收文本并对其进行标记,然后尝试将预测结果值返回给标签名称。
@tf.function(input_signature=[tf.TensorSpec(shape=(1, ), dtype=tf.string)])
def prediction(self, text: str):
input_ids, input_attention, input_token_type = self.tokenizer(text)
input_encoding = (input_ids, input_attention, input_token_type)
result = self.convert_label(self.model(input_encoding))
return result
但我收到了这个错误
TypeError: tf__prediction() missing 2 required positional arguments: 'input2' and 'input3'
我以为是因为我的模型接收了 3 个输入,所以我这样修改它,它似乎可以工作。
@tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.int32,
name="input_ids"),tf.TensorSpec(shape=None, dtype=tf.int32,
name="attention_mask"),tf.TensorSpec(shape=None, dtype=tf.int32, name="token_type_ids")])
def prediction(self, input1, input2, input3):
input = (input1, input2, input3)
return self.model(input)
但是,这与最初的目的不同,似乎不可能只接收文本并返回预测结果。
有什么办法可以吗?
【问题讨论】:
标签: tensorflow tensorflow-serving