【发布时间】:2019-08-13 06:06:14
【问题描述】:
我正在尝试使用 TensorFlow 中的 kNN 算法计算给定地址的最佳匹配,该算法效果很好,但是当我尝试导出模型并在我们的 Java 环境中使用它时,我陷入了困境从 Java 中提供稀疏的占位符。
这是一个非常精简的 python 部分版本,它返回测试名称和最佳参考名称之间的最小距离。到目前为止,这项工作符合预期。当我导出模型并将其导入我的 Java 程序时,它总是返回相同的值(默认占位符的距离)。我假设,python 函数sparse_from_word_vec(word_vec) 不在模型中,这对我来说完全有意义,但是我应该如何制作这个稀疏张量?我的输入是一个字符串,我需要创建一个合适的稀疏张量(值)来计算距离。我还搜索了一种在 Java 端生成稀疏张量的方法,但没有成功。
import tensorflow as tf
import pandas as pd
d = {'NAME': ['max mustermann',
'erika musterfrau',
'joseph haydn',
'johann sebastian bach',
'wolfgang amadeus mozart']}
df = pd.DataFrame(data=d)
input_name = tf.placeholder_with_default('max musterman',(), name='input_name')
output_dist = tf.placeholder(tf.float32, (), name='output_dist')
test_name = tf.sparse_placeholder(dtype=tf.string)
ref_names = tf.sparse_placeholder(dtype=tf.string)
output_dist = tf.edit_distance(test_name, ref_names, normalize=True)
def sparse_from_word_vec(word_vec):
num_words = len(word_vec)
indices = [[xi, 0, yi] for xi,x in enumerate(word_vec) for yi,y in enumerate(x)]
chars = list(''.join(word_vec))
return(tf.SparseTensorValue(indices, chars, [num_words,1,1]))
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
t_data_names=tf.constant(df['NAME'])
reference_names = [el.decode('UTF-8') for el in (t_data_names.eval())]
sparse_ref_names = sparse_from_word_vec(reference_names)
sparse_test_name = sparse_from_word_vec([str(input_name.eval().decode('utf-8'))]*5)
feeddict={test_name: sparse_test_name,
ref_names: sparse_ref_names,
}
output_dist = sess.run(output_dist, feed_dict=feeddict)
output_dist = tf.reduce_min(output_dist, 0)
print(output_dist.eval())
tf.saved_model.simple_save(sess,
"model-simple",
inputs={"input_name": input_name},
outputs={"output_dist": output_dist})
这是我的 Java 方法:
public void run(ApplicationArguments args) throws Exception {
log.info("Loading model...");
SavedModelBundle savedModelBundle = SavedModelBundle.load("/model", "serve");
byte[] test_name = "Max Mustermann".toLowerCase().getBytes("UTF-8");
List<Tensor<?>> output = savedModelBundle.session().runner()
.feed("input_name", Tensor.<String>create(test_names))
.fetch("output_dist")
.run();
System.out.printl("Nearest distance: " + output.get(0).floatValue());
}
【问题讨论】:
标签: java tensorflow