【问题标题】:How does Tensorflow's DirectoryIterator work?Tensorflow 的 DirectoryIterator 是如何工作的?
【发布时间】:2020-02-17 17:45:00
【问题描述】:

我习惯于使用model.fix(train_data,train_labels, epochs=10) 之类的东西,我使用 glob 将一个充满图像的文件夹读入 RAM。我想在训练发生时直接从硬盘读取。我发现:

https://www.tensorflow.org/api_docs/python/tf/keras/preprocessing/image/DirectoryIterator

只有我不知道它是如何工作的。我已经在互联网上搜索了更多帮助,然后链接了文档,但我没有找到任何帮助。我在 DirectoryIterator 中有标签和目录。我只是不知道如何将 DirectoryIterator 输入到我的模型中?

代码显示了我到目前为止所做的事情。我还尝试使用 tensorflow 会话并将 DirectoryIterator 作为 feed_dict 提供。代码很乱,一直在尝试这个那个。在代码中,我尝试使用 fit_generator 来适应 DirectoryIterator。

import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow.keras as keras
import cv2 as ocv
import glob
import matplotlib.pyplot as plt
from tensorflow import image
import glob

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Conv2D, Flatten, Dropout, MaxPooling2D
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Plot inline
%matplotlib inline

# Load an color image in 1-colour 0-grayscale -1-bw
img = ocv.imread('C:/Users/ew/Documents/Python Scripts/Noodles/my.png',1)
RGB_im = ocv.cvtColor(img, ocv.COLOR_BGR2RGB)
img.shape
plt.imshow(RGB_im)

cv_img = []
for img in glob.glob("C:\\Users\\EW\\pictures\\Noodles\\Banana\\*.jpg"):
    cv_img.append(img)
    #n= ocv.imread(img)
    #cv_img.append(n)
print(cv_img[1])

image_data_generator = keras.preprocessing.image.ImageDataGenerator(featurewise_center=False,
                                       samplewise_center=False,
                                       featurewise_std_normalization=False,
                                       samplewise_std_normalization=False,
                                       zca_whitening=False, zca_epsilon=1e-06,
                                       rotation_range=0,
                                       width_shift_range=0.0,
                                       height_shift_range=0.0,
                                       brightness_range=None,
                                       shear_range=0.0,
                                       zoom_range=0.0,
                                       channel_shift_range=0.0,
                                       fill_mode='nearest',
                                       cval=0.0,
                                       horizontal_flip=False,
                                       vertical_flip=False,
                                       rescale=None,
                                       preprocessing_function=None,
                                       data_format='channels_last',
                                       validation_split=0.3,
#                                       interpolation_order=1,
                                       dtype='float32')

noodle_data = directory = "C:\\Users\\EW\\pictures\\Noodles\\"
image_set = keras.preprocessing.image.DirectoryIterator(directory,
    image_data_generator,
    target_size=(256, 256),
    color_mode='rgb',
    classes=None,
    class_mode='categorical',
    batch_size=32,
    shuffle=True,
    seed=None,
    data_format=None,
    save_to_dir=None,
    save_prefix='',
    save_format='png',
    follow_links=False,
    subset=None,
    interpolation='nearest',
    dtype=None)

model = Sequential()

#add model layers
model.add(Dense(10, activation='relu', input_shape=(256,256)))
model.add(Dense(10, activation='relu'))
model.add(Dense(1))

model.fit_generator(noodle_data , steps_per_epoch=16, validation_data=val_it, validation_steps=8)
---> 12 model.fit_generator(prawn_data , steps_per_epoch=16, validation_data=prawn_data, validation_steps=8)
AttributeError: 'str' object has no attribute 'shape'

【问题讨论】:

  • 你将noodle_data传递给fit_generator,我认为你应该传递image_set

标签: tensorflow keras tensorflow-datasets


【解决方案1】:

我认为你应该这样写:

image_data_generator = keras.preprocessing.image.ImageDataGenerator(featurewise_center=False,
                                   ...,
                                   dtype='float32')

directory = "C:\\Users\\EW\\pictures\\Noodles\\"

image_set = image_data_generator.flow_from_directory(directory,
                                                     target_size=(256, 256),
                                                     color_mode='rgb',
                                                     ...,
                                                     dtype=None)

因此,您调用 ImageDataGenerator() 的一个实例,该实例名为 image_data_generator,并使用其方法 flow_from_directory() 从目录中读取。而且您不应该将 image_data_generator 作为 flow_from_directory 中的参数传递。

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 2020-05-04
    • 1970-01-01
    • 2017-09-10
    • 2017-12-21
    • 2018-04-12
    • 2017-06-22
    • 1970-01-01
    • 2018-01-04
    相关资源
    最近更新 更多