当前位置: 首页 > news >正文

使用pytorch创建模型时,nn.BatchNorm1d(128)的作用是什么?

在PyTorch中,nn.BatchNorm1d(128) 的作用是对 一维输入数据(如全连接层的输出或时间序列数据)进行批标准化(Batch Normalization),具体功能与实现原理如下:

1. 核心作用

  • 标准话数据分布
    对每个批次的输入数据进行归一化,使其均值接近0、方差接近1,公式如下:
    x^=x−μbatchσbatch2+e\hat{\mathbf{x}}=\frac{\mathbf{x}-\mathbf{\mu}_{batch}}{\sqrt{\sigma^{2}_{batch}+e}}x^=σbatch2+exμbatch
    其中:
    • μbatch\mu_{batch}μbatch:当前批次的均值
    • σbatch\sigma_{batch}σbatch:当前批次的方差
    • eee: 防止除零的小常数(默认1e-5)
  • 可学习的缩放与偏移:
    通过参数γ\gammaγ (缩放)和 β\betaβ(偏移)保留模型的表达能力:
    y=γx^+β y = \gamma \hat{\mathbf{x}}+\beta y=γx^+β

2. 参数解释

在这里插入图片描述

3. 全连接网络应用场景

import torch.nn as nnmodel = nn.Sequential(nn.Linear(64, 128),nn.BatchNorm1d(128),  # 对128维特征归一化nn.ReLU(),nn.Linear(128, 10)
)

数学效果:
若输入特征x∈Rm×128\mathbf{x}\in \mathbb{R}^{m\times128}xRm×128,输出yyy满足:
E[y:j]≈0,Var(y:,j)≈1 \mathbb{E}[y_{:j}]\approx0, Var(y_{:,j})\approx1 E[y:j]0,Var(y:,j)1

4. 与其他归一化层的对比

在这里插入图片描述

5. 训练与推理的差异

  • 训练阶段
    使用当前批次的统计量μbatch\mu_{batch}μbatchσbatch2\sigma_{batch}^2σbatch2,并更新全局统计量:
    μrunnning←μrunning×(1−momentum)+μbatch×momentum\mu_{runnning} \leftarrow \mu_{running}\times(1-momentum) + \mu_{batch}\times momentumμrunnningμrunning×(1momentum)+μbatch×momentum
  • 推理阶段(测试阶段)
    固定使用训练积累的全局统计量μbatch\mu_{batch}μbatchσbatch2\sigma_{batch}^2σbatch2
    KaTeX parse error: Undefined control sequence: \sigmma at position 54: …unning}}{\sqrt{\̲s̲i̲g̲m̲m̲a̲^{2}_{running}+…

6. 代码战争数学性质

import torch# 模拟输入(batch_size=4, 128维特征)
x = torch.randn(4, 128) * 2 + 1  # 均值1,方差4bn = nn.BatchNorm1d(128, affine=False)  # 禁用γ和β
output = bn(x)print("输入均值:", x.mean(dim=0).mean().item())   # ≈1
print("输出均值:", output.mean(dim=0).mean().item())  # ≈0
print("输入方差:", x.var(dim=0).mean().item())    # ≈4
print("输出方差:", output.var(dim=0).mean().item())  # ≈1
http://www.xdnf.cn/news/1156789.html

相关文章:

  • Muduo库中单例模式详解
  • Mysql(事务)
  • 小型支付项目3-5:检测未接收到或未正确处理的支付回调通知
  • UE5多人MOBA+GAS 番外篇:移植Lyra的伤害特效(没用GameplayCue,因为我失败了┭┮﹏┭┮)
  • 音视频学习(四十一):H264帧内压缩技术
  • 【Vue进阶学习笔记】Vue 路由入门指南
  • 单线程 Reactor 模式
  • 动静态库的制作和原理
  • 【unitrix】 6.10 类型转换(from.rs)
  • [BUG]关于UE5.6编译时出现“Microsoft.MakeFile.Targets(44,5): Error MSB3073”问题的解决
  • 【软件测试】从软件测试到Bug评审:生命周期与管理技巧
  • VUE2 学习笔记2 数据绑定、数据代理、MVVM
  • 【数据结构】第一讲 —— 概论
  • 基于Arduino的智能寻迹小车设计
  • 剑指offer——链表:旋转数组的最小数字
  • 【OD机试】池化资源共享
  • 「Java案例」利用方法求反素数
  • Ubuntu挂载和取消挂载
  • LP-MSPM0G3507学习--07定时器之二定时节拍
  • ZYNQ平台深度剖析:EMMC/FLASH/SD卡性能测试与创新实践
  • 从磁记录到数据中心:磁盘原理与服务器架构的完整技术链路
  • 两个数据表的故事:第 1 部分
  • Spring之事务使用指南
  • Java行为型模式---解释器模式
  • Openlayers 面试题及答案180道(121-140)
  • Node.js Express keep-alive 超时时间设置
  • @import导入css样式、scss变量用法、static目录
  • Java中List<int[]>()和List<int[]>[]的区别
  • PAT 1049 Counting Ones
  • 医学图像超分辨率重建深度学习模型开发报告