当前位置: 首页 > news >正文

Focal Loss 原理详解及 PyTorch 代码实现

Focal Loss 原理详解及 PyTorch 代码实现

  • 介绍
    • 一、Focal Loss 背景
    • 二、代码逐行解析
      • 1. 类定义与初始化
    • 三、核心参数作用
    • 四、使用示例
    • 五、应用场景
    • 六、总结

介绍

一、Focal Loss 背景

Focal Loss 是为解决类别不平衡问题设计的损失函数,通过引入 gamma 参数降低易分类样本的权重,使用 alpha 参数调节正负样本比例。在目标检测等类别不平衡场景中表现优异。

二、代码逐行解析

1. 类定义与初始化

class FocalLoss(nn.Module):"""应用 Focal Loss 通过 gamma 和 alpha 参数改进 BCEWithLogitsLoss 以处理类别不平衡"""def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):super().__init__()self.loss_fcn = loss_fcn  # 必须使用 nn.BCEWithLogitsLoss()self.gamma = gamma        # 调节难易样本权重的指数参数self.alpha = alpha        # 平衡正负样本比例的权重系数# 修改原损失函数的 reduction 为 'none' 进行逐元素计算self.reduction = loss_fcn.reductionself.loss_fcn.reduction = "none"def forward(self, pred, true):# 计算基础交叉熵损失loss = self.loss_fcn(pred, true)# 通过 sigmoid 获取概率预测值(范围0-1)pred_prob = torch.sigmoid(pred)# 计算 p_t(真实类别对应的预测概率)(正确分类的概率)p_t = true * pred_prob + (1 - true) * (1 - pred_prob)# 计算 alpha 因子:正样本乘 alpha,负样本乘 (1-alpha) (类别权重)alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)# 计算调制因子:难分类样本权重更大 (困难样本权重)modulating_factor = (1.0 - p_t) ** self.gamma# 组合得到最终的 Focal Lossloss *= alpha_factor * modulating_factor# 根据 reduction 设置返回结果if self.reduction == "mean":return loss.mean()elif self.reduction == "sum":return loss.sum()else:  # 'none'return loss

三、核心参数作用

  1. Gamma (γ)

    • γ > 0 时,降低易分类样本(p_t 接近 1)的损失权重
    • 典型取值范围:0.5-5.0
    • 示例:当 p_t=0.9,γ=2 → 调制因子 = 0.01
  2. Alpha (α)

    • 调节正负样本权重比例
    • α 接近 1 时强调正样本
    • α 接近 0 时强调负样本

四、使用示例

# 初始化
criterion = FocalLoss(loss_fcn=nn.BCEWithLogitsLoss(),gamma=2.0,alpha=0.75
)# 计算损失
pred = model(inputs)
loss = criterion(pred, targets)

五、应用场景

  • 目标检测(如 RetinaNet)
  • 医学图像分析
  • 任何存在严重类别不平衡的分类任务

六、总结

Focal Loss 通过两个关键参数实现了:

  1. 降低大量易分类样本的损失贡献
  2. 平衡正负样本的权重比例
  3. 改善模型对困难样本的学习能力
http://www.xdnf.cn/news/432685.html

相关文章:

  • 运行Spark程序-在shell中运行
  • 思路解析:第一性原理解 SQL
  • 2025.5.13山东大学软件学院计算机图形学期末考试回忆版本
  • msyql8.0.xx忘记密码解决方法
  • 2025.05.11阿里云机考真题算法岗-第二题
  • 重置集群(有异常时)
  • Spring 集成 SM4(国密对称加密)
  • Springboot | 如何上传文件
  • ros2-node
  • SpringBoot--springboot简述及快速入门
  • 2025年全国青少年信息素养大赛初赛模拟测试网站崩了的原因及应对比赛流程
  • SparkSQL操作Mysql
  • 1995-2022年各省能源消费总量数据(万吨标煤)
  • UDS诊断----------$11诊断服务
  • 【YOLO模型】参数全面解读
  • JavaWeb 前端开发
  • 优化的代价(AI编码带来的反思)-来自Grok
  • 基于TouchSocket实现WebSocket自定义OpCode扩展协议
  • day19-线性表(顺序表)(链表I)
  • 操作系统:内存管理
  • JavaScript编译原理
  • 数据结构(七)——图
  • ThingsBoard3.9.1 MQTT Topic(4)
  • UDP协议详细讲解及C++代码实例
  • 数据压缩的概念和优缺点
  • 【电子科技大学主办 | 往届快至会后2个月EI检索】第六届电子通讯与人工智能国际学术会议(ICECAI 2025)
  • Gatsby知识框架
  • angular的rxjs中的操作符
  • Vitrualbox完美显示系统界面(只需三步)
  • vue2将文字转为拼音