【发布时间】:2018-04-22 13:30:21
【问题描述】:
我有一个训练有素的估算器,用于在新输入数据进入时进行实时预测。
在代码的开头我实例化了估算器:
estimator = tf.estimator.Estimator(
model_fn=model_fn,
model_dir="{}/model_dir_{}".format(script_dir, 3))
然后在一个循环中,每次我获得足够的新数据进行预测时:
predict_input_fn = tf.estimator.inputs.numpy_input_fn(
x={"x": np.array([sample.normalized.input_data])},
num_epochs=1,
shuffle=False)
predictions = estimator.predict(
input_fn=predict_input_fn,
)
每次我这样做时,我都会在控制台中收到这些 tensorflow 消息:
2018-04-21 16:01:08.401319: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1195] 创建 TensorFlow 设备 (/device:GPU:0) -> (设备:0,名称:GeForce GTX 1060 6GB, pci bus id: 0000:04:00.0, 计算能力: 6.1)
INFO:tensorflow:从 /home/fgervais/tf/model_dir_3/model.ckpt-103712 恢复参数
似乎每次预测都会重新完成整个 GPU 检测过程和模型加载。
有没有办法在实时输入之间将模型加载到内存中,以便获得更好的预测率?
【问题讨论】:
-
我不知道为什么人们不赞成这个。我遇到了完全相同的问题,这个解决方案非常有帮助。
标签: tensorflow tensorflow-datasets