【发布时间】:2022-01-02 11:33:44
【问题描述】:
我遇到了问题!
最近遇到一个 I/O 问题。目标和输入数据存储在 h5py 文件中。每个目标文件为 2.6GB,而每个输入文件为 10.2GB。我总共有 5 个输入数据集和 5 个目标数据集。
我为每个 h5py 文件创建了一个自定义数据集函数,然后使用data.ConcatDataset 类链接所有数据集。自定义数据集函数为:
class MydataSet(Dataset):
def __init__(self, indx=1, root_path='./xxx', tar_size=128, data_aug=True, train=True):
self.train = train
if self.train:
self.in_file = pth.join(root_path, 'train', 'train_noisy_%d.h5' % indx)
self.tar_file = pth.join(root_path, 'train', 'train_clean_%d.h5' % indx)
else:
self.in_file = pth.join(root_path, 'test', 'test_noisy.h5')
self.tar_file = pth.join(root_path, 'test', 'test_clean.h5')
self.h5f_n = h5py.File(self.in_file, 'r', driver='core')
self.h5f_c = h5py.File(self.tar_file, 'r')
self.keys_n = list(self.h5f_n.keys())
self.keys_c = list(self.h5f_c.keys())
# h5f_n.close()
# h5f_c.close()
self.tar_size = tar_size
self.data_aug = data_aug
def __len__(self):
return len(self.keys_n)
def __del__(self):
self.h5f_n.close()
self.h5f_c.close()
def __getitem__(self, index):
keyn = self.keys_n[index]
keyc = self.keys_c[index]
datan = np.array(self.h5f_n[keyn])
datac = np.array(self.h5f_c[keyc])
datan_tensor = torch.from_numpy(datan).unsqueeze(0)
datac_tensor = torch.from_numpy(datac)
if self.data_aug and np.random.randint(2, size=1)[0] == 1: # horizontal flip
datan_tensor = torch.flip(datan_tensor,dims=[2]) # c h w
datac_tensor = torch.flip(datac_tensor,dims=[2])
然后我使用dataset_train = data.ConcatDataset([MydataSet(indx=index, train=True) for index in range(1, 6)]) 进行训练。当只使用 2-3 个 h5py 文件时,I/O 速度正常,一切正常。但是,当使用 5 个文件时,训练速度逐渐降低(5 次迭代/秒到 1 次迭代/秒)。我换了num_worker,问题依旧存在。
谁能给我一个解决方案?我应该将几个 h5py 文件合并成一个更大的文件吗?还是其他方法?提前致谢!
【问题讨论】:
-
如果您有 5 个输入文件,每个文件都是 10.2GB,这是否意味着合并的数据需要 51GB RAM(加上 13GB 用于目标数据)?如果是这样,那是很多内存。首先要确定的是性能瓶颈。它可以 a) 用这么多数据训练模型,b)
data.ConcatDataset()大量数据集的性能,或 c) 具有大量文件的类 MydataSet() 性能。如果要合并 HDF5 文件,这很容易(假设所有文件都有相似的架构和唯一的数据集名称)。 -
@kcw78 您好,感谢您的评论。我不明白你在(a)中的意思。我只使用了一个文件(每个文件都经过测试)进行训练(没有
ConcatDataset),训练正常。我之前用ConcatDataset处理较小的h5py文件,训练也正常。 -
@kcw78 对于 (c),
class MydataSet()仅处理一个 h5py 文件,ConcatDataset连接多个MydataSet类。它可以处理很多文件。
标签: python pytorch h5py pytorch-dataloader