【问题标题】:How to use existing keras model in tensorflow.js如何在 tensorflow.js 中使用现有的 keras 模型
【发布时间】:2021-02-05 23:49:26
【问题描述】:

我有 Keras 模型,已转换为 tensorflow.js,但无法在 javascript 中加载模型,具体步骤是什么?

model.add(Embedding(vocabulary_size, seq_len, input_length=seq_len))
model.add(LSTM(256,return_sequences=True))
model.add(LSTM(128))
model.add(Dense(256,activation='relu'))
model.add(Dense(vocabulary_size, activation='softmax'))
# compiling the network
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit(train_inputs,train_targets,epochs=256,verbose=1)

【问题讨论】:

  • 模型是否保存在本地存储中?
  • 是的,我已经转换成tensorflow.js并保存在本地了。
  • 应用是纯 JavaScript 还是像 React 这样的框架?
  • 是的,我已经在 React 中创建了它,并且我使用 express 创建了一台服务器来保存我的模型,因此我从该服务器获取我的模型,但我没有取得任何成功,我是否走在正确的道路上?
  • 啊,所以您正在使用 HTTPS 请求来加载模型。我将很快添加一个答案。保存模型的 URL 是什么?我的意思是model.json

标签: javascript python tensorflow keras tensorflow.js


【解决方案1】:
import tensorflowjs as tfjs

model.add(Embedding(vocabulary_size, seq_len, input_length=seq_len))
model.add(LSTM(256,return_sequences=True))
model.add(LSTM(128))
model.add(Dense(256,activation='relu'))
model.add(Dense(vocabulary_size, activation='softmax'))
# compiling the network
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit(train_inputs,train_targets,epochs=256,verbose=1)

# save the model in model.json

tfjs.converters.save_keras_model(model, './keras_converted')

在 javascript 中加载模型

  • 要在浏览器中加载 model.json,它必须在本地服务器中提供,因为浏览器无法直接访问文件系统
import * as tf from '@tensorflow/tfjs';

const model = await tf.loadLayersModel('https://hostname:port/path/to/model.json');

【讨论】:

    【解决方案2】:

    您可以像这样从端点加载模型。

    import * as tf from '@tensorflow/tfjs';
    import React, {useState, useEffect} from "react";
    
    const url = {
        model: 'http://localhost:81/tfjs-models/model.json',
    };
    
    async function loadModel(url) {
      try {
        // For layered model
        const model = await tf.loadLayersModel(url.model);
        // For graph model
        // const model = await tf.loadGraphModel(url.model);
        setModel(model);
        console.log("Load model success");
      } catch (err) {
        console.log(err);
      }
    } 
    const [model, setModel] = useState();
    
    useEffect(() => {
      tf.ready().then(() => {
        loadModel(url);
      });
    }, []);
    

    然后可以使用model from state 访问模型。

    【讨论】:

    • 我做了console.log(model.Summary()) 得到这个:TypeError: webglBackend.incRef is not a function,是要解决的问题吗?
    • console.log(model) 输出什么?
    • TypeError: webglBackend.incRef is not a function at reshape (Reshape.ts:50) at Object.min [as kernelFunc] (Min.ts:48) at h (engine.js:562)在 engine.js:619 在 t.e.scopedRun (engine.js:436) 在 t.e.runKernelFunc (engine.js:616) 在 t.e.runKernel (engine.js:494) 在 min_ (min.ts:64) 在 Module.min__op ( operation.ts:51) at Module.min (math_utils.ts:74) at new LSTMCell (recurrent.ts:1663) at new LSTM (recurrent.ts:1856) at fromConfig (recurrent.ts:1886) at deserializeKerasObject (generic_utils .ts:277) 在反序列化时.....
    • 对不起,我弄错了,试试if(model) {console.log(model)},你用的是什么版本的TFJS?
    • 3.0.0是版本
    猜你喜欢
    • 2021-10-10
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2022-12-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2020-03-16
    相关资源
    最近更新 更多