【Title】: Batch size in the input shape of a chainer CNN
【Posted】: 2019-01-07 07:19:16
【Question】:

I have a training set of 9,957 images, with shape (9957, 3, 60, 80). Do I need a batch size when feeding the training set to the model? If so, can the original shape be considered a correct fit for the Conv2D layer, or do I need to add the batch size to input_shape?

X_train.shape

(9957, 60, 80, 3)

from chainer.datasets import split_dataset_random
from chainer.dataset import DatasetMixin
from chainer import iterators

import numpy as np


class MyDataset(DatasetMixin):
   def __init__(self, X, labels):
       super(MyDataset, self).__init__()
       self.X_ = X
       self.labels_ = labels
       self.size_ = X.shape[0]

   def __len__(self):
       return self.size_

   def get_example(self, i):
       return np.transpose(self.X_[i, ...], (2, 0, 1)), self.labels_[i] 


batch_size = 3

label_train = y_trainHot1
dataset = MyDataset(X_train1, label_train)
dataset_train, valid = split_dataset_random(dataset, 8000, seed=0)
train_iter = iterators.SerialIterator(dataset_train, batch_size)
valid_iter = iterators.SerialIterator(valid, batch_size, repeat=False, shuffle=False)
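To make the batching explicit, here is a minimal numpy sketch (with placeholder data mimicking the shapes above, not the real training set) of where the batch axis comes from: the dataset returns single (C, H, W) examples, and the iterator groups them into a batch, so the batch size is never declared in any input shape.

```python
import numpy as np

# Placeholder data mimicking the question's X_train: (N, H, W, C) images.
X_train = np.random.rand(10, 60, 80, 3).astype(np.float32)
batch_size = 3

# get_example returns ONE image in (C, H, W) order, as MyDataset above does.
def get_example(i):
    return np.transpose(X_train[i], (2, 0, 1))

# The iterator's job: collect batch_size single examples; stacking them is
# where the leading batch axis appears -- it is never part of input_shape.
batch = np.stack([get_example(i) for i in range(batch_size)])
print(batch.shape)  # (3, 3, 60, 80)
```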

【Comments】:

    Tags: python-3.x conv-neural-network chainer


    【Solution 1】:

    The code below shows that you do not have to take care of the batch size yourself. You just use DatasetMixin and SerialIterator as explained in the chainer tutorial.

    from chainer.dataset import DatasetMixin
    from chainer.iterators import SerialIterator
    import numpy as np
    
    NUM_IMAGES = 9957
    NUM_CHANNELS = 3  # RGB
    IMAGE_WIDTH = 60
    IMAGE_HEIGHT = 80
    
    NUM_CLASSES = 10
    
    BATCH_SIZE = 32
    
    TRAIN_SIZE = min(8000, int(NUM_IMAGES * 0.9))
    
    images = np.random.rand(NUM_IMAGES, NUM_CHANNELS, IMAGE_WIDTH, IMAGE_HEIGHT)
    labels = np.random.randint(0, NUM_CLASSES, (NUM_IMAGES,))
    
    
    class MyDataset(DatasetMixin):
        def __init__(self, images_, labels_):
            # note: input arg.'s tailing underscore is just to avoid shadowing
            super(MyDataset, self).__init__()
            self.images_ = images_
            self.labels_ = labels_
            self.size_ = len(labels_)
    
        def __len__(self):
            return self.size_
    
        def get_example(self, i):
            return self.images_[i, ...], self.labels_[i]
    
    
    dataset_train = MyDataset(images[:TRAIN_SIZE, ...], labels[:TRAIN_SIZE])
    dataset_valid = MyDataset(images[TRAIN_SIZE:, ...], labels[TRAIN_SIZE:])
    train_iter = SerialIterator(dataset_train, BATCH_SIZE)
    valid_iter = SerialIterator(dataset_valid, BATCH_SIZE, repeat=False, shuffle=False)
    
    ###############################################################################
    """This block is just for the confirmation.
    
    .. note: NOT recommended to call :func:`concat_examples` in your code.
        Use :class:`chainer.training.StandardUpdater` instead.
    """
    from chainer.dataset import concat_examples
    
    batch_image, batch_label = concat_examples(next(train_iter))
    print("batch_image.shape\n{}".format(batch_image.shape))
    print("batch_label.shape\n{}".format(batch_label.shape))
    

    Output

    batch_image.shape
    (32, 3, 60, 80)
    batch_label.shape
    (32,)
    

    Note that chainer.dataset.concat_examples is the slightly tricky part here. Users usually do not need to pay attention to this function, because StandardUpdater hides the call to chainer.dataset.concat_examples.

    Since chainer is designed around the scheme of Trainer, (Standard)Updater, some Optimizer, (Serial)Iterator and Dataset(Mixin), if you do not follow this scheme you will have to dive into the sea of chainer's source code.
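    What concat_examples effectively does for (image, label) pairs can be sketched in plain numpy. This is a simplified illustration of the stacking behavior, not chainer's actual implementation, and the function name is illustrative only:

```python
import numpy as np

# Simplified sketch of how a list of (image, label) examples is merged
# into a single batch: images are stacked along a new leading axis and
# labels are collected into a 1-D array.
def concat_examples_sketch(batch):
    images = np.stack([img for img, _ in batch])
    labels = np.asarray([lbl for _, lbl in batch])
    return images, labels

# 32 dummy examples, each a (C, H, W) image plus an integer label.
examples = [(np.zeros((3, 60, 80), dtype=np.float32), 7) for _ in range(32)]
imgs, lbls = concat_examples_sketch(examples)
print(imgs.shape, lbls.shape)  # (32, 3, 60, 80) (32,)
```

    The shapes match the confirmation output above: the batch axis (32) exists only after the examples are merged.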

    【Discussion】:
