当前位置: 首页 > web >正文

PyTorch 损失函数详解:从理论到实践

目录

一、损失函数的基本概念

二、常用损失函数及实现

1. 均方误差损失(MSELoss)

2. 平均绝对误差损失(L1Loss/MAELoss)

3. 交叉熵损失(CrossEntropyLoss)

4. 二元交叉熵损失(BCELoss)

三、损失函数选择指南

四、损失函数在训练中的应用

五、总结


损失函数是深度学习模型训练的核心组件,它量化了模型预测值与真实值之间的差异,指导模型参数的更新方向。本文将结合 PyTorch 代码实例,详细讲解常用损失函数的原理、适用场景及实现方法。

一、损失函数的基本概念

损失函数(Loss Function)又称代价函数(Cost Function),是衡量模型预测结果与真实标签之间差异的指标。在模型训练过程中,通过优化算法(如梯度下降)最小化损失函数,使模型逐渐逼近最优解。

损失函数的选择取决于具体任务类型:

  • 回归任务:预测连续值(如房价、温度)
  • 分类任务:预测离散类别(如图片分类、垃圾邮件识别)
  • 其他任务:如生成任务、序列标注等

二、常用损失函数及实现

1. 均方误差损失(MSELoss)

均方误差损失是回归任务中最常用的损失函数,计算预测值与真实值之间平方差的平均值。

数学公式MSE = \frac{1}{n} \sum_{i=1}^{n} (y_i - \hat{y}_i)^2 其中,y_i为真实值,\hat{y}_i为预测值,n为样本数量。

代码实现

import torch
import torch.nn as nn# 初始化MSE损失函数
mse_loss = nn.MSELoss()# 示例数据
y_true = torch.tensor([3.0, 5.0, 2.5])  # 真实值
y_pred = torch.tensor([2.5, 5.0, 3.0])  # 预测值# 计算损失
loss = mse_loss(y_pred, y_true)
print(f'MSE Loss: {loss.item()}')  # 输出:MSE Loss: 0.0833333358168602

特点

  • 对异常值敏感,因为会对误差进行平方
  • 是凸函数,存在唯一全局最小值
  • 适用于大多数回归任务

2. 平均绝对误差损失(L1Loss/MAELoss)

平均绝对误差计算预测值与真实值之间绝对差的平均值,对异常值的敏感性低于 MSE。

数学公式MAE = \frac{1}{n} \sum_{i=1}^{n} |y_i - \hat{y}_i|

代码实现

# 初始化L1损失函数
l1_loss = nn.L1Loss()# 计算损失
loss = l1_loss(y_pred, y_true)
print(f'L1 Loss: {loss.item()}')  # 输出:L1 Loss: 0.25

特点

  • 对异常值更稳健
  • 梯度在零点处不连续,可能影响收敛速度
  • 适用于存在异常值的回归场景

3. 交叉熵损失(CrossEntropyLoss)

交叉熵损失是多分类任务的标准损失函数,在 PyTorch 中内置了 Softmax 操作,直接作用于模型输出的 logits。

数学公式CrossEntropyLoss = -\sum_{i=1}^{C} y_i \log(\hat{y}_i) 其中,C为类别数,y_i为真实标签的 one-hot 编码,\hat{y}_i为经过 Softmax 处理的预测概率。

代码实现

def test_cross_entropy():# 模型输出的logits(未经过softmax)logits = torch.tensor([[1.5, 2.0, 0.5], [0.5, 1.0, 1.5]])# 真实标签(类别索引)labels = torch.tensor([1, 2])  # 第一个样本属于类别1,第二个样本属于类别2# 初始化交叉熵损失函数criterion = nn.CrossEntropyLoss()loss = criterion(logits, labels)print(f'Cross Entropy Loss: {loss.item()}')  # 输出:Cross Entropy Loss: 0.6422222256660461test_cross_entropy()

计算过程解析

  1. 对 logits 应用 Softmax 得到概率分布
  2. 计算真实类别对应的负对数概率
  3. 取平均值作为最终损失

特点

  • 自动包含 Softmax 操作,无需手动添加
  • 适用于多分类任务(类别互斥)
  • 标签格式为类别索引(非 one-hot 编码)

4. 二元交叉熵损失(BCELoss)

二元交叉熵损失用于二分类任务,需要配合 Sigmoid 激活函数使用,确保输入值在 (0,1) 范围内。

数学公式BCELoss = -\frac{1}{n} \sum_{i=1}^{n} [y_i \log(\hat{y}_i) + (1-y_i) \log(1-\hat{y}_i)]

代码实现

def test_bce_loss():# 模型输出(已通过sigmoid处理)y_pred = torch.tensor([[0.7], [0.2], [0.9], [0.7]])# 真实标签(0或1)y_true = torch.tensor([[1], [0], [1], [0]], dtype=torch.float)# 方法1:使用BCELossbce_loss = nn.BCELoss()loss1 = bce_loss(y_pred, y_true)# 方法2:使用functional接口loss2 = nn.functional.binary_cross_entropy(y_pred, y_true)print(f'BCELoss: {loss1.item()}')  # 输出:BCELoss: 0.47234177589416504print(f'Functional BCELoss: {loss2.item()}')  # 输出:Functional BCELoss: 0.47234177589416504test_bce_loss()

变种:BCEWithLogitsLoss
对于未经过 Sigmoid 处理的 logits,推荐使用BCEWithLogitsLoss,它内部会自动应用 Sigmoid,数值稳定性更好:

# 对于logits输入(未经过sigmoid)
logits = torch.tensor([[0.8], [-0.5], [1.2], [0.6]])
bce_with_logits_loss = nn.BCEWithLogitsLoss()
loss = bce_with_logits_loss(logits, y_true)

三、损失函数选择指南

任务类型推荐损失函数特点
回归任务MSELoss对异常值敏感,适用于大多数回归场景
回归任务(含异常值)L1Loss对异常值稳健,梯度不连续
多分类任务CrossEntropyLoss内置 Softmax,处理互斥类别
二分类任务BCELoss/BCEWithLogitsLoss配合 Sigmoid 使用,输出概率值
多标签分类BCEWithLogitsLoss每个类别独立判断,可同时属于多个类别

四、损失函数在训练中的应用

以图像分类任务为例,展示损失函数在完整训练流程中的使用:

import torch
from torch import nn, optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader# 数据预处理
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])# 加载MNIST数据集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform
)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)# 定义简单的全连接网络
class SimpleNet(nn.Module):def __init__(self):super(SimpleNet, self).__init__()self.fc1 = nn.Linear(28*28, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = x.view(-1, 28*28)x = torch.relu(self.fc1(x))x = self.fc2(x)  # 输出logits,不使用softmaxreturn x# 初始化模型、损失函数和优化器
model = SimpleNet()
criterion = nn.CrossEntropyLoss()  # 多分类任务
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练循环
def train(epochs=5):model.train()for epoch in range(epochs):running_loss = 0.0for images, labels in train_loader:# 前向传播outputs = model(images)loss = criterion(outputs, labels)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()running_loss += loss.item()# 打印每轮的平均损失avg_loss = running_loss / len(train_loader)print(f'Epoch {epoch+1}, Loss: {avg_loss:.4f}')train()

五、总结

损失函数的选择直接影响模型的训练效果和收敛速度,关键要点:

  1. 回归任务优先选择 MSELoss,存在异常值时考虑 L1Loss
  2. 多分类任务使用 CrossEntropyLoss,无需手动添加 Softmax
  3. 二分类任务推荐使用 BCEWithLogitsLoss,数值稳定性更好
  4. 训练过程中需监控损失变化,判断模型是否收敛或过拟合

合理选择损失函数并配合适当的优化器,才能充分发挥模型的学习能力。在实际应用中,可根据具体任务特点和数据分布尝试不同的损失函数,选择表现最佳的方案。

http://www.xdnf.cn/news/15628.html

相关文章:

  • Qt小组件 - 7 SQL Thread Qt访问数据库ORM
  • Uniapp中双弹窗为什么无法显示?
  • 玩转Docker | 使用Docker部署bender个人导航页工具
  • 利用Java自定义格式,循环导出数据、图片到excel
  • 【论文阅读 | CVPR 2023 |CDDFuse:基于相关性驱动的双分支特征分解的多模态图像融合】
  • lua(xlua)基础知识点记录
  • 【前端】在Vue3中绘制多系列柱状图与曲线图
  • 量子比特耦合与系统集成:量子计算硬件的核心突破
  • 入门华为数通,HCIA/HCIP/HCIE该怎么选?
  • 2025年自动化工程、物联网与计算机应用国际会议(AEITCA 2025)
  • Java基础:分支/循环/数组
  • PLC-BMS电力载波通信技术深度解析:智能电网与储能系统的融合创新
  • 【WRFDA数据第一期】WRFDA Free Input 数据网页
  • Spring Boot 整合 Nacos 实战教程:服务注册发现与配置中心详解
  • 【后端】.NET Core API框架搭建(6) --配置使用MongoDB
  • 微软AutoGen:多智能体协作的工业级解决方案
  • PyCharm高效入门
  • NodeJS Express 静态文件、中间件、路由案例
  • 手撕Spring底层系列之:IOC、AOP
  • java操作Excel两种方式EasyExcel 和POI
  • 跟着Carl学算法--回溯【2】
  • React Hooks 数据请求库——SWR使用详解
  • Spring AI 系列之十四 - RAG-ETL之一
  • Vue3+Ts实现父子组件间传值的两种方式
  • Unity Android Logcat插件 输出日志中文乱码解决
  • 小白成长之路-Elasticsearch 7.0 配置
  • BNN 技术详解:当神经网络只剩下 +1 和 -1
  • 基于redis的分布式锁 lua脚本解决原子性
  • 免杀学习篇(1)—— 工具使用
  • 网页源码保护助手 海洋网页在线加密:HTML 源码防复制篡改,密文安全如铜墙铁壁