【发布时间】:2021-09-27 05:08:27
【问题描述】:
我正在将我的 tensorflow 代码转换为 pytorch-lightning 代码。我无法找到如何在 pytorch-lightning 中使用交叉验证。他们无论如何都要在闪电数据模块中做到这一点。 我将 tensorflow 代码保存在使用 sklearn 实现交叉验证的下方。
folds = RepeatedStratifiedKFold(n_splits = 5, n_repeats = 1)
for train_index, test_index in folds.split(left_input, targets):
left_input_cv, left_input_test, targets_cv, targets_test = left_input[train_index], left_input[test_index], targets[train_index], targets[test_index]
right_input_cv, right_input_test = right_input[train_index], right_input[test_index]
【问题讨论】:
-
您的问题没有指向任何关于 tf 或 pytorch 的 ml 模块。分层是在模型训练的数据准备步骤中。
-
是的,但是在 Pytorch Lightning 中,准备数据集进入
LightningModule。这就是我问的原因
标签: python scikit-learn