【问题标题】:Converting a tf.dataset to a PyTorch Dataset?将 tf.dataset 转换为 PyTorch 数据集?
【发布时间】:2022-04-24 20:24:39
【问题描述】:

我正在做这个项目,所有数据都经过预处理并准备好作为 tensorflow 数据集,如下所示:

<MapDataset shapes: {input_ids: (128,), input_mask: (128,), label_ids: (), segment_ids: (128,)}, types: {input_ids: tf.int64, input_mask: tf.int64, label_ids: tf.int64, segment_ids: tf.int64}>

我拥有的脚本在 PyTorch 中,并接受如下所示的 Dataset 对象:

Dataset({
    features: ['attention_mask', 'input_ids', 'label', 'sentence', 'token_type_ids'],
    num_rows: 12
})

是否可以将一种转换为另一种?我对这两个 API 都很陌生,所以我不太确定它们是如何工作的?我可以使用 dict 将一个转换为另一个吗?

谢谢

【问题讨论】:

    标签: tensorflow pytorch dataset tensorflow-datasets


    【解决方案1】:

    我使用tfds.as_numpy(dataset) 作为模型训练的数据加载器。为了转换传递给我的模型的数据,我在模型的 forward 函数中使用了torch.as_tensor(data, device=&lt;device&gt;)

    import tensorflow_datasets as tfds
    import torch.nn as nn
    
    def train_dataloader(batch_size):
        return tfds.as_numpy(tfds.load('mnist').batch(batch_size))
    
    class Model(nn.Module):
        def forward(self, x):
            x = torch.as_tensor(x, device='cuda')
            ...
    

    【讨论】:

      猜你喜欢
      • 2022-07-19
      • 2021-05-01
      • 2019-07-20
      • 2021-02-06
      • 1970-01-01
      • 2021-02-09
      • 2018-10-22
      相关资源
      最近更新 更多