【问题标题】:How to correctly save keras model to be able to load with hub.Module()?如何正确保存 keras 模型以便能够使用 hub.Module() 加载?
【发布时间】:2019-09-17 15:23:44
【问题描述】:

我正在尝试在新图像集上重新训练 inception v3。

当我尝试保存模型时收到错误消息。

我试过了:

    tf.keras.models.save_model(model, filename)

    model.save(filename)

    tf.contrib.saved_model.save_keras_model(model, filename)       

都给我一个类似的错误,Module has no 'name'

我已附上与问题相关的代码。

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os
import sys

import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
import matplotlib.pyplot as plt

FLAGS = None
def create_model(m, img_data):
    # load feature extractor (inception_v3)
    features_extractor_layer = tf.keras.layers.Lambda(m, input_shape=img_data.image_shape)

    # make pre-trained layers un-trainable
    features_extractor_layer.trainable = False

    print(features_extractor_layer.name)

    # add new activation layer to train to our classes
    model = tf.keras.Sequential([
        features_extractor_layer,
        tf.keras.layers.Dense(img_data.num_classes, activation='softmax')
    ])

    model.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])

    return model
def get_and_gen_images(module):
    """
    get images from image directory or url

    :param module: module (to get required image size info
    :return: batched image data
    """
    data_name = os.path.splitext(os.path.basename(FLAGS.image_dir_or_url))[0]
    print("data: ", data_name)

    # download images to cache if not already
    if FLAGS.image_dir_or_url.startswith('https://'):
        data_root = tf.keras.utils.get_file(data_name,
                                            FLAGS.image_dir_or_url,
                                            untar=True,
                                            cache_dir=os.getcwd())
    else:   # specify directory with images
        data_root = tf.keras.utils.get_file(data_name,
                                            FLAGS.image_dir_or_url)

    # get image size for specific module
    image_size = hub.get_expected_image_size(module)


    # TODO: this is where to add noise, rotations, shifts, etc.
    image_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1. / 255, validation_split=0.2)

    # create image stream
    train_image_data = image_generator.flow_from_directory(str(data_root),
                                                           target_size=image_size,
                                                           batch_size=FLAGS.batch_size,
                                                           subset='training')

    validation_image_data = image_generator.flow_from_directory(str(data_root),
                                                                target_size=image_size,
                                                                batch_size=FLAGS.batch_size,
                                                                subset='validation')

    return train_image_data, validation_image_data
# load module (will download from url or directory_
module = hub.Module(FLAGS.tfhub_module)

# generate image stream
train_image_data, validation_image_data = get_and_gen_images(module)

model = create_model(module, train_image_data)
model.summary()

file = FLAGS.saved_model_dir + "/modelname.h5"

model.save(file)

这应该保存一个“.h5”模型文件,但我收到一个命名错误:

Traceback (most recent call last):
  File "/home/raphy/projects/vmi/tf_cpu/retrain.py", line 305, in <module>
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
  File "/home/raphy/.local/lib/python3.6/site-packages/tensorflow/python/platform/app.py", line 125, in run
    _sys.exit(main(argv))
  File "/home/raphy/projects/vmi/tf_cpu/retrain.py", line 205, in main
    model.save(file)
  File "/home/raphy/.local/lib/python3.6/site-packages/tensorflow/python/keras/engine/sequential.py", line 319, in save
    save_model(self, filepath, overwrite, include_optimizer)
  File "/home/raphy/.local/lib/python3.6/site-packages/tensorflow/python/keras/engine/saving.py", line 105, in save_model
    'config': model.get_config()
  File "/home/raphy/.local/lib/python3.6/site-packages/tensorflow/python/keras/engine/sequential.py", line 326, in get_config
    'config': layer.get_config()
  File "/home/raphy/.local/lib/python3.6/site-packages/tensorflow/python/keras/layers/core.py", line 756, in get_config
    function = self.function.__name__
AttributeError: 'Module' object has no attribute '__name__'

我想以 tf_hub 模型的格式保存模型。

【问题讨论】:

    标签: python-3.x tf.keras tensorflow-hub


    【解决方案1】:

    Hosting TF Hub TF Org Link中指定,

    如果您有兴趣托管自己的模型存储库, 可通过 tensorflow_hub 库加载,您的 HTTP 分发版 服务应遵循以下协议。

    换句话说,您无法使用 TF Hub 加载任何模型,但您只能加载 TF Hub Modules Site 中存在的模块。

    如果你想加载你保存的模型,你可以使用 tf.saved_model.load

    但是如果你想使用 TF Hub 来做,请参考this link

    另外,如果链接不起作用,请提及以下说明:

    托管您自己的模型

    TensorFlow Hub 在thub.dev 上提供训练模型的开放存储库。 tensorflow_hub 库可以从此存储库和其他基于 HTTP 的机器学习模型存储库加载模型。特别是,该协议允许将标识模型的 URL 用于模型的文档和用于获取模型的端点。

    如果您有兴趣托管自己的模型存储库,这些模型可以使用 tensorflow_hub 库加载,您的 HTTP 分发服务应遵循以下协议。

    协议:

    当使用诸如https://example.com/model 之类的 URL 来标识要加载或实例化的模型时,模型解析器将在附加查询参数 ?tf-hub-format=compressed 后尝试从该 URL 下载压缩的 tarball。

    查询参数将被解释为客户感兴趣的模型格式的逗号分隔列表。目前仅定义“压缩”格式。

    压缩 格式表示客户端需要一个包含模型内容的tar.gz 存档。归档的根目录是模型目录的根目录,应该包含一个 SavedModel,如下例所示:

    # Create a compressed model from a SavedModel directory.
    $ tar -cz -f model.tar.gz --owner=0 --group=0 -C /tmp/export-model/ .
    
    # Inspect files inside a compressed model
    $ tar -tf model.tar.gz
    ./
    ./variables/
    ./variables/variables.data-00000-of-00001
    ./variables/variables.index
    ./assets/
    ./saved_model.pb
    

    用于 TF1 中已弃用的 hub.Module() API 的压缩包也将包含一个 ./tfhub_module.pb 文件。 TF2 SavedModels 的 hub.load() API 会忽略此类文件。

    tensorflow_hub 库要求模型 URL 是版本化的,并且给定版本的模型内容是不可变的,因此可以无限期缓存。

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2019-03-14
      • 1970-01-01
      • 1970-01-01
      • 2020-06-20
      • 2022-10-14
      • 1970-01-01
      • 1970-01-01
      • 2019-04-12
      相关资源
      最近更新 更多