【问题标题】:Tensorflow save one of multiple sessionsTensorFlow 保存多个会话之一
【发布时间】:2018-12-26 11:00:38
【问题描述】:

您有一个 Python 脚本,我在其中实例化了一个神经网络类的两个对象。 每个对象都定义了自己的会话并提供了保存图形的方法。

import tensorflow as tf
import os, shutil

class TestNetwork:

    def __init__(self, id):
        self.id = id

        tf.reset_default_graph()

        self.s = tf.placeholder(tf.float32, [None, 2], name='s')
        w_initializer, b_initializer = tf.random_normal_initializer(0., 1.0), tf.constant_initializer(0.1)
        self.k = tf.layers.dense(self.s, 2, kernel_initializer=w_initializer,
                    bias_initializer=b_initializer, name= 'k')

        '''Defines self.session and initialize the variables'''
        session_conf = tf.ConfigProto(
            allow_soft_placement = True,
            log_device_placement = False)
        self.session = tf.Session(config = session_conf)
        self.session.run(tf.global_variables_initializer())



    def save_model(self, output_dir):
        '''Save the network graph and weights to disk'''
        if os.path.exists(output_dir):
            # if provided output_dir already exists, remove it
            shutil.rmtree(output_dir)

        builder = tf.saved_model.builder.SavedModelBuilder(output_dir)
        builder.add_meta_graph_and_variables(
            self.session,
            [tf.saved_model.tag_constants.SERVING],
            clear_devices=True)
        # create a new directory output_dir and store the saved model in it
        builder.save()


t1 = TestNetwork(1)
t2 = TestNetwork(2)


t1.save_model("t1_model")
t2.save_model("t2_model")

我得到的错误是

TypeError:无法将 feed_dict 键解释为张量:名称 'save/Const:0' 指的是不存在的张量。操作, 'save/Const',图中不存在。

我读到一些东西说这个错误是由于tf.train.Saver

因此我在__init__ 方法的末尾添加了以下行:

self.saver = tf.train.Saver(tf.global_variables(), max_to_keep = 5)

但是我仍然得到错误。

【问题讨论】:

    标签: python tensorflow


    【解决方案1】:

    tf.reset_default_graph 将清除默认图堆栈并重置全局默认图。

    注意:默认图是当前线程的属性。这个 函数只适用于当前线程。调用这个函数 当 tf.Session 或 tf.InteractiveSession 处于活动状态时,将导致 未定义的行为。 使用任何以前创建的 tf.Operation 或 调用此函数后的 tf.Tensor 对象将导致 undefined 行为。

    您应该单独指定Graph,并在相应的图形范围内定义所有这些。

    def __init__(self, id):
        self.id = id
    
        self.graph = tf.Graph()
        with self.graph.as_default():
            self.s = tf.placeholder(tf.float32, [None, 2], name='s')
            w_initializer, b_initializer = tf.random_normal_initializer(0., 1.0), tf.constant_initializer(0.1)
            self.k = tf.layers.dense(self.s, 2, kernel_initializer=w_initializer,
                        bias_initializer=b_initializer, name= 'k')
            init = tf.global_variables_initializer()
    
        '''Defines self.session and initialize the variables'''
        session_conf = tf.ConfigProto(
            allow_soft_placement = True,
            log_device_placement = False)
        self.session = tf.Session(config = session_conf,graph=self.graph)
        self.session.run(init)
    

    tf.train.Saver 是另一种保存模型变量的方法。

    编辑 如果你得到空的“变量”,你应该在图中保存模型:

    def save_model(self, output_dir):
        '''Save the network graph and weights to disk'''
        if os.path.exists(output_dir):
            # if provided output_dir already exists, remove it
            shutil.rmtree(output_dir)
    
        with self.graph.as_default():
            builder = tf.saved_model.builder.SavedModelBuilder(output_dir)
            builder.add_meta_graph_and_variables(
                self.session,
                [tf.saved_model.tag_constants.SERVING],
                clear_devices=True)
            # create a new directory output_dir and store the saved model in it
            builder.save()
    

    【讨论】:

    • 谢谢,您的修复可以正常工作,但保存的模型目录包含一个空的“变量”子目录,因此我收到错误:传递的保存路径不是有效的检查点
    猜你喜欢
    • 2016-04-02
    • 2018-02-07
    • 2016-08-18
    • 2016-04-18
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多