【问题标题】:Is there a way to increase the size of the dataset with labels using data augmentation?有没有办法使用数据增强来增加带有标签的数据集的大小?
【发布时间】:2020-07-14 00:38:34
【问题描述】:

我正在尝试对 Kaggle 的数字识别 dataset 实施逻辑回归。训练集中有 42000 行,我想使用数据增强来增加计数。

我尝试使用 keras 的 ImageDataGenerator 对象

datagen = ImageDataGenerator(  
        rotation_range=30,   
        zoom_range = 0.2,  
        width_shift_range=0.2,         
        height_shift_range=0.2)

datagen.fit(X_train)

但大小保持不变,后来我发现ImageDataGenerator 实际上并没有添加行,而是在训练期间插入了增强数据。 有没有其他工具可以保存或增加相同标签的数据?

【问题讨论】:

  • 大小保持不变是什么意思?你能展示你的完整代码吗?您可能对this 感兴趣。
  • 数据集形状最初是 (42000, 784),运行上述脚本后,它保持不变。我认为它会像 (168000, 784) 一样增长 4 倍,我读到 keras 在训练时会实时创建数据
  • 它保持不变,除非您选择将增强保存在其他地方。但是,我不确定它是否适用于 CSV 数据(请参阅上面的链接)。
  • @NelsonGon,是的,我确实将 csv 行重塑为形状为 (-1, 28, 28, 1) 的 4D 数组,并且该函数运行时没有任何错误,所以我能够将增强数据与标签一起保存?那太好了
  • @NelsonGon,感谢您的输入,我能够将增强数据保存到数组中,一旦我用标签保存它们,我会尽快发布答案

标签: python tensorflow keras


【解决方案1】:

这是我最终保存带有标签的增强数据的方式。我采样了 5 行以获得观看乐趣。当考虑完整数据集时,for 循环可能不是写入数组的最佳方式

#importing data
train = pd.read_csv("train.csv")
X_train = train.drop(labels=["label"], axis=1)
y_train = train.label

#sampling 5 rows and reshaping x to 4D array
x = X_train[0:5].values.reshape(-1,28,28,1)
y = y_train[0:5]

#Augmentation parameters
from keras_preprocessing.image import ImageDataGenerator
datagen = ImageDataGenerator(  
        rotation_range=30,   
        zoom_range = 0.2,  
        width_shift_range=0.2,  
        height_shift_range=0.2,  
        )  

#using .flow instead of .fit to write to an array
augmented_data = []
num_augmented = 0
batch = 5  # for 5*5 = 25 entries
for X_batch, y_batch in datagen.flow(X_2, y, batch_size=batch, shuffle=False,):
    augmented_data.append(X_batch)
    augmented_labels.append(y_batch)
    num_augmented += 1
    if num_augmented == x.shape[0]:
        break
augmented_data = np.concatenate(augmented_data) #final shape = (25,28,28,1)
augmented_labels = np.concatenate(augmented_labels)


#Lets take a look at augmented images
for index, image in enumerate(augmented_data):
    plt.subplot(5, 5, index + 1)
    plt.imshow(np.reshape(image, (28,28)), cmap=plt.cm.gray)


# reshaping and converting to df
augmented_data_reshaped = augmented_data.reshape(25, 784)
augmented_dataframe = pd.DataFrame(augmented_data_reshaped)
# inserting labels in df
augmented_dataframe.insert(0, "label", augmented_labels)
header = list(train.columns.values)
augmented_dataframe.columns = header
# write
augmented_dataframe.to_csv("augmented.csv")

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 2016-05-17
    • 2020-03-15
    • 2019-07-02
    • 1970-01-01
    • 2020-10-28
    • 2022-01-22
    • 2022-08-21
    • 1970-01-01
    相关资源
    最近更新 更多