【发布时间】:2026-01-09 03:05:03
【问题描述】:
我需要在 pytorch DataLoader 中使用 BatchSampler,而不是多次调用数据集的 __getitem__(远程数据集,每个查询都很昂贵)。
我不明白如何将批处理采样器与任何给定的数据集一起使用。
例如
class MyDataset(Dataset):
def __init__(self, remote_ddf, ):
self.ddf = remote_ddf
def __len__(self):
return len(self.ddf)
def __getitem__(self, idx):
return self.ddf[idx] --------> This is as expensive as a batch call
def get_batch(self, batch_idx):
return self.ddf[batch_idx]
my_loader = DataLoader(MyDataset(remote_ddf),
batch_sampler=BatchSampler(Sampler(), batch_size=3))
我不明白的事情是,我如何使用我的 get_batch 函数而不是 __getitem__ 函数,在网上或 Torch 文档中都没有找到任何示例。
编辑:
按照 Szymon Maszke 的回答,这是我尝试过的,但是 \_\_get_item__ 每次调用都会获得一个索引,而不是大小为 batch_size 的列表
class Dataset(Dataset):
def __init__(self):
...
def __len__(self):
...
def __getitem__(self, batch_idx): ------> here I get only one index
return self.wiki_df.loc[batch_idx]
loader = DataLoader(
dataset=dataset,
batch_sampler=BatchSampler(
SequentialSampler(dataset), batch_size=self.hparams.batch_size, drop_last=False),
num_workers=self.hparams.num_data_workers,
)
【问题讨论】:
标签: pytorch dataloader