【问题标题】:How to use a TFRecord file for batch prediction on GCP AI Platform?如何在 GCP AI Platform 上使用 TFRecord 文件进行批量预测?
【发布时间】:2021-01-05 05:35:12
【问题描述】:

TL;DR Google Cloud AI Platform 在进行批量预测时如何解压TFRecord 文件?

我已将经过训练的 Keras 模型部署到 Google Cloud AI Platform,但我在批量预测的文件格式方面遇到了问题。对于培训,我使用tf.data.TFRecordDataset 来阅读TFRecord 的列表,如下所示,一切正常。

def unpack_tfrecord(record):
    parsed = tf.io.parse_example(record, {
        'chunk': tf.io.FixedLenFeature([128, 2, 3], tf.float32),  # Input
        'class': tf.io.FixedLenFeature([2], tf.int64),            # One-hot classification (binary)
    })

    return (parsed['chunk'], parsed['class'])

files = [str(p) for p in training_chunks_path.glob('*.tfrecord')]
dataset = tf.data.TFRecordDataset(files).batch(32).map(unpack_tfrecord)
model.fit(x=dataset, epochs=train_epochs)
tf.saved_model.save(model, model_save_path)

我将保存的模型上传到 Cloud Storage 并在 AI Platform 中创建一个新模型。 AI Platform 文档指出“使用 gcloud 工具进行批处理 [支持] 带有 JSON 实例字符串的文本文件或 TFRecord 文件(可能已压缩)”(https://cloud.google.com/ai-platform/prediction/docs/overview#prediction_input_data)。但是当我提供一个 TFRecord 文件时,我得到了错误:

("'utf-8' codec can't decode byte 0xa4 in position 1: invalid start byte", 8)

我的 TFRecord 文件包含一堆 Protobuf 编码的tf.train.Example。我没有向 AI Platform 提供 unpack_tfrecord 函数,所以我想它无法正确解包是有道理的,但我知道从这里去哪里。由于数据太大,我对使用 JSON 格式不感兴趣。

【问题讨论】:

  • 您找到解决方案了吗?我也有同样的疑惑。 GCP / AI Platform 文档和示例令人沮丧。
  • 我最终使用 json 作为输入。我将尝试再次使用 tfrecord 进行实验,看看我现在是否可以让它工作。
  • 我认为自定义预测功能是一种(唯一?)方式。
  • 您找到其他解决方案了吗?我使用了子类模型,因此无法使用函数 model_to_estimator。有什么想法吗?

标签: tensorflow tfrecord google-ai-platform


【解决方案1】:

我不知道这是否是解决此问题的最佳方式,但对于 TF 2.x,您可以执行以下操作:

import tensorflow as tf

def make_serving_input_fn():
    # your feature spec
    feature_spec = {
        'chunk': tf.io.FixedLenFeature([128, 2, 3], tf.float32),  
        'class': tf.io.FixedLenFeature([2], tf.int64),
    }

    serialized_tf_examples = tf.keras.Input(
        shape=[], name='input_example_tensor', dtype=tf.string)

    examples = tf.io.parse_example(serialized_tf_examples, feature_spec)

    # any processing 
    processed_chunks = tf.map_fn(
        <PROCESSING_FN>, 
        examples['chunk'], # ?
        dtype=tf.float32)

    return tf.estimator.export.ServingInputReceiver(
        features={<MODEL_FIRST_LAYER_NAME>: processed_chunks},
        receiver_tensors={"input_example_tensor": serialized_tf_examples}
    )


estimator = tf.keras.estimator.model_to_estimator(
    keras_model=model,
    model_dir=<ESTIMATOR_SAVE_DIR>)

estimator.export_saved_model(
    export_dir_base=<WORKING_DIR>,
    serving_input_receiver_fn=make_serving_input_fn)

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 2021-12-09
    • 2021-11-01
    • 1970-01-01
    • 1970-01-01
    • 2023-01-03
    • 2021-07-31
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多