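【Question title】: How can I feed a sparse placeholder in a TensorFlow model from Java
【Posted】: 2019-08-13 06:06:14
【Question】:

I am trying to compute the best match for a given address with the kNN algorithm in TensorFlow. That works well, but when I export the model and try to use it in our Java environment, I get stuck feeding the sparse placeholders from Java.

Here is a heavily stripped-down version of the Python part, which returns the minimal distance between the test name and the best reference name. So far this works as expected. When I export the model and import it into my Java program, it always returns the same value (the distance for the placeholder's default value). I assume the Python function sparse_from_word_vec(word_vec) is not part of the model, which makes complete sense to me, but how should I build this sparse tensor? My input is a string, and I need to create a matching sparse tensor (value) to compute the distance. I also searched for a way to create sparse tensors on the Java side, without success.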

import tensorflow as tf
import pandas as pd

d = {'NAME': ['max mustermann', 
              'erika musterfrau', 
              'joseph haydn', 
              'johann sebastian bach', 
              'wolfgang amadeus mozart']}

df = pd.DataFrame(data=d)  

input_name = tf.placeholder_with_default('max musterman',(), name='input_name')
output_dist = tf.placeholder(tf.float32, (), name='output_dist')

test_name = tf.sparse_placeholder(dtype=tf.string)
ref_names = tf.sparse_placeholder(dtype=tf.string)

output_dist = tf.edit_distance(test_name, ref_names, normalize=True)

def sparse_from_word_vec(word_vec):
    num_words = len(word_vec)
    indices = [[xi, 0, yi] for xi,x in enumerate(word_vec) for yi,y in enumerate(x)]
    chars = list(''.join(word_vec))
    return(tf.SparseTensorValue(indices, chars, [num_words,1,1]))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    t_data_names=tf.constant(df['NAME'])
    reference_names = [el.decode('UTF-8') for el in (t_data_names.eval())]

    sparse_ref_names = sparse_from_word_vec(reference_names)
    sparse_test_name = sparse_from_word_vec([str(input_name.eval().decode('utf-8'))]*5)

    feeddict={test_name: sparse_test_name,
              ref_names: sparse_ref_names, 
              }    

    output_dist = sess.run(output_dist, feed_dict=feeddict)
    output_dist = tf.reduce_min(output_dist, 0)
    print(output_dist.eval())

    tf.saved_model.simple_save(sess,
                               "model-simple",
                               inputs={"input_name": input_name},
                               outputs={"output_dist": output_dist})
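To make the structure concrete, the following plain-Python sketch (no TensorFlow needed) reproduces the indices/values layout that sparse_from_word_vec builds: one entry per character, indexed as [word, 0, char_position]. The function name `sparse_components` is hypothetical. One deliberate difference: the code above passes `[num_words, 1, 1]` as the dense shape, which does not bound the character index `yi`; the sketch uses the longest word length as the last dimension so every index lies inside the shape.

```python
def sparse_components(word_vec):
    """Return (indices, values, dense_shape) for a batch of words,
    one character per entry, laid out as [word, 0, char_position]."""
    indices = [[xi, 0, yi] for xi, word in enumerate(word_vec)
               for yi, _ in enumerate(word)]
    values = list("".join(word_vec))
    # The last dimension must cover the longest word so every index is in bounds.
    max_len = max((len(w) for w in word_vec), default=0)
    dense_shape = [len(word_vec), 1, max_len]
    return indices, values, dense_shape


if __name__ == "__main__":
    idx, vals, shape = sparse_components(["ab", "c"])
    print(idx)    # [[0, 0, 0], [0, 0, 1], [1, 0, 0]]
    print(vals)   # ['a', 'b', 'c']
    print(shape)  # [2, 1, 2]
```

These three lists are the kind of data handed to tf.SparseTensorValue. Because they are built by ordinary Python outside the graph, none of this logic ends up in an exported SavedModel.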

Here is my Java method:

public void run(ApplicationArguments args) throws Exception {
  log.info("Loading model...");

  SavedModelBundle savedModelBundle = SavedModelBundle.load("/model", "serve");

  byte[] test_name = "Max Mustermann".toLowerCase().getBytes("UTF-8");

  List<Tensor<?>> output = savedModelBundle.session().runner()
      .feed("input_name", Tensor.<String>create(test_name))
      .fetch("output_dist")
      .run();

  System.out.println("Nearest distance: " + output.get(0).floatValue());

}

【Discussion】:

    Tags: java tensorflow


    【Answer 1】:

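    I was able to get your example working. Before diving in, I have a few comments on your Python code.

    You use the variable output_dist throughout your code for several different things (a placeholder, the edit_distance tensor, the NumPy array returned by sess.run, and a tf.reduce_min tensor). I'm not a Python expert, but I think that is bad practice, and here it also means that the output_dist you finally export is tf.reduce_min applied to that NumPy array, so it is computed from a constant and cannot depend on any input. Likewise, the input_name placeholder is never connected to the ops that compute the distance; its default value is only read in Python with input_name.eval() and then exported as an input. Lastly, tf.saved_model.simple_save is deprecated; you should use tf.saved_model.Builder instead.

    Now for the actual problem.

    Looking through the libtensorflow jar with jar tvf libtensorflow-x.x.x.jar (thanks to this post), you can see there are no useful bindings for creating sparse tensors (maybe file a feature request?). So we have to change the inputs to dense tensors and then add ops to the graph that convert them to sparse. In your original code the sparse conversion happens on the Python side, which means the graph loaded in Java does not contain any of those ops.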
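    The "dense in, sparse inside the graph" idea can be illustrated without TensorFlow. The sketch below is a plain-Python emulation (the name `dense_to_sparse_py` is hypothetical, and this is not TensorFlow's implementation) of what an in-graph conversion such as the `tf.contrib.layers.dense_to_sparse` call in the code that follows does: every entry that differs from the padding token becomes a sparse entry, and the dense shape is kept.

```python
def dense_to_sparse_py(dense, eos_token="/"):
    """Plain-Python emulation of converting a padded 2-D batch into sparse
    components: every entry that differs from eos_token is kept."""
    indices, values = [], []
    for i, row in enumerate(dense):
        for j, item in enumerate(row):
            if item != eos_token:
                indices.append([i, j])
                values.append(item)
    # For a rectangular batch this is simply [rows, columns].
    num_cols = max((len(row) for row in dense), default=0)
    dense_shape = [len(dense), num_cols]
    return indices, values, dense_shape


if __name__ == "__main__":
    print(dense_to_sparse_py([["a", "b", "/"], ["c", "/", "/"]]))
    # ([[0, 0], [0, 1], [1, 0]], ['a', 'b', 'c'], [2, 3])
```

    One consequence of this design: every occurrence of the padding token is dropped, not only trailing ones, so "/" must be a character that never appears in real names.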

    Here is the new Python code:

    import tensorflow as tf
    import pandas as pd
    
    def model():
        #use dense tensors then convert to sparse for edit_distance
        test_name = tf.placeholder(shape=(None, None), dtype=tf.string, name="test_name")
        ref_names = tf.placeholder(shape=(None, None), dtype=tf.string, name="ref_names")
    
        #Java does not handle empty strings well as padding, so pad with "/" instead
        test_name_sparse = tf.contrib.layers.dense_to_sparse(test_name, "/")
        ref_names_sparse = tf.contrib.layers.dense_to_sparse(ref_names, "/")
    
        output_dist = tf.edit_distance(test_name_sparse, ref_names_sparse, normalize=True)
    
        #output the index to the closest ref name
        min_idx = tf.argmin(output_dist)
    
        return test_name, ref_names, min_idx
    
    #Python code to be replicated in Java
    def pad_string(s, max_len):
        return s + ["/"] * (max_len - len(s))
    
    d = {'NAME': ['joseph haydn', 
                  'max mustermann', 
                  'erika musterfrau', 
                  'johann sebastian bach', 
                  'wolfgang amadeus mozart']}
    
    df = pd.DataFrame(data=d)  
    input_name = 'max musterman'
    
    #pad dense tensor input
    max_len = max([len(n) for n in df['NAME']])
    
    #no need to pad the test input, every row has the same length
    test_input = [list(input_name)]*len(df['NAME'])
    #pad every reference name with "/" up to the longest name
    ref_input = list(map(lambda x: pad_string(x, max_len), [list(n) for n in df['NAME']]))
    
    
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
    
        test_name, ref_names, min_idx = model()
    
        #run a test to make sure the model works
        feeddict = {test_name: test_input,
                    ref_names: ref_input,
                }
        out = sess.run(min_idx, feed_dict=feeddict)
        print("test output:", out)
    
        #save the model with the new Builder API
        signature_def_map= {
        "predict": tf.saved_model.signature_def_utils.predict_signature_def(
            inputs= {"test_name": test_name, "ref_names": ref_names},
            outputs= {"min_idx": min_idx})
        }
    
        builder = tf.saved_model.Builder("model")
        builder.add_meta_graph_and_variables(sess, ["serve"], signature_def_map=signature_def_map)
        builder.save()
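    To check the model's answer by hand, here is a plain-Python sketch of what the graph computes: the Levenshtein distance between each test/reference pair, divided by the length of the reference (tf.edit_distance with normalize=True normalizes by the length of truth, which is ref_names here), followed by an argmin. The helper names `levenshtein` and `nearest_index` are hypothetical, and ties are broken by the first index, which is not necessarily how tf.argmin behaves.

```python
def levenshtein(a, b):
    """Dynamic-programming edit distance (insert/delete/substitute each cost 1)."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        cur = [i]
        for j, cb in enumerate(b, start=1):
            cur.append(min(prev[j] + 1,                # deletion
                           cur[j - 1] + 1,             # insertion
                           prev[j - 1] + (ca != cb)))  # substitution
        prev = cur
    return prev[-1]


def nearest_index(test_name, ref_names):
    """Index of the reference with the smallest distance normalized by the
    reference (truth) length. Assumes no reference name is empty."""
    dists = [levenshtein(test_name, ref) / len(ref) for ref in ref_names]
    return min(range(len(dists)), key=dists.__getitem__)


if __name__ == "__main__":
    names = ['joseph haydn', 'max mustermann', 'erika musterfrau',
             'johann sebastian bach', 'wolfgang amadeus mozart']
    print(nearest_index('max musterman', names))  # 1
```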
    

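    And here is the Java that loads and runs it. There is probably plenty of room for improvement (Java is not my main language), but it gives you the idea.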

    import org.tensorflow.SavedModelBundle;
    import org.tensorflow.Tensor;
    
    import java.util.Arrays;
    import java.util.List;
    
    public class Test {
        public static byte[][] makeTensor(String s, int padding) throws Exception
        {
            int len = s.length();
            int extra = padding - len;
    
            byte[][] ret = new byte[len + extra][];
            for (int i = 0; i < len; i++) {
                String cur = "" + s.charAt(i);
                byte[] cur_b = cur.getBytes("UTF-8");
                ret[i] = cur_b;
            }
    
            for (int i = 0; i < extra; i++) {
                byte[] cur = "/".getBytes("UTF-8");
                ret[len + i] = cur;
            }
    
            return ret;
        }
        public static byte[][][] makeTensor(List<String> l, int padding) throws Exception
        {
            byte[][][] ret = new byte[l.size()][][];
            for (int i = 0; i < l.size(); i++) {
                ret[i] = makeTensor(l.get(i), padding);
            }
    
            return ret;
        }
        public static void main(String[] args) throws Exception {
            System.out.println("Loading model...");
    
            SavedModelBundle savedModelBundle = SavedModelBundle.load("model", "serve");
    
    
            List<String> str_test_name = Arrays.asList("Max Mustermann",
                "Max Mustermann",
                "Max Mustermann",
                "Max Mustermann",
                "Max Mustermann");
            List<String> names = Arrays.asList("joseph haydn",
                "max mustermann",
                "erika musterfrau",
                "johann sebastian bach",
                "wolfgang amadeus mozart");
    
            //get the max length for each array
            int pad1 = str_test_name.get(0).length();
            int pad2 = 0;
            for (String var : names) {
                if(var.length() > pad2)
                    pad2 = var.length();
            }
    
    
            byte[][][] test_name = makeTensor(str_test_name, pad1);
            byte[][][] ref_names = makeTensor(names, pad2);
    
            //try-with-resources ensures close() is called on every tensor
            try (Tensor<String> t_test_name = Tensor.create(test_name, String.class);
                 Tensor<String> t_ref_names = Tensor.create(ref_names, String.class))
            {
                List<Tensor<?>> output = savedModelBundle.session().runner()
                    .feed("test_name", t_test_name)
                    .feed("ref_names", t_ref_names)
                    .fetch("ArgMin")
                    .run();

                //ArgMin returns the index of the closest reference name, not a distance
                try (Tensor<?> result = output.get(0))
                {
                    System.out.println("Index of nearest name: " + result.longValue());
                }
            }
        }
    }
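    The Java helper makeTensor builds a nested array of UTF-8 byte strings, one per character, padded with "/". A plain-Python mirror of that layout (the names `make_tensor_bytes` and `make_batch_bytes` are hypothetical) makes the expected shape easy to inspect. One caveat: Python iterates over code points, whereas Java's charAt returns UTF-16 code units, so characters outside the Basic Multilingual Plane (such as emoji) would be split differently by the Java version.

```python
def make_tensor_bytes(s, padding, pad_token="/"):
    """Mirror of the Java makeTensor(String, int): one UTF-8 byte string per
    character, right-padded with pad_token up to `padding` entries."""
    chars = [c.encode("utf-8") for c in s]
    if len(chars) > padding:
        # The Java version would throw ArrayIndexOutOfBoundsException here.
        raise ValueError("string is longer than the padding length")
    return chars + [pad_token.encode("utf-8")] * (padding - len(chars))


def make_batch_bytes(strings, padding):
    """Mirror of makeTensor(List<String>, int): a rank-2 batch of byte strings."""
    return [make_tensor_bytes(s, padding) for s in strings]


if __name__ == "__main__":
    print(make_batch_bytes(["ab", "c"], 2))  # [[b'a', b'b'], [b'c', b'/']]
```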
    

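    【Comments】:

    • I guessed that I needed to convert the dense tensor coming from Java into a sparse tensor inside the model, but I didn't know how. Thanks, your answer is exactly what I needed.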