【问题标题】:how to implement walk forward testing in sklearn?如何在 sklearn 中实现前向测试?
【发布时间】:2015-11-03 23:57:21
【问题描述】:

在 sklearn 中,GridSearchCV 可以将管道作为参数,通过交叉验证找到最佳估计器。但是,通常的交叉验证是这样的:

为了交叉验证时间序列数据,训练和测试数据通常是这样拆分的:

也就是说,测试数据应该总是领先于训练数据。

我的想法是:

  1. 编写我自己的k-fold版本类并将其传递给GridSearchCV,这样我就可以享受管道的便利。问题是让 GridSearchCV 使用指定的训练和测试数据索引似乎很困难。

  2. 写一个新的类GridSearchWalkForwardTest,类似于GridSearchCV,正在研究grid_search.py​​源码,发现有点复杂。

欢迎提出任何建议。

【问题讨论】:

标签: python scikit-learn time-series cross-validation


【解决方案1】:

我的意见是你应该尝试实现你自己的GridSearchWalkForwardTest。我曾经使用 GridSearch 进行培训并自己实现了相同的 GridSearch,但我没有得到相同的结果,尽管我应该这样做。

我最后做的是使用我自己的函数。您可以更好地控制训练和测试集,并且可以更好地控制您训练的参数。

【讨论】:

    【解决方案2】:

    我认为您可以使用Time Series Split 代替您自己的实现,或者作为实现与您描述的完全相同的 CV 方法的基础。

    经过一番挖掘,似乎有人在this PR 的 TimeSeriesSplit 中添加了一个 max_train_size ,这似乎是你想要的。

    【讨论】:

    • 你是对的,walk-forward cross-validation 是 sci-kit learn 的 TimeSeriesSplit 算法。但是如何在 LassoCV 和 ElasticNetCV 等 CV 估计器中选择它作为“cv”对象的选择? KFold、LeaveOneOut、train_test_split 和其他算法属于 sklearn 的 cross_validation 模块,我们可以从中为这些估计器选择一个“cv”对象。但是,TimeSeriesSplit 属于 sklearn 的 model_selection 模块,目前没有选择它。
    【解决方案3】:

    几个月前我做了一些关于这一切的工作。

    您可以在这个问题/答案中查看它:

    Rolling window REVISITED - Adding window rolling quantity as a parameter- Walk Forward Analysis

    【讨论】:

      【解决方案4】:

      我写了一些代码,希望对某人有所帮助。

      'sequence' 是时间序列的周期。我正在训练一个最多 40 个序列的模型,预测 41 个,然后训练最多 41 个以预测 42 个,依此类推......直到最大值。 “数量”是目标变量。然后我所有错误的平均值将成为我的评估指标

      for sequence in range(40, df.sequence.max() + 1):
              train = df[df['sequence'] < sequence]
              test = df[df['sequence'] == sequence]
              X_train, X_test = train.drop(['quantity'], axis=1), test.drop(['quantity'], axis=1)
              y_train, y_test = train['quantity'].values, test['quantity'].values
          
              mdl = LinearRegression()
              mdl.fit(X_train, y_train)
              y_pred = mdl.predict(X_test) 
              error = sklearn.metrics.mean_squared_error(test['quantity'].values, y_pred)
              RMSE.append(error)
      print('Mean RMSE = %.5f' % np.mean(RMSE))
          
      

      【讨论】:

        【解决方案5】:

        利用 sktime TimeSeriesSplit,定义训练和测试大小的固定滚动窗口。注意第一个训练窗口可能包含额外的多余数据(更喜欢保留而不是剪辑):

        def tscv(X, train_size, test_size):
            folds = math.floor(len(X) / test_size)
            tscv = TimeSeriesSplit(n_splits=folds, test_size=test_size)
            splits = []
            for train_index, test_index in tscv.split(X):
                if len(train_index) < train_size:
                    continue
                elif len(train_index) - train_size < test_size and len(train_index) - train_size > 0:
                    pass
                else:
                    train_index = train_index[-train_size:]
                splits.append([train_index, test_index])
            return splits
        

        【讨论】:

          【解决方案6】:

          我使用这个自定义类来创建基于 StratifiedKFold 的不相交分割(可以被 KFold 或其他替代),以便创建以下训练方案:

          |X||V|O|O|O|
          |O|X||V|O|O|
          |O|O|X||V|O|
          |O|O|O|X||V|
          

          X / V 是训练/验证集。 “||”表示在验证集的开头截断了一个间隙(参数 n_gap: int>0),以防止泄漏效应。

          您可以轻松扩展它以获得更长的训练集回顾窗口。

          class StratifiedWalkForward(object):
              
              def __init__(self,n_splits,n_gap):
                  self.n_splits = n_splits
                  self.n_gap = n_gap
                  self._cv = StratifiedKFold(n_splits=self.n_splits+1,shuffle=False)
                  return
              
              def split(self,X,y,groups=None):
                  splits = self._cv.split(X,y)
                  _ixs = []
                  for ix in splits: 
                      _ixs.append(ix[1])
                  for i in range(1,len(_ixs)): 
                      yield tuple((_ixs[i-1],_ixs[i][_ixs[i]>_ixs[i-1][-1]+self.n_gap]))
                      
              def get_n_splits(self,X,y,groups=None):
                  return self.n_splits
          

          请注意,数据集之后可能无法完美分层,原因是使用 n_gap 截断。

          【讨论】:

            猜你喜欢
            • 1970-01-01
            • 2017-09-04
            • 2010-10-22
            • 1970-01-01
            • 2019-08-17
            • 2015-12-20
            • 1970-01-01
            • 1970-01-01
            • 2014-01-13
            相关资源
            最近更新 更多