【Question title】: How to run asynchronous predictions with TensorFlow Estimator API?
【Posted】: 2018-01-06 18:00:38
【Question】:

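I am using the tf.estimator API to predict punctuation. I trained the model on preprocessed data using TFRecords and tf.train.shuffle_batch. Now I want to make predictions. Feeding static NumPy data into a tf.constant and returning it from input_fn works fine.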

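However, I am working with sequence data and need to feed one example at a time, where the next input depends on the previous output. I also want to be able to handle data that arrives via HTTP requests.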

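Every time estimator.predict is called, it reloads the checkpoint and recreates the entire graph. This is slow and expensive, so I need to be able to feed data to input_fn dynamically.

My current attempt is roughly this: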

feature_input = tf.placeholder(tf.int32, shape=[1, MAX_SUBSEQUENCE_LEN])
q = tf.FIFOQueue(1, tf.int32, shapes=[[1, MAX_SUBSEQUENCE_LEN]])
enqueue_op = q.enqueue(feature_input)

def input_fn():
    return q.dequeue()

estimator = tf.estimator.Estimator(model_fn, model_dir=model_file)
predictor = estimator.predict(input_fn=input_fn)
sess = tf.Session()
output = None

while True:
    x = get_numpy_data(output)  # the next input depends on the previous output
    if x is None:
        break
    sess.run(enqueue_op, {feature_input: x})
    output = next(predictor)
    save_to_file(output)

sess.close()

But I get the following error: ValueError: Input graph and Layer graph are not the same: Tensor("EmbedSequence/embedding_lookup:0", shape=(1, 200, 128), dtype=float32) is not from the passed-in graph.

How can I asynchronously feed data into the existing graph through input_fn so that I get predictions one at a time?

【Comments】:

    Tags: python tensorflow tensorflow-estimator


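    【Solution 1】:

    It turns out that the main problem was that all tensors need to be created inside input_fn; otherwise they are not added to the same graph as the model. I needed to run the enqueue op, but had no way of accessing anything returned by the input function.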
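The graph mismatch behind that ValueError can be reproduced in isolation. The following is a minimal sketch, assuming TensorFlow is installed; `caller_graph` and `estimator_graph` are illustrative names standing in for the graph the question's placeholder and queue lived in and the fresh graph the Estimator builds for prediction.

```python
import tensorflow as tf

# Stand-in for the graph in which the question's placeholder and queue
# were created (there it was the implicit default graph).
caller_graph = tf.Graph()
with caller_graph.as_default():
    outside = tf.constant([[1, 2, 3]], dtype=tf.int32)

# Stand-in for the fresh graph the Estimator creates for prediction.
estimator_graph = tf.Graph()
with estimator_graph.as_default():
    inside = tf.constant([[1, 2, 3]], dtype=tf.int32)

# Each tensor remembers the graph it was created in; tensors from
# different graphs cannot be wired into the same model.
same_graph = outside.graph is inside.graph
print(same_graph)  # False
```

Anything the model consumes therefore has to be created while the Estimator's own graph is the default one, which is what the subclass in this answer arranges.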

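    I ended up subclassing the Estimator class and writing a custom predict function that lets me add data to a prediction queue dynamically and get the results back: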

    # async_estimator.py
    
    import six
    import tensorflow as tf
    from tensorflow.python.estimator.estimator import Estimator
    from tensorflow.python.estimator.estimator import _check_hooks_type
    from tensorflow.python.estimator import model_fn as model_fn_lib
    from tensorflow.python.framework import ops
    from tensorflow.python.framework import random_seed
    from tensorflow.python.training import saver
    from tensorflow.python.training import training
    
    
    class AsyncEstimator(Estimator):
    
        def async_predictor(self,
                    dtype,
                    shape=None,
                    predict_keys=None,
                    hooks=None,
                    checkpoint_path=None):
            """Returns a tuple of functions: the first runs predictions on the model, the second cleans up.
            Args:
              dtype: the dtype of the input
              shape: the shape of the input placeholder (optional)
              predict_keys: list of `str`, name of the keys to predict. It is used if
                the `EstimatorSpec.predictions` is a `dict`. If `predict_keys` is used
                then rest of the predictions will be filtered from the dictionary. If
                `None`, returns all.
              hooks: List of `SessionRunHook` subclass instances. Used for callbacks
                inside the prediction call.
              checkpoint_path: Path of a specific checkpoint to predict. If `None`, the
                latest checkpoint in `model_dir` is used.
            Returns:
              (predict, finish): tuple of functions
    
                predict: runs a single prediction and returns the results
                    Args:
                        x: NumPy array of input
                    Returns:
                        Evaluated value of the prediction
    
                finish: closes the session, allowing the program to exit
    
            Raises:
              ValueError: Could not find a trained model in model_dir.
              ValueError: if batch length of predictions are not same.
              ValueError: If there is a conflict between `predict_keys` and
                `predictions`. For example if `predict_keys` is not `None` but
                `EstimatorSpec.predictions` is not a `dict`.
            """
            hooks = _check_hooks_type(hooks)
            # Check that model has been trained.
            if not checkpoint_path:
                checkpoint_path = saver.latest_checkpoint(self._model_dir)
            if not checkpoint_path:
                raise ValueError('Could not find trained model in model_dir: {}.'.format(
                    self._model_dir))
    
            with ops.Graph().as_default() as g:
                random_seed.set_random_seed(self._config.tf_random_seed)
                training.create_global_step(g)
                input_placeholder = tf.placeholder(dtype=dtype, shape=shape)
                queue = tf.FIFOQueue(1, dtype, shapes=shape)
                enqueue_op = queue.enqueue(input_placeholder)
                features = queue.dequeue()
                estimator_spec = self._call_model_fn(features, None,
                                                     model_fn_lib.ModeKeys.PREDICT)
                predictions = self._extract_keys(estimator_spec.predictions, predict_keys)
                mon_sess = training.MonitoredSession(
                        session_creator=training.ChiefSessionCreator(
                            checkpoint_filename_with_path=checkpoint_path,
                            scaffold=estimator_spec.scaffold,
                            config=self._session_config),
                        hooks=hooks)
    
                def predict(x):
                    if mon_sess.should_stop():
                        raise StopIteration
                    mon_sess.run(enqueue_op, {input_placeholder: x})
                    preds_evaluated = mon_sess.run(predictions)
                    if not isinstance(predictions, dict):
                        return preds_evaluated
                    else:
                        preds = []
                        for i in range(self._extract_batch_length(preds_evaluated)):
                            preds.append({
                                key: value[i]
                                for key, value in six.iteritems(preds_evaluated)
                            })
                        return preds
    
                def finish():
                    mon_sess.close()
    
                return predict, finish
    

    Here is rough code that uses it:

    import tensorflow as tf
    from async_estimator import AsyncEstimator
    
    
    def doPrediction(model_fn, model_dir, max_seq_length):
        estimator = AsyncEstimator(model_fn, model_dir=model_dir)
        predict, finish = estimator.async_predictor(dtype=tf.int32, shape=(1, max_seq_length))
        output = None
    
        while True:
            # my input is dependent on the previous output
            x = get_numpy_data(output)
            if x is None:
                break
            output = predict(x)
            save_to_disk(output)
    
        finish()
    

    Note: this is a simple solution that fits my needs; it may need to be modified for other cases. It runs on TensorFlow 1.2.1.
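One modification that may be worth considering is to make sure `finish()` runs even when the prediction loop raises, since the session otherwise stays open. The following is only a sketch of that idea, not part of the original answer: `open_predictor` and `fake_async_predictor` are hypothetical helpers, and the fake predictor merely stands in for `estimator.async_predictor(...)` so the example runs without TensorFlow.

```python
import contextlib


@contextlib.contextmanager
def open_predictor(make_predictor):
    """Yield predict() and always call finish(), even if the caller raises.

    `make_predictor` is any zero-argument callable returning a
    (predict, finish) pair, for example a lambda wrapping
    estimator.async_predictor(...).
    """
    predict, finish = make_predictor()
    try:
        yield predict
    finally:
        finish()


def fake_async_predictor():
    """Hypothetical stand-in for async_predictor; doubles every input value."""
    state = {"closed": False, "calls": 0}

    def predict(x):
        if state["closed"]:
            raise RuntimeError("session already closed")
        state["calls"] += 1
        return [v * 2 for v in x]

    def finish():
        state["closed"] = True

    predict.state = state  # exposed only so the demo can inspect it
    return predict, finish


with open_predictor(fake_async_predictor) as predict:
    first = predict([1, 2, 3])
    second = predict(first)  # the next input depends on the previous output
```

With the real estimator, the call would look something like `open_predictor(lambda: estimator.async_predictor(dtype=tf.int32, shape=(1, max_seq_length)))`, and the body of the `with` block would contain the same loop as in `doPrediction`.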

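    Hopefully TF will officially adopt something similar, to make it easier to serve dynamic predictions with Estimator.

    【Comments】:

    • This is a great solution! Thank you very much, Colin!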