【问题标题】:Why does a Java Tensorflow session seem to reset state when a Python Tensorflow session does not?为什么当 Python Tensorflow 会话没有重置状态时,Java Tensorflow 会话似乎会重置状态?
【发布时间】:2018-01-07 09:48:39
【问题描述】:

我正在尝试在 Linux 上通过 1.4 Java API 构建和评估 TensorFlow Graphs。我注意到,每次调用 Session.run() 时,Java API 似乎都会重置操作输出张量的值。这种行为似乎与 Python 中发生的情况不匹配。我最终的问题(详见下文)是如何避免这种明显的行为?

Python 示例

这里的例子是增加标量张量中的值的 Python 代码(也使用 1.4 API)。

>>> import tensorflow as tf
>>> x = tf.get_variable("x", [], dtype=tf.float32, initializer=tf.zeros_initializer)
>>> step = tf.constant(1.0)
>>> xUpdateOp = x.assign_add(step)
>>> s = tf.Session()
>>> s.run(x.initializer)
>>> x.eval(s)
0.0
>>> s.run(xUpdateOp)
1.0
>>> x.eval(s)
1.0
>>> s.run(xUpdateOp)
2.0
>>> x.eval(s)
2.0
>>>

请注意,正如预期的那样,评估 x 会给出其当前值,并且使用会话运行 xUpdateOp 会导致 x 增大 1。

Java 示例

这是我尝试使用 Java 构建递增标量张量的 Tensorflow 图。 Java API 中的初始化不同,因为它缺少一些 Python 的便捷方法。

public static void doCounting(){
    try(Graph g = new Graph()){
        try(Tensor<Float> zeroT = Tensors.create(0.0f);
            Tensor<Float> stepT = Tensors.create(1.0f)){
            Output<Float> zero = g.opBuilder("Const", "start")
                    .setAttr("dtype", zeroT.dataType())
                    .setAttr("value", zeroT)
                    .build().output(0);
            Output<Float> step = g.opBuilder("Const", "step")
                    .setAttr("dtype", stepT.dataType())
                    .setAttr("value", stepT)
                    .build().output(0);
            Output<Float> xVar = g.opBuilder("Variable", "x")
                    .setAttr("dtype", zero.dataType())
                    .setAttr("shape", zero.shape())
                    .build().output(0);
            Output<Float> x = g.opBuilder("Assign", "init_x")
                    .addInput(xVar)
                    .addInput(zero)
                    .build().output(0);

            Operation xUpdateOp = g.opBuilder("AssignAdd", "x_get_x_plus_step")
                    .addInput(x)
                    .addInput(step)
                    .build();

            try(Session s = new Session(g)) {
                s.runner().addTarget(xUpdateOp).run();
                s.runner().addTarget(xUpdateOp).run();
                s.runner().addTarget(xUpdateOp).run();

                try(Tensor<Float> result = s.runner().fetch(xUpdateOp.name(), 0).run().get(0).expect(Float.class)){
                    System.out.println(result.floatValue());
                }
            }
        }
    }
}

以上代码片段的输出

1.0

但我希望它是 4.0,因为我在 xUpdateOp 上调用了 run() 4 次。就算我一败涂地,1.0 也不是我所期望的。

问题

我需要对这个 Java 示例做什么才能获得与 Python 示例相同的行为?如何让 xUpdateOp 使用在之前调用 run() 时计算的 x 值?

我已经尝试过的

我已经尝试使用 feed() 函数来输入 x 值

try(Session s = new Session(g)) {
    try(Tensor<Float> x1 = s.runner().fetch(xUpdateOp.name()).run().get(0).expect(Float.class)) {
        s.runner().feed(xUpdateOp.name(), 0, x1);
        try (Tensor<Float> result = s.runner().fetch(xUpdateOp.name(), 0).run().get(0).expect(Float.class)) {
            System.out.println(result.floatValue());
        }
    }
}

结果

1.0

我还尝试在没有 addTarget 或 fetch() 的情况下调用 run(),认为 addTarget 或 fetch() 是导致状态重置的原因。或许一旦会话知道要运行什么,它就可以运行多次。

try(Session s = new Session(g)) {
    s.runner().addTarget(xUpdateOp).run();
    s.runner().run();
    s.runner().run();

    try(Tensor<Float> result = s.runner().fetch(xUpdateOp.name(), 0).run().get(0).expect(Float.class)){
        System.out.println(result.floatValue());
    }
}

结果

Exception in thread "main" java.lang.IllegalArgumentException: Must specify at least one target to fetch or execute.
at org.tensorflow.Session.run(Native Method)
at org.tensorflow.Session.access$100(Session.java:48)
at org.tensorflow.Session$Runner.runHelper(Session.java:298)
at org.tensorflow.Session$Runner.run(Session.java:248)
at org.tensorflow.examples.Example.doCounting(MandelbrotExample.java:80)
at org.tensorflow.examples.Example.main(MandelbrotExample.java:50)
ERROR: Non-zero return code '1' from command: Process exited with status 1.

一些相关的问题

How to create/initialize a Variable with Tensorflow 1.0 Java API

java tensorflow reset_default_graph

Java - train loaded tensorflow model

提前感谢您的宝贵时间!

【问题讨论】:

  • 也许真正的区别在于初始化。 Tensorflow 是否在每次调用 s.runner().addTarget(xUpdateOp).run() 时重新运行“init_x”操作?

标签: java python tensorflow


【解决方案1】:

在您的示例中,xUpdateOpx 作为其输入,x 是将zero 分配给变量的操作的输出。因此,每次运行 xUpdateOp 时,它首先将零分配给变量。

对您的代码稍作调整将生成 4.0:

# Changed addInput(x) to addInput(xVar)
Operation xUpdateOp =
    g.opBuilder("AssignAdd", "x_get_x_plus_step").addInput(xVar).addInput(step).build();

try (Session s = new Session(g)) {
  # Initialize the variable once
  s.runner().addTarget(x.op()).run();
  s.runner().addTarget(xUpdateOp).run();
  s.runner().addTarget(xUpdateOp).run();
  s.runner().addTarget(xUpdateOp).run();

  try (Tensor<Float> result =
       s.runner().fetch(xUpdateOp.name(), 0).run().get(0).expect(Float.class)) {
    System.out.println(result.floatValue());
  }                     
}

与 Python 代码平行:上面的 Java 代码 sn-p 更像是问题中的 Python 代码。虽然问题中的 Java 代码更像 Python 中的以下代码:

import tensorflow as tf

zero = tf.constant(0.0)
step = tf.constant(1.0)
xVar = tf.Variable(initial_value=zero, name="x")
x = tf.assign(xVar, zero)
xUpdateOp = tf.assign_add(x, step)

所以tf.assign_add(x, step)tf.assign_add(xVar, step) 会有所不同。在前者中,AssignAdd 操作适用于Assign 操作的输出。

希望对您有所帮助。

【讨论】:

  • 我可以确认这确实解决了问题,但我不明白为什么。 AssignAdd 的文档及其相关操作“分配”说“输出:= 与“参考”相同。返回是为了方便在变量更新后想要使用新值的操作。”我读到这意味着从充当初始化器的“分配”返回的值是初始化的输入变量,并且可以在其他地方用作变量。
  • 我试图通过详细说明答案来解释。这有帮助吗?
  • 在您的答案中添加 Java 和 Python 之间的对比确实有帮助。问题在于稍后在初始化变量的 Tensorflow 操作的结果图中的使用,而不是变量本身。似乎“分配”运算符的返回值是新值,而不是用新值填充的变量。谢谢!
猜你喜欢
  • 1970-01-01
  • 2021-11-07
  • 1970-01-01
  • 2016-04-14
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2010-09-28
  • 2010-12-13
相关资源
最近更新 更多