当前位置: 首页 > news >正文

遥感机器学习入门实战教程|Sklearn案例⑧:评估指标(metrics)全解析

很多同学问:“模型好不好,怎么量化?”
本篇系统梳理 sklearn.metrics 中常用且“够用”的多分类指标,并给出一段可直接运行的示例代码,覆盖:准确率、宏/微/加权 F1、Kappa、MCC、混淆矩阵(计数/归一化)、Top-K 准确率、ROC-AUC(OvR/OvO)、PR-AUC、对数损失、(多类)Brier 分数、以及 ROC/PR 曲线绘制。

🧭 指标速览与使用场景

  • 整体验证

    • accuracy_score(OA,总体准确率)
    • balanced_accuracy_score(类别不均衡时更合理)
  • 逐类与加权

    • precision_recall_fscore_support / classification_report
    • 平均方式:average="macro" | "micro" | "weighted"
  • 一致性/稳健性

    • cohen_kappa_score(Kappa)
    • matthews_corrcoef(MCC,抗不均衡)
  • 混淆矩阵

    • confusion_matrix(计数 & 归一化)
  • 概率质量/排序质量

    • roc_auc_score(多类:multi_class="ovr"|"ovo"average="macro"|"weighted"
    • average_precision_score(PR-AUC)
    • top_k_accuracy_score(Top-K)
    • log_loss(对数损失,校准敏感)
    • 多类 Brier(自定义:one-hot 与 predict_proba 的 MSE 均值)
  • 曲线

    • ROC 曲线(micro/macro)
    • Precision-Recall 曲线(micro)

经验:类不均衡→看 balanced_accuracy / macro-F1 / Kappa / MCC
要概率好坏→看 log_loss / ROC-AUC / PR-AUC
Top-K 检索/多候选→看 top_k_accuracy_score

💻 一键可跑代码(修改 DATA_DIR 后直接运行)

# -*- coding: utf-8 -*-
"""
Sklearn案例⑧:metrics 全解析(多分类 / 概率与曲线)
数据:KSC(将 DATA_DIR 改为你的数据路径)
"""import os, numpy as np, scipy.io as sio, matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (accuracy_score, balanced_accuracy_score,precision_recall_fscore_support, classification_report, confusion_matrix,cohen_kappa_score, matthews_corrcoef, top_k_accuracy_score, roc_auc_score,average_precision_score, log_loss)
from sklearn.preprocessing import label_binarize# ============ 参数 ============
DATA_DIR = "your_path"     # ←← 修改为包含 KSC.mat / KSC_gt.mat 的目录
PCA_DIM, TRAIN_RATIO, SEED = 30, 0.3, 42# ============ 1. 载入与预处理 ============
X = sio.loadmat(os.path.join(DATA_DIR, "KSC.mat"))["KSC"].astype(np.float32)  # (H,W,B)
Y = sio.loadmat(os.path.join(DATA_DIR, "KSC_gt.mat"))["KSC_gt"].astype(int)   # (H,W)
coords = np.argwhere(Y != 0)
Xpix   = X[coords[:,0], coords[:,1]]       # (N,B)
y      = Y[coords[:,0], coords[:,1]] - 1   # 0..C-1
num_classes = int(y.max() + 1)Xtr, Xte, ytr, yte = train_test_split(Xpix, y, train_size=TRAIN_RATIO,stratify=y, random_state=SEED)
scaler = StandardScaler().fit(Xtr)
pca    = PCA(n_components=PCA_DIM, random_state=SEED).fit(scaler.transform(Xtr))
Xtr    = pca.transform(scaler.transform(Xtr))
Xte    = pca.transform(scaler.transform(Xte))# ============ 2. 训练一个可输出概率的模型 ============
# 用 RF 示范(也可以换 SVC(probability=True)、LogReg 等)
clf = RandomForestClassifier(n_estimators=300, random_state=SEED, n_jobs=-1)
clf.fit(Xtr, ytr)
y_pred = clf.predict(Xte)
y_proba = clf.predict_proba(Xte)           # (N_test, C)# ============ 3. 基础/稳健指标 ============
oa  = accuracy_score(yte, y_pred)
boa = balanced_accuracy_score(yte, y_pred)
kappa = cohen_kappa_score(yte, y_pred)
mcc   = matthews_corrcoef(yte, y_pred)prec_m, rec_m, f1_m, _   = precision_recall_fscore_support(yte, y_pred, average="macro", zero_division=0)
prec_w, rec_w, f1_w, _   = precision_recall_fscore_support(yte, y_pred, average="weighted", zero_division=0)print("=== 基础评估 ===")
print(f"OA                : {oa*100:.2f}%")
print(f"Balanced Acc      : {boa*100:.2f}%")
print(f"Macro-F1          : {f1_m*100:.2f}% (P={prec_m*100:.1f} R={rec_m*100:.1f})")
print(f"Weighted-F1       : {f1_w*100:.2f}% (P={prec_w*100:.1f} R={rec_w*100:.1f})")
print(f"Cohen's Kappa     : {kappa:.4f}")
print(f"Matthews Corrcoef : {mcc:.4f}")
print("\n=== 分类报告(逐类) ===")
print(classification_report(yte, y_pred, digits=4, zero_division=0))# ============ 4. 混淆矩阵(计数/归一化) ============
cm = confusion_matrix(yte, y_pred, labels=np.arange(num_classes))
cm_norm = cm / np.maximum(cm.sum(axis=1, keepdims=True), 1)plt.figure(figsize=(10,4))
plt.subplot(1,2,1)
plt.imshow(cm, interpolation='nearest')
plt.title("Confusion Matrix (Counts)")
plt.xlabel("Pred"); plt.ylabel("True")
plt.colorbar(fraction=0.046, pad=0.04)plt.subplot(1,2,2)
plt.imshow(cm_norm, vmin=0, vmax=1, interpolation='nearest')
plt.title("Confusion Matrix (Normalized)")
plt.xlabel("Pred"); plt.ylabel("True")
plt.colorbar(fraction=0.046, pad=0.04)
plt.tight_layout(); plt.show()# ============ 5. 概率/排序质量 ============
# 5.1 多类 ROC-AUC:OvR & OvO(macro/weighted)
y_bin = label_binarize(yte, classes=np.arange(num_classes))  # (N,C)
auc_ovr_macro = roc_auc_score(yte, y_proba, multi_class="ovr", average="macro")
auc_ovr_weight= roc_auc_score(yte, y_proba, multi_class="ovr", average="weighted")
auc_ovo_macro = roc_auc_score(yte, y_proba, multi_class="ovo", average="macro")
print("\n=== 概率/排序质量 ===")
print(f"ROC-AUC OvR (macro)   : {auc_ovr_macro:.4f}")
print(f"ROC-AUC OvR (weighted): {auc_ovr_weight:.4f}")
print(f"ROC-AUC OvO (macro)   : {auc_ovo_macro:.4f}")# 5.2 PR-AUC(macro)
ap_macro = average_precision_score(y_bin, y_proba, average="macro")
print(f"PR-AUC (macro)        : {ap_macro:.4f}")# 5.3 对数损失(log-loss)
ll = log_loss(yte, y_proba, labels=np.arange(num_classes))
print(f"Log Loss              : {ll:.4f}")# 5.4 多类 Brier(自定义:one-hot 与 predict_proba 的 MSE 均值)
brier_multi = np.mean((y_bin - y_proba)**2)
print(f"Brier Score (multi)   : {brier_multi:.4f}")# 5.5 Top-K 准确率(以 K=3 为例)
top3 = top_k_accuracy_score(yte, y_proba, k=3, labels=np.arange(num_classes))
print(f"Top-3 Accuracy        : {top3*100:.2f}%")# ============ 6. 曲线:micro-ROC 与 micro-PR ============
# micro:将多类视为一个“整体二分类”汇总,便于一张图比较
from sklearn.metrics import RocCurveDisplay, PrecisionRecallDisplay
# ROC (micro)
fpr = dict(); tpr = dict()
from sklearn.metrics import roc_curve, precision_recall_curve, auc
y_bin_pred = y_proba
fpr_micro, tpr_micro, _ = roc_curve(y_bin.ravel(), y_bin_pred.ravel())
roc_auc_micro = auc(fpr_micro, tpr_micro)# PR (micro)
prec_micro, rec_micro, _ = precision_recall_curve(y_bin.ravel(), y_bin_pred.ravel())
pr_auc_micro = auc(rec_micro, prec_micro)plt.figure(figsize=(10,4))
plt.subplot(1,2,1)
plt.plot(fpr_micro, tpr_micro, lw=2, label=f"micro-ROC AUC={roc_auc_micro:.3f}")
plt.plot([0,1],[0,1],'--', lw=1)
plt.xlabel("FPR"); plt.ylabel("TPR")
plt.title("ROC (micro-average)")
plt.legend(frameon=False)plt.subplot(1,2,2)
plt.plot(rec_micro, prec_micro, lw=2, label=f"micro-PR AUC={pr_auc_micro:.3f}")
plt.xlabel("Recall"); plt.ylabel("Precision")
plt.title("Precision-Recall (micro-average)")
plt.legend(frameon=False)
plt.tight_layout(); plt.show()

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

✅ 实战要点(如何选指标)

  • 报告一页通读OA + macro-F1 + Kappa + MCC + 混淆矩阵(归一化)
    这几项能同时反映整体、逐类与稳健性,对不均衡也更有意义。
  • 需要概率质量:加上 log_loss + ROC-AUC(ovr, macro) + PR-AUC(macro)
    若要“多候选命中”,再加 Top-K
  • 展示与沟通:曲线(ROC/PR)更直观,归一化混淆矩阵能指出“易混类”。
  • 避免踩坑:类别极不均衡时,单看 accuracy 容易误判;阈值可调的任务(告警/检索),更应看 PR-AUCPrecision-Recall 曲线

欢迎大家关注下方我的公众获取更多内容!

http://www.xdnf.cn/news/1352323.html

相关文章:

  • tcpdump命令打印抓包信息
  • 【golang】ORM框架操作数据库
  • 2-5.Python 编码基础 - 键盘输入
  • STM32CubeIDE V1.9.0下载资源链接
  • 醋酸镨:催化剂领域的璀璨新星
  • LangChain4J-基础(整合Spring、RAG、MCP、向量数据库、提示词、流式输出)
  • 信贷模型域——信贷获客模型(获客模型)
  • 温度对直线导轨的性能有哪些影响?
  • 小白向:Obsidian(Markdown语法学习)快速入门完全指南:从零开始构建你的第二大脑(免费好用的笔记软件的知识管理系统)、黑曜石笔记
  • 数字经济、全球化与5G催生域名新价值的逻辑与实践路径
  • 快速掌握Java非线性数据结构:树(二叉树、平衡二叉树、多路平衡树)、堆、图【算法必备】
  • vue3 - 组件间的传值
  • 【小沐学GIS】基于Godot绘制三维数字地球Earth(Godot)
  • 计算机网络 TLS握手中三个随机数详解
  • 【Golang】有关垃圾收集器的笔记
  • 语义通信高斯信道仿真代码
  • GaussDB 数据库架构师修炼(十八) SQL引擎-计划管理-SQL PATCH
  • Base64编码、AES加密、RSA加密、MD5加密
  • RAG Embeddings 向量数据库
  • 使用Ollama部署自己的本地模型
  • 疯狂星期四文案网第48天运营日记
  • 12 SQL进阶-锁(8.20)
  • Python语法速成课程(二)
  • 科技赋能,宁夏农业绘就塞上新“丰”景
  • 进程的概念:进程调度算法
  • 【GPT入门】第57课 详解 LLamaFactory 与 XTuner 实现大模型多卡分布式训练的方案与实践
  • rust语言 (1.88) egui (0.32.1) 学习笔记(逐行注释)(七) 鼠标在控件上悬浮时的提示
  • linux中文本文件操作之grep命令
  • 【软件设计模式】策略模式
  • MySQL:事务管理