【发布时间】:2018-01-07 09:48:39
【问题描述】:
我正在尝试在 Linux 上通过 1.4 Java API 构建和评估 TensorFlow Graphs。我注意到,每次调用 Session.run() 时,Java API 似乎都会重置操作输出张量的值。这种行为似乎与 Python 中发生的情况不匹配。我最终的问题(详见下文)是如何避免这种明显的行为?
Python 示例
这里的例子是增加标量张量中的值的 Python 代码(也使用 1.4 API)。
>>> import tensorflow as tf
>>> x = tf.get_variable("x", [], dtype=tf.float32, initializer=tf.zeros_initializer)
>>> step = tf.constant(1.0)
>>> xUpdateOp = x.assign_add(step)
>>> s = tf.Session()
>>> s.run(x.initializer)
>>> x.eval(s)
0.0
>>> s.run(xUpdateOp)
1.0
>>> x.eval(s)
1.0
>>> s.run(xUpdateOp)
2.0
>>> x.eval(s)
2.0
>>>
请注意,正如预期的那样,评估 x 会给出其当前值,并且使用会话运行 xUpdateOp 会导致 x 增大 1。
Java 示例
这是我尝试使用 Java 构建递增标量张量的 Tensorflow 图。 Java API 中的初始化不同,因为它缺少一些 Python 的便捷方法。
public static void doCounting(){
try(Graph g = new Graph()){
try(Tensor<Float> zeroT = Tensors.create(0.0f);
Tensor<Float> stepT = Tensors.create(1.0f)){
Output<Float> zero = g.opBuilder("Const", "start")
.setAttr("dtype", zeroT.dataType())
.setAttr("value", zeroT)
.build().output(0);
Output<Float> step = g.opBuilder("Const", "step")
.setAttr("dtype", stepT.dataType())
.setAttr("value", stepT)
.build().output(0);
Output<Float> xVar = g.opBuilder("Variable", "x")
.setAttr("dtype", zero.dataType())
.setAttr("shape", zero.shape())
.build().output(0);
Output<Float> x = g.opBuilder("Assign", "init_x")
.addInput(xVar)
.addInput(zero)
.build().output(0);
Operation xUpdateOp = g.opBuilder("AssignAdd", "x_get_x_plus_step")
.addInput(x)
.addInput(step)
.build();
try(Session s = new Session(g)) {
s.runner().addTarget(xUpdateOp).run();
s.runner().addTarget(xUpdateOp).run();
s.runner().addTarget(xUpdateOp).run();
try(Tensor<Float> result = s.runner().fetch(xUpdateOp.name(), 0).run().get(0).expect(Float.class)){
System.out.println(result.floatValue());
}
}
}
}
}
以上代码片段的输出
1.0
但我希望它是 4.0,因为我在 xUpdateOp 上调用了 run() 4 次。就算我一败涂地,1.0 也不是我所期望的。
问题
我需要对这个 Java 示例做什么才能获得与 Python 示例相同的行为?如何让 xUpdateOp 使用在之前调用 run() 时计算的 x 值?
我已经尝试过的
我已经尝试使用 feed() 函数来输入 x 值
try(Session s = new Session(g)) {
try(Tensor<Float> x1 = s.runner().fetch(xUpdateOp.name()).run().get(0).expect(Float.class)) {
s.runner().feed(xUpdateOp.name(), 0, x1);
try (Tensor<Float> result = s.runner().fetch(xUpdateOp.name(), 0).run().get(0).expect(Float.class)) {
System.out.println(result.floatValue());
}
}
}
结果
1.0
我还尝试在没有 addTarget 或 fetch() 的情况下调用 run(),认为 addTarget 或 fetch() 是导致状态重置的原因。或许一旦会话知道要运行什么,它就可以运行多次。
try(Session s = new Session(g)) {
s.runner().addTarget(xUpdateOp).run();
s.runner().run();
s.runner().run();
try(Tensor<Float> result = s.runner().fetch(xUpdateOp.name(), 0).run().get(0).expect(Float.class)){
System.out.println(result.floatValue());
}
}
结果
Exception in thread "main" java.lang.IllegalArgumentException: Must specify at least one target to fetch or execute.
at org.tensorflow.Session.run(Native Method)
at org.tensorflow.Session.access$100(Session.java:48)
at org.tensorflow.Session$Runner.runHelper(Session.java:298)
at org.tensorflow.Session$Runner.run(Session.java:248)
at org.tensorflow.examples.Example.doCounting(MandelbrotExample.java:80)
at org.tensorflow.examples.Example.main(MandelbrotExample.java:50)
ERROR: Non-zero return code '1' from command: Process exited with status 1.
一些相关的问题
How to create/initialize a Variable with Tensorflow 1.0 Java API
java tensorflow reset_default_graph
Java - train loaded tensorflow model
提前感谢您的宝贵时间!
【问题讨论】:
-
也许真正的区别在于初始化。 Tensorflow 是否在每次调用 s.runner().addTarget(xUpdateOp).run() 时重新运行“init_x”操作?
标签: java python tensorflow