当前位置: 首页 > java >正文

PyTorch实现CrossEntropyLoss示例

PyTorch实现CrossEntropyLoss示例

  • PyTorch实现CrossEntropyLoss示例
    • 摘要
    • 标签平滑原理
      • 传统交叉熵的问题
      • 标签平滑的数学表达
    • PyTorch代码实现解析
      • 类定义与初始化
      • 前向传播过程
    • 代码示例
    • 关键点总结
    • 扩展思考

PyTorch实现CrossEntropyLoss示例

摘要

在深度学习的分类任务中,交叉熵损失函数被广泛应用。然而,传统的交叉熵损失容易导致模型对预测结果过于自信,从而引发过拟合问题。本文介绍一种改进方法——标签平滑(Label Smoothing),并通过PyTorch实现该技术。代码源自计算机视觉领域的经典论文《Rethinking the Inception Architecture for Computer Vision》,可有效提升模型的泛化能力。


标签平滑原理

传统交叉熵的问题

传统交叉熵损失将真实标签的预测概率设为1,其他类别为0。这种“非黑即白”的方式容易导致模型过度拟合训练数据,对错误标签过于敏感。

标签平滑的数学表达

标签平滑通过引入平滑因子 ϵ \epsilon ϵ,将真实标签的概率调整为:
( 1 − ϵ ) × y + ϵ K (1 - \epsilon) \times y + \frac{\epsilon}{K} (1ϵ)×y+Kϵ
其中 K K K为类别总数。例如,当 ϵ = 0.1 \epsilon=0.1 ϵ=0.1 K = 10 K=10 K=10时,真实标签的概率变为 0.9 + 0.1 / 10 = 0.91 0.9 + 0.1/10=0.91 0.9+0.1/10=0.91,其他类别均为 0.01 0.01 0.01。这种方式使模型输出更“柔和”,防止过拟合。


PyTorch代码实现解析

类定义与初始化

class CrossEntropyLoss(nn.Module):def __init__(self, num_classes, eps=0.1, use_gpu=True, label_smooth=True):super(CrossEntropyLoss, self).__init__()self.num_classes = num_classesself.eps = eps if label_smooth else 0  # 平滑因子self.use_gpu = use_gpuself.logsoftmax = nn.LogSoftmax(dim=1)  # LogSoftmax层def forward(self, inputs, targets):log_probs = self.logsoftmax(inputs)  # 计算Log Softmaxzeros = torch.zeros(log_probs.size())targets = zeros.scatter_(1, targets.unsqueeze(1).data.cpu(), 1)  # 生成One-hot编码if self.use_gpu:targets = targets.cuda()targets = (1 - self.eps) * targets + self.eps / self.num_classes  # 应用标签平滑return (-targets * log_probs).mean(0).sum()  # 计算损失
  • num_classes: 类别总数。
  • eps: 平滑因子,默认0.1。
  • use_gpu: 是否使用GPU加速。
  • label_smooth: 是否启用标签平滑。

前向传播过程

  1. Log Softmax计算
    使用nn.LogSoftmax对模型输出进行处理,获得对数概率。

  2. 生成One-hot标签
    通过scatter_方法将类别索引转换为One-hot向量。例如,假设targets=[2,5],则生成的One-hot矩阵为:

    [[0,0,1,0,0,0,0,0,0,0],[0,0,0,0,0,1,0,0,0,0]]
    
  3. 设备转移
    根据use_gpu参数将标签数据转移到GPU。

  4. 标签平滑公式应用
    调整真实标签的概率分布,公式为:

    t a r g e t s = ( 1 − ϵ ) × one_hot + ϵ / K targets = (1-\epsilon) \times \text{one\_hot} + \epsilon / K targets=(1ϵ)×one_hot+ϵ/K

  5. 损失计算
    计算每个类别的平均损失后求和,等价于对全体样本的损失求平均。


代码示例

# 示例:计算两个样本的损失
num_classes = 10
batch_size = 2
inputs = torch.randn(batch_size, num_classes)  # 随机生成预测结果
targets = torch.LongTensor([2, 5])             # 真实标签criterion = CrossEntropyLoss(num_classes=10, eps=0.1, use_gpu=False)
loss = criterion(inputs, targets)
print(f"Loss: {loss.item():.4f}")

关键点总结

  1. 标签平滑的作用
    通过引入平滑因子,防止模型对训练数据过度自信,提升泛化性能。

  2. 设备一致性
    需确保inputstargets位于同一设备(CPU/GPU)。当前代码依赖use_gpu参数控制,建议改进为自动匹配设备。

  3. 计算等价性
    .mean(0).sum()等价于对全体样本的损失求平均,与传统交叉熵计算方式一致。


扩展思考

  • 参数调优:平滑因子 ϵ \epsilon ϵ通常取0.1,但可根据任务调整。较大的 ϵ \epsilon ϵ适合类别噪声较多的场景。
  • 结合其他技术:可与其他正则化方法(如Dropout、权重衰减)结合使用,进一步提升模型效果。
http://www.xdnf.cn/news/6830.html

相关文章:

  • AIGC在电商行业的应用:革新零售体验
  • 计算机网络(1)——概述
  • Docker入门指南:镜像、容器与仓库的核心概念解析
  • Redis的Hot Key自动发现与处理方案?Redis大Key(Big Key)的优化策略?Redis内存碎片率高的原因及解决方案?
  • STM32 | FreeRTOS 递归信号量
  • C# 深入理解类(静态函数成员)
  • golang中的反射示例
  • 大模型AI原生应用效果测试与评估视频课来啦
  • Python多进程编程执行任务
  • sudo apt update是什么意思呢?
  • (3)python爬虫--Xpath
  • 2022河南CCPC(前四题)
  • pip升级或者安装报错怎么办?
  • 致敬经典 << KR C >> 之打印输入单词水平直方图和以每行一个单词打印输入 (练习1-12和练习1-13)
  • 最小二乘法拟合直线,用线性回归法、梯度下降法实现
  • SLAM定位常用地图对比示例
  • 【深度学习新浪潮】大模型时代,我们还需要学习传统机器学习么?
  • 计算机视觉与深度学习 | Python实现EMD-VMD-LSTM时间序列预测(完整源码和数据)
  • React Flow 节点事件处理实战:鼠标 / 键盘事件全解析(含节点交互代码示例)
  • 跨国应用程序的数据存储方案常见的解决方案
  • R语言空间数据处理入门教程
  • Redis——过期删除策略和内存
  • golang读、写、复制、创建目录、删除、重命名,文件方法总结
  • AI517 AI本地部署 docker微调(失败)
  • Baklib知识中台构建企业智能服务新引擎
  • 板凳-------Mysql cookbook学习 (二)
  • 【新能源轻卡行驶阻力模型参数计算实战:从国标试验到续航优化】
  • Linux | mdadm 创建软 RAID
  • C# WPF .NET Core和.NET5之后引用System.Windows.Forms的解决方案
  • 服务间的“握手”:OpenFeign声明式调用与客户端负载均衡