【发布时间】:2019-08-02 06:33:11
【问题描述】:
我有一个保存的模型,我设法加载、运行并获得对 1 行 9 个特征的预测。 (输入) 现在我正在尝试像这样预测 100 行, 但是当试图从 Tensor.copyTo() 读取结果到结果数组时,我得到了不兼容的形状
java.lang.IllegalArgumentException: cannot copy Tensor with shape [1, 1] into object with shape [100, 1]
显然我设法在循环中运行这一预测 - 但这比一次运行 100 次的等效 python 执行慢 20 倍。
这是 /saved_model_cli.py 报告的已保存模型信息
MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:
signature_def['serving_default']:
The given SavedModel SignatureDef contains the following input(s):
inputs['input'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 9)
name: dense_1_input:0
The given SavedModel SignatureDef contains the following output(s):
outputs['output'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 1)
name: dense_4/BiasAdd:0
Method name is: tensorflow/serving/predict
问题是 - 我是否需要像问题 here 那样为我想要预测的每一行运行 run()
【问题讨论】:
-
您不能在 CNN 上输入形状为
[100,1]的数组,因为它的第一层是固定的,您无法更改它。我的第一个猜测是你必须为每一行run(),但是 Tensorflow 的 Java 实现看起来很糟糕,我找不到合适的例子来证明这一点 -
OK 在这里回答我自己。根据 JAVA 示例的官方张量流 - run() 是每个预测。 github.com/tensorflow/models/blob/master/samples/languages/java/…
-
我没有提到它,但问题的原因是 java 的运行速度比 python 慢 20 倍。这对我来说很奇怪。
标签: java tensorflow deep-learning prediction tensorflow-serving