【发布时间】:2020-05-02 20:34:44
【问题描述】:
我想将一批文本填充成相同的长度,生成段 id、掩码向量,然后将它们提供给 bert 模型。
在pytorch中,我可以使用collate_fn,如下所示。
def collate_fn(self, batch):
rows = self.df.iloc[batch] # take a batch of data
ids, seg_ids = self.get_ids_segs(rows) # process data
attention_mask = (ids > 0)
return ids, seg_ids,attention_mask
但在 tensorflow 中,数据是通过矩阵元组传递的,因此所有文本都被填充到最大长度 512。
# ids.shape = seg_ids = attention_mask = (data_number, max_seq_len)
xs = (ids, seg_ids, attention_mask)
model.fit(xs,, ys, batch_size=batch_size)
我发现tf.data.dataset 有一个函数padded_batch。但它只能填充一个输入,我有3个输入数据,ids,seq_ids,attn_mask。
【问题讨论】:
-
你能举一个
tf.data.dataset创建ids,seq_ids,attn_mask的例子吗?
标签: tensorflow deep-learning nlp pytorch