【Title】: Different predictions for the same data
【Posted】: 2021-12-08 09:54:28
【Question】:

I am using Deeplearning4j to classify equipment names. I labeled about 50,000 items with 495 classes and used this data to train a neural network.

That is, as input I provide a set of 50,000 vectors of 0s and 1s, along with the expected class (0 to 494) for each vector.

I used the IrisClassifier example as the basis for my code.

I saved the trained model to a file, and now I can use it to predict the equipment class.

For example, I tried making predictions with the same data I used for training (all 50,000 items) and compared the predictions against my labels for that data.

The results were very good: the neural network's error was about 1%.

After that, I tried predicting with just the first 100 vectors of those 50,000 records, removing the remaining 49,900.

For those 100 vectors, the predictions were different from the predictions for the same 100 vectors made as part of the full 50,000.

In other words, the less data I feed to the trained model, the larger the prediction error.

Even for exactly the same vectors.

Why does this happen?

My code.

Training:

 //First: get the dataset using the record reader. CSVRecordReader handles loading/parsing
int numLinesToSkip = 0;
char delimiter = ',';
RecordReader recordReader = new CSVRecordReader(numLinesToSkip,delimiter);
recordReader.initialize(new FileSplit(new File(args[0])));

//Second: the RecordReaderDataSetIterator handles conversion to DataSet objects, ready for use in neural network
int labelIndex = 3331;
int numClasses = 495;
int batchSize = 4000;

// DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader,batchSize,labelIndex,numClasses);
DataSetIterator iterator = new RecordReaderDataSetIterator.Builder(recordReader, batchSize).classification(labelIndex, numClasses).build();

List<DataSet> trainingData = new ArrayList<>();
List<DataSet> testData = new ArrayList<>();

while (iterator.hasNext()) {
    DataSet allData = iterator.next();
    allData.shuffle();
    SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.8);  //Use 80% of data for training
    trainingData.add(testAndTrain.getTrain());
    testData.add(testAndTrain.getTest());
}

DataSet allTrainingData = DataSet.merge(trainingData);
DataSet allTestData = DataSet.merge(testData);

//We need to normalize our data. We'll use NormalizerStandardize (which gives us mean 0, unit variance):
DataNormalization normalizer = new NormalizerStandardize();
normalizer.fit(allTrainingData);           //Collect the statistics (mean/stdev) from the training data. This does not modify the input data
normalizer.transform(allTrainingData);     //Apply normalization to the training data
normalizer.transform(allTestData);         //Apply normalization to the test data. This is using statistics calculated from the *training* set

long seed = 6;
int firstHiddenLayerSize = labelIndex/6;
int secondHiddenLayerSize = firstHiddenLayerSize/4;

//log.info("Build model....");
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
        .seed(seed)
        .activation(Activation.TANH)
        .weightInit(WeightInit.XAVIER)
        .updater(new Sgd(0.1))
        .l2(1e-4)
        .list()
        .layer(new DenseLayer.Builder().nIn(labelIndex).nOut(firstHiddenLayerSize)
                .build())
        .layer(new DenseLayer.Builder().nIn(firstHiddenLayerSize).nOut(secondHiddenLayerSize)
                .build())
        .layer( new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                .activation(Activation.SOFTMAX) //Override the global TANH activation with softmax for this layer
                .nIn(secondHiddenLayerSize).nOut(numClasses).build())
        .build();

//run the model
MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init();

//record score once every 100 iterations
model.setListeners(new ScoreIterationListener(100));

for(int i=0; i<5000; i++ ) {
    model.fit(allTrainingData);
}

//evaluate the model on the test set
Evaluation eval = new Evaluation(numClasses);

INDArray output = model.output(allTestData.getFeatures());

eval.eval(allTestData.getLabels(), output);
log.info(eval.stats());

// Save the Model
File locationToSave = new File(args[1]);
model.save(locationToSave, false);

Prediction:

// Open the network file
File locationToLoad = new File(args[0]);
MultiLayerNetwork model = MultiLayerNetwork.load(locationToLoad, false);
model.init();

// First: get the dataset using the record reader. CSVRecordReader handles loading/parsing
int numLinesToSkip = 0;
char delimiter = ',';

// Data to predict
CSVRecordReader recordReader = new CSVRecordReader(numLinesToSkip, delimiter);  //skip no lines at the top - i.e. no header
recordReader.initialize(new FileSplit(new File(args[1])));

//Second: the RecordReaderDataSetIterator handles conversion to DataSet objects, ready for use in neural network
int batchSize = 4000;

DataSetIterator iterator = new RecordReaderDataSetIterator.Builder(recordReader, batchSize).build();

List<DataSet> dataSetList = new ArrayList<>();

while (iterator.hasNext()) {
    DataSet allData = iterator.next();
    dataSetList.add(allData);
}

DataSet dataSet = DataSet.merge(dataSetList);

DataNormalization normalizer = new NormalizerStandardize();
normalizer.fit(dataSet);
normalizer.transform(dataSet);

// Now use it to classify some data
INDArray output = model.output(dataSet.getFeatures());

// Save result
BufferedWriter writer = new BufferedWriter(new FileWriter(args[2], true));
for (int i=0; i<output.rows(); i++) {
    writer
            .append(output.getRow(i).argMax().toString())
            .append(" ")
            .append(String.valueOf(i))
            .append(" ")
            .append(output.getRow(i).toString())
            .append('\n');
}
writer.close();

【Question discussion】:

    Tags: machine-learning deeplearning4j dl4j


    【Solution 1】:

    Make sure you save the normalizer alongside the model, like this:

    import org.nd4j.linalg.dataset.api.preprocessor.serializer.NormalizerSerializer;

    NormalizerSerializer SUT = NormalizerSerializer.getDefault();
    SUT.write(normalizer, new File("outputFile.bin"));

    NormalizerStandardize restored = SUT.restore(new File("outputFile.bin"));
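
    For reference, a minimal sketch of where this could hook into the training program from the question (the file name normalizer.bin is just an example):

    // ... in the training code, right after normalizer.fit(allTrainingData) ...

    // Persist the fitted mean/stddev statistics next to the model file
    NormalizerSerializer serializer = NormalizerSerializer.getDefault();
    serializer.write(normalizer, new File("normalizer.bin"));

    // Save the model as before
    model.save(new File(args[1]), false);

    Alternatively, if memory serves, ModelSerializer.addNormalizerToModel(File, Normalizer) and ModelSerializer.restoreNormalizerFromFile(File) can embed and recover the normalizer directly from the saved model file.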
    

    【Discussion】:

    【Solution 2】:

    You need to use the same normalization statistics for training and prediction. Otherwise the wrong statistics are used when the data is transformed.

    The way you are doing it now makes the prediction data look very different from the training data, which is why you get such different results.
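
    As a minimal sketch of the fix on the prediction side (assuming the normalizer was serialized at training time as in Solution 1, under the example file name normalizer.bin), restore the fitted normalizer and only transform the new data instead of fitting a fresh one:

    // Restore the statistics that were fitted on the *training* data
    NormalizerSerializer serializer = NormalizerSerializer.getDefault();
    NormalizerStandardize restored = serializer.restore(new File("normalizer.bin"));

    // Transform the prediction data with those statistics. Do NOT call fit() here:
    // that would recompute mean/stddev from the prediction set itself, which is
    // exactly why 100 rows normalize differently than the full 50,000.
    restored.transform(dataSet);

    INDArray output = model.output(dataSet.getFeatures());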

    【Discussion】:

    • I use DataNormalization normalizer = new NormalizerStandardize(); in both training and prediction. Or do you mean I need to save the normalizer's data during training and load it for prediction?
    • The normalizer holds the mean and variance learned from the training data. You need to save the normalizer after calling fit and load it with those statistics so that the data is preprocessed correctly.
    • Do I understand correctly that I need to save the normalizer at the same time as the model? But the save method only takes two arguments: model.save(File, saveUpdater). Could you share a link to example code showing how to save and load the normalizer correctly? Thanks!