【问题标题】:Custom dataset Loader pytorch自定义数据集加载器 pytorch
【发布时间】:2020-12-09 19:59:08
【问题描述】:

我正在做 covid-19 分类。我从 kaggle 获取数据集。它有一个名为 dataset 的文件夹,其中包含 3 个文件夹 normal pnuemonia 和 covid-19,每个文件夹都包含这些类的图像我被困在 pytorch 自定义数据加载器中编写 getitem ? 数据集有 189 张 covid 图像,但通过这个获取项目,我得到了 920 张 covid 图像,请帮助

class_names = ['normal', 'viral', 'covid']
root_dir = 'COVID-19 Radiography Database'
source_dirs = ['NORMAL', 'Viral Pneumonia', 'COVID-19']

 if os.path.isdir(os.path.join(root_dir, source_dirs[1])):
   os.mkdir(os.path.join(root_dir, 'test'))

for i, d in enumerate(source_dirs):
    os.rename(os.path.join(root_dir, d), os.path.join(root_dir, class_names[i]))

for c in class_names:
    os.mkdir(os.path.join(root_dir, 'test', c))

for c in class_names:
    images = [x for x in os.listdir(os.path.join(root_dir, c)) if x.lower().endswith('png')]
    selected_images = random.sample(images, 30)
    for image in selected_images:
        source_path = os.path.join(root_dir, c, image)
        target_path = os.path.join(root_dir, 'test', c, image)
        shutil.move(source_path, target_path)

以上代码用于创建每个类别有 30 张图像的测试数据集

 class ChestXRayDataset(torch.utils.data.Dataset):
   def __init__(self, image_dirs, transform):
      def get_images(class_name):
        images = [x for x in os.listdir(image_dirs[class_name]) if 
        x[-3:].lower().endswith('png')]
        print(f'Found {len(images)} {class_name} examples')
        return images
    
    self.images = {}
    self.class_names = ['normal', 'viral', 'covid']
    
    for class_name in self.class_names:
        self.images[class_name] = get_images(class_name)
        
    self.image_dirs = image_dirs
    self.transform = transform
    

def __len__(self):
    return sum([len(self.images[class_name]) for class_name in self.class_names])


def __getitem__(self, index):
    class_name = random.choice(self.class_names)
    index = index % len(self.images[class_name])
    image_name = self.images[class_name][index]
    image_path = os.path.join(self.image_dirs[class_name], image_name)
    image = Image.open(image_path).convert('RGB')
    return self.transform(image), self.class_names.index(class_name)

**卡在获取此项目的**

images in folder are arranged as follows Dataset is as follows

**混淆矩阵的代码是**

nb_classes = 3

 confusion_matrix = torch.zeros(nb_classes, nb_classes)
 with torch.no_grad():
 for data in tqdm_notebook(dl_train,total=len(dl_train),unit='batch'):
    img,lab = data
    print(lab)
    img,lab = img.to(device),lab.to(device)
    _,output = torch.max(model(img),1)
    print(output)
    
    for t, p in zip(lab.view(-1), output.view(-1)):
            confusion_matrix[t.long(), p.long()] += 1

混淆矩阵的输出只有一个类正在接受训练 confusio matrix image

【问题讨论】:

  • only one class is getting trained - 很可能是因为您的模型太弱,无法做任何比一直预测最常见类别更有用的事情。

标签: machine-learning deep-learning pytorch


【解决方案1】:

将图像放入字典会使操作复杂化,而不是使用列表。此外,您的 Dataset 不应该有任何随机性,数据的改组应该来自 DataLoader 而不是来自 Dataset。

使用类似下面的东西:

 class ChestXRayDataset(torch.utils.data.Dataset):
   def __init__(self, image_dirs, transform):
      def get_images(class_name):
        images = [x for x in os.listdir(image_dirs[class_name]) if 
        x[-3:].lower().endswith('png')]
        print(f'Found {len(images)} {class_name} examples')
        return images
    
    self.images = []
    self.labels = []
    self.class_names = ['normal', 'viral', 'covid']
    
    for class_name in self.class_names:
        images = get_images(class_name)
        # This is a list containing all the images
        self.images.extend(images)
        # This is a list containing all the corresponding image labels
        self.labels.extend([class_name]*len(images))
        
    self.image_dirs = image_dirs
    self.transform = transform
    

def __len__(self):
    return len(self.images)

# Will return the image and its label at the position `index`
def __getitem__(self, index):
    # image at index position of all the images
    image_name = self.images[index]
    # Its label 
    class_name = self.labels[index]
    image_path = os.path.join(self.image_dirs[class_name], image_name)
    image = Image.open(image_path).convert('RGB')
    return self.transform(image), self.class_names.index(class_name)

如果你列举它说使用

ds = ChestXRayDataset(image_dirs, transform)
for x, y in ds:
   print (x.shape, y)

您应该按顺序看到所有图像和标签。

但是在实际情况下,您更愿意使用 Torch DataLoader 并将其传递给 ds 对象,并将随机参数设置为 True。因此,DataLoader 将通过调用 __getitem__ 并使用混洗后的 index 值来处理数据集的混洗。

【讨论】:

  • 你能解释一下我们在做什么吗?很抱歉打扰您
  • @MuditGupta 用 cmets 更新了答案
  • 我用过,但它不能正常工作我应该添加指向 colab 笔记本的链接,以便您可以建议我更改
  • 我在那里工作得很好,但是在训练后打印的混淆矩阵是 liketensor([[1311., 0., 0.], [1315., 0., 0.], [ 189., 0. , 0.]]) 为什么会这样
  • 因为您有三个类别,所以您的混淆矩阵的大小将符合预期的3 X 3
猜你喜欢
  • 2021-11-09
  • 1970-01-01
  • 2019-11-08
  • 2019-12-30
  • 2017-09-12
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
相关资源
最近更新 更多