查看原文
其他

终于把机器学习中的混淆矩阵搞懂了!

程序员小寒 程序员学长
2024-09-13
大家好,我是小寒
今天给大家分享一个机器学习中一个重要的概念,混淆矩阵

混淆矩阵是用于评估分类模型性能的表格。它通过将实际(真实)标签与预测标签进行比较,提供分类问题的预测结果摘要。


混淆矩阵本身是正方形(nxn),其中 n 是模型中的类别数。
对于二元分类问题,混淆矩阵由四个主要部分组成:
  • True Positive (TP, 真阳性):实际为正类,预测也为正类的数量。

  • True Negative (TN, 真阴性):实际为负类,预测也为负类的数量。

  • False Positive (FP, 假阳性):实际为负类,预测却为正类的数量,通常称为"Type I 错误"或"误报"。

  • False Negative (FN, 假阴性):实际为正类,预测却为负类的数量,通常称为"Type II 错误"或"漏报"。

为什么要使用混淆矩阵?

混淆矩阵是评估分类模型性能的基本工具。

  1. 错误分析

    它有助于识别模型所犯的错误类型,无论模型更容易出现假阳性还是假阴性,这在应用范围内(例如在医学诊断中)可能至关重要。

  2. 模型改进

    通过分析混淆矩阵,你可以专注于改进模型的特定方面,例如减少误报或提高召回率。
  3. 类别不平衡处理

    在类别不平衡的情况下,一个类别出现的频率高于另一个类别,单凭准确率可能会产生误导。
    混淆矩阵可让你更好地了解模型在每个类别中的表现。
  4. 性能指标计算

分类中的评估指标

1.准确率

准确率是分类任务中最简单的评估指标之一,用来衡量模型预测正确的比例。

准确率的局限性

当处理不平衡的数据集时,一个类别的数量远远超过其他类别,准确率可能会产生误导。

例如,在 95% 的样本属于同一类的数据集中,预测所有实例为多数类的模型的准确率为 95%,但在识别少数类时则无效。

import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, accuracy_score

# Example true labels (ytest) and predicted labels (ypred)
ytest = [0, 1, 0, 1, 0, 1, 0, 0, 1, 1]
ypred = [0, 1, 0, 0, 0, 1, 0, 1, 1, 1]

# Calculate confusion matrix
cm = confusion_matrix(ytest, ypred)

# Create a heatmap
plt.figure(figsize=(6, 4))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False,
            xticklabels=['1', '0'],
            yticklabels=['1', '0'])

# Add labels and title
plt.xlabel('Predicted Classes')
plt.ylabel('Actual Classes')
plt.title('Confusion Matrix')

# Calculate and display accuracy
accuracy = accuracy_score(ytest, ypred)
plt.text(2.3, 1.5, f'Accuracy: {accuracy:.2f}', fontsize=14, color='black', weight='bold')

plt.show()

2.精度

精度用来衡量模型预测为正类的样本中实际为正类的比例。
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, precision_score

# Example true labels (ytest) and predicted labels (ypred)
ytest = ['spam', 'spam', 'ham', 'spam', 'ham', 'spam', 'spam', 'ham', 'spam', 'spam', 'ham', 'spam', 'ham', 'ham', 'ham']
ypred = ['spam', 'spam', 'spam', 'spam', 'ham', 'spam', 'spam', 'ham', 'spam', 'spam', 'ham', 'ham', 'ham', 'ham', 'ham']

# Calculate the confusion matrix
cm = confusion_matrix(ytest, ypred, labels=['spam', 'ham'])
print("Confusion Matrix:\n", cm)

# Calculate precision
precision = precision_score(ytest, ypred, pos_label='spam')
print("Precision:", precision)

# Create a heatmap for the confusion matrix
plt.figure(figsize=(8, 6))
ax = sns.heatmap(cm, annot=True, fmt='d', cmap='viridis', cbar=False,
                 xticklabels=['Predicted Spam', 'Predicted Ham'],
                 yticklabels=['Actual Spam', 'Actual Ham'])

# Set labels and title
plt.xlabel('Predicted Classes')
plt.ylabel('Actual Classes')
plt.title(f'Confusion Matrix\nPrecision: {precision:.2f}')

# Show the plot
plt.show()

3.召回率

召回率用来衡量实际为正类的样本中模型预测为正类的比例。
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, recall_score

# Example true labels (ytest) and predicted labels (ypred)
ytest = ['positive', 'positive', 'negative', 'positive', 'negative']
ypred = ['positive', 'negative', 'negative', 'positive', 'positive']

# Calculate the confusion matrix
cm = confusion_matrix(ytest, ypred, labels=['positive', 'negative'])


# Calculate recall
recall = recall_score(ytest, ypred, pos_label='positive')


# Create a heatmap for the confusion matrix
plt.figure(figsize=(6, 4))
ax = sns.heatmap(cm, annot=True, fmt='d', cmap='viridis', cbar=False,
                 xticklabels=['Predicted Positive', 'Predicted Negative'],
                 yticklabels=['Actual Positive', 'Actual Negative'])

# Set labels and title
plt.xlabel('Predicted Classes')
plt.ylabel('Actual Classes')
plt.title(f'Confusion Matrix\nRecall: {recall:.2f}')

# Show the plot
plt.show()

4.F1-score

F1-score 是精度和召回率的调和平均数,用来综合考虑精度和召回率的平衡。
最后



今天的分享就到这里。如果觉得近期的文章不错,请点赞,转发安排起来。‍‍欢迎大家进高质量 python 学习群

「进群方式:加我微信,备注 “python”」



往期回顾


Fashion-MNIST 服装图片分类-Pytorch实现

python 探索性数据分析(EDA)案例分享

深度学习案例分享 | 房价预测 - PyTorch 实现

万字长文 |  面试高频算法题之动态规划系列

面试高频算法题之回溯算法(全文六千字)  

    



如果对本文有疑问可以加作者微信直接交流。

继续滑动看下一个
程序员学长
向上滑动看下一个

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存