【问题标题】:Tensorflow 2.0 & Java APITensorFlow 2.0 和 Java API
【发布时间】:2020-06-03 20:07:16
【问题描述】:

(注意,我已经解决了我的问题并将代码贴在了底部)

我在玩 TensorFlow,后端处理必须在 Java 中进行。我从https://developers.google.com/machine-learning/crash-course 中获取了其中一个模型,并将其保存为 tf.saved_model.save(my_model,"house_price_median_income") (使用 docker 容器)。我复制了模型并将其加载到 Java 中(使用从源代码构建的 2.0 东西,因为我在 Windows 上)。 我可以加载模型并运行它:

   try (SavedModelBundle model = SavedModelBundle.load("./house_price_median_income", "serve")) {
    try (Session session = model.session()) {
        Session.Runner runner = session.runner();
        float[][] in = new float[][]{ {2.1518f} } ;

        Tensor<?> jack = Tensor.create(in);
        runner.feed("serving_default_layer1_input", jack);

        float[][] probabilities = runner.fetch("StatefulPartitionedCall").run().get(0).copyTo(new float[1][1]);

        for (int i = 0; i < probabilities.length; ++i) {
            System.out.println(String.format("-- Input #%d", i));
            for (int j = 0; j < probabilities[i].length; ++j) {
              System.out.println(String.format("Class %d - %f", i, probabilities[i][j]));
            }
          }
    }
 }

以上内容被硬编码为输入和输出,但我希望能够读取模型并提供一些信息,以便最终用户可以选择输入和输出等。

我可以使用 python 命令获取输入和输出:saved_model_cli show --dir ./house_price_median_income --all

我想做的是通过 Java 获取输入和输出,因此我的代码不需要执行 python 脚本来获取它们。我可以通过以下方式进行操作:

 Graph graph = model.graph();
    Iterator<Operation> itr = graph.operations();
    while (itr.hasNext()) {
        GraphOperation e = (GraphOperation)itr.next();
        System.out.println(e);

这会将输入和输出都输出为“操作”但是我怎么知道它是输入和/或输出? python 工具使用 SignatureDef 但这似乎根本没有出现在 TensorFlow 2.0 java 的东西中。我是否遗漏了一些明显的东西,还是只是从 TensforFlow 2.0 Java 库中遗漏了?

注意,我已通过下面的答案帮助对我的问题进行了分类。这是我的全部代码,以防将来有人会喜欢它。请注意,这是 TF 2.0 并使用下面提到的 SNAPSHOT。我做了一些假设,但它展示了如何提取输入和输出,然后使用它们来运行模型

import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.exceptions.TensorFlowException;
import org.tensorflow.Session.Run;
import org.tensorflow.Graph;
import org.tensorflow.Operation;
import org.tensorflow.Output;
import org.tensorflow.GraphOperation;
import org.tensorflow.proto.framework.SignatureDef;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.tensorflow.proto.framework.MetaGraphDef;
import java.util.Map;
import org.tensorflow.proto.framework.TensorInfo;
import org.tensorflow.types.TFloat32;
import org.tensorflow.tools.Shape;
import java.nio.FloatBuffer;
import org.tensorflow.tools.buffer.DataBuffers;
import org.tensorflow.tools.ndarray.FloatNdArray;
import org.tensorflow.tools.ndarray.StdArrays;
import org.tensorflow.proto.framework.TensorInfo;

public class v2tensor {
    public static void main(String[] args) {
     try (SavedModelBundle savedModel = SavedModelBundle.load("./house_price_median_income", "serve")) {
        SignatureDef modelInfo = savedModel.metaGraphDef().getSignatureDefMap().get("serving_default");
        TensorInfo input1 = null;
        TensorInfo output1 = null;
        Map<String, TensorInfo> inputs = modelInfo.getInputsMap();
        for(Map.Entry<String, TensorInfo> input : inputs.entrySet()) {
            if (input1 == null) {
                input1 = input.getValue();
                System.out.println(input1.getName());
            }
            System.out.println(input);
        }
        Map<String, TensorInfo> outputs = modelInfo.getOutputsMap();
        for(Map.Entry<String, TensorInfo> output : outputs.entrySet()) {
            if (output1 == null) {
                output1=output.getValue();
            }
            System.out.println(output);
        }

        try (Session session = savedModel.session()) {
            Session.Runner runner = session.runner();
            FloatNdArray matrix = StdArrays.ndCopyOf(new float[][]{ { 2.1518f } } );

            try (Tensor<TFloat32> jack = TFloat32.tensorOf(matrix) ) {
                runner.feed(input1.getName(), jack);
                try ( Tensor<TFloat32> rezz = runner.fetch(output1.getName()).run().get(0).expect(TFloat32.DTYPE) ) { 
                    TFloat32 data = rezz.data();
                    data.scalars().forEachIndexed((i, s) -> {
                        System.out.println(s.getFloat());
                    }   );
                }
            }
        }
    } catch (TensorFlowException ex) {
        ex.printStackTrace();   
    }
    }
}

【问题讨论】:

    标签: java tensorflow


    【解决方案1】:

    您需要做的是将SavedModelBundle 元数据读取为MetaGraphDef,从那里您可以从SignatureDef 中检索输入和输出名称,就像在Python 中一样。

    在 TF Java 1.*(即您在示例中使用的客户端)中,原始定义不能从 tensorflow 工件中直接获得,您需要向 @987654327 添加依赖项@ 并将SavedModelBundle.metaGraphDef() 的结果反序列化为MetaGraphDef 原型。

    在 TF Java 2.* 中(新客户端实际上只能作为来自 here 的快照提供),protos 会立即出现,因此您只需调用此行即可检索正确的 SignatureDef

    savedModel.metaGraphDef().signatureDefMap.getValue("serving_default")
    

    【讨论】:

    • @john-mitchell,这对你有帮助吗?你解封了吗?
    • 抱歉耽搁了;这非常接近:SignatureDef modelInfo = savedModel.metaGraphDef().getSignatureDefMap().get("serving_default");
    • 抱歉,我才意识到我在 Kotlin 中写了原始答案 :)
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2019-09-06
    • 1970-01-01
    • 2020-11-02
    • 1970-01-01
    • 2017-07-21
    • 2020-08-12
    相关资源
    最近更新 更多