【问题标题】:Stratified train/validation/test split without scikit-learn没有 scikit-learn 的分层训练/验证/测试拆分
【发布时间】:2021-07-26 14:12:58
【问题描述】:

我正在处理 mnist 数据集,其中包含 1797 张图像,表示 0 到 10 位数字。我想将数据集拆分为训练、验证和测试子数据,以便为每个 sub_data 指定相同数量的每个数字。 python中没有sklearn库如何进行分层?

提前感谢您的回答。

【问题讨论】:

    标签: python scikit-learn split


    【解决方案1】:

    要进行分层数据拆分,您需要知道每个数据点属于哪个类。如果你有一个数据点列表和对应的类列表,你可以提取属于某个类的所有点,然后按照输入的比例进行拆分。

    下面是一些实现这个想法的代码: 请注意,您必须添加一些数组来跟踪数据点在循环中被拆分后所属的类。

    import numpy as np
    train, valid, test = 0.6, 0.2, 0.2
    data_points = np.random.rand(1000, 32, 32)
    classes     = np.random.randint(0, 10, size = (1000,))
    class_set   = np.unique(classes)
    data_train  = []
    data_valid  = []
    data_test   = []
    for class_i in class_set:
        data_inds    = np.where(classes==class_i)
        data_i       = data_points[data_inds, ...]
        N_i          = len(data_inds)
        N_i_train    = int(N_i*train)
        N_i_valid    = int(N_i*valid)
        data_train.append(data_i[:N_i_train])
        data_valid.append(data_i[N_i_train:N_i_train+N_i_valid])
        data_test.append(data_i[N_i_train+N_i_valid:])
        
    data_train = np.concatenate(data_train)
    data_valid = np.concatenate(data_valid)
    data_test = np.concatenate(data_test)
    

    【讨论】:

    • 非常感谢您的回答。它工作正常。我刚刚编辑了最后三行。
    • 很抱歉复制意大利面的错误。感谢您的编辑。
    猜你喜欢
    • 2017-04-11
    • 2015-06-08
    • 2018-12-07
    • 2019-04-10
    • 2020-06-22
    • 2021-06-20
    • 2018-04-22
    • 1970-01-01
    • 2020-10-25
    相关资源
    最近更新 更多