【发布时间】:2020-10-28 11:01:41
【问题描述】:
我想从我的原始图像中提取补丁以将它们用作 CNN 的输入。
经过一番研究,我找到了一种提取补丁的方法
tensorflow.compat.v1.extract_image_patches.
由于这些需要重新整形为“图像格式”,我实现了一个 reshape_image_patches 方法来重新整形并将重新整形的补丁存储在一个数组中。
image_patches2 = []
def reshape_image_patches(image_patches, sess, ksize_rows, ksize_cols):
a = sess.run(tf.shape(image_patches))
nr, nc = a[1], a[2]
for i in range(nr):
for j in range(nc):
patch = tf.reshape(image_patches[0,i,j,], [ksize_rows, ksize_cols, 3])
image_patches2.append(patch)
return image_patches2
如何将它与 Keras 生成器结合使用,以使这些补丁成为我的 CNN 的输入?
编辑 1:
我已经尝试过Load tensorflow images and create patches中的方法
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
dataset = tf.keras.preprocessing.image_dataset_from_directory(
<directory>,
label_mode=None,
seed=1,
subset='training',
validation_split=0.1,
image_size=(900, 900))
get_patches = lambda x: (tf.reshape(
tf.image.extract_patches(
x,
sizes=[1, 16, 16, 1],
strides=[1, 8, 8, 1],
rates=[1, 1, 1, 1],
padding='VALID'), (111*111, 16, 16, 3)))
dataset = dataset.map(get_patches)
fig = plt.figure()
plt.subplots_adjust(wspace=.1, hspace=.2)
images = next(iter(dataset))
for index, image in enumerate(images):
ax = plt.subplot(2, 2, index + 1)
ax.set_xticks([])
ax.set_yticks([])
ax.imshow(image)
plt.show()
在线:images = next(iter(dataset)) 我收到错误:InvalidArgumentError: Input to reshape is a tensor with 302800896 values, but the requested shape has 9462528 [[{{节点重塑}}]]
有人知道如何解决这个问题吗?
【问题讨论】:
-
这能回答你的问题吗? Load tensorflow images and create patches
-
谢谢!它看起来比我的方法更方便。如果我使用生成器,我仍然如何使用 get_patches 的一部分。根据我对生成器的理解,图像并没有在开始时全部加载,因此我不能在它们上使用 get_patches 对吗?
标签: tensorflow machine-learning keras neural-network conv-neural-network