当前位置: 首页 > java >正文

面试常问系列(一)-神经网络参数初始化-之-softmax

背景

本文内容还是对之前关于面试题transformer的一个延伸,详细讲解一下softmax

面试常问系列(二)-神经网络参数初始化之自注意力机制-CSDN博客

Softmax函数的梯度特性与输入值的幅度密切相关,这是Transformer中自注意力机制需要缩放点积结果的关键原因。以下从数学角度展开分析:

1. Softmax 函数回顾

给定输入向量 z = [z₁, z₂, ..., zₖ],Softmax 输出概率为:

\sigma(z)_i = \frac{e^{z_i}}{\sum_{j=1}^{k}e^{z_j}} =\frac{e^{z_i}}{S},S=\sum_{j=1}^{k}e^{z_j}                        

其中 S 是归一化因子。

2. 梯度计算目标

计算 Softmax 对输入 z 的梯度,即 \frac{\delta \sigma_i}{\delta z_j}对所有 i,j∈{1,…,k}。

3. 梯度推导

根据链式法则,对 σi​ 关于 zj​ 求导:

\frac{\delta \sigma_i}{\delta z_j} = \left\{\begin{matrix} &\sigma_i(1-\sigma_j)) &if &i=j, \\ & -\sigma_i\sigma_j &if &i\neq j, \end{matrix}\right.

具体推到过程就不展示了,感兴趣的有需要的可以评论下。因为本次重点不是通用的softmax分析,而是偏实战分析。

4. 与交叉熵损失结合的梯度

在实际应用中,Softmax 通常与交叉熵损失L = \sum_{i=1}^{n}y_i*log\sigma_i 结合使用。此时梯度计算更简单:

\frac{\partial L }{\partial z_j} = \sigma(z_j)-y_j

其中 y_j是真实标签的 one-hot 编码。

5. 推导

  1. 交叉熵损失对 ​\sigma_i 的梯度:

\frac{\partial L }{\partial\sigma_i} = -\frac{y_i}{\sigma_i}

    2. 通过链式法则:

\frac{\partial L }{\partial z_j} =\sum_{i} \frac{\partial L }{\partial \sigma_i}\frac{\partial \sigma_i }{\partial z_j}=\sum_{i} -\frac{y_i }{\sigma_i}\frac{\partial \sigma_i }{\partial z_j}

    3. 代入在上面求解出的\frac{\delta \sigma_i}{\delta z_j}

  • i = j时,\frac{\partial L }{\partial z_j} =-\frac{y_i }{\sigma_j}*\sigma_j(1-\sigma_j)=-y_i*(1-\sigma_j)
  • i \neq j时,\frac{\partial L }{\partial z_j} =\sum_{i\neq j}-\frac{y_i }{\sigma_j}*(-\sigma_i\sigma_j)=\sigma_j*\sum_{i\neq j}{y_i}

    4.合并上述结果

\frac{\partial L }{\partial z_j} =-y_j*(1-\sigma_j) + \sigma_j*(1-y_j)=\sigma_j-y_j

6. 梯度消失问题

  • 极端输入值:若z_k远大于其他z_i,则\sigma (z_k) \approx 1,其他\sigma (z_i) \approx 0。此时:
    • z_k的梯度:-y_k*(1-\sigma_{z_k}) \approx 0(若yk​=1,梯度接近0)。
    • 对其他zi​的梯度:\sigma (z_i) \approx 0, \sigma_j*\sum_{i\neq j}{y_i} \approx 0,梯度趋近于0。
  • 后果:梯度消失导致参数更新困难,模型难以训练。

7. 缩放的作用

在Transformer中,点积结果除以dk​​后:

  • 输入值范围受限:缩放后z_i的方差为1,避免极端值。
  • 梯度稳定性提升\sigma (z_i)分布更均匀,-y_k*(1-\sigma_{z_k})\sigma (z_i)不会趋近于0,梯度保持有效。

5. 直观示例

  • 未缩放:若dk​=512,点积标准差结果可能达±22,Softmax输出接近0或1,梯度消失。
  • 缩放后:点积结果范围约±5,σ(zi​)分布平缓,梯度稳定。
  • 这个示例在最开始的跳转链接有详细解释,可以参考。

总结

Softmax的梯度对输入值敏感,过大输入会导致梯度消失。Transformer通过除以dk​​控制点积方差,确保Softmax输入值合理,从而保持梯度稳定,提升训练效率。这一设计是深度学习中处理高维数据时的重要技巧。

http://www.xdnf.cn/news/4268.html

相关文章:

  • java springboot解析出一个图片的多个二维码
  • 软考-软件设计师中级备考 13、刷题 数据结构
  • k8s node soft lockup (内核软死锁) 优化方案
  • python3使用:macOS上通过Homebrew安装pip库
  • linux 如何防止内存碎片化?
  • C#中不能通过new关键字创建实例的情况
  • conda虚拟环境相关操作
  • LeetCode 热题 100 39. 组合总和
  • Jetpack Compose 自定义 Slider 完全指南
  • QT键盘触发按钮
  • laravel 12 监听syslog消息,并将消息格式化后存入mongodb
  • 深度解析:2D 写实交互数字人 —— 开启智能交互新时代
  • API 开发实战:基于京东开放平台的实时商品数据采集接口实现
  • 【25软考网工】第五章(6)TCP和UDP协议、流量控制和拥塞控制、重点协议与端口
  • 项目中为什么选择RabbitMQ
  • Vision-Language Models (VLMs) 视觉语言模型的技术背景、应用场景和商业前景(Grok3 DeepSearch模式回答)
  • 隔离端口配置
  • 消除AttributeError: module ‘ttsfrd‘ has no attribute ‘TtsFrontendEngine‘报错输出的记录
  • 2015-2018年 重要城市交通拥堵指数-社科数据
  • Ragflow服务器上部署教程
  • 前端、XSS(跨站脚本攻击,Cross-Site Scripting)
  • 深入理解 Oracle 数据块:行迁移与行链接的性能影响
  • 互联网大厂Java求职面试:云原生与AI融合下的系统设计挑战-2
  • 网络编程核心技术解析:从Socket基础到实战开发
  • 在Spring Boot 中如何配置MongoDB的副本集 (Replica Set) 或分片集群 (Sharded Cluster)?
  • C++ STL 基础与多线程安全性说明文档
  • 如何开发一个笑话管理小工具
  • 盛最多水的容器
  • conda 安装cudnn
  • SpringBoot中使用MCP和通义千问来处理和分析数据