当前位置: 首页 > web >正文

信息最大化(Information Maximization)

信息最大化在目标域无标签的域自适应任务中,它迫使模型在没有真实标签的情况下,对未标记数据产生高置信度且类别均衡的预测。此外,这些预测也可以作为伪标签用于自训练。

例如,在目标域没有标签时,信息最大化损失可以应用于目标域数据,使模型适应目标域并产生有意义的预测,缓解源域和目标域之间的分布偏移。在自训练或生成模型中,信息最大化通过要求整体预测分布均衡,有效防止模型将所有样本都预测到少数几个类别上。

信息最大化损失函数

信息最大化的损失函数可以表达为[1]:
L I M = L e n t + L d i v = H ( Y ∣ X ) − H ( Y ) \begin{align} \mathcal{L}_{IM} &= \mathcal{L}_{ent}+\mathcal{L}_{div} \\ &= H(Y|X)-H(Y) \end{align} LIM=Lent+Ldiv=H(YX)H(Y)

式中, H ( Y ∣ X ) H(Y|X) H(YX)模型预测输出标签的信息熵,最小化条件熵让模型的预测P(Y|X)更加自信。 H ( Y ) H(Y) H(Y)是预测类别标签的边缘熵,由于损失前面有个负号,则需要最大化边缘熵,迫使模型预测的各个类别均匀分布,而不是偏向其中某个类别。

信息最大化本质是最大化预测标签 Y Y Y 与输入 X X X 的互信息 I ( X ; Y ) = H ( Y ) − H ( Y ∣ X ) I(X;Y) = H(Y) - H(Y|X) I(X;Y)=H(Y)H(YX),因此 L I M = − I ( X ; Y ) \mathcal{L}_{IM} = -I(X;Y) LIM=I(X;Y)。最小化该损失等价于提升输入与预测标签之间的互信息。

① 熵最小化损失 (Entropy Minimization)
L e n t = − E x ∑ c = 1 C δ c ( f ( x ) ) log ⁡ δ c ( f ( x ) ) \begin{align} \mathcal{L}_{ent} = -\mathbb{E}_x \sum_{c=1}^C \delta_c(f(x)) \log \delta_c(f(x)) \end{align} Lent=Exc=1Cδc(f(x))logδc(f(x))

式中, f ( x ) f(x) f(x)表示模型的预测输出, δ c \delta_c δc是softmax函数,代表样本 x x x是类别 c c c的概率值。

② 多样性最大化损失 (Diversity Regularization)
L d i v = − ( − ∑ c = 1 C p ^ c log ⁡ p ^ c ) = ∑ c = 1 C p ^ c log ⁡ p ^ c \begin{align} \mathcal{L}_{div} &=-(- \sum_{c=1}^C \hat{p}_c \log \hat{p}_c)\\ &= \sum_{c=1}^C \hat{p}_c \log \hat{p}_c \end{align} Ldiv=(c=1Cp^clogp^c)=c=1Cp^clogp^c

其中, p ^ c = 1 N ∑ i = 1 N p i c \hat{p}_c=\frac{1}{N}\sum_{i=1}^{N}p_{ic} p^c=N1i=1Npic表示类别 c c c在整个批次上的平均预测概率,也就是类别 c c c的边缘分布。注意:在数学上, P ( Y = c ) P(Y=c) P(Y=c)的边缘概率应该为 P ( Y = c ) = ∑ i = 1 N P ( X = x i , Y = c ) P(Y = c) = \sum_{i=1}^N P(X = x_i, Y = c) P(Y=c)=i=1NP(X=xi,Y=c),这里采用的均值而非总和,即采用批次均值 p ^ c \hat{p}_c p^c 作为无偏估计。

总结

  • 最小化条件熵 H ( Y ∣ X ) H(Y|X) H(YX):迫使模型对每个目标域样本做出确定性预测。
  • 最大化边缘熵 H ( Y ) H(Y) H(Y):确保模型在整个目标域上的预测类别分布均匀,防止坍缩到少数类。

代码

import torch
import torch.nn as nn
import torch.nn.functional as Fclass IMLoss(nn.Module):def __init__(self, lambda_div=0.1, eps=1e-8):"""信息最大化损失函数:无监督学习范式Args:lambda_div (float): 多样性损失的权重系数,默认0.1eps (float): 数值稳定项,防止log(0),默认1e-8"""super().__init__()self.lambda_div = lambda_divself.eps = epsdef forward(self, logits):"""计算信息最大化损失Args:logits (torch.Tensor): 模型输出的logits张量,形状为(batch_size, num_classes)Returns:torch.Tensor: 计算得到的总损失值"""# 计算softmax概率probs = F.softmax(logits, dim=1)# 1. L_ent: 熵最小化损失,使预测更确定entropy_per_sample = -torch.sum(probs * torch.log(probs + self.eps), dim=1)entropy_loss = torch.mean(entropy_per_sample)# 2. L_div: 多样性最大化损失, 使类别分布均匀mean_probs = torch.mean(probs, dim=0) # 边缘分布,由于样本是独立同分布的,这里考虑概率的平均值而非总和diversity_loss = -torch.sum(mean_probs * torch.log(mean_probs + self.eps))# L_IM总损失total_loss = entropy_loss - self.lambda_div * diversity_lossreturn total_lossnum_classes=3
bs=2logits = torch.randn(bs, num_classes) # 模型输出的逻辑值:model(x)
loss_fn = IMLoss()
loss = loss_fn(logits)

知识点

边缘分布/边际分布 (Marginal distribution)定义[2]:Given a known joint distribution of two discrete random variables, say, X X X and Y Y Y, the marginal distribution of either variable – X X X for example – is the probability distribution of X X X when the values of Y Y Y are not taken into consideration.
p X ( x i ) = ∑ j p ( x i , y j ) , p Y ( y j ) = ∑ i p ( x i , y j ) p_X(x_i)=\sum_jp(x_i,y_j),\\ p_Y(y_j)=\sum_i p(x_i,y_j) pX(xi)=jp(xi,yj),pY(yj)=ip(xi,yj)

案例:

如下表所示,一个批次有3个样本,类别为4,对应的随机变量Y的边缘分布为最后一行。

y 1 y_1 y1 y 2 y_2 y2 y 3 y_3 y3 y 4 y_4 y4 p X ( x ) p_X(x) pX(x)
x 1 x_1 x1 4 32 \frac{4}{32} 324 2 32 \frac{2}{32} 322 1 32 \frac{1}{32} 321 1 32 \frac{1}{32} 321 8 32 \frac{8}{32} 328
x 2 x_2 x2 3 32 \frac{3}{32} 323 6 32 \frac{6}{32} 326 3 32 \frac{3}{32} 323 3 32 \frac{3}{32} 323 15 32 \frac{15}{32} 3215
x 3 x_3 x3 9 32 \frac{9}{32} 329 0 0 0 0 0 0 0 0 0 9 32 \frac{9}{32} 329
p Y ( y ) p_Y(y) pY(y) 16 32 \frac{16}{32} 3216 8 32 \frac{8}{32} 328 4 32 \frac{4}{32} 324 4 32 \frac{4}{32} 324 32 32 \frac{32}{32} 3232

参考

[1] [2002.08546] Do We Really Need to Access the Source Data? Source Hypothesis Transfer for Unsupervised Domain Adaptation

[2] Marginal distribution - Wikipedia

http://www.xdnf.cn/news/12655.html

相关文章:

  • Go语言进阶④:Go的数据结构和Java的有啥不一样
  • 光学字符识别(OCR)理论概述与实践教程
  • 动目标显示处理解析一(脉冲对消器)
  • Ubuntu 配置使用 zsh + 插件配置 + oh-my-zsh 美化过程
  • 前沿论文汇总(机器学习/深度学习/大模型/搜广推/自然语言处理)
  • 数据类型 -- 字符
  • SQL字符串截取函数全解析:LEFT、RIGHT、SUBSTRING 实战指南
  • 如何使用Jmeter进行压力测试?
  • MySQL-运维篇
  • 隐私计算时代B端页面安全设计:数据脱敏与权限体系升级路径
  • 数据结构算法(C语言)
  • 新能源汽车热管理核心技术解析:冬季续航提升40%的行业方案
  • 开源之夏·西安电子科技大学站精彩回顾:OpenTiny开源技术下沉校园,点燃高校开发者技术热情
  • 华为云Astro中服务编排、自定义模型,页面表格之间有什么关系?如何连接起来?如何操作?
  • 【第七篇】 SpringBoot项目的热部署
  • Mac 安装git心路历程(心累版)
  • Mysql批处理写入数据库
  • 虚幻基础:角色旋转
  • IEC 61347-1:2015 灯控制装置安全通用要求详解
  • Docker基础(一)
  • 轻量级Docker管理工具Docker Switchboard
  • python如何统计图片的颜色分布
  • jenkins gerrit-trigger插件配置
  • JVM 垃圾回收器 详解
  • C++算法训练营 Day11 栈与队列(2)
  • mysql已经安装,但是通过rpm -q 没有找mysql相关的已安装包
  • 阿里云ACP云计算备考笔记 (4)——企业应用服务
  • 【MySQL】视图、用户管理、MySQL使用C\C++连接
  • 【数据结构初阶】单链表
  • Harmony核心:动态方法修补与.NET游戏Mod开发