【发布时间】:2021-06-18 05:32:59
【问题描述】:
我有一个用 tensorflowjs 构建的非常基本的模型。它应该可以工作,因为它是直接从 google tensorflow 教程复制而来的。我只是将该代码转换为一个可重用的类。每次我传递一些数据进行预测时,它都会返回 undefined。
class TFModel {
constructor() {
this.xs = tf.randomUniform([10000, 2]);
this.ys = tf.randomUniform([10000, 1]);
this.valXs = tf.randomUniform([1000, 2]);
this.valYs = tf.randomUniform([1000, 1]);
this.model = tf.sequential();
}
init() {
this.model.add(tf.layers.dense({
units: 1,
inputShape: [2]
}));
this.model.compile({
loss: 'meanSquaredError',
optimizer: 'sgd',
metrics: ['MAE']
});
}
async train() {
await this.model.fit(this.xs, this.ys, {
epochs: 4,
validationData: [this.valXs, this.valYs]
});
}
predict(data, callback) {
let transData = tf.tensor(data)
console.log(transData)
this.model.predict(transData, result => {
console.log("Result Predict", result)
// callback(result)
})
}
dispose() {
}
}
这是我如何调用上面的类
model = new TFModel();
model.init()
model.train().then(data => {
console.log("Resul Predict", data)
})
model.predict([
[3, 3]
], result => {
// console.log("Result Predict", result)
})
【问题讨论】:
标签: javascript node.js tensorflow tensorflow.js