【问题标题】:Keras fit_generator: random augmentation inside the generator + shufflingKeras fit_generator:生成器内的随机增强+洗牌
【发布时间】:2018-05-12 03:12:28
【问题描述】:

我创建了一个生成器,将其输入到 keras 的fit_generator 函数中。生成器创建一些随机值。我就是这样做的:

class DataGenerator(object):
    def __init__(self, X_Y_file_path, batch_size, N):
        self.X_Y_file_path = X_Y_file_path
        self.batch_size = size
        self.N = N

    def initialize_zeros(self):
        X = np.zeros((self.batch_size, 1), dtype='int32')
        Y = np.zeros((self.batch_size, 1), dtype='int32')
        Y_neg = np.zeros((self.batch_size, self.N))
        return X, Y, Y_neg

     def generate(self):
        while True:
            i = 0 
            X, Y, Y_neg = initialize_zeros()
            for row in load_data_per_line(self.X_Y_file_path): # load_data_per_line is generator function which goes each line at a time from one file.
                x, y = row
                y_neg = random.sample(id_list, self.N) # a list of id to pick randomly
                X[i] = x
                Y[i] = y
                Y_neg[i] = y_neg
                if i == self.batch_size:
                    yield ([X, Y_neg], Y) # Y_neg goes as input in the model.(not important here. just mentioning)
                    X, Y, Y_neg = initialize_zeros()
                    i = 0

所以这是我的发电机。使用相同的样本数据,它似乎可以正常工作。

我想知道如何在这个生成器中实现一个 shuffle 函数,以便在每个 epoch 之后随机播放?

搜索了一下,我发现Sequence 可以覆盖on_epoch_end 方法,但不清楚如何使用Sequence 继承实现上述生成器。有什么帮助吗? (顺便说一句,在fit_generator 中使用use_multiprocessing 是否是上述函数'安全'?)

编辑

X_Y_file_path 是一个文件(已知长度)。 load_data_per_line 是一个生成器函数,每行产生一个。

【问题讨论】:

    标签: python tensorflow keras


    【解决方案1】:

    您在使用序列的正确轨道上。当与多处理一起使用时,它将保证每个数据点都被看到一次。构建序列的一种简单方法是预加载原始数据,每次请求批处理时进行动态处理。

    class MySeq(Sequence):
        def __init__(self, X_Y_file_path, batch_size, N):
            self.X_Y_file_path = X_Y_file_path
            self.batch_size = size
            self.N = N
            self.data = load_data_per_line(self.X_Y_file_path)
    
         def __len__(self):
            return int(np.ceil(len(self.data) / self.batch_size))
    
         def __getitem__(self, idx):
            # Just slice the data based on batch index (idx)
            batch_data = self.data[idx*self.batch_size:(idx+1)*self.batch_size]
            X = np.zeros((len(batch_data), 1), dtype='int32')
            Y = np.zeros((len(batch_data), 1), dtype='int32')
            Y_neg = np.zeros((len(batch_data), self.N))
            for row, i in enumerate(data):
                x, y = row
                y_neg = random.sample(id_list, self.N) # a list of id to pick randomly
                X[i] = x
                Y[i] = y
                Y_neg[i] = y_neg
            return [X, Y_neg], Y # This is a single batch
    

    现在您可以使用on_epoch_end() 对您的self.data 执行任何处理

    【讨论】:

    • 感谢您的回答。您在这里写的是文档中示例的复制粘贴。而且它不适用于问题中提供的我的案例。
    猜你喜欢
    • 2018-08-08
    • 1970-01-01
    • 1970-01-01
    • 2018-02-09
    • 1970-01-01
    • 1970-01-01
    • 2018-02-24
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多