【发布时间】:2018-10-30 17:56:34
【问题描述】:
我想将 Keras 模型中的变量与 TensorFlow 检查点中的变量进行比较。我可以像这样得到 TF 变量:
vars_in_checkpoint = tf.train.list_variables(os.path.join("./model.ckpt"))
如何从我的model 中获取要比较的 Keras 变量?
【问题讨论】:
标签: python-3.x tensorflow keras
我想将 Keras 模型中的变量与 TensorFlow 检查点中的变量进行比较。我可以像这样得到 TF 变量:
vars_in_checkpoint = tf.train.list_variables(os.path.join("./model.ckpt"))
如何从我的model 中获取要比较的 Keras 变量?
【问题讨论】:
标签: python-3.x tensorflow keras
您可以通过model.weights(tf.Variable 实例列表)获取 Keras 模型的变量。
【讨论】:
要获取变量的名称,您需要从模型层的权重属性访问它。像这样的:
names = [weight.name for layer in model.layers for weight in layer.weights]
并得到重量的形状:
weights = [weight.shape for weight in model.get_weights()]
【讨论】: