【问题标题】:Creating `input_fn` from iterator从迭代器创建`input_fn`
【发布时间】:2018-11-12 12:33:54
【问题描述】:

大多数教程都关注整个训练数据集适合内存的情况。但是,我有一个迭代器,它充当(特征、标签)元组的无限流(动态地廉价创建它们)。

在为张量流estimator 实现input_fn 时,我可以从迭代器返回一个实例吗

def input_fn():
   (feature_batch, label_batch) = next(it)
   return tf.constant(feature_batch), tf.constant(label_batch)

或者input_fn 是否必须在每次调用时返回相同的(特征、标签)元组?

此外,这个函数在训练期间被多次调用,我希望它像下面的伪代码一样:

for i in range(max_iter):
   learn_op(input_fn())

【问题讨论】:

  • 我使用Dataset.from_generator 并在生成器中循环遍历迭代器。

标签: python tensorflow


【解决方案1】:

input_fn 的参数在整个训练过程中使用,但函数本身只调用一次。因此,创建一个复杂的 input_fn 并不仅仅像 tutorial 中解释的那样返回一个常量数组。

Tensorflow 为 numpypanda 数组提出了两个此类非平凡的 input_fn 示例,但它们从内存中的数组开始,因此这对您的问题没有帮助。

您还可以通过上面的链接查看他们的代码,了解他们如何实现高效的非平凡input_fn,但您可能会发现它需要更多您想要的代码。

如果您愿意使用 Tensorflow 的低级接口,恕我直言,事情会更简单、更灵活。有一个tutorial 可以满足大多数需求,并且建议的解决方案很容易(-er)实施。

特别是,如果您已经有一个迭代器,如您在问题中描述的那样返回数据,那么使用占位符(上一个链接中的“Feeding”部分)应该很简单。

【讨论】:

  • 我原以为从迭代器/可迭代对象中提供网络是标准用例,而不是例外。
【解决方案2】:

我发现了一个将generator 转换为input_fn 的拉取请求: https://github.com/tensorflow/tensorflow/pull/7045/files

相关部分是

  def _generator_input_fn():
    """generator input function."""
    queue = feeding_functions.enqueue_data(
      x,
      queue_capacity,
      shuffle=shuffle,
      num_threads=num_threads,
      enqueue_size=batch_size,
      num_epochs=num_epochs)

    features = (queue.dequeue_many(batch_size) if num_epochs is None
                else queue.dequeue_up_to(batch_size))
    if not isinstance(features, list):
      features = [features]
    features = dict(zip(input_keys, features))
    if target_key is not None:
      if len(target_key) > 1:
        target = {key: features.pop(key) for key in target_key}
      else:
        target = features.pop(target_key[0])
      return features, target
    return features
  return _generator_input_fn

【讨论】:

    【解决方案3】:
    from tensorflow.contrib.learn.python.learn.learn_io import generator_io
    import numpy as np
    
    # define generator
    def generator():
        for index in range(2):
            yield {'a': np.ones(1) * index,'b': np.ones(1) * index + 32,'label': np.ones(1) * index - 32}
    
    input_fn = generator_io.generator_input_fn(generator, target_key='label', batch_size=2, shuffle=False, num_epochs=1)
    features, target = input_fn()
    

    参考测试用例https://github.com/tensorflow/tensorflow/pull/7045/files

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2016-07-24
      • 1970-01-01
      • 2018-05-20
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2010-10-09
      相关资源
      最近更新 更多