【问题标题】:Tensorflow Dataset in predict() method throws errorpredict() 方法中的 Tensorflow 数据集抛出错误
【发布时间】:2021-07-18 20:01:20
【问题描述】:

我目前正在尝试使用 Tensorflow 数据集作为 TensorFlow 模型的 predict() 函数的输入。由于某种原因,我总是收到错误,详细信息如下。

我有以下生成器:

def create_dataset(Q, unary_potentials, features_1, features_2, features_3, kernel_size, depth_image):
    height, width, _ = np.shape(unary_potentials)
    Q = np.pad(Q, [[kernel_size, kernel_size], [kernel_size, kernel_size], [0, 0]])
    unary_potentials = np.pad(unary_potentials, [[kernel_size, kernel_size], [kernel_size, kernel_size], [0, 0]])
    features_1 = np.pad(features_1, [[kernel_size, kernel_size], [kernel_size, kernel_size], [0, 0]])
    features_2 = np.pad(features_2, [[kernel_size, kernel_size], [kernel_size, kernel_size], [0, 0]])
    features_3 = np.pad(features_3, [[kernel_size, kernel_size], [kernel_size, kernel_size], [0, 0]])
    for y in range(kernel_size, height+kernel_size):
        print(f"y: {y}")
        Q_rows = Q[y-kernel_size:y+kernel_size+1]
        unary_rows = unary_potentials[y-kernel_size:y+kernel_size+1]
        f_1_rows = features_1[y-kernel_size:y+kernel_size+1]
        f_2_rows = features_2[y-kernel_size:y+kernel_size+1]
        f_3_rows = features_3[y-kernel_size:y+kernel_size+1]
        for x in range(kernel_size, width+kernel_size):
            if depth_image[y-kernel_size][x-kernel_size] > 0.0001:
                yield  f_1_rows[kernel_size][x], f_2_rows[kernel_size][x], f_3_rows[kernel_size][x],\
                       unary_rows[:, x-kernel_size:x+kernel_size+1], Q_rows[:, x-kernel_size:x+kernel_size+1],\
                       f_1_rows[:, x-kernel_size:x+kernel_size+1], f_2_rows[:, x-kernel_size:x+kernel_size+1],\
                       f_3_rows[:, x-kernel_size:x+kernel_size+1]

我正在从生成器返回一些 numpy 数组。 然后我像这样构建数据集:

dataset = tf.data.Dataset.from_generator(create_dataset,
                                             output_shapes=((3,), (6,), (5,), (kernel_size*2+1, kernel_size*2+1, number_of_surfaces),
                                                            (kernel_size*2+1, kernel_size*2+1, number_of_surfaces),
                                                            (kernel_size*2+1, kernel_size*2+1, 3),
                                                            (kernel_size*2+1, kernel_size*2+1, 6),
                                                            (kernel_size*2+1, kernel_size*2+1, 5)),
                                             output_types=(tf.float32, tf.float32, tf.float32, tf.float32, tf.float32, tf.float32, tf.float32, tf.float32),
                                             args=(initial_Q, unary_potentials, *features, kernel_size, depth_image)).batch(batch_size)

如果我现在打电话

Q = MFI_NN.predict(dataset)

其中 MFI_NN 是我的模型,我只是收到以下错误:

Traceback (most recent call last):
  File "C:/Users/marc/Desktop/MA/Code/find_planes_MS.py", line 496, in <module>
    test_model_on_image_2(test_indices[0])
  File "C:/Users/marc/Desktop/MA/Code/find_planes_MS.py", line 199, in test_model_on_image_2
    Q = MFI_NN.predict(x=dataset)
  File "C:\Users\marc\AppData\Local\Programs\Python\Python38\lib\site-packages\tensorflow\python\keras\engine\training.py", line 1629, in predict
    tmp_batch_outputs = self.predict_function(iterator)
  File "C:\Users\marc\AppData\Local\Programs\Python\Python38\lib\site-packages\tensorflow\python\eager\def_function.py", line 828, in __call__
    result = self._call(*args, **kwds)
  File "C:\Users\marc\AppData\Local\Programs\Python\Python38\lib\site-packages\tensorflow\python\eager\def_function.py", line 871, in _call
    self._initialize(args, kwds, add_initializers_to=initializers)
  File "C:\Users\marc\AppData\Local\Programs\Python\Python38\lib\site-packages\tensorflow\python\eager\def_function.py", line 725, in _initialize
    self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
  File "C:\Users\marc\AppData\Local\Programs\Python\Python38\lib\site-packages\tensorflow\python\eager\function.py", line 2969, in _get_concrete_function_internal_garbage_collected
    graph_function, _ = self._maybe_define_function(args, kwargs)
  File "C:\Users\marc\AppData\Local\Programs\Python\Python38\lib\site-packages\tensorflow\python\eager\function.py", line 3361, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "C:\Users\marc\AppData\Local\Programs\Python\Python38\lib\site-packages\tensorflow\python\eager\function.py", line 3196, in _create_graph_function
    func_graph_module.func_graph_from_py_func(
  File "C:\Users\marc\AppData\Local\Programs\Python\Python38\lib\site-packages\tensorflow\python\framework\func_graph.py", line 990, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "C:\Users\marc\AppData\Local\Programs\Python\Python38\lib\site-packages\tensorflow\python\eager\def_function.py", line 634, in wrapped_fn
    out = weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "C:\Users\marc\AppData\Local\Programs\Python\Python38\lib\site-packages\tensorflow\python\framework\func_graph.py", line 977, in wrapper
    raise e.ag_error_metadata.to_exception(e)
ValueError: in user code:

    C:\Users\marc\AppData\Local\Programs\Python\Python38\lib\site-packages\tensorflow\python\keras\engine\training.py:1478 predict_function  *
        return step_function(self, iterator)
    C:\Users\marc\AppData\Local\Programs\Python\Python38\lib\site-packages\tensorflow\python\keras\engine\training.py:1468 step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    C:\Users\marc\AppData\Local\Programs\Python\Python38\lib\site-packages\tensorflow\python\distribute\distribute_lib.py:1259 run
        return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
    C:\Users\marc\AppData\Local\Programs\Python\Python38\lib\site-packages\tensorflow\python\distribute\distribute_lib.py:2730 call_for_each_replica
        return self._call_for_each_replica(fn, args, kwargs)
    C:\Users\marc\AppData\Local\Programs\Python\Python38\lib\site-packages\tensorflow\python\distribute\distribute_lib.py:3417 _call_for_each_replica
        return fn(*args, **kwargs)
    C:\Users\marc\AppData\Local\Programs\Python\Python38\lib\site-packages\tensorflow\python\keras\engine\training.py:1461 run_step  **
        outputs = model.predict_step(data)
    C:\Users\marc\AppData\Local\Programs\Python\Python38\lib\site-packages\tensorflow\python\keras\engine\training.py:1433 predict_step
        x, _, _ = data_adapter.unpack_x_y_sample_weight(data)
    C:\Users\marc\AppData\Local\Programs\Python\Python38\lib\site-packages\tensorflow\python\keras\engine\data_adapter.py:1454 unpack_x_y_sample_weight
        raise ValueError(error_msg)

    ValueError: Data is expected to be in format `x`, `(x,)`, `(x, y)`, or `(x, y, sample_weight)`, found: (<tf.Tensor 'IteratorGetNext:0' shape=(None, 3) dtype=float32>, <tf.Tensor 'IteratorGetNext:1' shape=(None, 6) dtype=float32>, <tf.Tensor 'IteratorGetNext:2' shape=(None, 5) dtype=float32>, <tf.Tensor 'IteratorGetNext:3' shape=(None, 21, 21, 49) dtype=float32>, <tf.Tensor 'IteratorGetNext:4' shape=(None, 21, 21, 49) dtype=float32>, <tf.Tensor 'IteratorGetNext:5' shape=(None, 21, 21, 3) dtype=float32>, <tf.Tensor 'IteratorGetNext:6' shape=(None, 21, 21, 6) dtype=float32>, <tf.Tensor 'IteratorGetNext:7' shape=(None, 21, 21, 5) dtype=float32>)


Process finished with exit code 1

生成器本身的输出肯定是正确的,因为以下工作正常:

for x in dataset:
    Q = MFI_NN.predict(x)

我觉得我在这里遗漏了一些明显的东西,如果有人能告诉我它是什么,那就太好了。 非常感谢!

【问题讨论】:

    标签: python tensorflow tensorflow2.0


    【解决方案1】:

    来自tf.keras.Model.predict()方法的文档:

    x 输入样本。可能是:

    一个 Numpy 数组(或类似数组),或一个数组列表(如果模型有多个输入)。 TensorFlow 张量或张量列表(如果模型有多个输入)。 一个 tf.data 数据集。 生成器或 keras.utils.Sequence 实例。关于迭代器类型(数据集、生成器、序列)的解包行为的更详细描述在 Model.fit 的类似迭代器的输入部分的解包行为中给出。

    所以,我建议您将dataset 直接传递给MFI_NN.predict(),例如:

        Q = MFI_NN.predict(dataset)  # dataset is a tf.data dataset.
    

    或者如果您出于其他原因需要 for 循环:

    for x in dataset:
        Q = MFI_NN.predict([x])  # [x] is a list of tensors.
    
    

    【讨论】:

    • 感谢您的回答,不幸的是我的问题有一个错误,因为我实际上确实使用了 Q = MFI_NN.predict(dataset) (我只是将错误的内容复制到问题中)。这种情况正是发生错误的情况。
    • 然后编辑您的问题会很有帮助,以便更容易理解您的问题。将 for 循环中的数据集输出作为 list to predict 方法传递是否解决了您的问题?
    • 如果我像我在问题中展示的那样作为 for 循环执行它,它可以工作。但显然我宁愿在数据集上使用普通的 predict() 方法。我不只是在完整数据集上调用 predict() 作为列表的原因是它太大了,所以使用生成器是必要的。现在我只是使用 for 循环,尽管它不太优雅并且可能效率较低。
    猜你喜欢
    • 2021-07-23
    • 1970-01-01
    • 2020-05-21
    • 1970-01-01
    • 1970-01-01
    • 2020-07-20
    • 1970-01-01
    • 2015-06-05
    • 1970-01-01
    相关资源
    最近更新 更多