当前位置: 首页 > java >正文

实现层归一化

五、Layer Normalization

层归一化介绍

  • 层归一化(LayerNorm)通过对同一层的神经元输出进行归一化处理,有效提升模型训练稳定性。与BatchNorm不同,LayerNorm对单个样本的所有特征维度进行归一化,使其对序列长度变化具有更好的适应性。

    当处理变长序列数据时,LayerNorm保持每个时间步的独立计算特性,避免不同序列长度带来的统计量偏差。该操作通过可学习的缩放参数gamma和平移参数beta保留模型的表达能力。

数学公式

  • 层归一化对输入张量最后一个维度进行标准化处理:

    μ = 1 d m o d e l ∑ i = 1 d m o d e l x i σ 2 = 1 d m o d e l ∑ i = 1 d m o d e l ( x i − μ ) 2 out = γ ⋅ x − μ σ 2 + ϵ + β \mu = \frac{1}{d_{model}}\sum_{i=1}^{d_{model}}x_i \\ \sigma^2 = \frac{1}{d_{model}}\sum_{i=1}^{d_{model}}(x_i - \mu)^2 \\ \text{out} = \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta μ=dmodel1i=1dmodelxiσ2=dmodel1i=1dmodel(xiμ)2out=γσ2+ϵ xμ+β
    其中:

    • γ ∈ R d _ m o d e l \gamma \in \mathbb{R}^{d\_{model}} γRd_model:可学习缩放参数(初始化为1)
    • β ∈ R d _ m o d e l \beta \in \mathbb{R}^{d\_{model}} βRd_model:可学习平移参数(初始化为0)
    • ϵ \epsilon ϵ:数值稳定系数(默认1e-12)

    d _ m o d e l d\_model d_model 为模型维度。

代码实现

  • 层归一化代码实现

    import torch
    from torch import nnclass LayerNorm(nn.Module):def __init__(self, d_model, eps=1e-12):super(LayerNorm, self).__init__()self.d_model = d_modelself.eps = eps# 可学习参数初始化self.gamma = nn.Parameter(torch.ones(d_model))  # 缩放参数self.beta = nn.Parameter(torch.zeros(d_model))   # 平移参数def forward(self, x):"""shape of x:              [batch_size,seq_len,d_model]shape of mean and var :  [batch_size,seq_len,1]shape of gamma and beta: [d_model]"""# 步骤1:计算最后一个维度的均值和方差mean = x.mean(dim=-1, keepdim=True)var = x.var(dim=-1, unbiased=False,  # 使用有偏方差估计(与PyTorch官方实现保持一致)keepdim=True # 保持维度对其)# 步骤2:标准化计算normalized = (x - mean) / torch.sqrt(var + self.eps)# 步骤3:仿射变换out = self.gamma * normalized + self.betareturn out
    
  • 注意

    1. 在统计学中,“unbiased”(无偏)通常指的是一个估计量,其期望值等于所估计的参数值。对于方差的计算,有两种常用的方法:无偏估计和最大似然估计。

      无偏估计(Bessel’s Correction):在计算样本方差时,通常使用无偏估计,其公式为:
      Var = 1 N − 1 ∑ i = 1 N ( x i − x ˉ ) 2 \text{Var} = \frac{1}{N-1} \sum_{i=1}^{N} (x_i - \bar{x})^2 Var=N11i=1N(xixˉ)2
      这里我们除以的是 (N-1),而不是样本数量 (N)。这是因为用 (N-1) 除法能够补偿样本方差相对于总体方差的系统性低估,也就是所谓贝塞尔校正(Bessel’s correction)。

      有偏估计(或最大似然估计,MLE):这是普遍用于深度学习中的方法,就是直接使用:
      Var = 1 N ∑ i = 1 N ( x i − x ˉ ) 2 \text{Var} = \frac{1}{N} \sum_{i=1}^{N} (x_i - \bar{x})^2 Var=N1i=1N(xixˉ)2

      其中 N 为样本总量,延伸到我们transformer中就是 d _ m o d e l d\_model d_model

    2. 使用有偏估计可以减小计算的偏差,提升计算的效率和速度,这个同事也在深度学习过程中允许更快的梯度更新,虽然有轻微的偏差,但通常在大量数据下这并不显著影响模型的训练效果。

    3. unbiased=False 是指定要使用有偏估计来计算方差。这默认为 False,偏向于加速计算并使得在深度学习环境下的表现更加稳定。

  • 维度计算流程

    操作步骤张量形状变化示例
    输入数据[batch_size, seq_len, d_model]
    步骤1 - 计算均值(dim=-1)[batch_size, seq_len, 1]
    步骤1 - 计算方差(dim=-1)[batch_size, seq_len, 1]
    步骤2 - 标准化计算[batch_size, seq_len, d_model]
    步骤3 - 线性变换(gamma/beta)[batch_size, seq_len, d_model]

使用示例

  • 测试代码

    if __name__ == "__main__":# 参数配置batch_size = 4seq_len = 100d_model = 512# 生成测试数据x = torch.randn(batch_size, seq_len, d_model)# 实例化层归一化模块layer_norm = LayerNorm(d_model=d_model)# 前向传播out = layer_norm(x)# 验证输出print("输入形状:", x.shape)        # torch.Size([4, 100, 512])print("输出形状:", out.shape)      # torch.Size([4, 100, 512])print("参数gamma形状:", layer_norm.gamma.shape)  # torch.Size([512])print("参数beta形状:", layer_norm.beta.shape)    # torch.Size([512])
    

http://www.xdnf.cn/news/1961.html

相关文章:

  • 数据结构------C语言经典题目(7)
  • 【T-MRMSM】文本引导多层次交互多尺度空间记忆融合多模态情感分析
  • 基于cesium实现鼠标移动动态绘制矩形和圆
  • Rust 学习笔记:函数和控制流
  • React 中什么时候用事件总线
  • 微信小程序直传阿里云 OSS 实践指南(V4 签名 · 秒传支持 · 高性能封装)
  • ROS1、ROS2如何把预编译好的二进制文件封装成功能包?
  • 【Django】新增字段后兼容旧接口 This field is required
  • 代码随想录:数组
  • 如何实现Android屏幕和音频采集并启动RTSP服务?
  • 如何使用@KafkaListener实现从nacos中动态获取监听的topic
  • 【Hive入门】Hive数据导出完全指南:从HDFS到本地文件系统的专业实践
  • 利用JMeter代理服务器方式实现高效压测
  • Leetcode 2845 题解
  • C++_数据结构_详解红黑树
  • 微软官网Win10镜像下载快速获取ISO文件
  • 第18章:MCP在创作领域中的应用
  • Java集成Redisson实现分布式锁(实战)
  • 学生管理系统微服务方式实现
  • WebUI可视化:第3章:Gradio入门实战
  • FerretDB:基于PostgreSQL的MongoDB替代产品
  • 2、Ubuntu 环境下安装RabbitMQ
  • PDFMathTranslate:基于LLM的PDF文档翻译及双语对照的工具【使用教程】
  • Golang | 迭代器模式
  • 学习整理在centos7上安装mysql8.0版本教程
  • 同步定时器的用户数要和线程组保持一致,否则jmeter会出现接口不执行’stop‘和‘×’的情况
  • 基于线性LDA算法对鸢尾花数据集进行分类
  • 【uniapp】vue2 搜索文字高亮显示
  • 【Java】BitSet简介
  • 10.ArkUI Grid的介绍和使用