【发布时间】:2020-12-19 01:41:33
【问题描述】:
我正在使用 Deeplearning4j 和 datavec,并且我有一个代表我所有数据的 DataSetIterator 对象,它是一个时间序列。如何将其拆分为训练和测试迭代器?我检查并弃用了 DataSetIterator 类的方法。谢谢。
【问题讨论】:
标签: java deep-learning training-data deeplearning4j
我正在使用 Deeplearning4j 和 datavec,并且我有一个代表我所有数据的 DataSetIterator 对象,它是一个时间序列。如何将其拆分为训练和测试迭代器?我检查并弃用了 DataSetIterator 类的方法。谢谢。
【问题讨论】:
标签: java deep-learning training-data deeplearning4j
遍历您的 DataSetIterator 并为每个 DataSet 条目创建两个新的 DataSets,分别用于训练和测试。
关键是使用splitTestAndTrain 方法,它接受一个double fractionTrain,它将指定要训练的数据量(其余要测试)。该方法有不同的重载,因此您可以选择最适合您需求的一种。如果您希望将所有训练和测试数据集添加到一个公共迭代器中,您可以将它们存储在两个不同的列表中,稍后再获取它们对应的迭代器。比如:
List<DataSet> trainList = new ArrayList<>();
List<DataSet> testList= new ArrayList<>();
while (yourDataSetIterator.hasNext())
{
DataSet ds = yourDataSetIterator.next();
SplitTestAndTrain splData = ds.splitTestAndTrain(0.5); //half for each
DataSet trainDs = splData.getTrain();
trainList.add(trainDs);
DataSet testDs = splData.getTest();
testList.add(testDs);
(...)
}
Iterator<DataSet> trainIterator = trainList.iterator();
Iterator<DataSet> testIterator = testList.iterator();
由于我不太了解这个库的具体细节,所以这个例子只是创建了“基本”iterators。这可以自定义,因此您可以创建 DataSetIterators。
请注意,您可能还需要在拆分数据集之前对其进行洗牌 (ds.shuffle())。你可以找到一些例子here
如果您希望以特定方式拆分它,您可以标记不同的条目并找到测试数据集的最大索引;然后,调用splitTestAndTrain(int max) 方法,该方法专门针对最大参数拆分数据集。 sortByLabel 方法在这里也有帮助。
Adam Gibson 对 cmets 关于其他机制提出了很好的建议,以便拆分 DataSetIterator,这似乎也是一种“更自然”的方式,DataSetIteratorSplitter。
它提供了getTrainIterator() 和getTestIterator() 方法,它们返回库的特定迭代器DataSetIterator。
【讨论】: