【Posted】: 2021-09-25 23:22:25
【Problem description】:
This is how I build my model for a classification task:
import tensorflow as tf
from tensorflow import keras
from transformers import ElectraConfig, TFElectraForSequenceClassification

def bert_for_classification(transformer_model_name, max_sequence_length, num_labels):
    config = ElectraConfig.from_pretrained(
        transformer_model_name,
        num_labels=num_labels,
        output_hidden_states=False,
        output_attentions=False
    )
    model = TFElectraForSequenceClassification.from_pretrained(transformer_model_name, config=config)
    # This is the input for the tokens themselves (words from the dataset after encoding):
    input_ids = tf.keras.layers.Input(shape=(max_sequence_length,), dtype=tf.int32, name='input_ids')
    # attention_mask is a binary mask that tells BERT which tokens to attend to and which to ignore.
    # The encoder pads any sequence shorter than MAX_SEQUENCE_LENGTH with 0 tokens, and the
    # attention_mask tells BERT which positions hold tokens from the original data and which hold
    # 0-pad tokens:
    attention_mask = tf.keras.layers.Input((max_sequence_length,), dtype=tf.int32, name='attention_mask')
    # Use the previous inputs as BERT inputs:
    output = model([input_ids, attention_mask])[0]
    output = tf.keras.layers.Dense(num_labels, activation='softmax')(output)
    model = tf.keras.models.Model(inputs=[input_ids, attention_mask], outputs=output)
    model.compile(loss=keras.losses.CategoricalCrossentropy(),
                  optimizer=keras.optimizers.Adam(3e-05, epsilon=1e-08),
                  metrics=['accuracy'])
    return model
After training this model, I save it with model.save_weights('model.hd5').
But it turns out two files are saved: model.hd5.index and model.hd5.data-00000-of-00001.
How should I load this model from disk?
【Discussion】:
- Did you install the hd5py package?
- @Kaveh Yes, I have h5py installed
- Oh. You wrote the extension hd5, but it is h5. Use model.save_weights('model.h5'); if the extension is not h5, the weights are saved in the SavedModel format instead.
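As the discussion points out, save_weights with a path that does not end in .h5 writes TensorFlow's checkpoint format, which is exactly the .index / .data-00000-of-00001 pair from the question. Loading such weights means rebuilding the same architecture and calling load_weights with the same path prefix. A minimal sketch, assuming TensorFlow 2.x; the small Dense model below is a hypothetical stand-in for bert_for_classification, which would require downloading pretrained ELECTRA weights:

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

def build_model():
    # Stand-in architecture; the key point is that load_weights needs the
    # model rebuilt first, exactly as in bert_for_classification.
    inputs = tf.keras.layers.Input(shape=(4,), name='features')
    outputs = tf.keras.layers.Dense(2, activation='softmax')(inputs)
    model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
    model.compile(loss='categorical_crossentropy', optimizer='adam')
    return model

model = build_model()
x = np.random.rand(3, 4).astype('float32')
ref = model.predict(x)

ckpt_dir = tempfile.mkdtemp()
prefix = os.path.join(ckpt_dir, 'model.hd5')
model.save_weights(prefix)  # no .h5 extension -> TF checkpoint format
# The directory now holds model.hd5.index and model.hd5.data-00000-of-00001
# (plus a small 'checkpoint' bookkeeping file):
print(sorted(os.listdir(ckpt_dir)))

# To load: rebuild the same architecture, then point load_weights at the
# same prefix (not at the .index or .data file directly):
restored = build_model()
restored.load_weights(prefix)
assert np.allclose(restored.predict(x), ref)
```

The same pattern applies to the model in the question: call bert_for_classification again with identical arguments, then load_weights('model.hd5') on the result.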
Tags: python tensorflow tensorflow2.0 huggingface-transformers