当前位置: 首页 > backend >正文

Day 14 训练

Day 14 训练

  • SHAP(SHapley Additive exPlanations)
  • 1.创建解释器
  • 2.将特征贡献可视化
      • 第一部分:绘制SHAP特征重要性条形图
      • 第二部分:绘制SHAP特征重要性蜂巢图


SHAP(SHapley Additive exPlanations)

旨在解释复杂机器学习模型(如随机森林、梯度提升树、神经网络等 “黑箱” 模型)对特定输入的预测原因。其核心基于合作博弈论中的 Shapley 值。

  • 将模型的特征比作玩家,预测样本输出值是游戏目标,不同特征子集合作进行预测,特征子集预测得到的值是奖励 / 价值。通过计算每个特征的 Shapley 值来确定其对预测的贡献,具体是考虑所有可能的特征组合,计算特征在每种组合下的边际贡献,再求加权平均。
  • SHAP 具有加性解释特性,模型预测值等于基准值(模型在训练或背景数据集上的平均预测输出)加上所有特征的 SHAP 值之和。
  • SHAP 需要为每个样本的每个特征计算贡献值即 SHAP 值,形成 shap_values 数组。对于回归问题,shap_values 是形状为(n_samples,n_features)的数组;对于分类问题,通常返回一个列表,列表长度等于类别数,每个元素是(n_samples,n_features)数组,表示各特征对预测各类别的贡献。总之,SHAP 通过计算特征边际贡献,将模型预测分解到每个特征上,生成 shap_values 数组来解释预测。

1.创建解释器

import shap
import matplotlib.pyplot as pltexplainer = shap.TreeExplainer(rf_model)
shap_values = explainer.shap_values(X_test)
print(shap_values)
print(shap_values.shape) # 第一维是是样本数,第二维度是特征数量,第三维度是类别数量。
print("shap_values shape:", shap_values.shape)
print("shap_values[0] shape:", shap_values[0].shape)
print("shap_values[:, :, 0] shape:", shap_values[:, :, 0].shape)
print("X_test shape:", X_test.shape)

创建解释器对象

  • explainer = shap.TreeExplainer(rf_model):创建一个 SHAP 解释器对象。shap.TreeExplainer 是 SHAP 库中用于解释基于树的模型(如随机森林、梯度提升树等)的解释器类。rf_model 是一个已经训练好的随机森林模型对象,将其传递给 TreeExplainer,解释器就会根据该模型的结构和参数来计算特征的 SHAP 值。通过这个解释器对象,我们可以进一步获取模型预测的解释信息。

计算 SHAP 值

  • shap_values = explainer.shap_values(X_test):计算测试数据集 X_test 的 SHAP 值。X_test 是一个二维数组,包含了测试样本的特征值。explainer.shap_values 方法会根据之前创建的解释器对象和测试数据集,计算出每个特征对于每个测试样本的 SHAP 值,并将结果存储在 shap_values 变量中。SHAP 值是一个与特征数量相同的数组,其中每个元素表示一个特征对模型预测的贡献度。正的 SHAP 值表示该特征对预测结果有正向影响,负的 SHAP 值表示有负向影响。通过这些 SHAP 值,我们可以分析出哪些特征对模型的预测结果起到了关键作用,以及它们是如何影响预测结果的。

2.将特征贡献可视化

print("1.shap 特征重要性条形图")
shap.summary_plot(shap_values[:,:,0], X_test, plot_type="bar",show=False)
plt.title("SHAP Feature Importance (Class 0)")
plt.show()print("--- 2. SHAP 特征重要性蜂巢图 ---")
shap.summary_plot(shap_values[:, :, 0], X_test,plot_type="violin",show=False,max_display=10) # 这里的show=False表示不直接显示图形,这样可以继续用plt来修改元素,不然就直接输出了
plt.title("SHAP Feature Importance (Violin Plot)")
plt.show()

这段代码是用于绘制SHAP(SHapley Additive exPlanations)特征重要性图的Python代码,主要使用了shap库和matplotlib库。下面对这段代码进行逐行解释:

第一部分:绘制SHAP特征重要性条形图

print("1.shap 特征重要性条形图")
  • 这行代码会在控制台打印一条消息,提示接下来将绘制SHAP特征重要性条形图。
shap.summary_plot(shap_values[:,:,0], X_test, plot_type="bar", show=False)
  • shap.summary_plotshap 库中的一个函数,用于绘制特征重要性图。
  • shap_values[:,:,0]shap_values 是一个三维数组,这里取其第一个维度的所有值,通常对应于模型的某个类别(这里是类别0)。
  • X_test:这是测试数据集,用于提供特征名称和数据范围等信息。
  • plot_type="bar":指定绘制条形图。
  • show=False:表示不直接显示图形,这样可以在绘制完图形后继续使用 matplotlib 修改图形元素。

第二部分:绘制SHAP特征重要性蜂巢图

print("--- 2. SHAP 特征重要性蜂巢图 ---")
  • 这行代码会在控制台打印一条消息,提示接下来将绘制SHAP特征重要性蜂巢图。
shap.summary_plot(shap_values[:, :, 0], X_test, plot_type="violin", show=False, max_display=10)
  • 同样使用 shap.summary_plot 函数绘制特征重要性图。
  • plot_type="violin":指定绘制蜂巢图(小提琴图)。
  • show=False:不直接显示图形,以便后续使用 matplotlib 修改图形元素。
  • max_display=10:限制最多显示10个特征。

@浙大疏锦行

http://www.xdnf.cn/news/4829.html

相关文章:

  • 雷赛伺服电机
  • 山东安全员A证的考试科目有哪些?
  • MySQL中隔离级别那点事
  • 主备Smart Link + Monitor Link组网技术详细配置
  • 【LeetCode】删除排序数组中的重复项 II
  • 2018机械行业ERP软件发展趋势
  • 从 ImageNet 到产业革命:AlexNet 作为破局者的三大核心创新及其时代穿透力
  • SKNet、空间注意力介绍
  • 1.MySQL数据库初体验
  • Matlab 基于Hough变换的人眼虹膜定位方法
  • Prometheus实战教程:k8s平台-node-exporter监控物理机
  • OPCUA,OPCDA与MODBUS学习笔记
  • RabbitMQ学习(第二天)
  • ConcurrentHashMap解析
  • 3中AI领域的主流方向:预测模型、强化学习和世界模型
  • Pytorch的简单介绍(起源、历史、优缺点、应用领域等等)
  • stable-diffusion windows本地部署
  • uniapp上架苹果APP Store踩雷和部分流程注意事项(非完整流程)
  • word文档基本操作: 编辑页眉页脚和插入目录
  • 移动端前端开发中常用的css
  • SQLite3常用语句汇总
  • Kubernetes探针生产环境实战指南
  • 全连接神经网络学习笔记
  • 【Fifty Project - D25】
  • 在模 p 运算中,将负数 x 转换为对应的正数,执行 (x % p + p) % p 操作即可。
  • 单片机-STM32部分:9、定时器
  • 计算机网络笔记(十五)——3.2点对点协议PPP
  • 安装Pod网络插件时pod状态变为ImagePullBackOff
  • Spring Boot Controller 如何处理HTTP请求体
  • 微信小程序上传视频,解决ios上传完video组件无法播放