【发布时间】:2020-05-05 18:16:18
【问题描述】:
我正在使用 Matterport Mask RCNN 作为我的模型,并且我正在尝试构建我的数据库以进行训练。经过对以下问题的深思熟虑,我想我实际上要问的是如何添加多个类(+ BG)?
我收到以下AssertionError:
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
<ipython-input-21-c20768952b65> in <module>()
15
16 # display image with masks and bounding boxes
---> 17 display_instances(image, bbox, masks, class_ids/4, train_set.class_names)
/usr/local/lib/python3.6/dist-packages/mask_rcnn-2.1-py3.6.egg/mrcnn/visualize.py in display_instances(image, boxes, masks, class_ids, class_names, scores, title, figsize, ax, show_mask, show_bbox, colors, captions)
103 print("\n*** No instances to display *** \n")
104 else:
--> 105 assert boxes.shape[0] == masks.shape[-1] == class_ids.shape[0]
106
107 # If no axis is passed, create one and automatically call show()
AssertionError:
问题似乎来自 mask.shape[-1] == class_ids.shape[0] 导致 False 不应该是这种情况。
我现在追溯到masks.shape[-1] 是class_id.shape[0] 值的4 倍,我认为这可能与数据中有4 个类有关。不幸的是,我还没有想出如何解决这个问题。
# load the masks for an image
def load_mask(self, image_id):
# get details of image
info = self.image_info[image_id]
# define box file location
path = info['annotation']
# load XML
boxes, w, h = self.extract_boxes(path)
# create one array for all masks, each on a different channel
masks = zeros([h, w, len(boxes)], dtype='uint8')
# create masks
class_ids = list()
for i in range(len(boxes)):
box = boxes[i]
row_s, row_e = box[1], box[3]
col_s, col_e = box[0], box[2]
masks[row_s:row_e, col_s:col_e, i] = 1
class_ids.append(self.class_names.index('Resistor'))
class_ids.append(self.class_names.index('LED'))
class_ids.append(self.class_names.index('Capacitor'))
class_ids.append(self.class_names.index('Diode'))
return masks, asarray(class_ids, dtype='int32')
# load the masks and the class ids
mask, class_ids = train_set.load_mask(image_id)
print(mask, "and", class_ids)
# display image with masks and bounding boxes
display_instances(image, bbox, mask, class_ids, train_set.class_names)
【问题讨论】:
-
您是否验证了
masks.shape[-1] == class_ids.shape[0]对您的输入有效? -
请将您的问题减少到您作为更新提供的minimal reproducible example。调试这个小例子比调试完整代码更容易。
-
@IonicSolutions 感谢您的回复,对于您的第一条评论,我收到了
False。为冗长的代码道歉,我会减少它(老实说,我不是 100% 确定是什么部分导致它) -
不用道歉!现在你知道为什么断言失败了。您应该检查
display_instances期望mask和class_ids的格式。
标签: python-3.x tensorflow tensorflow-datasets transfer-learning faster-rcnn