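【Posted】: 2021-03-02 01:52:46
【Question】:
I have a dataset that I use to train a KNN model. Later I want to update the model with new training data. What I see is that the updated model uses only the new training data and ignores the data it was trained on before.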
Vectorizer vec = new DummyVectorizer<Integer>(1, 2).labeled(0);
DatasetTrainer<KNNClassificationModel, Double> trainer = new KNNClassificationTrainer();
KNNClassificationModel model;
KNNClassificationModel modelUpdated;
Map<Integer, Vector> trainingData = new HashMap<Integer, Vector>();
Map<Integer, Vector> trainingDataNew = new HashMap<Integer, Vector>();
Double[][] data1 = new Double[][] {
    {0.136, 0.644, 0.154},
    {0.302, 0.634, 0.779},
    {0.806, 0.254, 0.211},
    {0.241, 0.951, 0.744},
    {0.542, 0.893, 0.612},
    {0.334, 0.277, 0.486},
    {0.616, 0.259, 0.121},
    {0.738, 0.585, 0.017},
    {0.124, 0.567, 0.358},
    {0.934, 0.346, 0.863}};
Double[][] data2 = new Double[][] {
    {0.300, 0.236, 0.193}};
Double[] observationData = new Double[] { 0.8, 0.7 };
// fill dataset (in cache)
for (int i = 0; i < data1.length; i++)
    trainingData.put(i, new DenseVector(data1[i]));
// first training / prediction
model = trainer.fit(trainingData, 1, vec);
System.out.println("First prediction : " + model.predict(new DenseVector(observationData)));
// new training data
for (int i = 0; i < data2.length; i++)
    trainingDataNew.put(data1.length + i, new DenseVector(data2[i]));
// second training / prediction
modelUpdated = trainer.update(model, trainingDataNew, 1, vec);
System.out.println("Second prediction: " + modelUpdated.predict(new DenseVector(observationData)));
This is the output I get:
First prediction : 0.124
Second prediction: 0.3
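It looks like the second prediction used only data2: since data2 contains a single row whose label is 0.3, a model built on it alone can only predict 0.3.
How does the model update work? If I have to add data2 to data1 and then train again on the combined data, how would that differ from a completely fresh training run on all of the combined data?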
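For comparison, the "combine and retrain" alternative can be sketched in plain Java. This is only a rough illustration under stated assumptions: `KnnMergeSketch` and its `predict` method are hypothetical names, and the code is a brute-force 1-nearest-neighbour stand-in, not Apache Ignite's KNN implementation, so its numbers are not expected to match the Ignite output above. As in the question, column 0 is treated as the label and columns 1 and 2 as the features.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical plain-Java 1-nearest-neighbour stand-in; not part of Apache Ignite. */
final class KnnMergeSketch {
    /**
     * Returns the label (column 0) of the stored row whose features (columns 1 and 2)
     * are closest to the observation, by squared Euclidean distance.
     */
    static double predict(Map<Integer, double[]> rows, double[] observation) {
        double bestDist = Double.POSITIVE_INFINITY;
        double bestLabel = Double.NaN;
        for (double[] row : rows.values()) {
            double dx = row[1] - observation[0];
            double dy = row[2] - observation[1];
            double dist = dx * dx + dy * dy;
            if (dist < bestDist) {
                bestDist = dist;
                bestLabel = row[0];
            }
        }
        return bestLabel;
    }

    public static void main(String[] args) {
        double[][] data1 = {
            {0.136, 0.644, 0.154}, {0.302, 0.634, 0.779}, {0.806, 0.254, 0.211},
            {0.241, 0.951, 0.744}, {0.542, 0.893, 0.612}, {0.334, 0.277, 0.486},
            {0.616, 0.259, 0.121}, {0.738, 0.585, 0.017}, {0.124, 0.567, 0.358},
            {0.934, 0.346, 0.863}};
        double[][] data2 = {{0.300, 0.236, 0.193}};
        double[] observation = {0.8, 0.7};

        Map<Integer, double[]> trainingData = new HashMap<>();
        for (int i = 0; i < data1.length; i++)
            trainingData.put(i, data1[i]);

        Map<Integer, double[]> trainingDataNew = new HashMap<>();
        for (int i = 0; i < data2.length; i++)
            trainingDataNew.put(data1.length + i, data2[i]);

        // Merge old and new rows; the keys do not collide because the
        // new keys start at data1.length.
        Map<Integer, double[]> combined = new HashMap<>(trainingData);
        combined.putAll(trainingDataNew);

        System.out.println("New data only : " + predict(trainingDataNew, observation));
        System.out.println("Combined data : " + predict(combined, observation));
    }
}
```

In this sketch the new-data-only prediction is 0.3, while the combined prediction is 0.542, because the row {0.542, 0.893, 0.612} lies closest to the observation. With the Ignite trainer, the equivalent step would presumably be to pass the merged map to `trainer.fit(...)`, though whether that matches what `update` is meant to do is exactly the open question here.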
【Discussion】:
Tags: java machine-learning ignite