【发布时间】:2019-09-26 17:06:51
【问题描述】:
我有一个 RGB 图像,它包含不同颜色的掩码,每种颜色代表一个特定的类。
我想将其转换为格式 - n_masks, image_height, image_width 其中 n_masks 是图像中存在的掩码数。并且沿第 0 轴的矩阵的每个切片代表一个二进制掩码。
到目前为止,我已经能够将它转换为image_height, image_width 的格式,其中每个数组值代表它属于哪个类,但我有点被它打动了。
下面是我将其转换为image_height,image_width 格式的代码-
def mask_to_class(mask):
target = torch.from_numpy(mask)
h,w = target.shape[0],target.shape[1]
masks = torch.empty(h, w, dtype=torch.long)
colors = torch.unique(target.view(-1,target.size(2)),dim=0).numpy()
target = target.permute(2, 0, 1).contiguous()
mapping = {tuple(c): t for c, t in zip(colors.tolist(), range(len(colors)))}
for k in mapping:
idx = (target==torch.tensor(k, dtype=torch.uint8).unsqueeze(1).unsqueeze(2))
validx = (idx.sum(0) == 3)
masks[validx] = torch.tensor(mapping[k], dtype=torch.long)
return masks
它将比方说格式 (512,512,3) 的图像转换为 (512,512),其中每个像素值代表它所属的类,但我不知道如何进一步进行。
P.S- 我在 pytorch 中对其进行编码,但也欢迎任何涉及 numpy 的方法。
【问题讨论】:
标签: python arrays numpy pytorch