【Question Title】: tf.estimator.BoostedTreesRegressor SavedModel Restore Issue
【Posted】: 2019-06-20 18:55:41
【Question】:

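I'm having trouble restoring a tf.estimator.BoostedTreesRegressor model with tf.SavedModel. When I reload the model from the saved model directory with tf.contrib.predictor.from_saved_model(), I get the following error:

KeyError: "The name 'boosted_trees/QuantileAccumulator/' refers to an Operation not in the graph."

The error only occurs when I use numeric features (e.g. tf.feature_column.numeric_column). Reloading the model works fine when only categorical columns are used.

Without saving/restoring, the BoostedTreesRegressor evaluates and predicts successfully with all of the features.

The following estimator save/restore scenarios work:
- DNNRegressor with numeric and categorical features
- LinearRegressor with numeric and categorical features
- BoostedTreesRegressor with only categorical features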

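import tensorflow as tf
from tensorflow.contrib import predictor

fc = tf.feature_column
# f3, f4 are the vocabulary lists of the two categorical features (defined elsewhere)
feature_columns = [
    fc.numeric_column('f1', dtype=tf.int64),
    fc.numeric_column('f2', dtype=tf.int64),
    fc.indicator_column(
        fc.categorical_column_with_vocabulary_list('f3', f3)),
    fc.indicator_column(
        fc.categorical_column_with_vocabulary_list('f4', f4))
]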

feature_spec = fc.make_parse_example_spec(feature_columns)

params = {
    'feature_columns' : feature_columns,
    'n_batches_per_layer' : n_batches,
    'n_trees': 200,
    'max_depth': 6,
    'learning_rate': 0.01
}

regressor = tf.estimator.BoostedTreesRegressor(**params)
regressor.train(train_input_fn, max_steps=400)

serving_input_receiver_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec)

regressor.export_saved_model('saved_model', serving_input_receiver_fn)

...
# latest is path to saved model
predict_fn = predictor.from_saved_model(latest[:-4])
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-101-ee20beae4424> in <module>
----> 1 predict_fn = predictor.from_saved_model(latest[:-4])
/usr/local/anaconda3/envs/zume/lib/python3.7/site-packages/tensorflow/contrib/predictor/predictor_factories.py in from_saved_model(export_dir, signature_def_key, signature_def, input_names, output_names, tags, graph, config)
    151       tags=tags,
    152       graph=graph,
--> 153       config=config)
/usr/local/anaconda3/envs/zume/lib/python3.7/site-packages/tensorflow/contrib/predictor/saved_model_predictor.py in __init__(self, export_dir, signature_def_key, signature_def, input_names, output_names, tags, graph, config)
    151     with self._graph.as_default():
    152       self._session = session.Session(config=config)
--> 153       loader.load(self._session, tags.split(','), export_dir)
    154 
    155     if input_names is None:
/usr/local/anaconda3/envs/zume/lib/python3.7/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
    322               'in a future version' if date is None else ('after %s' % date),
    323               instructions)
--> 324       return func(*args, **kwargs)
    325     return tf_decorator.make_decorator(
    326         func, new_func, 'deprecated',
/usr/local/anaconda3/envs/zume/lib/python3.7/site-packages/tensorflow/python/saved_model/loader_impl.py in load(sess, tags, export_dir, import_scope, **saver_kwargs)
    267   """
    268   loader = SavedModelLoader(export_dir)
--> 269   return loader.load(sess, tags, import_scope, **saver_kwargs)
    270 
    271 
/usr/local/anaconda3/envs/zume/lib/python3.7/site-packages/tensorflow/python/saved_model/loader_impl.py in load(self, sess, tags, import_scope, **saver_kwargs)
    418     with sess.graph.as_default():
    419       saver, _ = self.load_graph(sess.graph, tags, import_scope,
--> 420                                  **saver_kwargs)
    421       self.restore_variables(sess, saver, import_scope)
    422       self.run_init_ops(sess, tags, import_scope)
/usr/local/anaconda3/envs/zume/lib/python3.7/site-packages/tensorflow/python/saved_model/loader_impl.py in load_graph(self, graph, tags, import_scope, **saver_kwargs)
    348     with graph.as_default():
    349       return tf_saver._import_meta_graph_with_return_elements(  # pylint: disable=protected-access
--> 350           meta_graph_def, import_scope=import_scope, **saver_kwargs)
    351 
    352   def restore_variables(self, sess, saver, import_scope=None):
/usr/local/anaconda3/envs/zume/lib/python3.7/site-packages/tensorflow/python/training/saver.py in _import_meta_graph_with_return_elements(meta_graph_or_file, clear_devices, import_scope, return_elements, **kwargs)
   1455           import_scope=import_scope,
   1456           return_elements=return_elements,
-> 1457           **kwargs))
   1458 
   1459   saver = _create_saver_from_imported_meta_graph(
/usr/local/anaconda3/envs/zume/lib/python3.7/site-packages/tensorflow/python/framework/meta_graph.py in import_scoped_meta_graph_with_return_elements(meta_graph_or_file, clear_devices, graph, import_scope, input_map, unbound_inputs_col_name, restore_collections_predicate, return_elements)
    850           for value in field.value:
    851             col_op = graph.as_graph_element(
--> 852                 ops.prepend_name_scope(value, scope_to_prepend_to_names))
    853             graph.add_to_collection(key, col_op)
    854         elif kind == "int64_list":
/usr/local/anaconda3/envs/zume/lib/python3.7/site-packages/tensorflow/python/framework/ops.py in as_graph_element(self, obj, allow_tensor, allow_operation)
   3476 
   3477     with self._lock:
-> 3478       return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
   3479 
   3480   def _as_graph_element_locked(self, obj, allow_tensor, allow_operation):
/usr/local/anaconda3/envs/zume/lib/python3.7/site-packages/tensorflow/python/framework/ops.py in _as_graph_element_locked(self, obj, allow_tensor, allow_operation)
   3536         if name not in self._nodes_by_name:
   3537           raise KeyError("The name %s refers to an Operation not in the "
-> 3538                          "graph." % repr(name))
   3539         return self._nodes_by_name[name]
   3540 
KeyError: "The name 'boosted_trees/QuantileAccumulator/' refers to an Operation not in the graph."
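The traceback ends in the loop that re-attaches `node_list` collections from the exported MetaGraphDef: each stored name is looked up in the imported graph, and `'boosted_trees/QuantileAccumulator/'` is not found. The following is a minimal pure-Python sketch of that check, which can help spot which collection entries point at missing operations. The function name and the sample node and collection names are hypothetical; in practice the node names would come from `meta_graph_def.graph_def.node` and the entries from `meta_graph_def.collection_def[key].node_list.value` of the parsed `saved_model.pb`. It compares op names only and ignores tensor output indices, so it simplifies what `Graph.as_graph_element` actually does.

```python
def find_dangling_collection_refs(node_names, node_list_collections):
    """Return (collection_key, ref) pairs whose op name is not a graph node.

    Hypothetical helper. node_names: op names present in the graph.
    node_list_collections: dict mapping a collection key to the list of
    names stored in that collection's node_list.
    """
    known = set(node_names)
    dangling = []
    for key, refs in node_list_collections.items():
        for ref in refs:
            # "op:0" names a tensor; the part before ":" is the op name.
            # Output indices are not validated (simplification).
            op_name = ref.split(":", 1)[0]
            if op_name not in known:
                dangling.append((key, ref))
    return dangling


# Hypothetical sample data shaped like the failing case above.
nodes = ["boosted_trees/QuantileAccumulator", "global_step"]
collections = {
    "coll_a": ["global_step", "global_step:0"],
    "coll_b": ["boosted_trees/QuantileAccumulator/"],
}
print(find_dangling_collection_refs(nodes, collections))
```

An empty result means every `node_list` entry resolves to a node; any pair returned names a collection entry that would likely trigger the KeyError shown above.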

【Discussion】:

    Tags: tensorflow tensorflow-estimator


    【Answer 1】:

    If you are using TensorFlow 1.x (1.14, 1.15), you can load the saved model with

    tf.compat.v1.saved_model.load, tf.compat.v1.saved_model.loader.load, or tf.saved_model.loader.load.
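The version split described above can be expressed as a small dispatch helper. This is a sketch, not code from the answer: `load_saved_model` is a hypothetical name, the `tensorflow` module is passed in as a parameter only so the version check is explicit and testable, and `"serve"` is assumed to be the tag under which `export_saved_model` stores the serving graph. Note that the traceback in the question shows `predictor.from_saved_model` already calling `loader.load(...)` internally, so on the affected 1.x version the 1.x branch may raise the same KeyError.

```python
def load_saved_model(export_dir, tf, tags=("serve",)):
    """Hypothetical helper: load a SavedModel with the API suggested per TF major version.

    `tf` is the imported tensorflow module; `tags` defaults to the serving tag
    (assumed to match what Estimator.export_saved_model writes).
    """
    major = int(tf.__version__.split(".")[0])
    if major >= 2:
        # TF 2.x: returns a trackable object exposing the exported signatures
        return tf.saved_model.load(export_dir)
    # TF 1.x (e.g. 1.14, 1.15): graph-mode loader restores into a session
    sess = tf.compat.v1.Session(graph=tf.Graph())
    tf.compat.v1.saved_model.loader.load(sess, list(tags), export_dir)
    return sess
```

Usage would be along the lines of `import tensorflow as tf` followed by `loaded = load_saved_model(export_path, tf)`; on 1.x the returned session holds the restored graph, on 2.x the returned object holds the signatures.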

    If you are using TensorFlow 2, the code below successfully saves and restores a tf.estimator.BoostedTreesClassifier:

    import pandas as pd
    import tensorflow as tf
    from IPython.display import clear_output

    fc = tf.feature_column

    # feature_columns, train_input_fn and eval_input_fn are defined in the full gist
    n_batches = 1
    est = tf.estimator.BoostedTreesClassifier(feature_columns,
                                              n_batches_per_layer=n_batches)
    
    # The model will stop training once the specified number of trees is built, not
    # based on the number of steps.
    est.train(train_input_fn, max_steps=100)
    
    # Eval.
    result = est.evaluate(eval_input_fn)
    clear_output()
    print(pd.Series(result))
    
    feature_spec = fc.make_parse_example_spec(feature_columns)
    
    serving_input_receiver_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec)
    
    # export_saved_model returns the path of the new timestamped export directory
    exported_path = est.export_saved_model('saved_model', serving_input_receiver_fn)
    
    imported = tf.saved_model.load(exported_path)
    

    For the complete working code for TensorFlow 2, see the GitHub Gist.
