【问题标题】:How can I cross-validate by Pytorch and Optuna如何通过 Pytorch 和 Optuna 进行交叉验证
【发布时间】:2020-08-03 06:00:15
【问题描述】:

我想对官方 Optuna 和基于 pytorch 的示例代码 (https://github.com/optuna/optuna/blob/master/examples/pytorch_simple.py) 使用交叉验证。

我想过拆分数据进行交叉验证并尝试对每个折叠进行参数调整,但似乎无法获得每个参数的平均准确度,因为在 study.trials_dataframe() 中可以检查的参数各不相同时间。

【问题讨论】:

  • 简短回答:Optuna 的贝叶斯过程是交叉验证试图近似的。如果可能,请查看此答案并在此处发表评论;我认为此时无需交叉发布:stats.stackexchange.com/a/491268/272731

标签: pytorch optuna


【解决方案1】:

我认为我们需要评估所有折叠并计算目标函数内的平均值。我创建了一个example notebook,所以请看一下。

在笔记本中,我稍微修改了objective 函数以传递带有参数的数据集,并添加了一个包装函数objective_cv 以使用拆分数据集调用objective 函数。然后,我优化了objective_cv 而不是objective 函数。

def objective(trial, train_loader, valid_loader):

    # Remove the following line.
    # train_loader, valid_loader = get_mnist()

    ...

    return accuracy


def objective_cv(trial):

    # Get the MNIST dataset.
    dataset = datasets.MNIST(DIR, train=True, download=True, transform=transforms.ToTensor())

    fold = KFold(n_splits=3, shuffle=True, random_state=0)
    scores = []
    for fold_idx, (train_idx, valid_idx) in enumerate(fold.split(range(len(dataset)))):
        train_data = torch.utils.data.Subset(dataset, train_idx)
        valid_data = torch.utils.data.Subset(dataset, valid_idx)

        train_loader = torch.utils.data.DataLoader(
            train_data,
            batch_size=BATCHSIZE,
            shuffle=True,
        )
        valid_loader = torch.utils.data.DataLoader(
            valid_data,
            batch_size=BATCHSIZE,
            shuffle=True,
        )

        accuracy = objective(trial, train_loader, valid_loader)
        scores.append(accuracy)
    return np.mean(scores)


study = optuna.create_study(direction="maximize")
study.optimize(objective_cv, n_trials=20, timeout=600)

【讨论】:

    猜你喜欢
    • 2019-04-20
    • 2020-03-18
    • 2019-09-20
    • 2020-07-08
    • 2013-12-13
    • 2023-02-05
    • 2020-04-27
    • 1970-01-01
    • 2022-01-25
    相关资源
    最近更新 更多