【发布时间】:2020-01-30 20:24:44
【问题描述】:
以下代码(可以运行)训练一个模型来识别猫,并对所选图片进行预测。(代码使用 TensorFlow.js,但这个问题对 TensorFlow 同样适用。)
到目前为止,它只能预测一个类别(“猫”),因此即使输入汽车或狗的图片,也会得到例如“80% 是猫”这样的结果。
问:
我如何添加其他类(如“狗”)?
应该是这样(抽象):model.fit([img1, img2, img3], [label1, label2, label3] ...) ?
我不明白的是:
标签和训练集之间是什么关系。
这是代码(请暂时忽略“预测”部分):
<head>
<!-- jQuery: used by runPrediction() to fill the #prediction-list element. -->
<script src="https://ajax.googleapis.com/ajax/libs/jquery/3.3.1/jquery.min.js"></script>
<!-- TensorFlow.js, pinned to 1.2.7: the script below uses its global `tf`. -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@1.2.7"> </script>
<!-- NOTE(review): the script below only passes the string "mobilenet" around and never calls into this package; confirm whether this include is needed. -->
<script src="https://unpkg.com/@tensorflow-models/mobilenet"></script>
</head>
<body>
<div class="container mt-5">
<div class="row">
<!-- File picker; its "change" event starts a prediction for the chosen image. -->
<input id ="image-selector" class="form-control border-0" type="file"/>
</div>
<div class="row">
<div class="col">
<h2>Prediction</h2>
<!-- Emptied and refilled by runPrediction() on every prediction. -->
<ol id="prediction-list"></ol>
</div>
</div>
<div class="row">
<div class="col-12">
<h2 class="ml-3">Image</h2>
<!-- The selected image is drawn here; the canvas is resized to the image's dimensions first. -->
<canvas id="canvas" width="400" height="300" style="border:1px solid #000000;"></canvas>
</div>
</div>
</div>
<!-- Training set. The load handler below reads only the elements carrying the "cat" class; the "dog" images are not used by this version of the script. -->
<div id="training-images">
<img width="400" height="300" class="train-image cat" src="training-images/cat.jpg" />
<img width="400" height="300" class="train-image cat" src="training-images/cat2.jpeg" />
<img width="400" height="300" class="train-image cat" src="training-images/cat3.jpeg" />
<img width="400" height="300" class="train-image cat" src="training-images/cat4.jpeg" />
<img width="400" height="300" class="train-image dog" src="training-images/dog.jpeg" />
<img width="400" height="300" class="train-image dog" src="training-images/dog2.jpeg" />
<img width="400" height="300" class="train-image dog" src="training-images/dog3.jpeg" />
<img width="400" height="300" class="train-image dog" src="training-images/dog4.jpeg" />
</div>
</body>
<script>
// Shared state for the training and prediction code below.
// modelType is passed to preprocessImage(); model starts empty and gets its
// layers in the load handler; label holds the training labels that the load
// handler one-hot encodes.
const modelType = "mobilenet";
const model = tf.sequential();
const label = ['cat'];
// Assigned by the load handler: the one-hot label tensor and the list of
// distinct class names.
var ys;
var setLabel;
// DOM handles used by the prediction code: the file picker, the preview
// canvas and its 2D drawing context.
var input = document.getElementById("image-selector");
var canvas = document.getElementById("canvas");
var context = canvas.getContext('2d');
//-------------------------- Training: --------------------------------
// Once the page (including the training <img> elements) has loaded: build the
// label tensor, assemble and compile the CNN, preprocess the training images
// and start training.
window.addEventListener('load', () => {
  // Number of output classes. The one-hot depth of the labels and the unit
  // count of the final dense layer must be the same number.
  const NUM_CLASSES = 10;

  // Labels: map every label to its index in the de-duplicated class list,
  // then one-hot encode those indices.
  setLabel = Array.from(new Set(label));
  const labelIndices = label.map((name) => setLabel.indexOf(name));
  ys = tf.oneHot(tf.tensor1d(labelIndices, 'int32'), NUM_CLASSES);
  console.log('ys:::' + ys);

  // Prepare model :
  model.add(tf.layers.conv2d({
    inputShape: [224, 224, 3],
    kernelSize: 5,
    filters: 8,
    strides: 2,
    activation: 'relu',
    kernelInitializer: 'VarianceScaling'
  }));
  model.add(tf.layers.maxPooling2d({poolSize: 2, strides: 2}));
  model.add(tf.layers.maxPooling2d({poolSize: 2, strides: 2}));
  model.add(tf.layers.flatten({}));
  model.add(tf.layers.dense({units: 64, activation: 'relu'}));
  model.add(tf.layers.dense({units: NUM_CLASSES, activation: 'softmax'}));
  model.compile({
    loss: 'meanSquaredError',
    optimizer: 'sgd'
  });

  // Prepare training images. Iterate over the elements that actually exist
  // instead of a fixed count: indexing past the end of the collection yields
  // undefined, which makes preprocessImage() throw.
  const catElements = Array.from(document.getElementsByClassName("cat"));
  // preprocessImage() already returns a [1, 224, 224, 3] tensor (it resizes to
  // 224x224 and calls expandDims), so no further reshape is needed.
  const images = catElements.map((element) => preprocessImage(element, modelType));
  console.log("processed images : ");
  console.log(images);

  // trainModel() logs its own errors, so the returned promise is not awaited.
  trainModel(images);
});
/**
 * Trains the model on each preprocessed image in turn (one fit() call per
 * image, always against the shared `ys` label tensor) and logs the model's
 * prediction for that image after each fit.
 *
 * A failure while fitting or predicting one image is logged and does not stop
 * the remaining images from being processed.
 *
 * @param {Array} images - Image tensors as returned by preprocessImage().
 * @returns {Promise<void>} Resolves after every image has been processed.
 */
async function trainModel(images) {
  for (const image of images) {
    try {
      await model.fit(image, ys, {epochs: 100, batchSize: 32});
      // tf.tidy releases the prediction tensor and the argMax result once the
      // class indices have been copied into a plain array.
      const classIndices = tf.tidy(() => {
        const t = model.predict(image);
        console.log('Prediction:::' + t);
        // Index of the highest probability along the class axis.
        return Array.from(t.argMax(1).dataSync());
      });
      const labelsPred = classIndices.map((e) => setLabel[e]);
      console.log('labelsPred:::' + labelsPred);
    } catch (e) {
      console.log(e.message);
    }
  }
  console.log("Training done!");
}
//-------------------------- Predict: --------------------------------
// When the user picks a file: read it as a data URL, draw it on the canvas
// (resized to the image's own dimensions) and run a prediction on it.
input.addEventListener("change", function () {
  const file = this.files[0];
  // The file dialog can be dismissed without choosing anything.
  if (!file) {
    return;
  }
  const reader = new FileReader();
  // "load" fires only after a successful read; "loadend" would also fire after
  // a failed read, when reader.result is null.
  reader.addEventListener("load", () => {
    const src_image = new Image();
    src_image.onload = () => {
      canvas.height = src_image.height;
      canvas.width = src_image.width;
      context.drawImage(src_image, 0, 0);
      runPrediction(src_image).catch((e) => {
        console.log(e.message);
      });
    };
    src_image.onerror = () => {
      console.log('The selected file could not be decoded as an image.');
    };
    src_image.src = reader.result;
  });
  reader.addEventListener("error", () => {
    console.log('The selected file could not be read: ' + reader.error);
  });
  reader.readAsDataURL(file);
});
/**
 * Classifies one image with the trained model and shows the most probable
 * class in the #prediction-list element.
 *
 * @param {*} imageData - Image source accepted by preprocessImage() (the
 *   caller passes an HTMLImageElement).
 * @returns {Promise<void>} Resolves once the list has been updated.
 */
async function runPrediction(imageData) {
  const tensor = preprocessImage(imageData, "mobilenet");
  const output = model.predict(tensor);
  let prediction;
  try {
    prediction = await output.data();
  } finally {
    // The probabilities have been copied out (or reading them failed); release
    // the input and output tensors either way.
    tensor.dispose();
    output.dispose();
  }
  console.log('prediction:::' + prediction);

  // Output index i corresponds to setLabel[i]; the model has more outputs than
  // known class names, so unnamed indices get a generic name.
  const classNames = setLabel || [];
  const best = Array.from(prediction)
    .map((p, i) => ({
      probability: p,
      className: i < classNames.length ? classNames[i] : `class ${i}`,
    }))
    .sort((a, b) => b.probability - a.probability)
    .slice(0, 1);

  $("#prediction-list").empty();
  best.forEach((p) => {
    $("#prediction-list").append(`<li>${p.className}:${p.probability.toFixed(6)}</li>`);
  });
}
//-------------------------- Helpers: --------------------------------
/**
 * Converts an image into a model input tensor: reads its pixels, resizes them
 * to 224x224 (nearest neighbour), converts to float, applies
 * (value - 127.5) / 127.5 and adds a leading batch dimension.
 *
 * @param {*} image - Pixel source accepted by tf.browser.fromPixels().
 * @param {string} modelName - Currently unused; kept so existing callers keep working.
 * @returns {*} The preprocessed tensor; the caller owns it.
 */
function preprocessImage(image, modelName) {
  // tf.tidy disposes every intermediate tensor created inside the callback and
  // keeps only the returned one, so repeated calls do not leak memory.
  return tf.tidy(() => {
    const pixels = tf.browser.fromPixels(image)
      .resizeNearestNeighbor([224, 224])
      .toFloat();
    const offset = tf.scalar(127.5);
    return pixels.sub(offset)
      .div(offset)
      .expandDims();
  });
}
</script>
代码基于TFJS文档和github上的评论:https://github.com/tensorflow/tfjs/issues/1288
更新:
所以我需要让 X(图像)和 Y(标签)的长度相同,其中 Y1 是 X1 的标签,依此类推……
我试过了:
ys:::Tensor (with only 2 classes represented in the training data set) :
[[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 1, 0, 0, 0, 0, 0, 0, 0, 0]]
一张图片 + 所有标签 -> 使用“model.fit(images[i], ys, {epochs: 100})...”,我得到:
错误:“输入张量应具有与目标张量相同数量的样本。找到 1 个输入样本和 10 个目标样本。”
一张图片 + 一个标签 -> 使用“model.fit(images[i], ys[i], {epochs: 100})...”,我得到:
错误:“无法读取 null 的属性‘形状’”,我猜是因为 ys 是张量,而 ys[i] 不是张量。
所有图像+所有标签->带有“model.fit(images,ys,{epochs:100})...”,我得到:
错误:“检查模型输入时:您传递给模型的张量数组不是模型预期的大小。
预计会看到 1 个张量,但得到了以下张量列表:张量 ..."
猜想:我需要将所有图像放在一个与 ys 结构相同的张量中。
已解决:
感谢 Rishabh Sahrawat 的帮助,标签问题解决了;之后我还需要借助 tf.concat(...) 将所有图像张量合并为一个张量。
[tensorImg1, tensorImg2, tensorImg3, tensorImg4, ...] x tensor[label1, label2, label3, label4, ...]
->
tensor[dataImg1, dataImg2, dataImg3, dataImg4, ...] x tensor[label1, label2, label3, label4, ...]
更新代码:
<head>
<!-- jQuery: used by runPrediction() to fill the #prediction-list element. -->
<script src="https://ajax.googleapis.com/ajax/libs/jquery/3.3.1/jquery.min.js"></script>
<!-- TensorFlow.js, pinned to 1.2.7: the script below uses its global `tf`. -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@1.2.7"> </script>
<!-- NOTE(review): the script below only passes the string "mobilenet" around and never calls into this package; confirm whether this include is needed. -->
<script src="https://unpkg.com/@tensorflow-models/mobilenet"></script>
</head>
<body>
<div class="container mt-5">
<div class="row">
<!-- File picker; its "change" event starts a prediction for the chosen image. -->
<input id ="image-selector" class="form-control border-0" type="file"/>
</div>
<div class="row">
<div class="col">
<h2>Prediction</h2>
<!-- Emptied and refilled by runPrediction() on every prediction. -->
<ol id="prediction-list"></ol>
</div>
</div>
<div class="row">
<div class="col-12">
<h2 class="ml-3">Image</h2>
<!-- The selected image is drawn here; the canvas is resized to the image's dimensions first. -->
<canvas id="canvas" width="400" height="300" style="border:1px solid #000000;"></canvas>
</div>
</div>
</div>
<!-- Training set. The load handler reads every "train-image" element; an element with the "cat" class is labelled 0 and every other element is labelled 1. -->
<div id="training-images">
<img width="400" height="300" class="train-image cat" src="training-images/cat.jpg" />
<img width="400" height="300" class="train-image cat" src="training-images/cat2.jpeg" />
<img width="400" height="300" class="train-image cat" src="training-images/cat3.jpeg" />
<img width="400" height="300" class="train-image dog" src="training-images/dog.jpeg" />
<img width="400" height="300" class="train-image dog" src="training-images/dog2.jpeg" />
<img width="400" height="300" class="train-image dog" src="training-images/dog3.jpeg" />
<img width="400" height="300" class="train-image dog" src="training-images/dog4.jpeg" />
</div>
</body>
<script>
// Shared state for the training and prediction code below.
// modelType is passed to preprocessImage(); model starts empty and gets its
// layers in prepareModel(); labels lists the class names.
const modelType = "mobilenet";
const model = tf.sequential();
var labels = ['cat', 'dog'];
// Assigned by the load handler: the one-hot label tensor and a copy of the
// class-name list.
var ys;
var setLabel;
// DOM handles used by the prediction code: the file picker, the preview
// canvas and its 2D drawing context.
var input = document.getElementById("image-selector");
var canvas = document.getElementById("canvas");
var context = canvas.getContext('2d');
//-------------------------- Training: --------------------------------
// Once the page (including the training <img> elements) has loaded: build the
// model, turn every training image into a tensor with a matching class index,
// one-hot encode the indices and start training.
window.addEventListener('load', () => {
  // Prepare model :
  prepareModel();

  // Prepare training images. The collection is looked up once rather than on
  // every loop iteration.
  const trainElements = Array.from(document.getElementsByClassName('train-image'));
  const images = [];
  const trainLabels = [];
  for (const element of trainElements) {
    // The label of an image is the position in `labels` of the class name the
    // element carries ("cat" -> 0, "dog" -> 1).
    const classIndex = labels.findIndex((name) => element.classList.contains(name));
    if (classIndex === -1) {
      // An image without a known class would otherwise be silently trained as
      // some arbitrary class.
      console.warn('Skipping training image without a known label class: ' + element.src);
      continue;
    }
    images.push(preprocessImage(element, modelType));
    trainLabels.push(classIndex);
  }
  console.log(labels)
  setLabel = Array.from(labels);
  if (images.length === 0) {
    console.warn('No labelled training images found; training skipped.');
    return;
  }
  // One row per image, in the same order as `images`. The depth must equal the
  // number of units in the model's final dense layer.
  ys = tf.oneHot(trainLabels, labels.length);
  console.log('ys:::' + ys);
  console.log(images);

  // trainModel() logs its own errors, so the returned promise is not awaited.
  trainModel(images);
});
/**
 * Trains the model once on all images together and then logs the prediction
 * for each training image.
 *
 * The per-image tensors are concatenated along the batch axis so that the
 * input has as many samples as the `ys` label tensor has rows.
 *
 * @param {Array} images - Image tensors as returned by preprocessImage(), in
 *   the same order as the rows of `ys`.
 * @returns {Promise<void>} Resolves when training and the logging are done;
 *   errors are logged rather than thrown.
 */
async function trainModel(images) {
  if (!images || images.length === 0) {
    console.log("No training images; training skipped.");
    return;
  }
  // Build the batch once; fitting it once per image would only repeat the
  // same training run images.length times.
  const xs = tf.concat(images, 0);
  try {
    await model.fit(xs, ys, {epochs: 100});
    for (const image of images) {
      // tf.tidy releases the prediction tensor and the argMax result once the
      // class indices have been copied into a plain array.
      const classIndices = tf.tidy(() => {
        const t = model.predict(image);
        console.log('Prediction:::' + t);
        // Reduce along the class axis (1), not the batch axis, to get the
        // class of highest probability.
        return Array.from(t.argMax(1).dataSync());
      });
      const labelsPred = classIndices.map((e) => labels[e]);
      console.log('labelsPred:::' + labelsPred);
    }
    console.log("Training done!");
  } catch (e) {
    console.log(e.message);
  } finally {
    // tf.concat may hand back the input tensor itself when given a single
    // tensor; only dispose a batch that this function created.
    if (!images.includes(xs)) {
      xs.dispose();
    }
  }
}
//-------------------------- Predict: --------------------------------
// When the user picks a file: read it as a data URL, draw it on the canvas
// (resized to the image's own dimensions) and run a prediction on it.
input.addEventListener("change", function () {
  const file = this.files[0];
  // The file dialog can be dismissed without choosing anything.
  if (!file) {
    return;
  }
  const reader = new FileReader();
  // "load" fires only after a successful read; "loadend" would also fire after
  // a failed read, when reader.result is null.
  reader.addEventListener("load", () => {
    const src_image = new Image();
    src_image.onload = () => {
      canvas.height = src_image.height;
      canvas.width = src_image.width;
      context.drawImage(src_image, 0, 0);
      runPrediction(src_image).catch((e) => {
        console.log(e.message);
      });
    };
    src_image.onerror = () => {
      console.log('The selected file could not be decoded as an image.');
    };
    src_image.src = reader.result;
  });
  reader.addEventListener("error", () => {
    console.log('The selected file could not be read: ' + reader.error);
  });
  reader.readAsDataURL(file);
});
/**
 * Classifies one image with the trained model and shows the most probable
 * class, by name, in the #prediction-list element.
 *
 * @param {*} imageData - Image source accepted by preprocessImage() (the
 *   caller passes an HTMLImageElement).
 * @returns {Promise<void>} Resolves once the list has been updated.
 */
async function runPrediction(imageData) {
  const tensor = preprocessImage(imageData, "mobilenet");
  const output = model.predict(tensor);
  let prediction;
  try {
    prediction = await output.data();
  } finally {
    // The probabilities have been copied out (or reading them failed); release
    // the input and output tensors either way.
    tensor.dispose();
    output.dispose();
  }
  console.log('prediction:::' + prediction);

  // Output index i corresponds to labels[i], the same order used to build the
  // training labels.
  const best = Array.from(prediction)
    .map((p, i) => ({
      probability: p,
      className: i < labels.length ? labels[i] : `class ${i}`,
    }))
    .sort((a, b) => b.probability - a.probability)
    .slice(0, 1);

  $("#prediction-list").empty();
  best.forEach((p) => {
    $("#prediction-list").append(`<li>${p.className}:${p.probability.toFixed(6)}</li>`);
  });
}
//-------------------------- Helpers: --------------------------------
/**
 * Adds the network layers to the shared sequential `model`, compiles it and
 * prints its summary.
 *
 * Layer order: 5x5 convolution (8 filters, stride 2, ReLU) on a 224x224x3
 * input, two 2x2 max-pooling layers, flatten, a 64-unit ReLU dense layer and
 * a 2-unit softmax output layer.
 */
function prepareModel() {
  // Each entry creates one layer; the layers are created and added in order.
  const layerBuilders = [
    () => tf.layers.conv2d({
      inputShape: [224, 224, 3],
      kernelSize: 5,
      filters: 8,
      strides: 2,
      activation: 'relu',
      kernelInitializer: 'VarianceScaling',
    }),
    () => tf.layers.maxPooling2d({ poolSize: 2, strides: 2 }),
    () => tf.layers.maxPooling2d({ poolSize: 2, strides: 2 }),
    () => tf.layers.flatten({}),
    () => tf.layers.dense({ units: 64, activation: 'relu' }),
    () => tf.layers.dense({ units: 2, activation: 'softmax' }),
  ];
  for (const buildLayer of layerBuilders) {
    model.add(buildLayer());
  }

  // NOTE(review): categorical cross-entropy is the more common loss for a
  // softmax classifier; mean squared error is kept here unchanged.
  const compileOptions = { optimizer: 'sgd', loss: 'meanSquaredError' };
  model.compile(compileOptions);
  model.summary();
}
/**
 * Converts an image into a model input tensor: reads its pixels, resizes them
 * to 224x224 (nearest neighbour), converts to float, applies
 * (value - 127.5) / 127.5 and adds a leading batch dimension.
 *
 * @param {*} image - Pixel source accepted by tf.browser.fromPixels().
 * @param {string} modelName - Currently unused; kept so existing callers keep working.
 * @returns {*} The preprocessed tensor; the caller owns it.
 */
function preprocessImage(image, modelName) {
  // tf.tidy disposes every intermediate tensor created inside the callback and
  // keeps only the returned one, so repeated calls do not leak memory.
  return tf.tidy(() => {
    const pixels = tf.browser.fromPixels(image)
      .resizeNearestNeighbor([224, 224])
      .toFloat();
    const offset = tf.scalar(127.5);
    return pixels.sub(offset)
      .div(offset)
      .expandDims();
  });
}
</script>
【问题讨论】:
标签: tensorflow