当前位置: 首页 > news >正文

【Torch】nn.GRU算法详解

1. 输入输出

  1. 输入张量

    • 默认形状:(seq_len, batch_size, input_size)
    • batch_first=True(batch_size, seq_len, input_size)
    • 含义:序列长度 × 批大小 × 每步特征维度
  2. 可选初始隐状态

    • 形状:(num_layers * num_directions, batch_size, hidden_size)
    • 默认为全零张量。如果要自定义,需提供此形状的 h0
  3. 输出
    调用 output, h_n = gru(x, h0) 返回两部分:

    • output:所有时间步的隐藏状态序列
      • 形状:
        • 默认:(seq_len, batch_size, num_directions * hidden_size)
        • batch_first=True(batch_size, seq_len, num_directions * hidden_size)
      • 含义:每个时间步的隐藏状态,可以直接接全连接或其它后续层。
    • h_n:最后一个时间步的隐藏状态
      • 形状:(num_layers * num_directions, batch_size, hidden_size)
      • 含义:每一层(及方向)在序列末尾的隐藏状态,常用于初始化下一个序列或分类任务。

2. 构造函数参数详解

nn.GRU(input_size: int,hidden_size: int,num_layers: int = 1,bias: bool = True,batch_first: bool = False,dropout: float = 0.0,bidirectional: bool = False
)
参数类型含义
input_sizeint输入特征维度,即每步输入向量的大小。
hidden_sizeint隐状态(隐藏层)维度,也决定输出特征维度(单向时即 hidden_size)。
num_layersint堆叠的 GRU 层数(深度),默认为 1。
biasbool是否使用偏置;当为 False 时,所有线性变换均无 bias。
batch_firstbool是否将批量维放到第二维(True),默认序列维在最前(False)。
dropoutfloat除最后一层外,每层输出后使用的 Dropout 比例;仅在 num_layers>1 时生效。
bidirectionalbool是否使用双向 RNN;若 True,则隐状态和输出维度翻倍。

3. 输出含义详解

  • output

    • 大小:[..., num_directions * hidden_size]
    • 如果 bidirectional=Falsenum_directions=1;否则 =2
    • output[t, b, :](或在 batch_first 模式下 output[b, t, :])表示第 t 步第 b 个样本的隐藏状态。
  • h_n

    • 大小:(num_layers * num_directions, batch_size, hidden_size)
    • 维度索引含义:
      • 维度 0:层数 × 方向(例如 3 层双向时索引 0–5,对应层1正向、层1反向、层2正向…)
      • 维度 1:批内样本索引
      • 维度 2:隐藏状态向量

4. 使用注意事项

  1. batch_first 的选择

    • 若后续直接接全连接层、BatchNorm 等,更习惯 batch_first=True;否则可用默认格式节省一次转置。
  2. 双向与输出维度

    • bidirectional=True 时,output 的最后一维和 h_nhidden_size 均会翻倍,需要相应修改下游网络维度。
  3. Dropout 的生效条件

    • 只有在 num_layers > 1 并且 dropout > 0 时,才会在各层间插入 Dropout;单层时不会应用。
  4. 初始隐状态

    • 默认为零。若在两个连续序列之间保持状态(stateful RNN),可将上一次的 h_n 作为下一次的 h0
  5. PackedSequence

    • 对变长序列,可用 torch.nn.utils.rnn.pack_padded_sequence 输入,输出再用 pad_packed_sequence 恢复,对长短不一的序列批处理很有用。
  6. 性能与稳定性

    • GRU 相比 LSTM 参数更少、速度稍快,但有时在长期依赖或梯度流问题上略不如 LSTM。
    • 可在多层 RNN 之间加 LayerNorm 或 Residual 连接,提升深度模型的收敛和稳定性。

简单示例

import torch, torch.nn as nn# 定义单层单向 GRU
gru = nn.GRU(input_size=10, hidden_size=20, num_layers=2,batch_first=True, dropout=0.1, bidirectional=True)# 输入:batch=8, seq_len=15, features=10
x = torch.randn(8, 15, 10)# 默认 h0 为零
output, h_n = gru(x)
print(output.shape)  # (8, 15, 2*20)  双向,所以 hidden_size*2
print(h_n.shape)     # (2*2, 8, 20)  num_layers=2, num_directions=2
http://www.xdnf.cn/news/1066177.html

相关文章:

  • 前端跨域解决方案(7):Node中间件
  • 容器技术入门与Docker环境部署指南
  • asp.net core Razor动态语言编程代替asp.net .aspx更高级吗?
  • 如何在 Vue 应用中嵌入 ONLYOFFICE 编辑器
  • LED-Merging: 无需训练的模型合并框架,兼顾LLM安全和性能!!
  • WebSocket长连接在小程序中的实践:消息推送与断线重连机制设计
  • 运维打铁: Windows 服务器基础运维要点解析
  • 详解HarmonyOS NEXT仓颉开发语言中的全局弹窗
  • AI编程再突破,文心快码发布行业首个多模态、多智能体协同AI IDE
  • vue3整合element-plus
  • WebSocket快速入门
  • 卓易通是什么
  • 深度学习:PyTorch卷积神经网络(CNN)之图像入门
  • 【软考高级系统架构论文】论企业集成平台的理解与应用
  • Spring Boot 使用 ElasticSearch
  • 大数据时代UI前端的变革:从静态展示到动态交互
  • ISCSI存储
  • FreeRTOS 介绍、使用方法及应用场景
  • RabbitMQ从入门到实践:消息队列核心原理与典型应用场景
  • 跨域视角下强化学习重塑大模型推理:GURU框架与多领域推理新突破
  • 【论文阅读笔记】TransparentGS:当高斯溅射学会“看穿”玻璃,如何攻克透明物体重建难题?
  • 【破局痛点,赋能未来】领码 SPARK:铸就企业业务永续进化的智慧引擎—— 深度剖析持续演进之道,引领数字化新范式
  • 针对数据仓库方向的大数据算法工程师面试经验总结
  • 计算机网络通信技术与协议(九)————交换机技术
  • 前端手写题(一)
  • leetcode51.N皇后:回溯算法与冲突检测的核心逻辑
  • Linux——6.检测磁盘空间、处理数据文件
  • 【分布式技术】Bearer Token以及MAC Token深入理解
  • Python商务数据分析——Python 入门基础知识学习笔记
  • Node.js特训专栏-实战进阶:6. MVC架构在Express中的应用