一、概述
上一篇文章我们利用ML.NET的多元分类算法实现了一个手写数字识别的例子。这个例子存在一个问题:输入的数据是预处理过的,很不直观。这次我们要直接通过图片来进行学习和判断。思路很简单:写一个自定义的数据处理管道,输入为图片文件名,输出为float数组,其中保存的是图片的像素信息。
样本包括6万张训练图片和1万张测试图片,图片为灰度图片,分辨率为20×20。train_tags.tsv文件对每张图片代表的数字进行了标记,格式如下:以制表符(Tab)分隔、没有标题行,第一列为图片文件名,第二列为该图片对应的数字。
二、源码
全部代码如下:
namespace MulticlassClassification_Mnist
{
    class Program
    {
        // Assets files download from: https://gitee.com/seabluescn/ML_Assets
        static readonly string AssetsFolder = @"D:\StepByStep\Blogs\ML_Assets\MNIST";
        static readonly string TrainTagsPath = Path.Combine(AssetsFolder, "train_tags.tsv");
        static readonly string TrainDataFolder = Path.Combine(AssetsFolder, "train");

        // NOTE: the file name says "SDCA" but the trainer used below is L-BFGS maximum entropy.
        static readonly string ModelPath = Path.Combine(Environment.CurrentDirectory, "Data", "SDCA-Model.zip");

        static void Main(string[] args)
        {
            // Fixed seed so the train/test split and training are reproducible.
            MLContext mlContext = new MLContext(seed: 1);

            TrainAndSaveModel(mlContext);
            TestSomePredictions(mlContext);

            Console.WriteLine("Hit any key to finish the app");
            Console.ReadKey();
        }

        /// <summary>
        /// Trains a multiclass classifier on the images listed in train_tags.tsv,
        /// prints evaluation metrics on a 10% hold-out split, and saves the model to <see cref="ModelPath"/>.
        /// </summary>
        /// <param name="mlContext">The ML.NET context used for loading, training and saving.</param>
        public static void TrainAndSaveModel(MLContext mlContext)
        {
            // STEP 1: Prepare data. train_tags.tsv is tab-separated with no header row:
            // column 0 = image file name, column 1 = digit label (see InputData).
            var fullData = mlContext.Data.LoadFromTextFile<InputData>(path: TrainTagsPath, separatorChar: '\t', hasHeader: false);
            var trainTestData = mlContext.Data.TrainTestSplit(fullData, testFraction: 0.1);
            var trainData = trainTestData.TrainSet;
            var testData = trainTestData.TestSet;

            // STEP 2: Data processing pipeline.
            //   - The custom mapping turns the file name into pixel data; presumably it emits the
            //     "ImagePixels" column consumed below (LoadImageConversion is defined elsewhere).
            //   - The string label "Number" becomes the key column "Label". Keys are ordered by value,
            //     so (assuming labels are "0".."9") Score[i] corresponds to digit i.
            //   - Pixels are normalized to zero mean / unit variance.
            var dataProcessPipeline = mlContext.Transforms.CustomMapping(new LoadImageConversion().GetMapping(), contractName: "LoadImageConversionAction")
                .Append(mlContext.Transforms.Conversion.MapValueToKey("Label", "Number", keyOrdinality: ValueToKeyMappingEstimator.KeyOrdinality.ByValue))
                .Append(mlContext.Transforms.NormalizeMeanVariance(outputColumnName: "FeaturesNormalizedByMeanVar", inputColumnName: "ImagePixels"));

            // STEP 3: Training algorithm (maximum entropy classification model trained with L-BFGS).
            var trainer = mlContext.MulticlassClassification.Trainers.LbfgsMaximumEntropy(labelColumnName: "Label", featureColumnName: "FeaturesNormalizedByMeanVar");

            // Map the trainer's predicted key ("PredictedLabel") back to the original label text.
            // Mapping "Label" here would only echo the ground-truth input, not the prediction.
            var trainingPipeline = dataProcessPipeline.Append(trainer)
                .Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictNumber", "PredictedLabel"));

            // STEP 4: Fit the model to the training set.
            ITransformer trainedModel = trainingPipeline.Fit(trainData);

            // STEP 5: Evaluate accuracy on the hold-out set.
            var predictions = trainedModel.Transform(testData);
            var metrics = mlContext.MulticlassClassification.Evaluate(data: predictions, labelColumnName: "Label", scoreColumnName: "Score");
            PrintMultiClassClassificationMetrics(trainer.ToString(), metrics);

            // STEP 6: Save the model. The "Data" directory under the working directory
            // may not exist yet, so create it first (no-op if it already exists).
            Directory.CreateDirectory(Path.GetDirectoryName(ModelPath));
            mlContext.Model.Save(trainedModel, trainData.Schema, ModelPath);
        }

        /// <summary>
        /// Loads the saved model and prints a prediction for every file in the "test" asset folder.
        /// </summary>
        /// <param name="mlContext">The ML.NET context used to load the model.</param>
        private static void TestSomePredictions(MLContext mlContext)
        {
            // The saved model contains a CustomMapping transform with a contract name; the assembly
            // holding its factory must be registered before the model can be loaded.
            mlContext.ComponentCatalog.RegisterAssembly(typeof(LoadImageConversion).Assembly);

            ITransformer trainedModel = mlContext.Model.Load(ModelPath, out _);

            // PredictionEngine is IDisposable; release it once all predictions are done.
            using (var predEngine = mlContext.Model.CreatePredictionEngine<InputData, OutPutData>(trainedModel))
            {
                var testFolder = new DirectoryInfo(Path.Combine(AssetsFolder, "test"));
                foreach (var image in testFolder.GetFiles())
                {
                    // NOTE(review): only the bare file name is passed; confirm LoadImageConversion
                    // resolves it against the test folder rather than the training folder.
                    var img = new InputData { FileName = image.Name };
                    var result = predEngine.Predict(img);
                    Console.WriteLine($"Current Source={img.FileName},PredictResult={result.GetPredictResult()}");
                }
            }
        }
    }

    /// <summary>One row of train_tags.tsv.</summary>
    class InputData
    {
        // Column 0: image file name.
        [LoadColumn(0)]
        public string FileName;

        // Column 1: the digit label as text.
        [LoadColumn(1)]
        public string Number;

        // Column 1 again, loaded as a float; not used by the code shown here.
        [LoadColumn(1)]
        public float Serial;
    }

    /// <summary>Prediction output: the input columns plus the per-class scores.</summary>
    class OutPutData : InputData
    {
        // Per-class scores from the trainer's "Score" column.
        public float[] Score;

        /// <summary>
        /// Returns the index of the highest value in <see cref="Score"/>, i.e. the predicted class index.
        /// Returns 0 when <see cref="Score"/> is empty.
        /// </summary>
        /// <exception cref="InvalidOperationException">Thrown when <see cref="Score"/> is null.</exception>
        public int GetPredictResult()
        {
            if (Score == null)
            {
                throw new InvalidOperationException("Score has not been populated by a prediction.");
            }

            // Seed with -infinity so the result does not depend on scores being non-negative.
            int index = 0;
            float max = float.NegativeInfinity;
            for (int i = 0; i < Score.Length; i++)
            {
                if (Score[i] > max)
                {
                    max = Score[i];
                    index = i;
                }
            }
            return index;
        }
    }
}