【问题标题】:How to properly serve an object detection model from Tensorflow Object Detection API?如何从 Tensorflow 对象检测 API 正确地提供对象检测模型?
【发布时间】:2017-07-27 23:46:05
【问题描述】:

我正在使用 Tensorflow 对象检测 API (github.com/tensorflow/models/tree/master/object_detection) 执行一项对象检测任务。现在,我在为使用 Tensorflow Serving(tensorflow.github.io/serving/) 训练的检测模型提供服务时遇到问题。

1. 我遇到的第一个问题是将模型导出到可服务文件。 对象检测 api 包含导出脚本,以便我能够将 ckpt 文件转换为带有变量的 pb 文件。但是,输出文件在“变量”文件夹中不会有任何内容。我虽然这是一个错误并在 Github 上报告了它,但似乎他们实习将变量转换为常量,这样就不会有变量了。详情可见HERE

我在导出保存的模型时使用的标志如下:

    CUDA_VISIBLE_DEVICES=0 python export_inference_graph.py \
        --input_type image_tensor \
            --pipeline_config_path configs/rfcn_resnet50_car_Jul_20.config \
                --checkpoint_path resnet_ckpt/model.ckpt-17586 \
                    --inference_graph_path serving_model/1 \
                      --export_as_saved_model True

当我将 --export_as_saved_model 切换为 False 时,它​​在 python 中运行得非常好。

但是,我仍然无法为模型提供服务。

当我试图跑步时:

~/serving$ bazel-bin/tensorflow_serving/model_servers/tensorflow_model_server --port=9000 --model_name=gan --model_base_path=<my_model_path>

我明白了:

2017-07-27 16:11:53.222439: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:155] Restoring SavedModel bundle.
2017-07-27 16:11:53.222497: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:165] The specified SavedModel has no variables; no checkpoints were restored.
2017-07-27 16:11:53.222502: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:190] Running LegacyInitOp on SavedModel bundle.
2017-07-27 16:11:53.229463: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:284] Loading SavedModel: success. Took 281805 microseconds.
2017-07-27 16:11:53.229508: I tensorflow_serving/core/loader_harness.cc:86] Successfully loaded servable version {name: gan version: 1}
2017-07-27 16:11:53.244716: I tensorflow_serving/model_servers/main.cc:290] Running ModelServer at 0.0.0.0:9000 ...

我认为模型没有正确加载,因为它显示“指定的 SavedModel 没有变量;没有恢复检查点。”

但是既然我们已经把所有的变量都转换成了常量,这似乎是合理的。我不确定这里。

2。我无法使用客户端调用服务器并对示例图像进行检测。

客户端脚本如下:

from __future__ import print_function
from __future__ import absolute_import

# Communication to TensorFlow server via gRPC
from grpc.beta import implementations
import tensorflow as tf
import numpy as np
from PIL import Image
# TensorFlow serving stuff to send messages
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2


# Command line arguments
tf.app.flags.DEFINE_string('server', 'localhost:9000',
                       'PredictionService host:port')
tf.app.flags.DEFINE_string('image', '', 'path to image in JPEG format')
FLAGS = tf.app.flags.FLAGS


def load_image_into_numpy_array(image):
    (im_width, im_height) = image.size
    return np.array(image.getdata()).reshape(
    (im_height, im_width, 3)).astype(np.uint8)

def main(_):
    host, port = FLAGS.server.split(':')
    channel = implementations.insecure_channel(host, int(port))
    stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)

    # Send request
    request = predict_pb2.PredictRequest()
    image = Image.open(FLAGS.image)
    image_np = load_image_into_numpy_array(image)
    image_np_expanded = np.expand_dims(image_np, axis=0)
    # Call GAN model to make prediction on the image
    request.model_spec.name = 'gan'
    request.model_spec.signature_name = 'predict_images'
    request.inputs['inputs'].CopyFrom(
    tf.contrib.util.make_tensor_proto(image_np_expanded))

    result = stub.Predict(request, 60.0)  # 60 secs timeout
    print(result)


if __name__ == '__main__':
    tf.app.run()

为了匹配request.model_spec.signature_name = 'predict_images',我从第 289 行开始修改了对象检测 api (github.com/tensorflow/models/blob/master/object_detection/exporter.py) 中的 exporter.py 脚本:

          signature_def_map={
          signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
              detection_signature,
      },

收件人:

          signature_def_map={
          'predict_images': detection_signature,
          signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
              detection_signature,
      },

因为我不知道如何调用默认签名密钥。

当我运行以下命令时:

bazel-bin/tensorflow_serving/example/client --server=localhost:9000 --image=<my_image_file>

我收到以下错误消息:

    Traceback (most recent call last):
  File "/home/xinyao/serving/bazel-bin/tensorflow_serving/example/client.runfiles/tf_serving/tensorflow_serving/example/client.py", line 54, in <module>
    tf.app.run()
  File "/home/xinyao/serving/bazel-bin/tensorflow_serving/example/client.runfiles/org_tensorflow/tensorflow/python/platform/app.py", line 48, in run
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))
  File "/home/xinyao/serving/bazel-bin/tensorflow_serving/example/client.runfiles/tf_serving/tensorflow_serving/example/client.py", line 49, in main
    result = stub.Predict(request, 60.0)  # 60 secs timeout
  File "/usr/local/lib/python2.7/dist-packages/grpc/beta/_client_adaptations.py", line 324, in __call__
    self._request_serializer, self._response_deserializer)
  File "/usr/local/lib/python2.7/dist-packages/grpc/beta/_client_adaptations.py", line 210, in _blocking_unary_unary
    raise _abortion_error(rpc_error_call)
grpc.framework.interfaces.face.face.AbortionError: AbortionError(code=StatusCode.NOT_FOUND, details="FeedInputs: unable to find feed output ToFloat:0")

不太清楚这里发生了什么。

最初我虽然可能我的客户端脚本不正确,但在我发现 AbortionError 来自 github.com/tensorflow/tensorflow/blob/f488419cd6d9256b25ba25cbe736097dfeee79f9/tensorflow/core/graph/subgraph.cc 之后。似乎我在构建图表时遇到了这个错误。所以这可能是我遇到的第一个问题造成的。

我对这个东西很陌生,所以我真的很困惑。我想我可能一开始就错了。有什么方法可以正确导出和服务检测模型?任何建议都会有很大帮助!

【问题讨论】:

  • 我收到 code=StatusCode.FAILED_PRECONDITION, details="Serving signature key "predict_images" not found." 。不过,我根据您的代码更新了 exporter.py 文件。有什么想法吗?
  • @PamioSolanky 你可以看到exporter.py 276-277行的原始代码,我做了一些修改。而是使用 signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY 作为服务签名密钥,我将其更改为“predict_images”。所以我可以使用上面发布的客户端代码来调用它。如果您没有更改它,您可能会使用默认服务签名密钥,或者通过将其添加到您的标志来分配您自己的密钥。

标签: tensorflow object-detection tensorflow-serving


【解决方案1】:

当前的导出器代码未正确填充签名字段。所以使用模型服务器服务是行不通的。对此表示歉意。一个更好地支持导出模型的新版本即将推出。它包括服务所需的一些重要修复和改进,尤其是在 Cloud ML Engine 上服务。如果您想尝试它的早期版本,请参阅github issue

对于“指定的 SavedModel 没有变量;没有恢复检查点。”消息,由于您所说的确切原因,这是预期的,因为所有变量都在图中转换为常量。对于“FeedInputs:无法找到Feed输出ToFloat:0”的错误,请确保在构建模型服务器时使用TF 1.2。

【讨论】:

  • 感谢您的更新。对于“FeedInputs:无法找到Feed输出ToFloat:0”的错误,我使用的是tensorflow_serving的默认tensorflow版本,我今天再次使用tensorflow承诺16d39e94e3724417fcaed87035434e098e892842和最新的tf_models重新构建了tensorflow_serving,但错误仍然存​​在。稍后我将尝试更新用于导出模型的 tf_models(我在不同的文件夹中有 2 个 tf_models,一个用于服务,一个用于 object_detection),看看是否能解决问题。
  • 我尝试了更新导出脚本并且能够让它运行,但似乎只使用CPU而不是GPU进行推理,你知道可能导致这个问题的原因吗?
【解决方案2】:
  1. 你的想法很好。可以有这个警告。

  2. 问题是输入需要按照模型的预期转换为uint8。这是对我有用的代码 sn-p。

request = predict_pb2.PredictRequest()
request.model_spec.name = 'gan'
request.model_spec.signature_name = 
    signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY

image = Image.open('any.jpg')
image_np = load_image_into_numpy_array(image)
image_np_expanded = np.expand_dims(image_np, axis=0)

request.inputs['inputs'].CopyFrom(
    tf.contrib.util.make_tensor_proto(image_np_expanded, 
        shape=image_np_expanded.shape, dtype='uint8'))

这部分对您很重要shape=image_np_expanded.shape, dtype='uint8',并确保拉取最新更新以进行服务。

【讨论】:

  • 对我来说,这种方法部分有效,但与其他方法(缺失检测等)相比,并没有给我带来那么好的结果。相反,我使用--input_type encoded_image_string_tensor 导出模型并直接发送jpeg 数据data = image.read(); request.inputs['inputs'].CopyFrom(tf.contrib.util.make_tensor_proto(data, shape=[1]))
  • 这个 sn-p 仅适用于期望图像像素(--input_type image_tensor)作为输入的模型。在您的情况下,它是编码图像,而不是像素。这就是我认为它不起作用的原因。
  • 至少对我来说,这个 sn-p 结合--input_type image_tensor '工作',因为它将返回检测结果。但是,当我使用--input_type encoded_image_string_tensor 导出并发送编码图像时,我得到了更好的结果,例如在图像中检测到 10 次,而使用您的方法检测到 3-4 次。您的方法会返回一些检测结果,因此它似乎可以正常工作,但考虑到模型在训练期间在我的评估集上的表现,它并没有像我预期的那样工作。我的方法表现更接近我的预期。
【解决方案3】:

我一直在努力解决确切的问题。我试图托管来自 Tensorflow 对象检测 API Zoo 的预训练 SSDMobileNet-COCO 检查点

原来我使用的是 tensorflow/models 的旧提交,它恰好是服务的默认子模块

我只是简单地用

提取了最近的提交

cd serving/tf_models git pull origin master git checkout master

之后,再次搭建模型服务器。

bazel build //tensorflow_serving/model_servers:tensorflow_model_server

错误消失了,我能够得到准确的预测

【讨论】:

  • 我试过上面的指令还是不行,你愿意和我分享你的python客户端脚本吗?
  • 我使用了您的客户端脚本的精确副本来测试我的配置并且它有效。请尝试提取 object_detection 的最新提交,因为导出器脚本已更改并且现在更加复杂。
  • 要使用 GPU 进行推理,请确保您已经为 GPU 构建了 Tensorflow Serving。该线程有正确的说明来设置github.com/tensorflow/serving/issues/345
【解决方案4】:

对于错误

grpc.framework.interfaces.face.face.AbortionError: AbortionError(code=StatusCode.NOT_FOUND, details="FeedInputs: unable to find feed output ToFloat:0"

只需将 tf_models 升级到最新版本,然后重新导出模型。

https://github.com/tensorflow/tensorflow/issues/11863

【讨论】:

  • 感谢您的参考,我尝试了但仍然无法解决我的问题。
  • 它工作正常,但似乎只使用CPU而不是GPU进行推理,你知道可能导致这个问题的原因吗?
  • 为 GPU 编译 tf-serving 需要一些文档中没有的特殊技巧。看看这个:github.com/tensorflow/serving/issues/459
  • 非常感谢!
猜你喜欢
  • 1970-01-01
  • 2019-11-09
  • 2019-04-05
  • 1970-01-01
  • 2020-11-02
  • 2019-07-04
  • 1970-01-01
  • 2017-12-02
  • 1970-01-01
相关资源
最近更新 更多