【发布时间】:2020-05-14 07:58:28
【问题描述】:
我的模型没有学习.. 它应该在最后进行 softmax 计算。因此,我想要一个分类(退出或不退出)。该模型应预测客户是否会退出。我将退出列作为标签并有 196 个输入特征。
我的面罩说根本没有学习。但是我不确定,如果我的模型学习,遮阳板将如何获取信息。我对 javascript 非常陌生,如果有任何帮助,我将不胜感激。
ngOnInit() {
this.train();
}
async train(): Promise<any> {
const csvUrl = '/assets/little.csv';
const csvDataset = tf.data.csv(
csvUrl,
{
columnConfigs: {
quit: {
isLabel: true
}
},
delimiter:','
});
const numOfFeatures = (await csvDataset.columnNames()).length -1;
console.log(numOfFeatures);
const flattenedDataset =
csvDataset
.map(({xs, ys}: any) =>
{
// Convert xs(features) and ys(labels) from object form (keyed by
// column name) to array form.
return {xs:Object.values(xs), ys:Object.values(ys)};
}).batch(10);
console.log(flattenedDataset.toArray());
const model = tf.sequential({
layers: [
tf.layers.dense({inputShape: [196], units: 100, activation: 'relu'}),
tf.layers.dense({units: 100, activation: 'relu'}),
tf.layers.dense({units: 100, activation: 'relu'}),
tf.layers.dense({units: 1, activation: 'softmax'}),
]
});
await trainModel(model, flattenedDataset);
const surface = { name: 'Model Summary', tab: 'Model Inspection'};
tfvis.show.modelSummary(surface, model);
console.log('Done Training');
}
async function trainModel(model, flattenedDataset) {
// Prepare the model for training.
model.compile({
optimizer: tf.train.adam(),
loss: tf.losses.sigmoidCrossEntropy,
metrics: ['accuracy']
});
const batchSize = 32;
const epochs = 50;
return await model.fitDataset(flattenedDataset, {
batchSize,
epochs,
shuffle: true,
callbacks: tfvis.show.fitCallbacks(
{ name: 'Training Performance' },
['loss'],
{ height: 200, callbacks: ['onEpochEnd'] }
)
});
}
【问题讨论】:
-
您是否尝试过不同的损失函数,例如
categoricalCrossentropy?并将配置更改为:loss: 'categoricalCrossentropy' -
随着你的变化,loss-Value 保持在一个低得多的水平,但仍然是一条平线。之前,该线在 0.75 左右。随着您的更改,它保持在 0.00005 左右。但是算法似乎仍然没有学习..
-
softmax 激活用于分类问题。您的模型似乎没有进行分类。您的最后一层有一个单位,表明您正在进行回归。因此,您的模型很可能无法学习
-
这将是一个很好的解释!我更新了问题。您能否提供一个建议,我必须如何更改代码,以便模型正确分类?我希望输出是退出/不退出的分类。
-
最后一层的单元数为类别数。在
quit和no-quit中有两个类别。此外,您的标签应该是一次性编码的。可以在here 找到有关模型不学习的更一般性的答案
标签: machine-learning neural-network tensorflow.js