【问题标题】:Randomized stratified k-fold cross-validation in scikit-learn?scikit-learn中的随机分层k折交叉验证?
【发布时间】:2013-05-03 04:10:47
【问题描述】:

是否有任何内置方法可以让 scikit-learn 执行 shuffled 分层 k 折交叉验证?这是最常见的 CV 方法之一,我很惊讶找不到内置方法来执行此操作。

我看到cross_validation.KFold() 有一个洗牌标志,但它没有分层。不幸的是cross_validation.StratifiedKFold() 没有这样的选项,cross_validation.StratifiedShuffleSplit() 不会产生不相交的折叠。

我错过了什么吗?这是计划好的吗?

(显然我可以自己实现)

【问题讨论】:

    标签: python machine-learning scikit-learn cross-validation


    【解决方案1】:

    这是我将分层洗牌分成训练和测试集的实现:

    import numpy as np
    
    def get_train_test_inds(y,train_proportion=0.7):
        '''Generates indices, making random stratified split into training set and testing sets
        with proportions train_proportion and (1-train_proportion) of initial sample.
        y is any iterable indicating classes of each observation in the sample.
        Initial proportions of classes inside training and 
        test sets are preserved (stratified sampling).
        '''
    
        y=np.array(y)
        train_inds = np.zeros(len(y),dtype=bool)
        test_inds = np.zeros(len(y),dtype=bool)
        values = np.unique(y)
        for value in values:
            value_inds = np.nonzero(y==value)[0]
            np.random.shuffle(value_inds)
            n = int(train_proportion*len(value_inds))
    
            train_inds[value_inds[:n]]=True
            test_inds[value_inds[n:]]=True
    
        return train_inds,test_inds
    
    
    y = np.array([1,1,2,2,3,3])
    train_inds,test_inds = get_train_test_inds(y,train_proportion=0.5)
    print y[train_inds]
    print y[test_inds]
    

    这段代码输出:

    [1 2 3]
    [1 2 3]
    

    【讨论】:

      【解决方案2】:

      cross_validation.StratifiedKFold 的改组标志已在当前版本 0.15 中引入:

      http://scikit-learn.org/0.15/modules/generated/sklearn.cross_validation.StratifiedKFold.html

      这可以在更新日志中找到:

      http://scikit-learn.org/stable/whats_new.html#new-features

      cross_validation.StratifiedKFold 的随机播放选项。通过杰弗里 布莱克本。

      【讨论】:

        【解决方案3】:

        我想我会发布我的解决方案,以防它对其他人有用。

        from collections import defaultdict
        import random
        def strat_map(y):
            """
            Returns permuted indices that maintain class
            """
            smap = defaultdict(list)
            for i,v in enumerate(y):
                smap[v].append(i)
            for values in smap.values():
                random.shuffle(values)
            y_map = np.zeros_like(y)
            for i,v in enumerate(y):
                y_map[i] = smap[v].pop()
            return y_map
        
        ##########
        #Example Use
        ##########
        skf = StratifiedKFold(y, nfolds)
        sm = strat_map(y)
        for test, train in skf:
            test,train = sm[test], sm[train]
            #then cv as usual
        
        
        #######
        #tests#
        #######
        import numpy.random as rnd
        for _ in range(100):
            y = np.array( [0]*10 + [1]*20 + [3] * 10)
            rnd.shuffle(y)
            sm = strat_map(y)
            shuffled = y[sm]
            assert (sm != range(len(y))).any() , "did not shuffle"
            assert (shuffled == y).all(), "classes not in right position"
            assert (set(sm) == set(range(len(y)))), "missing indices"
        
        
        for _ in range(100):
            nfolds = 10
            skf = StratifiedKFold(y, nfolds)
            sm = strat_map(y)
            for test, train in skf:
                assert (sm[test] != test).any(), "did not shuffle"
                assert (y[sm[test]] == y[test]).all(), "classes not in right position"
        

        【讨论】:

          【解决方案4】:

          据我所知,这实际上是在 scikit-learn 中实现的。

          """ 分层 ShuffleSplit 交叉验证迭代器

          提供训练/测试索引以拆分训练测试集中的数据。

          这个交叉验证对象是 StratifiedKFold 和 ShuffleSplit,返回分层的随机折叠。褶皱 是通过保留每个类的样本百分比来制作的。

          注意:与 ShuffleSplit 策略一样,分层随机拆分 不保证所有折叠都会不同,尽管这是 对于相当大的数据集仍然很有可能。 """

          【讨论】:

          • 正如我在问题中所写,StratifiedShuffleSplit() 不会执行 StratifiedKFold() 的洗牌版本,即在 StratifiedKFold() 之前洗牌。这甚至在您答案的最后一句中提到。 KFold CV 要求折叠之间没有交集,并且它们的并集是整个数据集。
          • 啊,是的,折叠不能保证分离。很抱歉没有读到你的问题的结尾..
          猜你喜欢
          • 2017-01-11
          • 2021-01-25
          • 2015-12-11
          • 2012-01-07
          • 1970-01-01
          • 2016-04-25
          • 2021-06-05
          • 2018-12-07
          • 2017-09-02
          相关资源
          最近更新 更多