【问题标题】:TensorFlow Custom Estimator predict throwing value errorTensorFlow Custom Estimator 预测抛出值错误
【发布时间】:2019-04-17 21:02:53
【问题描述】:

注意:这个问题有一个随附的记录在案的 Colab 笔记本。

TensorFlow 的文档有时可能有很多不足之处。一些用于较低级别 api 的旧文档似乎已被删除,大多数较新的文档都指向使用更高级别的 api,例如 TensorFlow 的 kerasestimators 的子集。如果更高级别的 api 不经常密切依赖它们的较低级别,这将不会有那么大的问题。举个例子,estimators(尤其是使用 TensorFlow Records 时的input_fn)。

以下 Stack Overflow 帖子:

在 TensorFlow / StackOverflow 社区的慷慨帮助下,我们更接近于做 TensorFlow "Creating Custom Estimators" guide 没有做的事情,展示了如何制作一个可以在实践中实际使用的估算器(而不是玩具示例),例如一个:

  • 有一个验证集,用于在性能恶化时提前停止,
  • 从 TF Records 读取,因为许多数据集大于 TensorFlow 推荐的 1Gb 内存,并且
  • 在训练时保存其最佳版本

虽然我对此仍有许多疑问(从将数据编码到 TF 记录的最佳方式,到 serving_input_fn 的确切期望),但有一个问题比其他问题更突出:

如何使用我们刚刚制作的自定义估算器进行预测?

predict 的文档中,它指出:

input_fn:构造特征的函数。预测会一直持续到input_fn 引发输入结束异常(tf.errors.OutOfRangeErrorStopIteration)。有关详细信息,请参阅预制估算​​器。该函数应构造并返回以下之一:

  • tf.data.Dataset 对象:Dataset 对象的输出必须具有与以下相同的约束。
  • features: 一个 tf.Tensor 或一个字符串特征名到 Tensor 的字典。特征由 model_fn 使用。它们应该满足输入对 model_fn 的期望。
  • 一个元组,在这种情况下,第一项被提取为特征。

(也许)最有可能的是,如果使用estimator.predict,他们正在使用内存中的数据,例如密集张量(因为保留的测试集可能会通过evaluate)。

因此,我在随附的 Colab 中创建了一个密集示例,将其包装在 tf.data.Dataset 中,然后调用 predict 以获取 ValueError

如果有人能向我解释我该怎么做,我将不胜感激:

  1. 加载我保存的估算器
  2. 给定一个密集的内存示例,使用估计器预测输出

【问题讨论】:

    标签: python tensorflow machine-learning


    【解决方案1】:
    to_predict = random_onehot((1, SEQUENCE_LENGTH, SEQUENCE_CHANNELS))\
            .astype(tf_type_string(I_DTYPE))
    pred_features = {'input_tensors': to_predict}
    
    pred_ds = tf.data.Dataset.from_tensor_slices(pred_features)
    predicted = est.predict(lambda: pred_ds, yield_single_examples=True)
    
    next(predicted)
    

    ValueError: Tensor("IteratorV2:0", shape=(), dtype=resource) 必须与 Tensor("TensorSliceDataset:0", shape=(), dtype=variant) 来自同一个图。

    当您使用tf.data.Dataset 模块时,它实际上定义了一个独立于模型图的输入图。这里发生的情况是,您首先通过调用 tf.data.Dataset.from_tensor_slices() 创建了一个小图,然后估算器 API 通过自动调用 dataset.make_one_shot_iterator() 创建了第二个图。这 2 个图无法通信,因此会引发错误。

    为了避免这种情况,您永远不应该在 estimator.train/evaluate/predict 之外创建数据集。这就是为什么所有相​​关数据都包含在输入函数中的原因。

    def predict_input_fn(data, batch_size=1):
      dataset = tf.data.Dataset.from_tensor_slices(data)
      return dataset.batch(batch_size).prefetch(None)
    
    predicted = est.predict(lambda: predict_input_fn(pred_features), yield_single_examples=True)
    next(predicted)
    

    现在,图表不是在预测调用之外创建的。

    我还添加了dataset.batch(),因为您的其余代码需要批处理数据并且它会引发形状错误。预取只是加快速度。

    【讨论】:

    • 那么如果我想用批处理数据进行训练,有没有办法在没有批处理数据或不同大小的批处理的情况下进行预测?
    • 几乎所有的 tf 操作都需要批处理数据。例如,卷积需要 4 级输入。因此,除了一些非常罕见的情况外,您永远不想使用未批处理的数据。你可以有 batch_size = 1 虽然。不同大小的批次不会造成任何问题。
    猜你喜欢
    • 1970-01-01
    • 2020-05-21
    • 1970-01-01
    • 1970-01-01
    • 2020-03-16
    • 1970-01-01
    • 2018-11-08
    • 2018-04-24
    • 1970-01-01
    相关资源
    最近更新 更多