【发布时间】:2020-08-19 22:33:49
【问题描述】:
在 python 中,我有一个迭代器返回固定范围[0, N] 中称为Sampler 的无限索引字符串。实际上我有一个列表,它们所做的只是返回[0, N_0], [N_0, N_1], ..., [N_{n-1}, N_n].范围内的索引
我现在要做的是首先根据它们的范围长度选择其中一个迭代器,所以我有一个 weights 列表 [N_0, N_1 - N_0, ...] 并选择其中一个:
iterator_idx = random.choices(range(len(weights)), weights=weights/weights.sum())[0]
接下来,我要做的是创建一个迭代器,它随机选择一个迭代器并选择一批M 样本。
class BatchSampler:
def __init__(self, M):
self.M = M
self.weights = [weight_list]
self.samplers = [list_of_iterators]
]
self._batch_samplers = [
self.batch_sampler(sampler) for sampler in self.samplers
]
def batch_sampler(self, sampler):
batch = []
for batch_idx in sampler:
batch.append(batch_idx)
if len(batch) == self.M:
yield batch
if len(batch) > 0:
yield batch
def __iter__(self):
# First select one of the datasets.
iterator_idx = random.choices(
range(len(self.weights)), weights=self.weights / self.weights.sum()
)[0]
return self._batch_samplers[iterator_idx]
这个问题是iter() 似乎只被调用一次,所以只选择了第一次iterator_idx。显然这是错误的......解决这个问题的方法是什么?
当您在 pytorch 中有多个数据集,但您只想从其中一个数据集中抽取批次时,这是一种可能的情况。
【问题讨论】: