【问题标题】:How to list all used operations in Tensorflow SavedModel?如何在 Tensorflow SavedModel 中列出所有使用的操作?
【发布时间】:2020-02-10 16:37:10
【问题描述】:

如果我使用 tensorflow.saved_model.save 函数以 SavedModel 格式保存我的模型,之后我如何检索该模型中使用了哪些 Tensorflow Ops。由于模型可以恢复,这些操作都存储在图中,我猜是在saved_model.pb文件中。如果我加载这个 protobuf(所以不是整个模型),protobuf 的库部分会列出这些,但目前没有记录并标记为实验性功能。在 TensorFlow 1.x 中创建的模型没有这部分。

那么,从 SavedModel 格式的模型中检索已使用操作列表(如 MatchingFilesWriteFile)的快速可靠方法是什么?

现在我可以像tensorflowjs-converter 一样冻结整个事情。因为他们还检查支持的操作。当模型中有 LSTM 时,这目前不起作用,请参阅 here。有没有更好的方法来做到这一点,因为 Ops 肯定在那里?

示例模型:

class FileReader(tf.Module):

@tf.function(input_signature=[tf.TensorSpec(name='filename', shape=[None], dtype=tf.string)])
def read_disk(self, file_name):
    input_scalar = tf.reshape(file_name, [])
    output = tf.io.read_file(input_scalar)
    return tf.stack([output], name='content')

file_reader = FileReader()

tf.saved_model.save(file_reader, 'file_reader')

预计输出所有 Ops,在这种情况下至少包含:

  • ReadFilehere 所述
  • ...

【问题讨论】:

  • 很难准确说出你想要什么,saved_model.pb 是什么,是tf.GraphDef,还是SavedModel protobuf 消息?如果您有一个名为gdtf.GraphDef,您可以使用sorted(set(n.op for n in gd.node)) 获取已使用操作的列表。如果你有一个加载模型,你可以做sorted(set(op.type for op in tf.get_default_graph().get_operations()))。如果是SavedModel,则可以从中获取tf.GraphDef(例如saved_model.meta_graphs[0].graph_def)。
  • 我想从存储的 SavedModel 中检索操作。确实,您描述的最后一个选项。您上一个示例中的 saved_model 变量是什么? tf.saved_model.load('/path/to/model') 的结果或者加载了 saved_model.pb 文件的 protobuf。

标签: python tensorflow


【解决方案1】:

如果saved_model.pbSavedModel protobuf 消息,那么您可以直接从那里获取操作。假设我们创建一个模型如下:

import tensorflow as tf

class FileReader(tf.Module):
    @tf.function(input_signature=[tf.TensorSpec(name='filename', shape=[None], dtype=tf.string)])
    def read_disk(self, file_name):
        input_scalar = tf.reshape(file_name, [])
        output = tf.io.read_file(input_scalar)
        return tf.stack([output], name='content')

file_reader = FileReader()
tf.saved_model.save(file_reader, 'tmp')

我们现在可以像这样找到该模型使用的操作:

from tensorflow.core.protobuf.saved_model_pb2 import SavedModel

saved_model = SavedModel()
with open('tmp/saved_model.pb', 'rb') as f:
    saved_model.ParseFromString(f.read())
model_op_names = set()
# Iterate over every metagraph in case there is more than one
for meta_graph in saved_model.meta_graphs:
    # Add operations in the graph definition
    model_op_names.update(node.op for node in meta_graph.graph_def.node)
    # Go through the functions in the graph definition
    for func in meta_graph.graph_def.library.function:
        # Add operations in each function
        model_op_names.update(node.op for node in func.node_def)
# Convert to list, sorted if you want
model_op_names = sorted(model_op_names)
print(*model_op_names, sep='\n')
# Const
# Identity
# MergeV2Checkpoints
# NoOp
# Pack
# PartitionedCall
# Placeholder
# ReadFile
# Reshape
# RestoreV2
# SaveV2
# ShardedFilename
# StatefulPartitionedCall
# StringJoin

【讨论】:

  • 我尝试过这样的事情,但不幸的是,这不是我所期望的:假设我有一个这样做的模型:input_scalar = tf.reshape(file_name, []) output = tf.io.read_file(input_scalar) return tf.stack([output], name='content') 然后列出here 的ReadFile Op 在那里,但没有打印出来。
  • @sampers 我已经用你建议的例子编辑了答案。我确实在输出中得到了ReadFile 操作。在您的实际情况下,该操作是否可能不在保存模型的输入和输出之间?在那种情况下,我认为它可能会被修剪。
  • 确实适用于给定的模型。不幸的是,对于 tf2 中的模块,它没有。如果我创建一个带有 1 个函数的 tf.Module 和 file_name 参数 @tf.function 注释,其中包含我在之前的评论中列出的调用,它会给出以下列表:Const, NoOp, PartitionedCall, Placeholder, StatefulPartitionedCall
  • 在我的问题中添加了一个模型
  • @sampers 我已经更新了我的答案。我之前使用的是 TF 1.x,我不熟悉 TF 2.x 中图形定义对象的更改,我认为现在答案涵盖了已保存模型中的所有内容。我认为与您编写的 Python 函数对应的操作在 saved_model.meta_graphs[0].graph_def.library.function[0](该函数对象内的 node_def 集合)中。
猜你喜欢
  • 2021-01-07
  • 2021-02-22
  • 1970-01-01
  • 2020-01-13
  • 2018-03-12
  • 1970-01-01
  • 1970-01-01
  • 2021-10-27
  • 2020-11-09
相关资源
最近更新 更多