【问题标题】:How to easily convert a PyTorch dataloader to tf.Dataset?如何轻松地将 PyTorch 数据加载器转换为 tf.Dataset?
【发布时间】:2022-07-19 20:22:55
【问题描述】:

我们如何将pytorch 数据加载器转换为tf.Dataset

我发现了这个 sn-p:-

def convert_pytorch_dataloader_to_tf_dataset(dataloader, batch_size, shuffle=True):
    dataset = tf.data.Dataset.from_generator(
        lambda: dataloader,
        output_types=(tf.float32, tf.float32),
        output_shapes=(tf.TensorShape([256, 512]), tf.TensorShape([2,]))
    )
    if shuffle:
        dataset = dataset.shuffle(buffer_size=len(dataloader.dataset))
    dataset = dataset.batch(batch_size)
    return dataset

但它根本不起作用。

是否有一个内置选项可以轻松地将dataloaders 导出到tf.Datasets?我有一个非常复杂的数据加载器,所以一个简单的解决方案应该确保没有错误:)

【问题讨论】:

  • 你的 pytorch 数据加载器是在哪里定义的?
  • 我认为您可以尝试使用 MNIST 数据加载器作为示例 - 但问题特别是一种更简单的内置方式,而不是弄乱函数

标签: python tensorflow pytorch data-conversion


【解决方案1】:

对于 h5py 格式的数据,您可以使用下面的脚本。 name_x 是 h5py 中的功能名称,name_y 是标签的文件名。此方法内存效率高,您可以批量输入数据。

class Generator(object):

def __init__(self,open_directory,batch_size,name_x,name_y):

    self.open_directory = open_directory

    data_f = h5py.File(open_directory, "r")

    self.x = data_f[name_x]
    self.y = data_f[name_y]

    if len(self.x.shape) == 4:
        self.shape_x = (None, self.x.shape[1], self.x.shape[2], self.x.shape[3])

    if len(self.x.shape) == 3:
        self.shape_x = (None, self.x.shape[1], self.x.shape[2])

    if len(self.y.shape) == 4:
        self.shape_y = (None, self.y.shape[1], self.y.shape[2], self.y.shape[3])

    if len(self.y.shape) == 3:
        self.shape_y = (None, self.y.shape[1], self.y.shape[2])

    self.num_samples = self.x.shape[0]
    self.batch_size = batch_size
    self.epoch_size = self.num_samples//self.batch_size+1*(self.num_samples % self.batch_size != 0)

    self.pointer = 0
    self.sample_nums = np.arange(0, self.num_samples)
    np.random.shuffle(self.sample_nums)


def data_generator(self):

    for batch_num in range(self.epoch_size):

        x = []
        y = []

        for elem_num in range(self.batch_size):

            sample_num = self.sample_nums[self.pointer]

            x += [self.x[sample_num]]
            y += [self.y[sample_num]]

            self.pointer += 1

            if self.pointer == self.num_samples:
                self.pointer = 0
                np.random.shuffle(self.sample_nums)
                break

        x = np.array(x,
                     dtype=np.float32)
        y = np.array(y,
                     dtype=np.float32)

        yield x, y

def get_dataset(self):
    dataset = tf.data.Dataset.from_generator(self.data_generator,
                                             output_types=(tf.float32,
                                                           tf.float32),
                                             output_shapes=(tf.TensorShape(self.shape_x),
                                                            tf.TensorShape(self.shape_y)))
    dataset = dataset.prefetch(1)

    return dataset

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 2013-02-09
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2014-05-06
    • 2010-10-29
    • 2015-05-23
    • 2019-06-26
    相关资源
    最近更新 更多