【问题标题】:tensorflow how to change datasettensorflow如何更改数据集
【发布时间】:2018-01-09 20:41:18
【问题描述】:

我有一个 Dataset API doohickey,它是我的 tensorflow 图的一部分。当我想使用不同的数据时如何换掉它?

dataset = tf.data.Dataset.range(3)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

variable = tf.Variable(3, dtype=tf.int64)
model = variable*next_element

#pretend like this is me training my model, or something
with tf.Session() as sess:
    sess.run(variable.initializer)
    try:
        while True:
            print(sess.run(model)) # (0,3,6)
    except:
        pass

dataset = tf.data.Dataset.range(2)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()  

### HOW TO DO THIS THING?
with tf.Session() as sess:
    sess.run(variable.initializer) #This would be a saver restore operation, normally...
    try:
        while True:
            print(sess.run(model)) # (0,3)... hopefully
    except:
        pass

【问题讨论】:

    标签: python tensorflow tensorflow-datasets


    【解决方案1】:

    我不相信这是可能的。您要求更改计算图本身,这在 tensorflow 中是不允许的。我没有自己解释,而是发现这篇文章中接受的答案在解释这一点时特别清楚Is it possible to modify an existing TensorFlow computation graph?

    现在,就是说,我认为有一种相当简单/干净的方法来完成您所寻求的。本质上,您想要重置图形并重建 Dataset 部分。当然,您想重用代码的model 部分。因此,只需将该模型放在一个类或函数中以允许重用。基于您的代码构建的一个简单示例:

    # the part of the graph you want to reuse
    def get_model(next_element):
        variable = tf.Variable(3,dtype=tf.int64)
        return variable*next_element
    
    # the first graph you want to build
    tf.reset_default_graph()
    
    # the part of the graph you don't want to reuse
    dataset = tf.data.Dataset.range(3)
    iterator = dataset.make_one_shot_iterator()
    next_element = iterator.get_next()
    
    # reusable part
    model = get_model(next_element)
    
    #pretend like this is me training my model, or something
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        try:
            while True:
                print(sess.run(model)) # (0,3,6)
        except:
            pass
    
    # now the second graph
    tf.reset_default_graph()
    
    # the part of the graph you don't want to reuse
    dataset = tf.data.Dataset.range(2)
    iterator = dataset.make_one_shot_iterator()
    next_element = iterator.get_next()  
    
    # reusable part
    model = get_model(next_element)
    
    ### HOW TO DO THIS THING?
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        try:
            while True:
                print(sess.run(model)) # (0,3)... hopefully
        except:
            pass
    

    最后提示:您还会在这里和那里看到一些对tf.contrib.graph_editor docs here 的引用。他们特别说您无法使用 graph_editor 完成您想要的(请参阅该链接:“这是您不能做的示例”;但您可以非常接近)。尽管如此,这不是一个好习惯。他们有充分的理由只添加图表,我认为我建议的上述方法是完成您所寻求的更清洁的方法。

    【讨论】:

    • 这就是我一直在做的事情,并且正在发挥作用。当我问这个问题时,我在恢复变量时遇到了问题,因为我没有重置默认图表(呃!)。这绝对是要走的路。
    【解决方案2】:

    我建议的一种方法是使用place_holders,然后使用tf.data.dataset,这会使事情变慢。因此,您将拥有以下内容:

    train_data = tf.placeholder(dtype=tf.float32, shape=[None, None, 1]) # just an example
    # Then add the tf.data.dataset here
    train_data = tf.data.Dataset.from_tensor_slices(train_data).shuffle(10000).batch(batch_size)
    

    现在在会话中运行图表时,您必须使用占位符输入数据。因此,您可以随意喂食...

    希望这会有所帮助!

    【讨论】:

      猜你喜欢
      • 2018-04-02
      • 2020-04-09
      • 1970-01-01
      • 2020-05-31
      • 1970-01-01
      • 2019-10-26
      • 2018-12-10
      • 2022-10-15
      • 2018-08-16
      相关资源
      最近更新 更多