当前位置: 首页 > ds >正文

利用模型生成每个样本每个特征的 SHAP 值

SHAP 值是怎么来的?


1. 加载训练好的模型和特征数据

model = joblib.load(model_path)
X = pd.read_csv(feature_csv)
  • model:比如 XGBoostLightGBM 等树模型,是之前训练好的。
  • X:每个样本的特征表(不能含 label,否则解释不纯粹)。

2. 构建 SHAP 解释器:TreeExplainer

explainer = shap.TreeExplainer(model, X, model_output='raw')
  • TreeExplainer:专门为树模型(如 XGBoost)设计的解释器。

  • model_output='raw'

    • 表示在logit 空间计算 SHAP 值(而非概率),这避免了模型后处理(sigmoid)带来的非线性变形;
    • logit 是逻辑回归中的线性部分,值域为 (负无穷,正无穷)也叫“对数几率”。

3. 获取 SHAP 值

shap_values = explainer.shap_values(X, check_additivity=False)
  • shap_values[i][j] 代表第 i 个样本、特征 j 的 SHAP 值;

  • 每个 SHAP 值表示“该特征对模型输出的贡献(在 logit 空间)”;

  • 二分类下,shap_values 是一个长度为 2 的 list:

    • [0]: 对负类的解释;
    • [1]: 对正类的解释 → 你只保留这个(模型预测为1的方向)。

4. 聚合每个特征的平均影响力

mean_abs_shap = np.abs(shap_values).mean(axis=0)
  • 对所有样本,计算每个特征 SHAP 值的平均绝对值
  • 反映“总体上该特征对模型预测的影响力”。

5. 排序 & 截取前10重要特征

top_idx = np.argsort(-mean_abs_shap)[:10]
  • 用于绘图或报告展示最关键变量。

6. 替换特征名称(简写)并导出CSV

shap_df_all = pd.DataFrame(shap_values, columns=all_features_short)
shap_df_all.to_csv("shap_values_all_logit.csv", index=False)
  • 每一行是一个样本;
  • 每一列是一个特征的 SHAP 值(对正类的贡献);
  • 这个 CSV 就是你在 R 中读取可视化的源文件。

📌 图示帮助理解(简化版)

模型(训练好) + 特征数据(X)↓
shap.TreeExplainer()↓
shap_values(每个样本、每个特征的贡献值)↓
按列聚合 → mean(|shap|) → 找出重要特征↓
导出为 CSV → R 语言绘图
http://www.xdnf.cn/news/15628.html

相关文章:

  • 【Git 中的 branch 工作流】关于git 中 branch 的一些基本操作
  • 【每日算法】专题十_字符串
  • 小架构step系列15:白盒集成测试
  • Translational Psychiatry | 通过流形学习和网络分析揭示精神分裂症与双相I型障碍的差异性精神病症状
  • 音视频学习(三十九):IDR帧和I帧
  • 《黑马笔记》 --- C++核心编程
  • PHP安全漏洞深度解析:文件包含与SSRF攻击的攻防实战
  • 在新闻资讯 APP 中添加不同新闻分类页面,通过 ViewPager2 实现滑动切换
  • 网络基础协议综合实验
  • GeoTools 工厂设计模式
  • 【Linux庖丁解牛】— 保存信号!
  • SAP学习笔记 - 开发45 - RAP开发 Managed App New Service Definition,Metadata Extension
  • C++中list各种基本接口的模拟实现
  • 25、企业能源管理(Energy):锚定双碳目标,从分类管控到智能优化的数字化转型之路
  • npu-smi info命令参数解释
  • C++-linux系统编程 8.进程(三)孤儿进程、僵尸进程与进程回收
  • 数据结构之单链表
  • Java :List,LinkedList,ArrayList
  • sqli-labs靶场通关笔记:第17关 POST请求的密码重置
  • 连接new服务器注意事项
  • kiro, 新款 AI 编辑器, 简单了解一下
  • Java基础(八):封装、继承、多态与关键字this、super详解
  • 笔试——Day8
  • Scrapy扩展深度解析:构建可定制化爬虫生态系统的核心技术
  • 直播数据统计:如何让数据为我们所用?
  • CommunityToolkit.Mvvm IOC 示例
  • C++回顾 Day8
  • 一文深入:AI 智能体系统架构设计
  • 简单工厂设计模式
  • QT 中各种坑