一、概述

 上一篇文章我们利用ML.NET的多元分类算法实现了一个手写数字识别的例子。这个例子存在一个问题:输入的数据是预处理过的,很不直观。这次我们要直接通过图片来进行学习和判断。思路很简单,就是写一个自定义的数据处理通道:输入为文件名,输出为float数组,里面保存的是图片的像素信息。

 样本包括6万张训练图片和1万张测试图片,图片均为灰度图片,分辨率为20*20。train_tags.tsv文件对每张图片的数值进行了标记,文件没有表头,每行格式为“文件名<Tab>数值”,如下:

机器学习框架ML.NET学习笔记【5】多元分类之手写数字识别(续)

  

二、源码

 全部代码: 

namespace MulticlassClassification_Mnist
{
    /// <summary>
    /// Trains a multiclass classifier that recognizes handwritten digits directly from image files,
    /// evaluates it on a hold-out split, saves it, and then runs it over the test image folder.
    /// </summary>
    class Program
    {
        // Assets files download from: https://gitee.com/seabluescn/ML_Assets
        static readonly string AssetsFolder = @"D:\StepByStep\Blogs\ML_Assets\MNIST";
        static readonly string TrainTagsPath = Path.Combine(AssetsFolder, "train_tags.tsv");
        // NOTE(review): not referenced by the code in this class; image loading is done by
        // LoadImageConversion (defined elsewhere) — confirm whether this field is still needed.
        static readonly string TrainDataFolder = Path.Combine(AssetsFolder, "train");
        // The file name says "SDCA" but the trainer below is L-BFGS maximum entropy; the path is
        // kept unchanged so previously saved models are still found.
        static readonly string ModelPath = Path.Combine(Environment.CurrentDirectory, "Data", "SDCA-Model.zip");

        static void Main(string[] args)
        {
            // Fixed seed so the train/test split and training are reproducible between runs.
            MLContext mlContext = new MLContext(seed: 1);

            TrainAndSaveModel(mlContext);
            TestSomePredictions(mlContext);

            Console.WriteLine("Hit any key to finish the app");
            Console.ReadKey();
        }

        /// <summary>
        /// Loads the tagged training list, trains an L-BFGS maximum-entropy classifier on the image
        /// pixels, prints metrics on a 10% hold-out split, and saves the model to <see cref="ModelPath"/>.
        /// </summary>
        /// <param name="mlContext">The ML.NET context used for every operation.</param>
        public static void TrainAndSaveModel(MLContext mlContext)
        {
            // STEP 1: Prepare data. Column 0 of train_tags.tsv is the image file name and
            // column 1 the tagged number; the file is tab-separated and has no header row.
            var fullData = mlContext.Data.LoadFromTextFile<InputData>(path: TrainTagsPath, separatorChar: '\t', hasHeader: false);
            var split = mlContext.Data.TrainTestSplit(fullData, testFraction: 0.1);
            var trainData = split.TrainSet;
            var testData = split.TestSet;

            // STEP 2: Data processing pipeline:
            //   FileName -> "ImagePixels" (custom mapping), Number -> "Label" key (ordered by value),
            //   "ImagePixels" -> mean/variance normalized features.
            var dataProcessPipeline = mlContext.Transforms.CustomMapping(new LoadImageConversion().GetMapping(), contractName: "LoadImageConversionAction")
               .Append(mlContext.Transforms.Conversion.MapValueToKey("Label", "Number", keyOrdinality: ValueToKeyMappingEstimator.KeyOrdinality.ByValue))
               .Append(mlContext.Transforms.NormalizeMeanVariance(outputColumnName: "FeaturesNormalizedByMeanVar", inputColumnName: "ImagePixels"));

            // STEP 3: Training algorithm (a maximum entropy classification model trained with the L-BFGS method).
            var trainer = mlContext.MulticlassClassification.Trainers.LbfgsMaximumEntropy(labelColumnName: "Label", featureColumnName: "FeaturesNormalizedByMeanVar");
            // Convert the trainer's predicted key ("PredictedLabel") back to its original Number
            // value. Mapping "Label" here would only echo the input label, not the prediction.
            var trainingPipeline = dataProcessPipeline.Append(trainer)
                 .Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictNumber", "PredictedLabel"));

            // STEP 4: Fit the model to the training set.
            ITransformer trainedModel = trainingPipeline.Fit(trainData);

            // STEP 5: Evaluate accuracy on the hold-out set.
            var predictions = trainedModel.Transform(testData);
            var metrics = mlContext.MulticlassClassification.Evaluate(data: predictions, labelColumnName: "Label", scoreColumnName: "Score");
            PrintMultiClassClassificationMetrics(trainer.ToString(), metrics);

            // STEP 6: Save the model. The "Data" folder under the current directory is not created
            // anywhere else, so make sure it exists before writing the zip.
            Directory.CreateDirectory(Path.GetDirectoryName(ModelPath));
            mlContext.Model.Save(trainedModel, trainData.Schema, ModelPath);
        }

        /// <summary>
        /// Loads the saved model and prints the predicted digit for every file in the "test" asset folder.
        /// </summary>
        /// <param name="mlContext">The ML.NET context used to load the model and build the prediction engine.</param>
        private static void TestSomePredictions(MLContext mlContext)
        {
            // The model contains a CustomMapping saved under a contract name; the assembly defining
            // its factory must be registered with the component catalog before loading the model.
            mlContext.ComponentCatalog.RegisterAssembly(typeof(LoadImageConversion).Assembly);

            ITransformer trainedModel = mlContext.Model.Load(ModelPath, out _);

            using (var predEngine = mlContext.Model.CreatePredictionEngine<InputData, OutPutData>(trainedModel))
            {
                var testFolder = new DirectoryInfo(Path.Combine(AssetsFolder, "test"));
                foreach (FileInfo image in testFolder.GetFiles())
                {
                    // NOTE(review): only the bare file name is passed; LoadImageConversion (not shown)
                    // decides which folder it reads from — confirm it resolves files in "test".
                    var img = new InputData { FileName = image.Name };
                    var result = predEngine.Predict(img);

                    Console.WriteLine($"Current Source={img.FileName},PredictResult={result.GetPredictResult()}");
                }
            }
        }
    }

    /// <summary>
    /// One row of train_tags.tsv (also used as the prediction input): the image file name and
    /// the value it is tagged with.
    /// </summary>
    class InputData
    {
        /// <summary>Image file name, read from column 0.</summary>
        [LoadColumn(0)] public string FileName;

        /// <summary>Tagged value as text, read from column 1.</summary>
        [LoadColumn(1)] public string Number;

        /// <summary>
        /// Column 1 again, this time parsed as a float.
        /// NOTE(review): not referenced by the code in this file — confirm whether it is still needed.
        /// </summary>
        [LoadColumn(1)] public float Serial;
    }

    /// <summary>
    /// Prediction output: the input columns plus the per-class score vector produced by the model.
    /// </summary>
    class OutPutData : InputData
    {
        /// <summary>
        /// Per-class scores. Index i presumably corresponds to the i-th key of the "Label" mapping
        /// (keys ordered by value) — confirm against the training pipeline.
        /// </summary>
        public float[] Score;

        /// <summary>
        /// Returns the index of the highest score (the first one on ties).
        /// Returns 0 when <see cref="Score"/> is null, empty, or contains no comparable values (all NaN).
        /// </summary>
        public int GetPredictResult()
        {
            if (Score == null)
            {
                return 0;
            }

            // Seed with -infinity instead of 0 so the arg-max is still correct when every
            // score is zero or negative; NaN entries never win a comparison and are skipped.
            float max = float.NegativeInfinity;
            int index = 0;
            for (int i = 0; i < Score.Length; i++)
            {
                if (Score[i] > max)
                {
                    max = Score[i];
                    index = i;
                }
            }
            return index;
        }
    }
}
View Code

相关文章:

  • 2021-05-15
  • 2022-12-23
  • 2021-06-08
  • 2022-12-23
  • 2021-06-06
  • 2021-07-26
  • 2021-12-04
  • 2021-07-13
猜你喜欢
  • 2022-02-18
  • 2022-12-23
  • 2021-07-12
  • 2022-02-10
  • 2021-05-20
  • 2022-01-01
  • 2021-12-04
相关资源
相似解决方案