【发布时间】:2020-09-05 17:27:07
【问题描述】:
我正在研究一个核分割问题,我试图在染色组织的图像中识别核的位置。给定的训练数据集有一张染色组织的图片和一个带有细胞核位置的掩码。由于数据集很小,我想尝试在 PyTorch 中进行数据增强,但在这样做之后,由于某种原因,当我输出我的蒙版图像时,它看起来很好,但对应的组织图像不正确。
我所有的训练图像都在X_train 中,形状为(128, 128, 3),对应的掩码在Y_train 中,形状为(128, 128, 1),类似的交叉验证图像和掩码分别在X_val 和Y_val。
Y_train 和Y_val 有dtype = np.bool、X_train 和X_val 有dtype = np.uint8。
在数据增强之前,我会像这样检查我的图像:
fig, axis = plt.subplots(2, 2)
axis[0][0].imshow(X_train[0].astype(np.uint8))
axis[0][1].imshow(np.squeeze(Y_train[0]).astype(np.uint8))
axis[1][0].imshow(X_val[0].astype(np.uint8))
axis[1][1].imshow(np.squeeze(Y_val[0]).astype(np.uint8))
输出如下: Before Data Augmentation
对于数据扩充,我定义了一个自定义类如下:
在这里,我将torchvision.transforms.functional 导入为TF 和torchvision.transforms as transforms。 images_np 和 masks_np 是 numpy 数组的输入。
class Nuc_Seg(Dataset):
def __init__(self, images_np, masks_np):
self.images_np = images_np
self.masks_np = masks_np
def transform(self, image_np, mask_np):
ToPILImage = transforms.ToPILImage()
image = ToPILImage(image_np)
mask = ToPILImage(mask_np.astype(np.int32))
angle = random.uniform(-10, 10)
width, height = image.size
max_dx = 0.2 * width
max_dy = 0.2 * height
translations = (np.round(random.uniform(-max_dx, max_dx)), np.round(random.uniform(-max_dy, max_dy)))
scale = random.uniform(0.8, 1.2)
shear = random.uniform(-0.5, 0.5)
image = TF.affine(image, angle = angle, translate = translations, scale = scale, shear = shear)
mask = TF.affine(mask, angle = angle, translate = translations, scale = scale, shear = shear)
image = TF.to_tensor(image)
mask = TF.to_tensor(mask)
return image, mask
def __len__(self):
return len(self.images_np)
def __getitem__(self, idx):
image_np = self.images_np[idx]
mask_np = self.masks_np[idx]
image, mask = self.transform(image_np, mask_np)
return image, mask
接下来是:
我用过from torch.utils.data import DataLoader
train_dataset = Nuc_Seg(X_train, Y_train)
train_loader = DataLoader(train_dataset, batch_size = 16, shuffle = True)
val_dataset = Nuc_Seg(X_val, Y_val)
val_loader = DataLoader(val_dataset, batch_size = 16, shuffle = True)
在这一步之后,我尝试使用以下方法检查我的第一组训练图像和蒙版:
%matplotlib inline
for ex_img, ex_mask in train_loader:
img = ex_img[0]
img = img.reshape(128, 128, 3)
mask = ex_mask[0]
mask = mask.reshape(128, 128)
img = img.numpy()
mask = mask.numpy()
fig, (axis_1, axis_2) = plt.subplots(1, 2)
axis_1.imshow(img.astype(np.uint8))
axis_2.imshow(mask.astype(np.uint8))
break
我得到这个作为我的输出: After Data Augmentation 1
当我将axis_1.imshow(img.astype(np.uint8)) 更改为axis_1.imshow(img) 时,
我得到这张图片: After Data Augmentation 2
面具的图像是正确的,但由于某种原因,细胞核的图像是错误的。使用.astype(np.uint8),组织图像是完全黑色的。
没有.astype(np.uint8),原子核的位置是正确的,但是配色方案全乱了(我希望图像像数据增强之前看到的那样,灰色或粉红色),加上 9 个副本出于某种原因,在网格中显示相同的图像。你能帮我得到组织图像的正确输出吗?
【问题讨论】:
标签: python matplotlib pytorch data-augmentation semantic-segmentation