【Question title】: PyTorch custom Dataset class giving wrong output
【Posted】: 2020-02-19 15:04:27
【Question】:

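I'm trying to use this class I built for my dataset, but it complains that the input should be a PIL Image or an ndarray. I'm not sure what's wrong with it. Here is the class I'm using: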

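import math
import random

import cv2
import numpy as np
from torch.utils.data import Dataset

# rotate_image, crop_around_center and largest_rotated_rect are the asker's
# own helper functions; they are not shown in the question.


class RotateDataset(Dataset):
    def __init__(self, image_list, size, transform=None):
        self.image_list = image_list  # paths of the image files
        self.size = size              # (width, height) passed to cv2.resize
        self.transform = transform

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        img = cv2.imread(self.image_list[idx])  # HxWxC uint8 ndarray (BGR)
        image_height, image_width = img.shape[:2]
        print("ID: ", idx)
        if idx % 2 == 0:
            label = 0  # even indices: rotated image
            # choose a rotation angle between 35 and 49 degrees, positive or negative
            rotation_degree = random.randrange(35, 50, 1)
            posnegrot = np.random.randint(2)
            if posnegrot == 0:
                # Positive rotation
                #rotation_matrix = cv2.getRotationMatrix2D((num_cols/2, num_rows/2), rotation_degree, 1)
                #img = cv2.warpAffine(img, rotation_matrix, (num_cols, num_rows))
                img = rotate_image(img, rotation_degree)
                img = crop_around_center(img, *largest_rotated_rect(image_width,
                                                                    image_height,
                                                                    math.radians(rotation_degree)))
            else:
                # Negative rotation: rotate first, then crop, as in the positive branch
                rotation_degree = -rotation_degree
                img = rotate_image(img, rotation_degree)
                img = crop_around_center(img, *largest_rotated_rect(image_width,
                                                                    image_height,
                                                                    math.radians(rotation_degree)))
        else:
            label = 1  # odd indices: unrotated image
        # interpolation must be passed by keyword: the third positional
        # parameter of cv2.resize is dst, not interpolation
        img = cv2.resize(img, self.size, interpolation=cv2.INTER_AREA)
        # self.transform is transforms.ToTensor(), so self.transform(label)
        # receives a plain int and raises the TypeError below
        return self.transform(img), self.transform(label)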

The error it gives me is:

TypeError: pic should be PIL Image or ndarray. Got <class 'int'>

It should give me an img (tensor) and a label (tensor), but I don't think it's doing that correctly.

TypeError                                 Traceback (most recent call last)
<ipython-input-34-f47943b2600c> in <module>
      2     train_loss = 0.0
      3     net.train()
----> 4     for image, label in enumerate(train_loader):
      5         if train_on_gpu:
      6             image, label = image.cuda(), label.cuda()

~\Anaconda3\envs\TF2\lib\site-packages\torch\utils\data\dataloader.py in __next__(self)
    343 
    344     def __next__(self):
--> 345         data = self._next_data()
    346         self._num_yielded += 1
    347         if self._dataset_kind == _DatasetKind.Iterable and \

~\Anaconda3\envs\TF2\lib\site-packages\torch\utils\data\dataloader.py in _next_data(self)
    383     def _next_data(self):
    384         index = self._next_index()  # may raise StopIteration
--> 385         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    386         if self._pin_memory:
    387             data = _utils.pin_memory.pin_memory(data)

~\Anaconda3\envs\TF2\lib\site-packages\torch\utils\data\_utils\fetch.py in fetch(self, possibly_batched_index)
     42     def fetch(self, possibly_batched_index):
     43         if self.auto_collation:
---> 44             data = [self.dataset[idx] for idx in possibly_batched_index]
     45         else:
     46             data = self.dataset[possibly_batched_index]

~\Anaconda3\envs\TF2\lib\site-packages\torch\utils\data\_utils\fetch.py in <listcomp>(.0)
     42     def fetch(self, possibly_batched_index):
     43         if self.auto_collation:
---> 44             data = [self.dataset[idx] for idx in possibly_batched_index]
     45         else:
     46             data = self.dataset[possibly_batched_index]

<ipython-input-28-6c77357ff619> in __getitem__(self, idx)
     35             label = 1
     36         img = cv2.resize(img, self.size, cv2.INTER_AREA)
---> 37         return self.transform(img), self.transform(label)

~\Anaconda3\envs\TF2\lib\site-packages\torchvision\transforms\transforms.py in __call__(self, pic)
     99             Tensor: Converted image.
    100         """
--> 101         return F.to_tensor(pic)
    102 
    103     def __repr__(self):

~\Anaconda3\envs\TF2\lib\site-packages\torchvision\transforms\functional.py in to_tensor(pic)
     53     """
     54     if not(_is_pil_image(pic) or _is_numpy(pic)):
---> 55         raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic)))
     56 
     57     if _is_numpy(pic) and not _is_numpy_image(pic):

TypeError: pic should be PIL Image or ndarray. Got <class 'int'>

【Comments】:

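  • Can you paste the full error (pointing to the line where it occurs)?
  • The problem is (most likely) in how you define your transform. Can you share that code?
  • TRANSFOM = transforms.ToTensor()
  • Why are you applying the transform to label (self.transform(label))? Why not return it as-is with return label?
  • It gives that error, so I don't think it's a label problem.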
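Returning label unchanged works because the DataLoader's default collate function already turns a batch of Python ints into an int64 tensor. A minimal sketch, assuming torch is installed; ToyDataset is a hypothetical stand-in for RotateDataset that returns random image tensors instead of reading files, with labels alternating 0/1 by index as in the question.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Hypothetical stand-in for RotateDataset: returns (image tensor, int label)."""

    def __init__(self, n, size=(3, 8, 8)):
        self.n = n
        self.size = size

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        img = torch.rand(self.size)
        label = 0 if idx % 2 == 0 else 1  # plain Python int, no transform applied
        return img, label


loader = DataLoader(ToyDataset(4), batch_size=4, shuffle=False)
images, labels = next(iter(loader))
print(images.shape)  # torch.Size([4, 3, 8, 8])
print(labels)        # tensor([0, 1, 0, 1])
print(labels.dtype)  # torch.int64
```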

Tags: python pytorch


【Answer 1】:

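As discussed in the comments, the problem is that the transform is also applied to label. ToTensor only accepts a PIL Image or an ndarray, so the int label should simply be wrapped in a tensor directly: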

return self.transform(img), torch.tensor(label)

