绘制混淆矩阵热图

使用 Seaborn 绘制。

1
2
3
4
5
6
7
8
9
10
11
12
# 混淆矩阵
print('Confusion Matrix')
print(confusion_matrix(y_test.argmax(axis=1), y_pred.argmax(axis=1)))

# 混淆矩阵 heatmap
import seaborn as sn
con = confusion_matrix(y_test.argmax(axis=1), y_pred.argmax(axis=1))

# heatmap 参数
heatmap = sn.heatmap(con, annot=True, fmt='.5g', cmap='Blues' )
# 图保存
heatmap = plt.savefig('heatmap.png')