【问题标题】:Keras generator and fit_generator, how to build the generator to avoid 'function shape' errorKeras 生成器和 fit_generator,如何构建生成器以避免“函数形状”错误
【发布时间】:2019-08-20 02:20:38
【问题描述】:

我正在为 Keras 构建一个生成器,以便能够加载我的数据集图像,因为它对我的 ram 来说有点大。

我是这样构建生成器的:

# import the necessary packages
import tensorflow
from tensorflow import keras
from keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
from sklearn.preprocessing import OneHotEncoder
import numpy as np
import pandas as pd
from tqdm import tqdm

#loading
path_to_txt = "/content/test/leafsnap-dataset/leafsnap-dataset- 
images_improved.txt"
df = pd.read_csv(path_to_txt ,sep='\t')
arr = np.array(df)
#epochs and steps:
NUM_TRAIN_IMAGES = 0
NUM_EPOCHS = 30

def image_generator(arr, bs, mode="train", aug=None):
  while True:
    images = []
    labels = []
    for row in arr:
      if len(images) < bs:
        img = (cv2.resize(cv2.imread("/content/test/leafsnap-dataset/" + 
        row[0]),(224,224)))
        images.append(img)
        labels.append([row[2]])
        NUM_TRAIN_IMAGES += 1
      else:
        break


  if aug is not None:
    (images, labels) = next(aug.flow(np.array(images),labels, 
     batch_size=bs))

  obj = OneHotEncoder()
  values = obj.fit_transform(labels).toarray()

  yield (np.array(images), labels)

然后我从 Sequential 模型中调用 fit_generator(cnn 一直工作,直到出现 OOM 错误)

#create the augmentation function:
 aug = ImageDataGenerator(rotation_range=20, zoom_range=0.15,
    width_shift_range=0.2, height_shift_range=0.2, shear_range=0.15,
    horizontal_flip=True, fill_mode="nearest")

#create the generator:
gen = image_generator(arr, bs = 32, mode = "train", aug = aug)

history = model.fit_generator(image_generator,
    steps_per_epoch = NUM_TRAIN_IMAGES,
    epochs = NUM_EPOCHS)

从这里,我得到这个错误:

# Create generator from NumPy or EagerTensor Input.
--> 377   num_samples = int(nest.flatten(data)[0].shape[0])
378   if batch_size is None:
379     raise ValueError('You must specify `batch_size`')
AttributeError: 'function' object has no attribute 'shape'

【问题讨论】:

  • 首先,您的生成器函数内存效率不高。因为您首先加载所有图像。您应该遍历图像文件和内部循环 yield np.array.

标签: python tensorflow keras


【解决方案1】:

我在这里看到两个主要错误。

首先,您的生成器函数内存效率不高。因为您首先加载所有图像(while 循环)。您应该遍历图像文件并在循环内部产生带有标签的图像的 np.array。

其次,当您应该使用它返回的对象 - gen 时,您将生成器函数名称传递给 fit_generator。

【讨论】:

  • 天哪,看了一百遍代码都没看到函数名和对象名的错误。至于优化,是的,我是这个概念的新手,我现在将对其进行优化。谢谢。
  • 我假设 yield 语句在某个循环内?如果不是,这个生成器将与没有生成器的内存效率相同
  • 实际上我认为它不在循环内,因为您不会遍历所有图像,而只是直到达到 bs 大小。但无论如何我的实现都是错误的。
猜你喜欢
  • 2018-11-23
  • 1970-01-01
  • 2021-11-24
  • 2019-04-20
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
相关资源
最近更新 更多