【问题标题】:How to use cross validation in pytorch lightning?如何在 pytorch 闪电中使用交叉验证?
【发布时间】:2021-09-27 05:08:27
【问题描述】:

我正在将我的 tensorflow 代码转换为 pytorch-lightning 代码。我无法找到如何在 pytorch-lightning 中使用交叉验证。他们无论如何都要在闪电数据模块中做到这一点。 我将 tensorflow 代码保存在使用 sklearn 实现交叉验证的下方。

folds = RepeatedStratifiedKFold(n_splits = 5, n_repeats = 1)

for train_index, test_index in folds.split(left_input, targets):
    left_input_cv, left_input_test, targets_cv, targets_test = left_input[train_index], left_input[test_index], targets[train_index], targets[test_index]
    right_input_cv, right_input_test = right_input[train_index], right_input[test_index]

【问题讨论】:

  • 您的问题没有指向任何关于 tf 或 pytorch 的 ml 模块。分层是在模型训练的数据准备步骤中。
  • 是的,但是在 Pytorch Lightning 中,准备数据集进入 LightningModule。这就是我问的原因

标签: python scikit-learn


【解决方案1】:

如果使用数据框,你可以做这样的事情

for fold,(train_idx,val_idx) in enumerate(kfold.split(df)):
    print('------------fold no---------{}----------------------'.format(fold))
    train_split=df.loc[train_idx].reset_index(drop=True)
    val_split=df.loc[val_idx].reset_index(drop=True)

    model=OurModel(train_split,val_split,fold)

class OurModel(LightningModule):
    def __init__(self,train_split,val_split,fold):
        super(OurModel,self).__init__()
        
        self.train_split=train_split
        self.val_split=val_split
        self.fold=fold

如果您从图像文件夹中读取,则可以这样做

combined=torchvision.datasets.ImageFolder('../multiclass/train/')

for fold,(train_idx,val_idx) in enumerate(kfold.split(combined)):
    print('------------fold no---------{}----------------------'.format(fold))
    train_subsampler = torch.utils.data.SubsetRandomSampler(train_idx)
    val_subsampler = torch.utils.data.SubsetRandomSampler(val_idx)
    model=OurModel(combined,train_subsampler,val_subsampler)

class OurModel(LightningModule):
    def __init__(self,combined,train_subsampler,test_subsampler,test_data=None):
        super(OurModel,self).__init__()

        self.train_subsampler=train_subsampler
        self.test_subsampler=test_subsampler
        self.combined=combined

    def train_dataloader(self):
        return DataLoader(DataReader(self.combined,aug),,sampler=self.train_subsampler,shuffle=False)

【讨论】:

    【解决方案2】:

    您可以使用下面的 lambda 转换或类似逻辑。我不确定,但这可能会对你有所帮助。这是一个 10 倍交叉验证。这里 cut 是您要裁剪的图像大小(用于排除背景噪音)。

    transform_test = transforms.Compose([
        transforms.TenCrop(cut_size),
        transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops]))
    

    并将此转换传递给您的test_loader

    【讨论】:

      猜你喜欢
      • 2019-04-20
      • 2020-07-08
      • 2022-12-06
      • 1970-01-01
      • 2015-09-06
      • 2019-09-20
      • 2011-01-24
      • 2016-06-27
      • 1970-01-01
      相关资源
      最近更新 更多