【发布时间】:2021-05-12 20:39:08
【问题描述】:
当我尝试从 tensorflow-hub resporitory 获取模型时。 我可以将其视为已保存的模型格式,但我无法访问模型架构以及每一层的权重存储。
import tensorflow_hub as hub
model = hub.load("https://tfhub.dev/tensorflow/centernet/hourglass_512x512/1")
)
有什么正式的方法可以使用它吗?
对于原始模型中的特定层,我可以通过model.__dict__ 获得的所有属性都不清楚。
{'_self_setattr_tracking': True,
'_self_unconditional_checkpoint_dependencies': [TrackableReference(name='_model', ref=<tensorflow.python.saved_model.load.Loader._recreate_base_user_object.<locals>._UserObject object at 0x7fe4e4914710>),
TrackableReference(name='signatures', ref=_SignatureMap({'serving_default': <ConcreteFunction signature_wrapper(input_tensor) at 0x7FE4E601F210>})),
TrackableReference(name='_self_saveable_object_factories', ref=DictWrapper({}))],
'_self_unconditional_dependency_names': {'_model': <tensorflow.python.saved_model.load.Loader._recreate_base_user_object.<locals>._UserObject at 0x7fe4e4914710>,
'signatures': _SignatureMap({'serving_default': <ConcreteFunction signature_wrapper(input_tensor) at 0x7FE4E601F210>}),
'_self_saveable_object_factories': {}},
'_self_unconditional_deferred_dependencies': {},
'_self_update_uid': 176794,
'_self_name_based_restores': set(),
'_self_saveable_object_factories': {},
'_model': <tensorflow.python.saved_model.load.Loader._recreate_base_user_object.<locals>._UserObject at 0x7fe4e4914710>,
'signatures': _SignatureMap({'serving_default': <ConcreteFunction signature_wrapper(input_tensor) at 0x7FE4E601F210>}),
'__call__': <tensorflow.python.saved_model.function_deserialization.RestoredFunction at 0x7fe315a28950>,
'graph_debug_info': ,
'tensorflow_version': '2.4.0',
'tensorflow_git_version': 'unknown'}
我也试过model.signatures['serving_default'].__dict__,每层的张量代表不可见
[<tf.Tensor: shape=(), dtype=resource, numpy=<unprintable>>,
<tf.Tensor: shape=(), dtype=resource, numpy=<unprintable>>,
<tf.Tensor: shape=(), dtype=resource, numpy=<unprintable>>,
<tf.Tensor: shape=(), dtype=resource, numpy=<unprintable>>,
<tf.Tensor: shape=(), dtype=resource, numpy=<unprintable>>,
<tf.Tensor: shape=(), dtype=resource, numpy=<unprintable>>,
<tf.Tensor: shape=(), dtype=resource, numpy=<unprintable>>,
<tf.Tensor: shape=(), dtype=resource, numpy=<unprintable>>,
<tf.Tensor: shape=(), dtype=resource, numpy=<unprintable>>,
<tf.Tensor: shape=(), dtype=resource, numpy=<unprintable>>,
<tf.Tensor: shape=(), dtype=resource, numpy=<unprintable>>,
<tf.Tensor: shape=(), dtype=resource, numpy=<unprintable>>,
<tf.Tensor: shape=(), dtype=resource, numpy=<unprintable>>,
<tf.Tensor: shape=(), dtype=resource, numpy=<unprintable>>],
【问题讨论】:
-
您能否分享一下您使用的具体型号,以便我们重现设置?
-
我只是从模型中心加载这个模型:model = tf.saved_model.load('/tmp/tfhub_modules/3085eb2fbe2ad0b69801d50844c97b7a7a5ecade')
-
为此,您必须事先将模型下载到
/tmp/tfhub_modules/(经过训练的模型不会出现在系统的临时文件夹中)。您是如何想出这条特定路径的? -
是的。我通过 tensorflow_hub 的传统 hub.load() 得到它
-
我已经更新了问题
标签: tensorflow keras tf.keras tensorflow-serving tensorflow-hub