混淆矩阵的生成(python实现,含机器学习方法)

2024-05-01 1412阅读

混淆矩阵(Confusion Matrix)是用于评估分类模型性能的一种表格形式。它显示了在分类问题中模型的预测结果与实际标签之间的各种组合情况。

混淆矩阵通常用于二分类问题,但也可以扩展到多分类问题。对于二分类问题,它由四个重要的指标组成:

真正例(True Positive, TP):模型预测为正例,并且实际上是正例的数量。

真反例(True Negative, TN):模型预测为反例,并且实际上是反例的数量。

假正例(False Positive, FP):模型预测为正例,但实际上是反例的数量。也称为"误报"。

假反例(False Negative, FN):模型预测为反例,但实际上是正例的数量。也称为"漏报"。

混淆矩阵的一般形式如下:

混淆矩阵的生成(python实现,含机器学习方法)

使用混淆矩阵可以计算多个衡量分类器性能的指标,如准确率(Accuracy)、精确率(Precision)、召回率(Recall,也称为敏感度或真正例率)和 F1 值等。这些指标可以通过混淆矩阵中的各个元素计算得出:

准确率(Accuracy):分类器预测正确的样本占总样本数的比例,计算公式为 (TP + TN) / (TP + TN + FP + FN) 。

精确率(Precision):正例预测正确的比例,计算公式为 TP / (TP + FP) 。

召回率(Recall):正例被正确预测为正例的比例,计算公式为 TP / (TP + FN) 。

F1 值:综合考虑了精确率和召回率的指标,计算公式为 2 (Precision Recall) / (Precision + Recall) 。

混淆矩阵提供了更详细和全面地评估分类模型性能的能力,帮助我们了解预测中的误报和漏报情况。通过分析混淆矩阵,我们可以获得对分类器在每个类别上的表现有关的宝贵见解,并对分类结果进行优化。



废话不多数,上代码:

def draw_confusion_matrix(label_true, label_pred, label_name, normlize, , pdf_save_path=None, dpi=100):
    """
    @param label_true: 真实标签,比如[0,1,2,7,4,5,...]
    @param label_pred: 预测标签,比如[0,5,4,2,1,4,...]
    @param label_name: 标签名字,比如['cat','dog','flower',...]
    @param normlize: 是否设元素为百分比形式
    @param title: 图标题
    @param pdf_save_path: 是否保存,是则为保存路径pdf_save_path=xxx.png | xxx.pdf | ...等其他plt.savefig支持的保存格式
    @param dpi: 保存到文件的分辨率,论文一般要求至少300dpi
    @return:
    example:
            draw_confusion_matrix(label_true=y_gt,
                          label_pred=y_pred,
                          label_name=["Angry", "Disgust", "Fear", "Happy", "Sad", "Surprise", "Neutral"],
                          normlize=True,
                          ,
                          pdf_save_path="Confusion_Matrix_on_Fer2013.png",
                          dpi=300)
    """
    cm1=confusion_matrix(label_true, label_pred)
    cm = confusion_matrix(label_true, label_pred)
    if normlize:
        row_sums = np.sum(cm, axis=1)
        cm = cm / row_sums[:, np.newaxis]
    cm=cm.T
    cm1=cm1.T
    plt.imshow(cm, cmap='Blues')
    plt.title(title)
    plt.xlabel("Predict label")
    plt.ylabel("Truth label")
    plt.yticks(range(label_name.__len__()), label_name)
    plt.xticks(range(label_name.__len__()), label_name, rotation=45)
    plt.tight_layout()
    plt.colorbar()
    for i in range(label_name.__len__()):
        for j in range(label_name.__len__()):
            color = (1, 1, 1) if i == j else (0, 0, 0)	# 对角线字体白色,其他黑色
            value = float(format('%.1f' % (cm[i, j]*100)))
            value1=str(value)+'%\n'+str(cm1[i, j])
            plt.text(i, j, value1, verticalalignment='center', horizontalalignment='center', color=color)
    # plt.show()
    if not pdf_save_path is None:
        plt.savefig(pdf_save_path, bbox_inches='tight',dpi=dpi)
labels_name = ['bananaquit', 'Black Skimmer', 'Black Throated Bushtiti', 'Cockatoo']
y_gt=[]
y_pred=[]
model_weight_path = "./best_CBAM_model.pth"
models = Xception(num_classes = 4)
models.load_state_dict(torch.load(model_weight_path))

models.eval()
for index, (imgs, labels) in enumerate(test_dl):
    labels_pd = models(imgs)
    predict_np = np.argmax(labels_pd.cpu().detach().numpy(), axis=-1).tolist()
    labels_np = labels.numpy().tolist()
    y_pred.extend(predict_np)
    y_gt.extend(labels_np)
print("预测标签为:", y_pred)
print("真实标签为", y_gt)
draw_confusion_matrix(label_true=y_gt,
                      label_pred=y_pred,
                      label_name=labels_name,
                      normlize=True,
                      ,
                      pdf_save_path="Confusion_Matrix.jpg",
                      dpi=300)

结果如下:

混淆矩阵的生成(python实现,含机器学习方法)

更新

这里大佬给我提供了一种更加简单的混淆矩阵生成方法,是基于机器学习sklearn里面的库confusion_matrix和seaborn 库。原文参考深度学习100例-卷积神经网络(CNN)识别眼睛状态 | 第17天

Seaborn是一个基于Matplotlib的Python数据可视化库,它提供了一种高层次的接口,用于绘制具有吸引力和信息丰富的统计图形。Seaborn的设计目标是使得数据可视化更加简单,同时也更具吸引力,以便更好地理解和传达数据的含义。它具有许多内置的图表类型和样式,可用于探索数据分布、比较多个变量之间的关系、绘制分类数据以及在统计模型中的可视化等。它还提供了许多自定义选项,使您能够根据自己的需求进行图形的修改和美化。

而sklearn.metrics是Scikit-learn库中的一个模块,用于评估机器学习模型的性能和预测结果。这个模块提供了各种用于计算模型性能指标(如准确度、精确度、召回率、F1值等)的函数,以及用于绘制混淆矩阵、ROC曲线、学习曲线等的工具函数。

混淆矩阵是衡量分类模型性能的一种方法,它以矩阵形式表示了模型预测结果与真实标签之间的差异。混淆矩阵的行表示真实标签,列表示预测标签,每个单元格中的值表示对应标签的样本数量。通过分析混淆矩阵,我们可以得出模型的准确性、错误类型和偏差等信息。

通过Seaborn库的heatmap函数,我们可以将混淆矩阵可视化为一个热力图,更直观地展示模型预测结果的分布情况。热力图的每个单元格的颜色深浅表示对应标签的样本数量或其他统计指标。

from sklearn.metrics import confusion_matrix
import seaborn as sns
import pandas as pd
# 定义一个绘制混淆矩阵图的函数
def plot_cm(labels, predictions):
    
    # 生成混淆矩阵
    conf_numpy = confusion_matrix(labels, predictions)
    # 将矩阵转化为 DataFrame
    conf_df = pd.DataFrame(conf_numpy, index=class_names ,columns=class_names)  
    
    plt.figure(figsize=(8,7))
    
    sns.heatmap(conf_df, annot=True, fmt="d", cmap="BuPu")
    
    plt.title('混淆矩阵',fontsize=15)
    plt.ylabel('真实值',fontsize=14)
    plt.xlabel('预测值',fontsize=14)
val_pre   = []
val_label = []
for images, labels in val_ds:#这里可以取部分验证数据(.take(1))生成混淆矩阵
    for image, label in zip(images, labels):
        # 需要给图片增加一个维度
        img_array = tf.expand_dims(image, 0) 
        # 使用模型预测图片中的人物
        prediction = model.predict(img_array)
        val_pre.append(class_names[np.argmax(prediction)])
        val_label.append(class_names[label])
plot_cm(val_label, val_pre)

混淆矩阵的生成(python实现,含机器学习方法)

VPS购买请点击我

免责声明:我们致力于保护作者版权,注重分享,被刊用文章因无法核实真实出处,未能及时与作者取得联系,或有版权异议的,请联系管理员,我们会立即处理! 部分文章是来自自研大数据AI进行生成,内容摘自(百度百科,百度知道,头条百科,中国民法典,刑法,牛津词典,新华词典,汉语词典,国家院校,科普平台)等数据,内容仅供学习参考,不准确地方联系删除处理! 图片声明:本站部分配图来自人工智能系统AI生成,觅知网授权图片,PxHere摄影无版权图库和百度,360,搜狗等多加搜索引擎自动关键词搜索配图,如有侵权的图片,请第一时间联系我们,邮箱:ciyunidc@ciyunshuju.com。本站只作为美观性配图使用,无任何非法侵犯第三方意图,一切解释权归图片著作权方,本站不承担任何责任。如有恶意碰瓷者,必当奉陪到底严惩不贷!

目录[+]