【笔记】BCEWithLogitsLoss
工作原理
BCEWithLogitsLoss
是 PyTorch 中的一个损失函数,用于二分类问题。
它结合了 Sigmoid 激活函数和二元交叉熵(Binary Cross Entropy, BCE)损失在一个类中。
这不仅简化了代码,而且通过数值稳定性优化提高了模型训练的效率和效果。
使用方法
import torch
import torch.nn as nn# 假设我们有一个批次大小为32,单通道,高度和宽度分别为64的图像
inputs = torch.randn(32, 1, 64, 64) # 这是模型的输出(logits)
targets = torch.empty(32, 1, 64, 64).random_(2) # 随机生成的目标(0或1)# 创建损失函数实例
criterion = nn.BCEWithLogitsLoss()# 计算损失
loss = criterion(inputs, targets)print(f"Loss: {loss.item():.4f}")
需要注意的是,inputs和targets应该格式匹配
注意事项
由于BCEWithLogitsLoss 已经内置了Sigmoid函数,所以不需要显示的再应用sigmoid函数
seg_maps = model(images) # 输出是 logits(不需要激活)loss = criterion_segment(seg_maps, masks.unsqueeze(1).float()) # 直接输入 logits