当前位置: 首页 > news >正文

nn.LayerNorm():对输入张量的最后一个维度(特征维度)进行归一化

nn.LayerNorm 是 PyTorch 中的一个归一化层,用于对输入数据进行归一化处理。通常用于处理序列数据或任何需要在特征维度上进行归一化的场景。

1. 张量的维度

在深度学习中,张量(Tensor)是一个多维数组,用于表示数据。张量的维度通常有不同的含义,具体取决于应用场景。例如:

  • 二维张量:形状为 (batch_size, feature_size),通常用于表示一批样本的特征。
  • 三维张量:形状为 (batch_size, sequence_length, feature_size),通常用于表示一批序列数据(如自然语言处理中的句子)。
  • 四维张量:形状为 (batch_size, channels, height, width),通常用于表示一批图像数据。

2. 特征维度

在上述张量中,“特征维度”通常是指最后一个维度,它表示每个样本或序列元素的特征向量。例如:

  • 对于二维张量 (batch_size, feature_size),特征维度是 feature_size
  • 对于三维张量 (batch_size, sequence_length, feature_size),特征维度是 feature_size
  • 对于四维张量 (batch_size, channels, height, width),特征维度是 channels

3. 归一化

归一化是一种数据预处理方法,目的是将数据转换为具有零均值和单位方差的分布。归一化的公式如下:

Norm ( x ) = x − μ σ 2 + ϵ \text{Norm}(x) = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} Norm(x)=σ2+ϵ xμ

其中:

  • x x x 是输入数据。
  • μ \mu μ 是输入数据的均值。
  • σ 2 \sigma^2 σ2 是输入数据的方差。
  • ϵ \epsilon ϵ 是一个小的常数,用于防止除以零。

4. 对最后一个维度(特征维度)进行归一化

当我们说“对输入张量的最后一个维度(特征维度)进行归一化”时,意味着我们对每个样本或序列元素的特征向量进行归一化。具体来说:

  • 二维张量:对每个样本的特征向量进行归一化。
  • 三维张量:对每个序列元素的特征向量进行归一化。
  • 四维张量:对每个像素的通道值进行归一化。

eg:

假设有一个三维张量,形状为 (2, 3, 4),表示 2 个样本,每个样本有 3 个序列元素,每个序列元素有 4 个特征。我们使用 nn.LayerNorm 对最后一个维度(特征维度)进行归一化:

import torch
import torch.nn as nn# 输入张量
input_tensor = torch.randn(2, 3, 4)  # (batch_size, sequence_length, feature_size)# 创建 LayerNorm 层
layer_norm = nn.LayerNorm(normalized_shape=4)  # 对最后一个维度进行归一化# 应用 LayerNorm
normalized_tensor = layer_norm(input_tensor)
print(normalized_tensor)

在这个例子中:

  • 输入张量的形状是 (2, 3, 4)
  • normalized_shape=4 表示我们对最后一个维度(特征维度)进行归一化。
  • LayerNorm 会分别对每个序列元素的特征向量进行归一化,计算每个特征向量的均值和方差,并进行归一化处理。

具体步骤

  1. 计算均值和方差:对每个序列元素的特征向量计算均值和方差。
    • 对于第一个样本的第一个序列元素,计算其 4 个特征的均值和方差。
    • 对于第一个样本的第二个序列元素,计算其 4 个特征的均值和方差。
    • 依此类推。
  2. 归一化:使用上述公式对每个特征向量进行归一化。
    • 归一化后的特征向量具有零均值和单位方差。
  3. 应用缩放和偏移(可选):如果 elementwise_affine=True,还会对归一化后的特征向量进行缩放和偏移。

为什么这样做?

对特征维度进行归一化有以下好处:

  • 稳定训练:归一化可以防止特征值的范围差异过大,从而稳定训练过程。
  • 加速收敛:归一化后的数据更容易优化,可以加速模型的收敛。
  • 适应不同特征范围:不同特征可能有不同的范围,归一化可以将它们统一到相同的尺度。
http://www.xdnf.cn/news/62983.html

相关文章:

  • 【目标检测】目标检测综述 目标检测技巧
  • 全球首个人形机器人半程马拉松技术分析:翻车名场面背后的突破与挑战
  • DeepSeek赋能Nuclei:打造网络安全检测的“超级助手”
  • 量化研究---小果全球大类低相关性动量趋势增强轮动策略实盘设置
  • ThinkPHP5 的 SQL 注入漏洞
  • 【时时三省】(C语言基础)循环的嵌套和几种循环的比较
  • STM32——新建工程并使用寄存器以及库函数进行点灯
  • DeepSeek 大模型 + LlamaIndex + MySQL 数据库 + 知识文档 实现简单 RAG 系统
  • electron从安装到启动再到打包全教程
  • Python 网络编程:TCP 与 UDP 协议详解及实战代码
  • uni-app 开发企业级小程序课程
  • LangChain、LlamaIndex 和 ChatGPT 的详细对比分析及总结表格
  • 【Flink SQL实战】 UTC 时区格式的 ISO 时间转东八区时间
  • 2025.04.20【Lollipop】| Lollipop图绘制命令简介
  • python——函数
  • Unocss 类名基操, tailwindcss 类名
  • 分数线降低,25西电马克思主义学院(考研录取情况)
  • RESTful学习笔记(一)
  • 国产仪器进化论:“鲁般号”基于无人机的天线测试系统
  • 微软Entra新安全功能引发大规模账户锁定事件
  • 【Vue】组件基础
  • Linux系统下docker 安装 redis
  • Mybatis延迟加载、懒加载、二级缓存
  • 统计图表ECharts
  • 2025年世界职业院校技能大赛实施方案(意见稿)
  • 【单片机 C语言】单片机学习过程中常见C库函数(学习笔记)
  • 由Ai生成的Linux 入门到精通学习路径
  • 记录seatunnel排查重复数据的案例分析
  • ESP8266_ESP32 Smartconfig一键配网功能
  • qt 配置 mysql 驱动问题:Cannot load library qsqlmysql;QMYSQL driver not loaded