【Question title】: How to do Keras image augmentation using a custom data generator?
【Posted】: 2020-11-20 12:47:57
【Question】:

I am using a Keras custom generator, and I want to apply image augmentation to the data returned by that custom data generator.

These are the augmentation techniques I want:

ImageDataGenerator(
        rotation_range=40,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest')
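To see what this configuration does to a single image, the sketch below applies one random transform with random_transform. It assumes TensorFlow 2.x, where ImageDataGenerator is still reachable as tf.keras.preprocessing.image.ImageDataGenerator, and SciPy, which the rotation/shift/shear/zoom transforms rely on. The dummy image and the seed are made up for illustration. The output has the same shape as the input, so it can be written straight back into a batch array.

```python
import numpy as np
import tensorflow as tf

# Same augmentation settings as in the question
augmentor = tf.keras.preprocessing.image.ImageDataGenerator(
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest')

# Hypothetical dummy image: 64x64 pixels, 3 channels (no batch axis)
image = np.random.rand(64, 64, 3).astype('float32')

# random_transform takes one 3D image and returns a randomly
# transformed copy with the same shape; the seed makes it repeatable
augmented = augmentor.random_transform(image, seed=42)
print(augmented.shape)  # (64, 64, 3)
```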

This is my Keras custom generator:

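import numpy as np
import tensorflow as tf
tfk = tf.keras  # alias used below

def __data_generation(self, list_IDs_temp):
    'Generates data containing batch_size samples'  # X : (n_samples, *dim, n_channels)
    # Initialization
    X = np.empty((self.batch_size, *self.dim, self.n_channels))
    y = np.empty((self.batch_size), dtype=int)

    # Generate data
    for i, ID in enumerate(list_IDs_temp):
        # Store sample: load_img returns a PIL image, so resize it to self.dim
        # and convert it to an array before writing it into X
        img = tfk.preprocessing.image.load_img(self.list_IDs[ID], target_size=self.dim)
        X[i,] = tfk.preprocessing.image.img_to_array(img)

        # Store class
        y[i] = self.labels[ID]

    return X, tfk.utils.to_categorical(y, num_classes=self.n_classes)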

【Question discussion】:

    Tags: python tensorflow keras tensorflow2.0


    【Answer 1】:

    I haven't tried this, but I think you can use the flow method of an ImageDataGenerator instance. For example, your custom class could look like this:

    import numpy as np
    import tensorflow as tf
    from tensorflow.keras.preprocessing.image import ImageDataGenerator

    tfk = tf.keras

    class CustomDataGenerator(tf.keras.utils.Sequence):

        def __init__(self, batch_size=32):
            self.batch_size = batch_size
            # Augmentation settings from the question
            self.augmentor = ImageDataGenerator(
                rotation_range=40,
                width_shift_range=0.2,
                height_shift_range=0.2,
                shear_range=0.2,
                zoom_range=0.2,
                horizontal_flip=True,
                fill_mode='nearest'
            )

        ...

        def __data_generation(self, list_IDs_temp):
            'Generates data containing batch_size samples'  # X : (n_samples, *dim, n_channels)
            # Initialization
            X = np.empty((self.batch_size, *self.dim, self.n_channels))
            y = np.empty((self.batch_size), dtype=int)

            # Generate data
            for i, ID in enumerate(list_IDs_temp):
                # Store sample (resized to self.dim and converted to an array)
                img = tfk.preprocessing.image.load_img(self.list_IDs[ID], target_size=self.dim)
                X[i,] = tfk.preprocessing.image.img_to_array(img)

                # Store class
                y[i] = self.labels[ID]

            # Do not shuffle here: your custom class already shuffles beforehand.
            # You only want the transformations applied, and above all you want
            # the images to stay in sync with their labels.
            X_gen = self.augmentor.flow(X, batch_size=self.batch_size, shuffle=False)

            # X holds exactly batch_size images, so one next() returns them all, augmented
            return next(X_gen), tfk.utils.to_categorical(y, num_classes=self.n_classes)
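A variant of the same idea, swapping flow for ImageDataGenerator.random_transform, augments each image as it is placed into the batch, which makes it explicit that every label is read from the same index as its image. This is a self-contained sketch, not the answer's code: the class name AugmentedSequence is hypothetical, it reads from in-memory arrays instead of calling load_img on file paths, and it assumes TensorFlow 2.x with SciPy installed.

```python
import numpy as np
import tensorflow as tf


class AugmentedSequence(tf.keras.utils.Sequence):
    """Hypothetical Sequence that augments each sample while building a batch."""

    def __init__(self, images, labels, n_classes, batch_size=32,
                 shuffle=True, augment=True):
        super().__init__()
        self.images = images          # (n_samples, height, width, channels)
        self.labels = labels          # (n_samples,) integer class ids
        self.n_classes = n_classes
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.augment = augment
        self.augmentor = tf.keras.preprocessing.image.ImageDataGenerator(
            rotation_range=40,
            width_shift_range=0.2,
            height_shift_range=0.2,
            shear_range=0.2,
            zoom_range=0.2,
            horizontal_flip=True,
            fill_mode='nearest')
        self.indexes = np.arange(len(self.images))
        self.on_epoch_end()

    def __len__(self):
        # Number of full batches per epoch
        return len(self.images) // self.batch_size

    def on_epoch_end(self):
        # Shuffle the shared index array; images and labels both use it
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __getitem__(self, index):
        batch_ids = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]
        X = np.empty((len(batch_ids), *self.images.shape[1:]), dtype='float32')
        y = np.empty(len(batch_ids), dtype=int)
        for i, sample_id in enumerate(batch_ids):
            image = self.images[sample_id].astype('float32')
            # Augment one image at a time; its label comes from the same index
            X[i] = self.augmentor.random_transform(image) if self.augment else image
            y[i] = self.labels[sample_id]
        return X, tf.keras.utils.to_categorical(y, num_classes=self.n_classes)


# Usage with made-up data: 10 RGB images of 32x32 pixels, 3 classes
images = np.random.rand(10, 32, 32, 3)
labels = np.random.randint(0, 3, size=10)
gen = AugmentedSequence(images, labels, n_classes=3, batch_size=4)
X_batch, y_batch = gen[0]
print(X_batch.shape, y_batch.shape)  # (4, 32, 32, 3) (4, 3)
```

Transforming per image costs a Python loop per batch, but it avoids creating a new iterator with flow on every call and keeps the image/label pairing visible in one line.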
    
