【发布时间】:2021-11-04 09:10:53
【问题描述】:
我正在尝试利用 tensorflow_hub 中的预训练 BERT 模型实现自定义分类器。 我遇到了一个问题,不知道如何解决。
代码如下:
class BERTClassifier(tf.keras.models.Model):
def __init__(self):
super(BERTClassifier, self).__init__()
self.preprocessing_layer = hub.KerasLayer('https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3', name='preprocessing')
self.encoder = hub.KerasLayer('https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/3', trainable=True, name='BERT_encoder')
def call(self, inputs):
x = self.preprocessing_layer(inputs)
x = self.encoder(x)
x = x['outputs']
return x
bert_clf = BERTClassifier('small_bert/bert_en_uncased_L-8_H-768_A-12', 'small_bert/bert_en_uncased_L-8_H-768_A-12')
bert_clf.predict(np.array(tf.reshape(["[CLS] Hello world [SEP]"])
我希望 predict 方法会返回句子的嵌入,但是当我运行代码时出现以下错误:
ValueError: Could not find matching function to call loaded from the SavedModel. Got:
Positional arguments (3 total):
* Tensor("inputs:0", shape=(None, 1), dtype=string)
* False
* None
Keyword arguments: {}
Expected these arguments to match one of the following 4 option(s):
Option 1:
Positional arguments (3 total):
* TensorSpec(shape=(None,), dtype=tf.string, name='sentences')
* False
* None
Keyword arguments: {}
Option 2:
Positional arguments (3 total):
* TensorSpec(shape=(None,), dtype=tf.string, name='sentences')
* True
* None
Keyword arguments: {}
Option 3:
Positional arguments (3 total):
* TensorSpec(shape=(None,), dtype=tf.string, name='inputs')
* False
* None
Keyword arguments: {}
Option 4:
Positional arguments (3 total):
* TensorSpec(shape=(None,), dtype=tf.string, name='inputs')
* True
* None
Keyword arguments: {}
有什么问题?我该如何解决? 提前谢谢大家!
【问题讨论】:
标签: python tensorflow bert-language-model tensorflow-hub