Focal Loss
Focal loss是一种分类任务损失函数,是经过优化的交叉熵损失函数,用于解决难/易分类样本数量不均衡的问题(一般多用于二分类)。
背景
one stage目标检测算法存在以下两个问题:
- 正负样本不平衡:预测的bounding box数量非常多, 且负样本数量 >> 正样本数量 (背景数量 >> 前景数量)
- 正样本: 预测框中包含目标体的为正样本 (预测框与真实框的IOUIOUIOU大于 threshold)
- 负样本: 预测框中不包含目标体的为负样本 (预测框与真实框的IOUIOUIOU小于threshold)
- 难易样本不平衡:易分的样本数量多,损失主导了训练方向,而难分的样本数量少,对 loss 的贡献几乎被忽略。
样本属于某个类,且预测结果中该类的概率越大,其越容易分类:
- 假设正/负样本 预测为属于正样本/负样本的概率为 0.9 ——>易分类样本
- 假设正/负样本 预测为属于正样本/负样本的概率为 0.3 ——>难分类样本
控制正负样本的平衡
思路:
数量多的负样本,给予较小的损失权重系数;数量少的正样本,给予较大的损失权重系数。
以二分类交叉熵损失函数CE为例,假设 ppp 表示样本被预测为正样本的概率,那么损失函数为 :
CE(p,y)={−log(p)if y=1−log(1−p)otherwise(1)CE(p, y)= \begin{cases} -log(p)& \text{if \; y=1} \tag{1} \\ -log(1-p)& \text{otherwise} \end{cases}CE(p,y)={−log(p)−log(1−p)if y=1otherwise(1)
简化公式,令
pt={pif y=11−potherwisep_t= \begin{cases} p& \text{if \; y=1}\\ 1-p& \text{otherwise} \end{cases}pt={p1−pif y=1otherwise
那么
CE(p,y)=CE(pt)=−log(pt)CE(p, y) = CE(p_t)=-log(p_t)CE(p,y)=CE(pt)=−log(pt)
为平衡正负样本的影响,在常规的损失函数前增加一个系数 αt\alpha_tαt
CE(p,y)=−αtlog(pt)(2)CE(p, y) = -\alpha_t log(p_t) \tag{2}CE(p,y)=−αtlog(pt)(2)
其中,αt\alpha_tαt 为超参数,与 ptp_tpt 类似:
αt={αif y=11−αotherwise\alpha_t= \begin{cases} \alpha& \text{if \; y=1}\\ 1-\alpha& \text{otherwise} \end{cases}αt={α1−αif y=1otherwise
α\alphaα的范围是(0,1)(0,1)(0,1),通过设置α\alphaα实现控制正负样本对loss的贡献。
如果损失函数展开就是:
CE(p,y,α)={−α⋅log(p)if y=1−(1−α)⋅log(1−p)otherwiseCE(p, y, \alpha)= \begin{cases} -\alpha \cdot log(p)& \text{if \; y=1} \\ -(1-\alpha) \cdot log(1-p)& \text{otherwise} \end{cases}CE(p,y,α)={−α⋅log(p)−(1−α)⋅log(1−p)if y=1otherwise
延申:torch.nn.BCEWithLogitsLoss() 中,有对应设置参数α\alphaα 的变量pos_weight。详细可见https://blog.csdn.net/qq_39088868/article/details/151113961?spm=1011.2124.3001.6209
控制难易样本的平衡
思路:
数量多的易分类样本,给予较小的损失权重系数;数量少的难分类样本,给予较大的权重系数。
在二分类问题中,正样本的标签为1,负样本的标签为0,ppp代表样本为正样本的概率:
对于正样本而言,1−p1-p1−p 的值越大,样本越难分类。
对于负样本而言,ppp 的值越大,样本越难分类
ptp_tpt的定义如下:
pt={pif y=11−potherwisep_t= \begin{cases} p& \text{if \; y=1}\\ 1-p& \text{otherwise} \end{cases}pt={p1−pif y=1otherwise
那么1−pt1-p_t1−pt就可以表示出每个样本属于易分类或难分类样本。
FL(pt)=−(1−pt)γ⋅log(pt)FL(p_t)= -(1- p_t)^{\gamma} \cdot log(p_t)FL(pt)=−(1−pt)γ⋅log(pt)
其中,(1−pt)γ(1- p_t)^{\gamma}(1−pt)γ 为调制系数。
- 对于 ptp_tpt :
- 当 pt→0p_t \rightarrow 0pt→0 时,说明正样本预测结果为正的概率很小 (负样本预测结果为负的概率很小),这是一个难分类样本。这时 (1−pt)→1(1-p_t) \rightarrow 1(1−pt)→1,损失对于总的loss贡献大。
- 当 pt→1p_t \rightarrow 1pt→1 时,说明正样本预测结果为正的概率很大 (负样本预测结果为负的概率很大),这是一个易分类样本,需要降权,这时 (1−pt)(1-p_t)(1−pt) 是一个较小值,达到降权效果。
- 对于 γ\gammaγ:
当 γ=0\gamma=0γ=0 时,就是传统的交叉熵损失,通过调整 γ\gammaγ 可以实现权重的改变。
两种权重控制方法合并
将两种权重控制一起考虑(既考虑正负样本数量,又考虑难分/易分样本),即为类别加权 Focal Loss
FL(pt)=−αt⋅(1−pt)γ⋅log(pt)(4)FL(p_t) = - \alpha_t \cdot (1- p_t)^{\gamma} \cdot log(p_t) \tag{4}FL(pt)=−αt⋅(1−pt)γ⋅log(pt)(4)
其中,αt\alpha_tαt类别权重, (1−pt)γ(1- p_t)^{\gamma}(1−pt)γ 难度权重
在focal loss论文中,对于二分类问题,一般 α=0.25\alpha = 0.25α=0.25, γ=2\gamma = 2γ=2 时,效果较好。
代码示例
手动编写的focal loss 代码如下:
二分类 Focal Loss:
import torch
import torch.nn as nn
import torch.nn.functional as Fclass Focal_Loss(nn.Module):def __init__(self, alpha=0.25, gamma=2, logits=True):super(Focal_Loss, self).__init__()self.alpha = alphaself.gamma = gammaself.logits = logitsdef forward(self, preds, labels):""""""pred_logit: size=(sample_num), 值表示预测为正样本的概率labels:# size=(sample_num), 值为1表示为正样本,值为0表示为负样本"""if not self.logits:preds = F.sigmoid(preds)eps = 1e-8pos_floss = -1 * self.alpha * torch.pow((1 - preds), self.gamma) * torch.log(preds + eps) * labelsneg_floss = -1 * (1 - self.alpha) * torch.pow(preds, self.gamma) * torch.log(1 - preds + eps) * (1 - labels)floss = pos_floss + neg_flossreturn torch.mean(floss)if __name__ == '__main__':preds = torch.randn(10) # 预测为正样本的概率labels = torch.tensor([0, 0, 1, 1, 1, 0, 1, 1, 0, 1]) # 1表示为正样本,0表示为负样本focal_loss = Focal_Loss(logits=False)loss = focal_loss(preds, labels)print(loss)
多分类 Focal Loss:
import torch
import torch.nn as nn
import torch.nn.functional as Fclass Focal_Loss(nn.Module):def __init__(self, alpha, gamma=2, logits=True):super(Focal_Loss, self).__init__()self.alpha = alpha # 损失权重,长度等于类别数的向量self.gamma = gammaself.logits = logitsself.class_num = len(alpha)def forward(self, preds, labels):"""preds: size=(n, m), n个样本, 以及对应的预测出的m个类别的概率labels: size=(n), n个样本的真实类别"""if not self.logits:preds = F.softmax(preds, dim=-1)#类别序号转换one hot形式labels_onehot = F.one_hot(labels, num_classes=self.class_num)eps = 1e-8v_loss = -1 * self.alpha * torch.pow((1 - preds), self.gamma) * torch.log(preds + eps) * labels_onehots_loss = torch.sum(v_loss, dim=1)return torch.mean(s_loss)if __name__ == '__main__':preds = torch.randn(5, 10) # 5个样本对应的预测出10个类别的概率labels = torch.randint(0, 10, (5,)) # 5个样本的真实类别,共10个类别,序号从 0~9weights = torch.randn(10) # 10个类别的损失权重focal_loss = Focal_Loss(alpha=weights, logits=False)loss = focal_loss(preds, labels)print(loss)