【问题标题】:Confusion Matrix in Keras+TensorflowKeras+Tensorflow 中的混淆矩阵
【发布时间】:2018-04-25 16:36:30
【问题描述】:

第一季度

我已经训练了一个 CNN 模型并将其保存为 model.h5。我正在尝试检测 3 个对象。比如说“猫”、“狗”和“其他”。我的测试集有 300 张图片,每个类别 100 张。第一个 100 是“猫”,第二个 100 是“狗”,第三个 100 是“其他”。我正在使用 Keras 类 ImageDataGeneratorflow_from_directory。这是示例代码:

test_datagen = ImageDataGenerator(rescale=1./255)
test_generator = test_datagen.flow_from_directory(
        test_dir,
        target_size=(150, 150),
        batch_size=20,
        class_mode='sparse',
        shuffle=False)

现在使用

from sklearn.metrics import confusion_matrix

cnf_matrix = confusion_matrix(y_test, y_pred)

我需要y_testy_pred。我可以使用以下代码获得y_pred

probabilities = model.predict_generator(test_generator)
y_pred = np.argmax(probabilities, axis=1)
print (y_pred)

[0 0 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 2 0 0 0 1 0 0 0 0 0 0 1 0 0 0
 0 0 0 0 1 0 0 0 0 1 2 0 2 0 0 0 1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 0 0 1 1
 0 2 0 0 0 0 1 0 0 0 0 0 0 1 0 2 0 1 0 0 1 0 0 1 0 0 1 1 1 1 1 1 1 1 1 1 2
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 1 1 1 1 1 1 1 1 1 1 2 1 1 1 1
 1 1 1 2 1 1 1 1 1 1 1 1 1 1 1 0 1 1 1 2 2 1 1 1 1 1 1 2 1 1 1 1 1 1 1 1 2
 1 1 1 1 1 2 1 1 1 1 1 2 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2 2 2 2 2 2 2 2 2 1 2 2 2 2 2 2 2 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2 2 2 1 2 2 2 2 2 2 2 2 1 2 2 2 1 2 2 2 1 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2]

这基本上将对象预测为 0,1 和 2。现在我知道前 100 个对象(猫)是 0,第二个 100 个对象(狗)是 1,第三个 100 个对象(其他)是 2。我创建手动使用numpy 的列表,其中第一个 100 点为 0,第二个 100 点为 1,第三个 100 点为 2 以获得y_test?是否有任何 Keras 类可以做到(创建 y_test)?

第二季度

如何查看错误检测到的对象。如果你查看print(y_pred),第 3 点是 1,这是错误预测的。如果不手动进入我的“test_dir”文件夹,如何查看该图像?

【问题讨论】:

    标签: python-3.x keras confusion-matrix


    【解决方案1】:

    由于您没有使用任何增强和shuffle=False,您可以简单地从生成器中获取图像:

    imgBatch = next(test_generator)
        #it may be interesting to create the generator again if 
        #you're not sure it has output exactly all images before
    

    使用 Pillow (PIL) 或 MatplotLib 等绘图库在 imgBatch 中绘制每个图像。

    要仅绘制所需的图像,请将y_testy_pred 进行比较:

    compare = y_test == y_pred
    
    position = 0
    while position < len(y_test):
        imgBatch = next(test_generator)
        batch = imgBatch.shape[0]
    
        for i in range(position,position+batch):
            if compare[i] == False:
                plot(imgBatch[i-position])
    
        position += batch
    

    【讨论】:

    • #Q2:我认为这会在test_generator 中绘制每张图片。我想知道只绘制“错误检测”图像的方法是什么?
    • 如果生成器是 keras Sequence,则可以按索引获取图像,否则,您必须获取图像,直到达到所需的索引。
    猜你喜欢
    • 2020-03-08
    • 2018-09-26
    • 2018-06-06
    • 2017-05-27
    • 2018-11-22
    • 2018-11-06
    • 2019-05-26
    • 2021-02-01
    • 2018-08-06
    相关资源
    最近更新 更多