【问题标题】:Number of training samples for each class using ImageDataGenerator with validation_split使用 ImageDataGenerator 和 validation_split 的每个类的训练样本数
【发布时间】:2019-05-19 10:23:18
【问题描述】:

使用 Keras,我在 X 中有图像,在 Y 中有标签。然后我做:

 train_datagen = ImageDataGenerator(validation_split = 0.25)

 train_generator = train_datagen.flow(X, Y, subset = 'training')

我的问题是:当train_generator 在模型的fit_generator 中使用时,每个类中有多少样本实际上作为训练样本呈现?

例如,如果我有 3 个类的 1000 个 (x, y) 对:A 类 500 个,B 类 300 个,C 类 200 个,那么有多少来自 A、B 和 C 类的样本 fit_generator 真的视为训练样本?或者我们能做的只有:500*(1.0 - 0.25) 等等?

【问题讨论】:

    标签: python machine-learning keras


    【解决方案1】:

    如果我们检查the relevant part of the source code,我们会发现X(和y)中的最后一个validation_split * num_samples样本将用于验证,而其他样本将用于训练:

    split_idx = int(len(x) * image_data_generator._validation_split)
    
    # ...
    if subset == 'validation':
        x = x[:split_idx]
        x_misc = [np.asarray(xx[:split_idx]) for xx in x_misc]
        if y is not None:
            y = y[:split_idx]
    else:
        x = x[split_idx:]
        x_misc = [np.asarray(xx[split_idx:]) for xx in x_misc]
        if y is not None:
            y = y[split_idx:]
    

    因此,如果您想确保训练和验证子集中的类比例相同(即 Keras 在使用此功能时不保证这一点),则由您负责。 Keras verifies 唯一的一点是,每个类中至少有一个样本包含在训练和验证子集中:

    if not np.array_equal(
            np.unique(y[:split_idx]),
            np.unique(y[split_idx:])):
        raise ValueError('Training and validation subsets '
                         'have different number of classes after '
                         'the split. If your numpy arrays are '
                         'sorted by the label, you might want '
                         'to shuffle them.')
    

    因此,分层拆分的解决方案(即在训练和验证拆分中保留每个类的样本比例)是使用 sklearn.model_selection.train_test_splitstratify 参数集:

    from sklearn.model_selection import train_test_split
    
    val_split = 0.25
    X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=val_split, stratify=y)
    
    X = np.concatenate((X_train, X_val))
    y = np.concatenate((y_train, y_val))
    

    现在您可以将validation_split=val_split 传递给ImageDataGenerator,并保证训练和验证子集中的类比例相同。

    【讨论】:

    • 水晶般清澈 - 非常感谢!除了花时间和熟悉软件包之外,还有什么技巧可以放大相应的源代码吗?曾经是一个表面上的、只看文档的用户:-)
    • @willz 好吧,Keras 源代码并不难检查,因为它没有那么多文件。您只需要根据您使用的模块转到相关文件,该文件很容易找到,并阅读源代码。所以没有技巧,至少对我来说:)
    猜你喜欢
    • 2020-03-07
    • 2021-07-03
    • 2021-01-20
    • 1970-01-01
    • 2012-04-18
    • 1970-01-01
    • 2018-06-26
    • 2020-08-24
    • 2021-09-27
    相关资源
    最近更新 更多