【发布时间】:2017-12-15 05:54:51
【问题描述】:
我正在使用 TensorFlow 的全新 Object Detection API,并决定在其他一些公开可用的数据集上对其进行训练。
我偶然发现了thisgrocery 数据集,该数据集由超市货架上各种品牌的香烟盒的图像以及一个列出每个图像中每个香烟盒的边界框的文本文件组成。数据集中有 10 个主要品牌被标记,所有其他品牌都属于第 11 个“杂项”类别。
我关注了他们的tutorial,并设法在这个数据集上训练了模型。由于处理能力的限制,我只使用了三分之一的数据集,并对训练和测试数据进行了 70:30 的拆分。我使用了 faster_rcnn_resnet101 模型。我的配置文件中的所有参数都和TF提供的默认参数一样。
在 16491 个全局步骤之后,我在一些图像上测试了模型,但我对结果不太满意 -
我遇到的另一个问题是模型从未检测到除标签 1 之外的任何其他标签
未从训练数据中检测到产品的裁剪实例
即使在负面图像中,它也能以 99% 的置信度检测烟盒!
有人可以帮我解决问题吗?我可以做些什么来提高准确性?为什么它检测到所有产品都属于类别 1,尽管我提到了总共有 11 个类别?
编辑添加了我的标签图:
item {
id: 1
name: '1'
}
item {
id: 2
name: '2'
}
item {
id: 3
name: '3'
}
item {
id: 4
name: '4'
}
item {
id: 5
name: '5'
}
item {
id: 6
name: '6'
}
item {
id: 7
name: '7'
}
item {
id: 8
name: '8'
}
item {
id: 9
name: '9'
}
item {
id: 10
name: '10'
}
item {
id: 11
name: '11'
}
【问题讨论】:
-
你能提供你工作的标签图吗?
-
@JonathanHuang 我在编辑中添加了我的标签映射
-
谢谢,看起来不错。正如其他人所提到的那样,您可能需要更多数据,但我很困惑为什么您总是预测同一个类......也许您需要再次仔细检查 TFRecord 文件?
-
我注意到标签以某种方式限制为 20..
-
@BanachTarski 干得好。您可以分享您从杂货数据集创建 tfrecord 的代码吗?
标签: python machine-learning tensorflow classification object-detection