【问题标题】:tf.data.Dataset feedable iterator for training and inference用于训练和推理的 tf.data.Dataset 可馈送迭代器
【发布时间】:2019-03-04 22:25:54
【问题描述】:

我有一个 TensorFlow 模型,它使用 tf.data.Dataset feedable 迭代器在训练和验证之间切换。两个数据集共享相同的结构,即它们具有特征矩阵和相应的标签向量。为了使用相同的模型和迭代器进行推理(没有标签向量只有特征矩阵),我需要理想地提供一个零标签向量。是否有更高效、更优雅的方式将数据集 API 用于训练(验证)和推理?

在代码中:

training_dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train))
validation_dataset = tf.data.Dataset.from_tensor_slices((X_validation, y_validation))

handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(handle, training_dataset.output_types, training_dataset.output_shapes)
features, labels = iterator.get_next()

特征和标签在模型内部用作输入占位符。 为了在数据集之间切换,我需要为每个数据集创建一个迭代器:

training_iterator = training_dataset.make_initializable_iterator()
validation_iterator = validation_dataset.make_initializable_iterator()

然后创建句柄

training_handle = sess.run(training_iterator.string_handle())
validation_handle = sess.run(validation_iterator.string_handle())

并使用handle 选择要使用的数据集,例如:

sess.run(next_element, feed_dict={handle: training_handle})

现在,如果我有没有标签的推理数据会怎样?

inference_dataset = tf.data.Dataset.from_tensor_slices(X_inference) # NO y values
inferece_iterator = inference_dataset.make_initializable_iterator()

如果我添加这个迭代器,它会抛出异常,因为“组件数不匹配:预期 2 种类型但得到 1 种。” 有什么建议吗?

这个帖子How to use tf.Dataset design in both training and inferring?和这个问题有关,但是tf.data.Dataset没有解压方法。

解决此问题的最佳做法是什么?

【问题讨论】:

    标签: python tensorflow tensorflow-datasets


    【解决方案1】:

    如果您的图形代码我假设您正在尝试从数据集中提取标签 y 的值,对吗?在推理时,这可能已经融入到 tensorflow 依赖图中。

    您在这里有几个选择。可能最简单的解决方案是从代码重新创建图形(运行 build_graph() 函数,然后使用 saver.restore(sess, "/tmp/model.ckpt") 之类的东西加载权重)。如果您这样做,您可以重新创建没有标签y 的图形。我假设y 上没有其他依赖项(有时 tensorboard 摘要会添加您需要检查的依赖项)。您的问题现在应该已经解决了。

    但是,既然我已经写了上面的评论(我将保持原样,因为它仍然是有用的信息),我意识到您甚至可能不需要它。在推理时,您不应该在任何地方使用标签(再次检查张量板摘要)。如果您不需要y,则 tensorflow 不应运行任何使用y 的操作。这应该包括不尝试从数据集中提取它们。仔细检查您是否没有要求 tensorflow 在推理时的任何地方使用您的标签。

    【讨论】:

    • 感谢您的回答,但问题不在于我在图中使用y。如果我使用 feed_dict,则使用相同的图表它可以正常工作(仅将 X 作为输入传递)。当我尝试使用迭代器时会引发错误,该迭代器期望底层数据集有两个组件Xy。所以我想知道,是否可以在 production 中使用迭代器和数据集进行训练和推理,推理数据集 only X
    • 在图形运行时,如果该组件未使用,则不应尝试从迭代器中提取第二个组件。至少我希望它是这样的。也许您正在重建图形并尝试从迭代器中提取 2 个值,在这种情况下,一个简单的基于 python 的 if 语句应该可以解决它。如果您明确命名数据集元素,则可能是另一种解决方案。这应该是相当简单的事情,只是某个地方的一个细节。问题是在图形运行时(sess.run)还是在构建图形时(例如在 python 领域)出现?
    • 它发生在图形运行时。您能否提供第二个解决方案的示例,即明确命名数据集元素?
    • 这是一个与命名数据集中元素有关的问题:stackoverflow.com/questions/48471688/… 如果是出现错误的图形,您也可以尝试实现tf.cond,尽管我觉得有一个更清洁的解决方案:tensorflow.org/api_docs/python/tf/cond
    【解决方案2】:

    我认为David Parks提出的第一个解决方案看起来是这样的,我认为比在代码中乱用tf.cond要好。

    import tensorflow as tf
    import numpy as np
    
    def build_model(features, labels=None, train=False):
        linear_model = tf.layers.Dense(units=1)
        y_pred = linear_model(features)
        if train:
            loss = tf.losses.mean_squared_error(labels=labels, predictions=y_pred)
            optimizer = tf.train.GradientDescentOptimizer(1e-4)
            train = optimizer.minimize(loss)
            return train, loss
        else:
            return y_pred
    
    X_train = np.random.random(100).reshape(-1, 1)
    y_train = np.random.random(100).reshape(-1, 1)
    
    training_dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train))
    training_dataset = training_dataset.batch(10)
    training_dataset = training_dataset.shuffle(20)
    handle = tf.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(handle, training_dataset.output_types, training_dataset.output_shapes)
    
    features, labels = iterator.get_next()
    training_iterator = training_dataset.make_one_shot_iterator()
    
    train, loss = build_model(features, labels, train=True)
    
    saver = tf.train.Saver()
    init = tf.global_variables_initializer()
    
    sess = tf.Session()
    training_handle = sess.run(training_iterator.string_handle())
    
    sess.run(init)
    for i in range(10):
        _, loss_value = sess.run((train, loss), feed_dict={handle: training_handle})
        print(loss_value)
    
    saver.save(sess, "tmp/model.ckpt")
    sess.close()
    
    tf.reset_default_graph()
    
    X_test = np.random.random(10).reshape(-1, 1)
    inference_dataset = tf.data.Dataset.from_tensor_slices(X_test)
    inference_dataset = inference_dataset.batch(5)
    
    handle = tf.placeholder(tf.string, shape=[])
    iterator_inference = tf.data.Iterator.from_string_handle(handle, inference_dataset.output_types, inference_dataset.output_shapes)
    
    inference_iterator = inference_dataset.make_one_shot_iterator()
    
    features_inference = iterator_inference.get_next()
    
    y_pred = build_model(features_inference)
    
    saver = tf.train.Saver()
    sess = tf.Session()
    inference_handle = sess.run(inference_iterator.string_handle())
    saver.restore(sess, "tmp/model.ckpt") # Restore variables from disk.
    print(sess.run(y_pred, feed_dict={handle: inference_handle}))
    sess.close()
    

    【讨论】:

      猜你喜欢
      • 2019-07-11
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2019-06-05
      • 1970-01-01
      • 2019-10-15
      • 2020-06-24
      • 2019-03-26
      相关资源
      最近更新 更多