【发布时间】:2018-11-13 16:13:35
【问题描述】:
我正在尝试训练和部署简化的 Quick, Draw!来自 Google Cloud 上 here 的分类器。我已经成功地在 GC 中训练模型,现在坚持部署它,更准确地说,在 creating serving input functions。
我正在遵循 here 的指示,并且很难理解输入张量应该是什么类型。
错误:
TypeError:无法将类型对象转换为张量。内容:SparseTensor(indices=Tensor("ParseExample/ParseExample:0", shape=(?, 2), dtype=int64), values=Tensor("ParseExample/ParseExample:1", shape=(?,), dtype= float32), dense_shape=Tensor("ParseExample/ParseExample:2", shape=(2,), dtype=int64))。考虑将元素转换为支持的类型。
服务功能:
def serving_input_receiver_fn():
serialized_tf_example = tf.placeholder(dtype=tf.string, shape=[None], name='input_tensors')
receiver_tensors = {'infer_inputs': serialized_tf_example}
features = tf.parse_example(serialized_tf_example, feature_spec)
return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)
功能说明:
feature_spec = {
"ink": tf.VarLenFeature(dtype=tf.float32),
"shape": tf.FixedLenFeature([2], dtype=tf.int64)
}
输入层:
def _get_input_tensors(features, labels):
shapes = features["shape"]
lengths = tf.squeeze(
tf.slice(shapes, begin=[0, 0], size=[params.batch_size, 1]))
inks = tf.reshape(features["ink"], [params.batch_size, -1, 3])
if labels is not None:
labels = tf.squeeze(labels)
return inks, lengths, labels
模型代码和训练数据取自here。
【问题讨论】:
标签: python tensorflow tensorflow-serving google-cloud-ml