【问题标题】:Using tf extract_image_patches for input to a CNN?使用 tf extract_image_patches 作为 CNN 的输入?
【发布时间】:2020-10-28 11:01:41
【问题描述】:

我想从我的原始图像中提取补丁以将它们用作 CNN 的输入。 经过一番研究,我找到了一种提取补丁的方法 tensorflow.compat.v1.extract_image_patches.

由于这些需要重新整形为“图像格式”,我实现了一个 reshape_image_patches 方法来重新整形并将重新整形的补丁存储在一个数组中。

image_patches2 = []

def reshape_image_patches(image_patches, sess, ksize_rows, ksize_cols):
    a = sess.run(tf.shape(image_patches))
    nr, nc = a[1], a[2]
    for i in range(nr):
      for j in range(nc):
        patch = tf.reshape(image_patches[0,i,j,], [ksize_rows, ksize_cols, 3])
        image_patches2.append(patch)
    return image_patches2

如何将它与 Keras 生成器结合使用,以使这些补丁成为我的 CNN 的输入?

编辑 1:

我已经尝试过Load tensorflow images and create patches中的方法

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np

dataset = tf.keras.preprocessing.image_dataset_from_directory(
    <directory>,
    label_mode=None,
    seed=1,
    subset='training',
    validation_split=0.1,
    image_size=(900, 900))

get_patches = lambda x: (tf.reshape(
    tf.image.extract_patches(
        x,
        sizes=[1, 16, 16, 1],
        strides=[1, 8, 8, 1],
        rates=[1, 1, 1, 1],
        padding='VALID'),  (111*111, 16, 16, 3)))

dataset = dataset.map(get_patches)

fig = plt.figure()
plt.subplots_adjust(wspace=.1, hspace=.2)
images = next(iter(dataset))
for index, image in enumerate(images):
    ax = plt.subplot(2, 2, index + 1)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.imshow(image)
plt.show()

在线:images = next(iter(dataset)) 我收到错误:InvalidArgumentError: Input to reshape is a tensor with 302800896 values, but the requested shape has 9462528 [[{{节点重塑}}]]

有人知道如何解决这个问题吗?

【问题讨论】:

  • 这能回答你的问题吗? Load tensorflow images and create patches
  • 谢谢!它看起来比我的方法更方便。如果我使用生成器,我仍然如何使用 get_patches 的一部分。根据我对生成器的理解,图像并没有在开始时全部加载,因此我不能在它们上使用 get_patches 对吗?

标签: tensorflow machine-learning keras neural-network conv-neural-network


【解决方案1】:

tf.reshape 不会改变张量中元素的顺序或总数。错误表明,您正试图将元素总数从 302800896 减少到 9462528 。您在lambda 函数中使用了tf.reshape

在下面的示例中,我重新创建了您的场景,其中我将shape 参数作为2 给定tf.reshape,它不能容纳原始张量的所有元素,因此会引发错误 -

代码 -

%tensorflow_version 2.x
import tensorflow as tf
t1 = tf.Variable([1,2,2,4,5,6])

t2 = tf.reshape(t1, 2)

输出 -

---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-3-0ff1d701ff22> in <module>()
      3 t1 = tf.Variable([1,2,2,4,5,6])
      4 
----> 5 t2 = tf.reshape(t1, 2)

3 frames
/usr/local/lib/python3.6/dist-packages/six.py in raise_from(value, from_value)

InvalidArgumentError: Input to reshape is a tensor with 6 values, but the requested shape has 2 [Op:Reshape]

tf.reshape 应该是这样一种方式,即元素的排列可以改变,但元素的总数必须保持不变。所以解决方法是将形状更改为[2,3] -

代码 -

%tensorflow_version 2.x
import tensorflow as tf
t1 = tf.Variable([1,2,2,4,5,6])

t2 = tf.reshape(t1, [2,3])
print(t2)

输出 -

tf.Tensor(
[[1 2 2]
 [4 5 6]], shape=(2, 3), dtype=int32)

要解决您的问题,请提取您尝试使用tf.reshape 大小的补丁(tf.image.extract_patches)或将tf.reshape 更改为提取补丁的大小。

还建议您研究其他 tf.image 功能,例如 tf.image.central_croptf.image.crop_and_resize

【讨论】:

    猜你喜欢
    • 2020-06-23
    • 2020-10-10
    • 2020-06-30
    • 1970-01-01
    • 2016-03-16
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2020-01-05
    相关资源
    最近更新 更多