【问题标题】:Tensorflow session error in universal sentence encoder通用句子编码器中的Tensorflow会话错误
【发布时间】:2020-05-14 20:30:40
【问题描述】:

我有以下通用句子编码器的代码,一旦我将模型加载到烧瓶 api 中并尝试点击它,它会给出以下错误(检查下方):

'''

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

module_url = "https://tfhub.dev/google/universal-sentence-encoder-large/5"
model_2 = hub.load(module_url)
print ("module %s loaded" % module_url)

def embed(input):
    return model_2(input)


def universalModel(messages):
    accuracy = []
    similarity_input_placeholder = tf.placeholder(tf.string, shape=(None))
    similarity_message_encodings = embed(similarity_input_placeholder)
    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        session.run(tf.tables_initializer())
        message_embeddings_ = session.run(similarity_message_encodings, feed_dict={similarity_input_placeholder: messages})

        corr = np.inner(message_embeddings_, message_embeddings_)
        accuracy.append(corr[0,1])
    # print(corr[0,1])
    return "%.2f" % accuracy[0]

'''

在flask api中使用模型时出现以下错误: tensorflow.python.framework.errors_impl.InvalidArgumentError: Graph is invalid, contains a circle with 1 nodes, 包括:StatefulPartitionedCall 尽管此代码在 colab 笔记本中运行时没有任何错误。

我使用的是 tensorflow 2.2.0 版。

【问题讨论】:

    标签: python tensorflow tensorflow2.0 tensorflow-serving sentence-similarity


    【解决方案1】:
    import tensorflow.compat.v1 as tf
    tf.disable_v2_behavior()
    

    这两行是为了让 tensorflow 2.x 变成 tensorflow 1.x。

    对于 Tensorflow 1.x,这是使用烧瓶、django 等服务时的常见问题。 你必须为推理定义一个图和会话,

    将张量流导入为 tf 导入 tensorflow_hub 作为集线器

    # Create graph and finalize (finalizing optional but recommended).
    g = tf.Graph()
    with g.as_default():
      # We will be feeding 1D tensors of text into the graph.
      text_input = tf.placeholder(dtype=tf.string, shape=[None])
      embed = hub.Module("https://tfhub.dev/google/universal-sentence-encoder/2")
      embedded_text = embed(text_input)
      init_op = tf.group([tf.global_variables_initializer(), tf.tables_initializer()])
    g.finalize()
    
    # Create session and initialize.
    session = tf.Session(graph=g)
    session.run(init_op)
    

    可以通过输入请求处理

    result = session.run(embedded_text, feed_dict={text_input: ["Hello world"]})
    

    详情 https://www.tensorflow.org/hub/common_issues

    对于 tensorflow 2.x 会话和图形不是必需的。

    import tensorflow as tf
    module_url = "https://tfhub.dev/google/universal-sentence-encoder-large/5"
    model_2 = hub.load(module_url)
    print ("module %s loaded" % module_url)
    
    def embed(input):
        return model_2(input)
    #pass messages as list
    def universalModel(messages):
        accuracy = []
        message_embeddings_= embed(messages)
        corr = np.inner(message_embeddings_, message_embeddings_)
        accuracy.append(corr[0,1])
        # print(corr[0,1])
        return "%.2f" % accuracy[0]
    

    【讨论】:

    • 同样的代码也在使用 tensorflow 2 的 jupyter notebook 上工作,所以我不明白为什么它不能在 api 中工作? @vivek-ananthan
    • 这是 TF 1.x 与 Flask api 的已知问题。如果要与 TF 2.x 一起使用,则无需像 TF 1.x 那样构建图形和会话。用 import tensorflow 替换前 2 行,以便直接加载 tf 2.x 而不是 tf1.x 功能。
    • 我在查找相关性和加载模型 5 时遇到错误,您能否将我的模型 5 代码深入到烧瓶 api 代码中? @vivek-ananthan
    • 将前 2 个 import 语句替换为“import tensorflow”,让我知道你在那之后遇到了什么错误
    • 我收到占位符错误,因为它在 v2 中不存在,我还需要删除什么? @vivek-ananthan
    猜你喜欢
    • 2021-04-26
    • 2020-04-15
    • 1970-01-01
    • 1970-01-01
    • 2020-03-03
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多