【问题标题】:Two confusion matrix plots have different sizes两个混淆矩阵图的大小不同
【发布时间】:2021-10-26 16:36:25
【问题描述】:

我正在尝试在同一张图像中绘制混淆矩阵,但它们的大小不同。

代码如下:

fig, ax = plt.subplots(nrows=1, ncols=2, figsize = (18,8))
fig.suptitle('Matriz de Confusão')


skplt.metrics.plot_confusion_matrix(y_test, y_pred_log, normalize=True, ax=ax[0], title=('Regressão Logística'))
skplt.metrics.plot_confusion_matrix(y_test, y_pred_tree, normalize=True, ax=ax[1], title=('Árvore de decisão'))
ax[0].xaxis.set_ticklabels(['Normal', 'Fraude']); ax[0].yaxis.set_ticklabels(['Normal', 'Fraude']);
ax[1].xaxis.set_ticklabels(['Normal', 'Fraude']); ax[1].yaxis.set_ticklabels(['Normal', 'Fraude']);


plt.show()

这就是我得到的:

如何更改第二个图的大小?
另外,如果我可以删除多余的颜色条会很好。

【问题讨论】:

  • fig, ax = ... 之后尝试ax = ax.ravel() 看看会发生什么。虽然我无法在matplotlib 3.4.3 中重现此问题
  • plot_confusion_matrix 有参数colorbar=False,但是,删除颜色条会改变矩阵的大小,这也是为什么你的两个绘图轴大小不同的原因。一个没有颜色条,一个有两个,所以 API 调整了它们的大小。由于这不包含完整的minimal reproducible example,因此无法重现。

标签: python matplotlib plot scikit-learn confusion-matrix


【解决方案1】:

您应该定义需要放置颜色条的轴。您可以查看this answer 作为参考。
将这些概念应用于您的案例将导致类似于以下内容:

import matplotlib.pyplot as plt
import numpy as np


M1 = np.random.rand(2, 2)
M2 = np.random.rand(2, 2)


fig, ax = plt.subplots(1, 2, figsize = (18, 8))
plt.subplots_adjust(right = 0.77)
cbar_ax_1 = fig.add_axes([0.8, 0.1, 0.04, 0.8])
cbar_ax_2 = fig.add_axes([0.9, 0.1, 0.04, 0.8])

im_1 = ax[0].imshow(M1, cmap = 'magma')
im_2 = ax[1].imshow(M2, cmap = 'magma')

plt.colorbar(im_1, cax = cbar_ax_1)
plt.colorbar(im_2, cax = cbar_ax_2)

plt.show()


如果您只想要一个颜色条,明智的做法是根据两个矩阵的值对唯一的颜色条进行归一化:

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import Normalize
from matplotlib import cm


M1 = np.random.rand(2, 2)
M2 = np.random.rand(2, 2)


fig, ax = plt.subplots(1, 2, figsize = (18, 8))
plt.subplots_adjust(right = 0.87)
cbar_ax = fig.add_axes([0.9, 0.1, 0.04, 0.8])
norm = Normalize(vmin = min(np.min(M1), np.min(M2)), vmax = max(np.max(M1), np.max(M2)))
cmap = cm.magma

im_1 = ax[0].imshow(M1, cmap = cmap)
im_2 = ax[1].imshow(M2, cmap = cmap)

plt.colorbar(cm.ScalarMappable(norm = norm, cmap = cmap), cax = cbar_ax)

plt.show()

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 2018-11-27
    • 2022-01-02
    • 2020-02-23
    • 2019-11-23
    • 2020-07-09
    • 2019-09-23
    • 1970-01-01
    相关资源
    最近更新 更多