【问题标题】:How to load TF hub model from local system如何从本地系统加载 TF hub 模型
【发布时间】:2021-07-18 19:44:15
【问题描述】:

一种方法是每次从tensorflow_hub 下载模型,如下所示

import tensorflow as tf
import tensorflow_hub as hub

hub_url = "https://tfhub.dev/google/tf2-preview/nnlm-en-dim128/1"
embed = hub.KerasLayer(hub_url)
embeddings = embed(["A long sentence.", "single-word", "http://example.com"])
print(embeddings.shape, embeddings.dtype)

我想下载文件一次又一次地使用,而不是每次都下载

【问题讨论】:

    标签: python tensorflow


    【解决方案1】:
    1. 从 url + "?tf-hub-format=compressed"
      下载您的模型 例如“https://tfhub.dev/google/tf2-preview/nnlm-en-dim128/1?tf-hub-format=compressed”
    2. 解压
    3. 在代码中加载解压文件夹
    import tensorflow as tf
    import tensorflow_hub as hub
    
    embed = hub.KerasLayer('path/to/untarred/folder')
    embeddings = embed(["A long sentence.", "single-word", "http://example.com"])
    print(embeddings.shape, embeddings.dtype)
    

    【讨论】:

      【解决方案2】:

      您可以使用hub.load() 方法加载 TF Hub 模块。另外,docs 说,

      目前只有 TensorFlow 2.x 和 通过调用tensorflow.saved_model.save() 创建的模块。这 该方法适用于 Eager 模式和图形模式。

      hub.load 方法有一个参数handle。模块句柄的类型是,

      1. 智能 URL 解析器,例如 tfhub.dev,例如:https://tfhub.dev/google/nnlm-en-dim128/1

      2. Tensorflow 支持的文件系统上包含模块文件的目录。这可能包括本地目录(例如/usr/local/mymodule)或谷歌云存储桶(gs://mymodule)。

      3. 指向模块的 TGZ 存档的 URL,例如https://example.com/mymodule.tar.gz

      您可以使用第 2 点和第 3 点。

      【讨论】:

      • hub_module = hub.load('/tmp/arbitrary-image-stylization-v1-256/') 不起作用。我做错了吗?
      【解决方案3】:

      如果有人想知道模型在 Windows 上默认保存在哪里,比如我,就在这里。

      C:\Users\AvrakDavra\AppData\Local\Temp\tfhub_modules\

      显然,您可以在任何地方下载并提及路径和 tfhub 将从那里获取,但以防万一。 在 Windows 上立即打开临时文件夹。

      1. 按 WindowsButton+R
      2. 写入 %TEMP%

      它将为您的用户名打开 temp 文件夹,并且默认情况下 tfhub_modules 文件夹位于该文件夹中。它将包含以下文件夹

      文本文件内容如下。

      Module: https://tfhub.dev/google/universal-sentence-encoder/4 Download Time: 2021-07-17 18:17:09.714147 Downloader Hostname: LAPTOP(PID:12720)

      【讨论】:

      • 如果您不想每次都手动提及路径,但您已经手动下载了模型,则可以将其设置为具有特定名称的默认位置。文件夹的名称实际上是 tfhub 在第一次尝试下载时决定的,它会创建一个 .lock 文件,下载完成后变成 .txt 文件,下载完成后的 .tmp 文件夹变成普通文件夹。该文件夹是tar所有内容所在的位置,txt文件是告诉下载是否完成。
      【解决方案4】:

      也许其他人可能会从一个具体的、可重复的答案中受益。这个帖子对应这个specific tfhub model

      tensorflow_hub 版本:0.12.0
      张量流版本:2.2.0

      我在我的 Linux 服务器上设置了以下路径:

      # Note, I manually created this entire path before ever downloading tfhub models
      /opt/tfhub/tf2/bert_en_uncased_L-12_H-768_A-12_4/
      

      (由于各种原因,我们仍然对 Tensorflow 1.x 有一些需求,所以我认为根据它们是否设计用于 tensorflow 1.x 和 tensorflow 2.x 来分离模型可能是个好主意,因此tf2 在我的路径中)

      然后我下载了模型文件,将其推送到我的 Linux 服务器,将其放在上述位置,然后执行:

      # bash
      tar xzf bert_en_uncased_L-12_H-768_A-12_4.tar.gz
      

      这给了我以下文件:

      # python
      import os
      os.listdir("/opt/tfhub/tf2/bert_en_uncased_L-12_H-768_A-12_4/")
      >>> ['keras_metadata.pb', 'saved_model.pb', 'assets', 'variables']
      

      那么我可以像这样加载模型:

      # python
      import tensorflow_hub as tfhub
      import tensorflow as tf
      bert_layer = tfhub.KerasLayer(tfhub.load("/opt/tfhub/tf2/bert_en_uncased_L-12_H-768_A-12_4"))
      

      【讨论】:

        猜你喜欢
        • 2020-10-21
        • 2019-05-07
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        • 2022-08-04
        • 2021-12-26
        • 2022-10-07
        • 1970-01-01
        相关资源
        最近更新 更多