【发布时间】:2016-12-23 15:36:52
【问题描述】:
我正在尝试找到一种方法来加快我的代码速度。
简而言之,我有一个训练有素的模型,我使用它来获取预测、对它们进行排序并输出排名。
def predict(feed_dict, truth):
# Feed dict contains about 10K candidates to obtain scores
pred = self.sess.run([self.mdl.predict_op], feed_dict)
pred = np.array(pred)
# With the scores, I sort them by likelihood
sort = np.argsort(pred)[::-1]
# I find the rank of the ground truth
rank = np.where(sort==truth)[0][0] + 1
return rank
但是,这个过程非常缓慢。我有大约 10K 测试样本。我相信 session 不能很好地与 python 中的标准多处理库一起使用,而多 cpu/gpu 支持仅适用于 tensorflow 操作。
有没有什么优雅的方法可以通过多处理加快速度?还是我必须将其作为 TF 计算图的一部分来实现。
非常感谢!
【问题讨论】:
-
哪一部分慢?
-
顺便说一句,
tf.nn.top_k(pred)[1]与您的np.argsort行相同。如果你把所有东西都变成 TF 图,你就不需要多处理——并行的session.run调用可以从同一进程中的不同 Python 线程开始。 -
速度慢的原因是我每次都必须在有效集或测试集上调用 10K+ 次。
-
谢谢!顺便说一句,你知道 TF 中 np.where 的等价物是什么吗?非常感谢
-
在 0.12 版本中是
tf.where(在早期版本中是tf.select)
标签: python numpy tensorflow