当前位置: 首页 > ai >正文

深入理解 Pre-LayerNorm :让 Transformer 训练更稳

摘要

在超深 Transformer 与大语言模型(LLM)时代,归一化策略直接决定了模型能否稳定收敛、推理性能能否最大化。把归一化层从 “残差之后” 挪到 “子层之前”(Pre-LayerNorm,Pre-LN),再将传统 LayerNorm 简化为 RMSNorm——只做均方根缩放、不再减均值——是 GPT-3、LLaMA-4、DeepSeek-V3 等主流 LLM 的标准做法。Pre-LN 让每一层在进入注意力或前馈前就保持单位尺度,显著缓解梯度爆炸/消失;RMSNorm 进一步减少 7-64 % 归一化 FLOPs,同时保持收敛性能。本文先对比 Post-LN 与 Pre-LN 的梯度流,再解释 RMSNorm 的数学原理,最后给出 PyTorch 伪代码

Pre-LayerNorm(Pre-LN) 结构里,输入向量 x 会先经过 LayerNorm(或 RMSNorm)再送入 Masked Multi-Head Attention;注意力子层完成后再与原始 x 做残差相加。这与原始 Transformer(Post-LN)“先算子层→残差→再 LayerNorm”的顺序正好相反。


1 Pre-LN 子层的计算流程

# 以解码器的 Masked Multi-Head Attention 为例
norm_x  = LN(x)                         # ① 先归一化
att_out = MHA(norm_x, norm_x, norm_x)   # ② 计算 Q/K/V 并做 Masked Attention
y       = x + att_out                   # ③ 残差相加
  • LayerNorm 放前:保证传入注意力的张量均值≈0、方差≈1,数值尺度固定。

  • 残差连接保留原信息:子层只需学习“增量”,梯度更易传播。


2 为什么要这样做?(逐步推理)

  1. 梯度稳定

    • Post-LN 时,梯度要先穿过注意力的大矩阵,再被 LayerNorm 缩放,深层模型易爆炸/消失。

    • Pre-LN 把归一化提前,子层输入始终单位方差,梯度连乘更稳,可支撑 100+ 层深度。

  2. 调参简单

    • 许多实践表明 Pre-LN 可直接使用较大学习率并把 warm-up 步数缩短到 0-500。

  3. 推理省显存

    • 由于不必保留 LayerNorm 前的大量激活以做反向梯度,训练峰值显存可再降 5-10 %。

3 梯度推理:为什么 Pre-LN 更稳?


3 与 Masked Multi-Head Attention 的关系

Decoder 的第 1 个注意力子层 中,需要“未来位屏蔽(mask)”。

  • Pre-LN 只改变 先归一化再 Attention 的顺序,并 不影响“上三角 mask”逻辑;掩码仍在 Softmax 前把未来得分置 -∞。

  • 这样既保持自回归条件,又享受梯度稳定优势。


4 代码模板(PyTorch ≥ 2.7)

class PreLNDecoderBlock(nn.Module):def __init__(self, d_model, n_heads):super().__init__()self.norm1 = nn.LayerNorm(d_model, eps=1e-6)      # 可改 nn.RMSNormself.mha    = nn.MultiheadAttention(d_model, n_heads,batch_first=True)self.norm2 = nn.LayerNorm(d_model, eps=1e-6)self.ffn    = nn.Sequential(nn.Linear(d_model, 4*d_model),nn.GELU(), nn.Linear(4*d_model, d_model))def forward(self, x, mask):# Masked Multi-Head Attentionx = x + self.mha(self.norm1(x), self.norm1(x),self.norm1(x), attn_mask=mask)[0]# Feed-Forwardx = x + self.ffn(self.norm2(x))return x
若需 RMSNorm 只要把 nn.LayerNorm 换成 nn.RMSNorm,其他接口不变。


5 参考文献

  1. S.H. Tsang, Pre-LN Transformer Review 2022 

  2. ApX ML Blog, Pre-LN vs Post-LN 2023 

  3. Vaswani et al., Attention Is All You Need 2017 (原始 Post-LN 结构)

  4. Sebastian Raschka Blog, Why Pre-LN Works Better 2022 

  5. GitHub issue #278 (nanoGPT) 讨论 Pre-LN 实现 2023 

  6. Medium, Masked Multi-Head Attention in Transformer 2024 

  7. StackOverflow #58127059 解读注意力 mask 2019 

  8. arXiv 2002.04745, On Layer Normalization in the Transformer Architecture 2020 

  9. arXiv 2502.02732, Peri-LN: Revisiting LayerNorm 2025 


结束

  • Pre-LayerNorm 把梯度问题“扼杀在源头”;

  • RMSNorm 在此基础上再省 7-64 % FLOPs;

  • 二者组合已是 LLM 标配。今晚就把 LayerNorm 换成 RMSNorm,让 GPU 算力用在刀刃上!

觉得有用 👍 点赞 / ⭐️ 收藏 / 💬 评论 / 🚀 转发三连支持一下,让更多工程师告别梯度爆炸的烦恼!
http://www.xdnf.cn/news/7764.html

相关文章:

  • Day123 | 灵神 | 二叉树 | 找树左下角的值
  • Vue3中插槽, pinia的安装和使用(超详细教程)
  • 物联网之使用Vertx实现UDP最佳实践【响应式】
  • DataOutputStream DataInputStream转换流
  • I.MX6U Mini开发板测试GPIO
  • Linux中进程控制(上)
  • 【Rust智能指针】Rust智能指针原理剖析与应用指导
  • C++初阶-vector的模拟实现3
  • vue原生table表格实现动态添加列,一行添加完换行继续添加。el-select输入框背景颜色根据所选内容不同而改变
  • BeamDojo: Learning Agile Humanoid Locomotion on Sparse Footholds
  • 如果教材这样讲--单片机IO口Additional Functions和 Alternate Functions的区别
  • 基于Android的XX校园交流APP
  • 工业路由器WiFi6+5G的作用与使用指南,和普通路由器对比
  • Veo 3 可以生成视频,并附带配乐
  • springboot项目读取dll
  • RT_Thread——快速入门
  • 电子电路:怎么理解放大电路中集电极电流Ic漂移?
  • 如何使用Java生成pdf报告
  • 面向恶劣条件的道路交通目标检测----大创自用(当然你也可以在里面学到很多东西)
  • 如何使用AI搭建WordPress网站
  • SAP-ABAP:ABAP异常处理与安全工程的融合 —— 构建防注入、防泄漏、合规审计的防御性编程体系
  • C# Prism框架详解:构建模块化WPF应用程序
  • 轩辕杯Wp
  • 【Java】泛型在 Java 中是怎样实现的?
  • java day14
  • debian系统redis-dump安装
  • Vite + Vue 工程中,为什么需要关注 `postcss.config.ts`?
  • 如何用JAVA手写一个Tomcat
  • c#基础03(运算符)
  • LeetCode 3355.零数组变换 I:差分数组