【问题标题】:Training stuck at Epoch 3 PyTorch训练停留在 Epoch 3 PyTorch
【发布时间】:2021-07-28 23:56:06
【问题描述】:

我正在训练一个自定义编码器-解码器网络,但训练卡在 Epoch 3。大约 2 小时内没有任何反应。我将分享 Dataset 类和 DataLoader 对象。 CUDA 和 GPU 的版本见下图。

训练卡在这里:

nvidia-smi 输出如下所示:

数据集类的__getitem__ 方法如下所示:

    def __init__(self,
                 images_dir,
                 annots_dir,
                 train=True,
                 img_size=(512, 1536),
                 stride=4,
                 model='custom',
                 transforms=None):
        """
        :param root: dataset directory
        :param filenames: filenames inside the root directory
        :param labels: Object Detection Labels
        super(CustomDataset).__init__()
        self.images_dir = images_dir
        self.annots_dir = annots_dir
        self.train = train
        self.image_size = img_size
        self.stride = stride
        self.transforms = transforms
        self.model = model

        # Load the image and annotation files from the dataset
        # self.image_files, self.annot_files = self._load_image_and_annot_files()

        self.image_files = [os.path.join(self.images_dir, idx) for idx in os.listdir(self.images_dir)]
        self.annot_files = [os.path.join(self.annots_dir, idx) for idx in os.listdir(self.annots_dir)]

    def __getitem__(self, index):
        """
        :param index: index...0 to N
        :return: tensor_image and tensor_label
        """
        # Image filename from _load_image_files()
        # Load Image with _read_matrix() and label
        curr_image_filename = self.image_files[index]
        curr_annot_filename = self.annot_files[index]
        # curr_image_filename = self.image_files[index]
        # curr_annot_filename = self.annot_files[index]
        np_image = self._read_matrix(raw_img=curr_image_filename)
        np_image_normalized = np.squeeze(self._normalize_raw_img(np_image))

        # label = self.labels[index]
        boxes, classes, depths, tgts = self._load_annotations(curr_annot_filename)

        # Normalize bounding boxes: range [0, 1]
        targets_normalized = self._normalize_bbox(np_image_normalized, tgts)

        # image and the corresponding label should be a tensor
        torch_image = torch.from_numpy(np_image).reshape(1, 512, 1536).float()  # dtype: torch.float64
        torch_boxes = torch.from_numpy(boxes).type(torch.FloatTensor)
        torch_depths = torch.from_numpy(depths)

        if self.model == 'fasterrcnn':
            # For FasterRCNN: As COCO format
            area = (torch_boxes[:, 3] - torch_boxes[:, 1]) * (torch_boxes[:, 2] - torch_boxes[:, 0])
            iscrowd = torch.zeros((boxes.shape[0],), dtype=torch.int64)
            image_id = torch.Tensor([index])
            torch_classes = torch.from_numpy(classes)

            target = {'boxes': torch_boxes, 'labels': torch_classes.long(),
                      'area': area, 'iscrowd': iscrowd, 'image_id': image_id}

            return torch_image, target

        elif self.model == 'custom':
            if self.train:
                if self.transforms:
                    try:
                        tr = self.transforms()
                        transform_image, transform_boxes, labels = tr.__call__(np_image, tgts, tgts[:, :4], tgts[:, 4:])
                        transform_targets = np.hstack((np.array(transform_boxes), labels))
                        gt_tensor = gt_creator(img_size=self.image_size,
                                               stride=self.stride,
                                               num_classes=8,
                                               label_lists=transform_targets)
                        return torch.from_numpy(transform_image).float(), gt_tensor
                    except IndexError:
                        pass
                else:
                    gt_tensor = gt_creator(img_size=self.image_size,
                                           stride=self.stride,
                                           num_classes=8,
                                           label_lists=targets_normalized)
                    return torch_image, gt_tensor
            else:
                return torch_image, targets_normalized

而在 train.py 脚本中,DataLoader 对象是:

    train_loader = torch.utils.data.DataLoader(dataset=dataset,
                                               shuffle=True,
                                               batch_size=1,
                                               num_workers=0,
                                               collate_fn=detection_collate,
                                               pin_memory=True)

为什么训练会卡住? __getitem__ 方法有问题吗?还是DataLoader

谢谢。

【问题讨论】:

    标签: python pytorch


    【解决方案1】:

    发生这种情况是因为 Torch 不会重新启动您的数据集,如果您的数据用完,它会停止并等待更多输入,因此必须手动完成循环。

    我使用了一些类似的东西

    from itertools import cycle
    
    class Dataloader(): 
        #init and whatever
        self.__iter__():
            return cycle(get_sample()) # get_sample is your current getitem
    

    【讨论】:

    • 等等,我不明白。如果仍然是 2256 个样本中的第 1794 个样本,数据怎么会用完?另外,您是否建议创建一个不同的 DataLoader 类?get_sample() 然后是 img, tgt = dataset[100] # 只是一个例子
    • 因为它没有足够的时间用于下一个时代。我会尽力挽救你能做的事
    • 有趣。但是,您编写的代码是否正确?因为我用 Dataset 类中的 __getitem__() 替换了 gt_sample() 并且语法无效。
    • @duddal 它直接取自我的代码,所以是的,它对我有用。试一试,语法错误总是比完全不知道发生了什么更容易修复:D
    猜你喜欢
    • 2018-11-27
    • 2022-07-05
    • 2019-07-15
    • 2018-10-13
    • 2020-10-22
    • 1970-01-01
    • 1970-01-01
    • 2020-02-03
    • 1970-01-01
    相关资源
    最近更新 更多