【发布时间】:2020-12-09 19:59:08
【问题描述】:
我正在做 covid-19 分类。我从 kaggle 获取数据集。它有一个名为 dataset 的文件夹,其中包含 3 个文件夹 normal pnuemonia 和 covid-19,每个文件夹都包含这些类的图像我被困在 pytorch 自定义数据加载器中编写 getitem ? 数据集有 189 张 covid 图像,但通过这个获取项目,我得到了 920 张 covid 图像,请帮助
class_names = ['normal', 'viral', 'covid']
root_dir = 'COVID-19 Radiography Database'
source_dirs = ['NORMAL', 'Viral Pneumonia', 'COVID-19']
if os.path.isdir(os.path.join(root_dir, source_dirs[1])):
os.mkdir(os.path.join(root_dir, 'test'))
for i, d in enumerate(source_dirs):
os.rename(os.path.join(root_dir, d), os.path.join(root_dir, class_names[i]))
for c in class_names:
os.mkdir(os.path.join(root_dir, 'test', c))
for c in class_names:
images = [x for x in os.listdir(os.path.join(root_dir, c)) if x.lower().endswith('png')]
selected_images = random.sample(images, 30)
for image in selected_images:
source_path = os.path.join(root_dir, c, image)
target_path = os.path.join(root_dir, 'test', c, image)
shutil.move(source_path, target_path)
以上代码用于创建每个类别有 30 张图像的测试数据集
class ChestXRayDataset(torch.utils.data.Dataset):
def __init__(self, image_dirs, transform):
def get_images(class_name):
images = [x for x in os.listdir(image_dirs[class_name]) if
x[-3:].lower().endswith('png')]
print(f'Found {len(images)} {class_name} examples')
return images
self.images = {}
self.class_names = ['normal', 'viral', 'covid']
for class_name in self.class_names:
self.images[class_name] = get_images(class_name)
self.image_dirs = image_dirs
self.transform = transform
def __len__(self):
return sum([len(self.images[class_name]) for class_name in self.class_names])
def __getitem__(self, index):
class_name = random.choice(self.class_names)
index = index % len(self.images[class_name])
image_name = self.images[class_name][index]
image_path = os.path.join(self.image_dirs[class_name], image_name)
image = Image.open(image_path).convert('RGB')
return self.transform(image), self.class_names.index(class_name)
**卡在获取此项目的**
images in folder are arranged as follows Dataset is as follows
**混淆矩阵的代码是**
nb_classes = 3
confusion_matrix = torch.zeros(nb_classes, nb_classes)
with torch.no_grad():
for data in tqdm_notebook(dl_train,total=len(dl_train),unit='batch'):
img,lab = data
print(lab)
img,lab = img.to(device),lab.to(device)
_,output = torch.max(model(img),1)
print(output)
for t, p in zip(lab.view(-1), output.view(-1)):
confusion_matrix[t.long(), p.long()] += 1
混淆矩阵的输出只有一个类正在接受训练 confusio matrix image
【问题讨论】:
-
only one class is getting trained- 很可能是因为您的模型太弱,无法做任何比一直预测最常见类别更有用的事情。
标签: machine-learning deep-learning pytorch