【问题标题】:How do you obtain the output and input values of a model in tensorflow?在 tensorflow 中如何获取模型的输出和输入值?
【发布时间】:2017-05-24 14:17:38
【问题描述】:

我正在研究 GAN,并决定使用 HyperGAN 来实现我的算法。它是使用 TensorFlow 对 DCGAN 的封装。 HyperGAN 使用TF 的检查点方法保存输出。

后来,我尝试使用以下方式运行加载模型:

import tensorflow as tf
sess=tf.Session()    
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('./'))
sess.run(tf.global_variables_initializer())

但是,由于它是一个 GAN,它需要一个输入潜在向量并输出一个图像。这是使用

out_image = sess.run(last_node, feed_dict(input_node: value))

但是由于我加载了模型,我不知道最后一个节点的名称是什么,输入节点占位符的名称是什么。我如何获得最初用于创建图表的名称?我尝试使用 TensorBoard 进行可视化,但图表很大,因此卡住了。

【问题讨论】:

    标签: python tensorflow tensorboard


    【解决方案1】:

    您应该尝试在图中打印张量列表:

    with tf.Graph().as_default() as graph:
    ....
    
    count = 0
    for op in graph.get_operations():
        print op.values()
        count+=1
        if count == 50:
            assert False
    

    为了查看图表的前 50 个节点,您将看到如下内容:

    (<tf.Tensor 'import/Placeholder_only:0' shape=<unknown> dtype=float32>,)
    (<tf.Tensor 'import/MobileNet/conv_ds_8/dw_batch_norm/batchnorm/sub/_53__cf__53_quantized_max:0' shape=() dtype=float32>,)
    (<tf.Tensor 'import/MobileNet/conv_ds_8/dw_batch_norm/batchnorm/sub/_53__cf__53_quantized_min:0' shape=() dtype=float32>,)
    (<tf.Tensor 'import/MobileNet/conv_ds_8/dw_batch_norm/batchnorm/sub/_53__cf__53_quantized_const:0' shape=(512,) dtype=quint8>,)
    (<tf.Tensor 'import/MobileNet/conv_ds_8/dw_batch_norm/batchnorm/sub/_53__cf__53:0' shape=(512,) dtype=float32>,)
    

    我将计数放在那里,因为通常终端会打印出太多张量,以至于初始输入节点名称会在终端中消失。

    最后,简单地注释掉计数使用的行:

    #count = 0
    for op in graph.get_operations():
        print op.values()
        #count+=1
        #if count == 50:
        #    assert False
    

    打印出最后几个节点(即您的输出节点)。

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2022-12-24
      • 1970-01-01
      • 2022-01-23
      • 2019-07-09
      • 2017-11-21
      • 2021-08-20
      • 2021-08-05
      • 2017-02-23
      相关资源
      最近更新 更多