【Question title】: Predicted mask image has wrong dimension unet - TypeError: Invalid shape (2023, 2023, 256) for image data
【Posted】: 2022-09-23 02:19:44
【Question description】:

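I have managed to train a unet network and am now trying to visualize its predictions. This is related to a question I asked previously. The mask should have the same dimensions as the image and should be single-channel, right?

Please find the code below:

The saved model is loaded as follows: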

import torch

#load weights to network
weights_path = unet_dir + "unet1.pt"
device = "cpu"

unet = UNet(in_channels=3, out_channels=3, init_features=8)
unet.to(device)
unet.load_state_dict(torch.load(weights_path, map_location=device))


Initial functions:

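import albumentations as A
import cv2
import matplotlib.pyplot as plt
import numpy as np
from albumentations.pytorch import ToTensorV2

#define augmentations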
inference_transform = A.Compose([
    A.Resize(256, 256, always_apply=True),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 
    ToTensorV2()
])

#define function for predictions
def predict(model, img, device):
    model.eval()
    with torch.no_grad():
        images = img.to(device)
        output = model(images)
        predicted_masks = (output.squeeze() >= 0.5).float().cpu().numpy()
        
    return(predicted_masks)

#define function to load image and output mask
def get_mask(img_path):
    image = cv2.imread(img_path)
    #assert image is not None
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    original_height, original_width = tuple(image.shape[:2])
    
    image_trans = inference_transform(image = image)
    image_trans = image_trans["image"]
    image_trans = image_trans.unsqueeze(0)
    
    image_mask = predict(unet, image_trans, device)
    #image_mask = image_mask.astype(np.int16)
    image_mask = cv2.resize(image_mask,(original_width, original_height),
                          interpolation=cv2.INTER_NEAREST)
    #image_mask = cv2.resize(image_mask, (original_height, original_width))
    #Y_train[n] = mask > 0.5    
    return(image_mask)



#image example
example_path = "../input/test-image/10078.tiff"
image = cv2.imread(example_path)
#assert image is not None
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

mask = get_mask(example_path)

#masked_img = image*np.expand_dims(mask, 2).astype("uint8")

#plot the image, mask and multiplied together
fig, (ax1, ax2) = plt.subplots(2)

ax1.imshow(image)
ax2.imshow(mask)
#ax3.imshow(masked_img)

Output:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/tmp/ipykernel_4859/3003834023.py in <module>
     13 
     14 ax1.imshow(image)
---> 15 ax2.imshow(mask)
     16 #ax3.imshow(masked_img)

/opt/conda/lib/python3.7/site-packages/matplotlib/_api/deprecation.py in wrapper(*args, **kwargs)
    457                 "parameter will become keyword-only %(removal)s.",
    458                 name=name, obj_type=f"parameter of {func.__name__}()")
--> 459         return func(*args, **kwargs)
    460 
    461     # Don't modify *func*'s signature, as boilerplate.py needs it.

/opt/conda/lib/python3.7/site-packages/matplotlib/__init__.py in inner(ax, data, *args, **kwargs)
   1412     def inner(ax, *args, data=None, **kwargs):
   1413         if data is None:
-> 1414             return func(ax, *map(sanitize_sequence, args), **kwargs)
   1415 
   1416         bound = new_sig.bind(ax, *args, **kwargs)

/opt/conda/lib/python3.7/site-packages/matplotlib/axes/_axes.py in imshow(self, X, cmap, norm, aspect, interpolation, alpha, vmin, vmax, origin, extent, interpolation_stage, filternorm, filterrad, resample, url, **kwargs)
   5485                               **kwargs)
   5486 
-> 5487         im.set_data(X)
   5488         im.set_alpha(alpha)
   5489         if im.get_clip_path() is None:

/opt/conda/lib/python3.7/site-packages/matplotlib/image.py in set_data(self, A)
    714                 or self._A.ndim == 3 and self._A.shape[-1] in [3, 4]):
    715             raise TypeError("Invalid shape {} for image data"
--> 716                             .format(self._A.shape))
    717 
    718         if self._A.ndim == 3:

TypeError: Invalid shape (2023, 2023, 256) for image data
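
A minimal sketch of why this shape is rejected, assuming the model output has shape (1, 3, 256, 256) (batch, channels, height, width), which is what `out_channels=3` and the 256x256 resize suggest. The helper `imshow_accepts` is hypothetical; it only mirrors the shape check visible in the traceback (`matplotlib/image.py`, `set_data`). NumPy zeros stand in for the real prediction.

```python
import numpy as np

def imshow_accepts(shape):
    """Hypothetical helper mirroring the shape check shown in the traceback."""
    ndim = len(shape)
    return ndim == 2 or (ndim == 3 and shape[-1] in (3, 4))

# Assumed model output for one image: (N, C, H, W) with C = out_channels = 3.
output = np.zeros((1, 3, 256, 256), dtype=np.float32)

# squeeze() only drops the batch axis, so the channels stay first: (C, H, W).
squeezed = output.squeeze()
print(squeezed.shape, imshow_accepts(squeezed.shape))

# Moving the channel axis last gives the (H, W, C) layout image code expects.
channels_last = np.transpose(squeezed, (1, 2, 0))
print(channels_last.shape, imshow_accepts(channels_last.shape))

# The failing shape from the traceback has 256 in the channel position.
print(imshow_accepts((2023, 2023, 256)))
```

The failing shape is consistent with `cv2.resize` reading the (3, 256, 256) array as 3 rows, 256 columns and 256 channels: only the first two axes are resized to 2023 x 2023, and the trailing 256 is carried along as the channel count.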

Output image:

Could someone please help me with this?

Thanks and best regards

Schroter Michael

    Tags: python machine-learning image-segmentation kaggle semantic-segmentation


    【Solution 1】:

    First, if you use vscode, I recommend this extension for debugging:

    https://marketplace.visualstudio.com/items?itemName=elazarcoh.simply-view-image-for-python-debugging

    Off the top of my head, I would say you should sum the values along one axis (I am picturing them as one-hot encoded):

    np.sum(A,axis = 1)

    Taken from:

    sum numpy ndarray with 3d array along a given axis 1
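
A hedged sketch of that idea, assuming the predicted mask is channels-first with shape (C, H, W), as `output.squeeze()` would produce for a 3-channel output. Under that assumption the channel axis is 0, not the `axis = 1` shown above. The function `to_single_channel` is a hypothetical name. Its `"argmax"` mode is an alternative not mentioned in the answer: it returns the index of the strongest channel per pixel instead of a sum.

```python
import numpy as np

def to_single_channel(mask_chw, mode="sum"):
    """Collapse an assumed (C, H, W) mask to a 2-D (H, W) array."""
    if mask_chw.ndim != 3:
        raise ValueError(f"expected a 3-D (C, H, W) array, got shape {mask_chw.shape}")
    if mode == "sum":
        # Sum over the channel axis, as the answer suggests for one-hot masks.
        return np.sum(mask_chw, axis=0)
    if mode == "argmax":
        # Class index per pixel; uint8 keeps the result usable by cv2.resize.
        return np.argmax(mask_chw, axis=0).astype(np.uint8)
    raise ValueError(f"unknown mode: {mode}")

# Toy one-hot mask: 3 channels, 4x4 pixels.
toy = np.zeros((3, 4, 4), dtype=np.float32)
toy[1, :2, :] = 1.0   # top half belongs to channel 1
toy[2, 2:, :] = 1.0   # bottom half belongs to channel 2

summed = to_single_channel(toy, "sum")
labels = to_single_channel(toy, "argmax")
print(summed.shape, labels.shape)
```

Either result is two-dimensional, so it passes the shape check that `imshow` applies. Resizing it to the original image size should then only change the height and width.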

    【Discussion】:
