【Title】: Feeding example to tf predictor.from_saved_model() for estimator trained with tf hub module
【Posted】: 2018-08-20 17:08:10
【Question】:

I am trying to export a text-classification model that uses tf hub modules, and then use predictor.from_saved_model() to infer a prediction for a single string example from it. I have seen some examples with a similar idea, but still could not make it work when the features are built with a tf hub module. Here is what I do:

    train_input_fn = tf.estimator.inputs.pandas_input_fn(
        train_df, train_df['label_ids'], num_epochs=None, shuffle=True)

    # Prediction on the whole training set.
    predict_train_input_fn = tf.estimator.inputs.pandas_input_fn(
        train_df, train_df['label_ids'], shuffle=False)

    embedded_text_feature_column = hub.text_embedding_column(
        key='sentence',
        module_spec='https://tfhub.dev/google/nnlm-de-dim128/1')

    #Estimator
    estimator = tf.estimator.DNNClassifier(
        hidden_units=[500, 100],
        feature_columns=[embedded_text_feature_column],
        n_classes=num_of_class,
        optimizer=tf.train.AdagradOptimizer(learning_rate=0.003) )

    # Training
    estimator.train(input_fn=train_input_fn, steps=1000)

    #prediction on training set
    train_eval_result = estimator.evaluate(input_fn=predict_train_input_fn)

    print('Training set accuracy: {accuracy}'.format(**train_eval_result))

    feature_spec = tf.feature_column.make_parse_example_spec([embedded_text_feature_column])
    serving_input_receiver_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec)

    export_dir_base = self.cfg['model_path']
    servable_model_path = estimator.export_savedmodel(export_dir_base, serving_input_receiver_fn)

    # Example message for inference
    message = "Was ist denn los"
    saved_model_predictor = predictor.from_saved_model(export_dir=servable_model_path)
    content_tf_list = tf.train.BytesList(value=[str.encode(message)])
    example = tf.train.Example(
            features=tf.train.Features(
                feature={
                    'sentence': tf.train.Feature(
                        bytes_list=content_tf_list
                    )
                }
            )
        )

    with tf.python_io.TFRecordWriter('the_message.tfrecords') as writer:
        writer.write(example.SerializeToString())

    reader = tf.TFRecordReader()
    data_path = 'the_message.tfrecords'
    filename_queue = tf.train.string_input_producer([data_path], num_epochs=1)
    _, serialized_example = reader.read(filename_queue)
    output_dict = saved_model_predictor({'inputs': [serialized_example]})

And the output:

Traceback (most recent call last):
  File "/Users/dimitrs/component-pythia/src/pythia.py", line 321, in _train
    model = algo.generate_model(samples, generation_id)
  File "/Users/dimitrs/component-pythia/src/algorithm_layer/algorithm.py", line 56, in generate_model
    model = self._process_training(samples, generation)
  File "/Users/dimitrs/component-pythia/src/algorithm_layer/tf_hub_classifier.py", line 91, in _process_training
    output_dict = saved_model_predictor({'inputs': [serialized_example]})
  File "/Users/dimitrs/anaconda3/envs/pythia/lib/python3.6/site-packages/tensorflow/contrib/predictor/predictor.py", line 77, in __call__
    return self._session.run(fetches=self.fetch_tensors, feed_dict=feed_dict)
  File "/Users/dimitrs/anaconda3/envs/pythia/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 900, in run
    run_metadata_ptr)
  File "/Users/dimitrs/anaconda3/envs/pythia/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1135, in _run
    feed_dict_tensor, options, run_metadata)
  File "/Users/dimitrs/anaconda3/envs/pythia/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1316, in _do_run
    run_metadata)
  File "/Users/dimitrs/anaconda3/envs/pythia/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1335, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InternalError: Unable to get element as bytes.

Isn't serialized_example the right input suggested by serving_input_receiver_fn?

【Question discussion】:

    Tags: tensorflow tensorflow-hub


    【Solution 1】:

    So, I just needed serialized_example = example.SerializeToString(). Writing the example to a file and reading it back would require starting a session; simple serialization is enough:

        # Example message for inference
        message = "Was ist denn los"
        saved_model_predictor = predictor.from_saved_model(export_dir=servable_model_path)
        content_tf_list = tf.train.BytesList(value=[message.encode('utf-8')])
        sentence = tf.train.Feature(bytes_list=content_tf_list)
        sentence_dict = {'sentence': sentence}
        features = tf.train.Features(feature=sentence_dict)
    
        example = tf.train.Example(features=features)
    
        serialized_example = example.SerializeToString()
        output_dict = saved_model_predictor({'inputs': [serialized_example]})
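    As a sanity check, the serialized string above can be round-tripped with the plain protobuf API, without any session or TFRecord file. This is a minimal, self-contained sketch; the `sentence` key is the same one used for the feature column in the question:

    ```python
    import tensorflow as tf

    message = "Was ist denn los"

    # Build the tf.train.Example exactly as in the solution above.
    example = tf.train.Example(
        features=tf.train.Features(
            feature={
                'sentence': tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[message.encode('utf-8')])
                )
            }
        )
    )

    # SerializeToString() yields the bytes the parsing serving_input_receiver_fn expects.
    serialized_example = example.SerializeToString()

    # Round-trip: parsing the bytes back recovers the original sentence.
    parsed = tf.train.Example.FromString(serialized_example)
    recovered = parsed.features.feature['sentence'].bytes_list.value[0].decode('utf-8')
    print(recovered)  # -> Was ist denn los
    ```

    If the round-trip prints the original message, the serialized bytes are well formed, and any remaining error comes from the SavedModel side rather than the Example construction.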
    

    【Discussion】:
