【问题标题】:tensorflow how to pad batched text like pytorch's 'collate_fn'?tensorflow如何填充像pytorch的'collat​​e_fn'这样的批处理文本?
【发布时间】:2020-05-02 20:34:44
【问题描述】:

我想将一批文本填充成相同的长度,生成段 id、掩码向量,然后将它们提供给 bert 模型。 在pytorch中,我可以使用collate_fn,如下所示。

def collate_fn(self, batch):
    rows = self.df.iloc[batch] # take a batch of data
    ids, seg_ids = self.get_ids_segs(rows) # process data
    attention_mask = (ids > 0)
    return ids, seg_ids,attention_mask

但在 tensorflow 中,数据是通过矩阵元组传递的,因此所有文本都被填充到最大长度 512。

# ids.shape = seg_ids = attention_mask = (data_number, max_seq_len) 
xs = (ids, seg_ids, attention_mask)

model.fit(xs,, ys, batch_size=batch_size)

我发现tf.data.dataset 有一个函数padded_batch。但它只能填充一个输入,我有3个输入数据,idsseq_idsattn_mask

【问题讨论】:

  • 你能举一个tf.data.dataset创建ids,seq_ids,attn_mask的例子吗?

标签: tensorflow deep-learning nlp pytorch


【解决方案1】:

可能使用了apply或map方法

tf.data.Dataset

应用批处理方法后应该可以解决问题。

【讨论】:

    猜你喜欢
    • 2020-01-08
    • 1970-01-01
    • 1970-01-01
    • 2021-05-12
    • 2020-01-17
    • 1970-01-01
    • 1970-01-01
    • 2019-06-02
    • 2020-09-25
    相关资源
    最近更新 更多