当前位置: 首页 > ai >正文

pytorch LSTM 结构详解

最近项目用到了LSTM ,但是对LSTM 的输入输出不是很理解,对此,我详细查找了lstm 的资料

import torch.nn as nnclass LSTMModel(nn.Module):def __init__(self, input_size=1, hidden_size=50, num_layers=2):super(LSTMModel, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)self.fc = nn.Linear(hidden_size, 1)  # 1 表示预测输出变量为1def forward(self, x):h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)out, _ = self.lstm(x, (h0, c0))out = self.fc(out[:, -1, :])return out # out 形状为(batch_size,1)
  • input_size=1:输入特征的维度,适用于单变量时间序列。

  • hidden_size=50:LSTM 隐藏层的维度,决定了模型的记忆能力。

  • num_layers=2:堆叠的 LSTM 层数,增加层数可以提升模型的表达能力。

  • batch_first=True:指定输入和输出的张量形状为 (batch_size, seq_len, input_size)

  • self.fc:一个全连接层,将 LSTM 的输出映射到最终的预测值。

  • batch_size 表示批次、seq_len 表示窗口大小、input_size 表示输入尺寸,单变量输入为1 ,多变量要基于个数变化

  • 初始化隐藏状态和细胞状态

    • h0c0 分别表示初始的隐藏状态和细胞状态,形状为 (num_layers, batch_size, hidden_size)

    • 在每次前向传播时,初始化为零张量。

  • LSTM 层处理

    • self.lstm(x, (h0, c0)):将输入 x 和初始状态传入 LSTM 层,输出 out 和新的状态。

    • out 的形状为 (batch_size, seq_len, hidden_size),包含了每个时间步的输出。

  • 全连接层映射

    • out[:, -1, :]:提取序列中最后一个时间步的输出。

    • self.fc(...):将提取的输出通过全连接层,得到最终的预测结果。

http://www.xdnf.cn/news/8319.html

相关文章:

  • PR-2014《The MinMax K-Means clustering algorithm》
  • HTML5的新语义化标签
  • 腾讯地图WebServiceAPI提供基于HTTPS/HTTP协议的数据接口
  • JAVA:Kafka 存储接口详解与实践样例
  • 练习小项目7:天气状态切换器②
  • 机器学习中的维度、过拟合、降维
  • 从制造到智造:猎板PCB的技术实践与产业价值重构
  • 攻防世界 - MISCall
  • JMeter-SSE响应数据自动化
  • SVN被锁定解决svn is already locked
  • 青少年编程与数学 02-020 C#程序设计基础 02课题、开发环境
  • FME入门系列教程7-基于FME的ArcGIS空间数据互操作技术研究与实践
  • 线程封装与互斥
  • 使用OpenSSL生成根证书并自签署证书
  • OpenCV入门
  • React组件(二):常见属性和函数
  • React从基础入门到高级实战:React 基础入门 - 简介与开发环境搭建
  • React从基础入门到高级实战:React 基础入门 - 列表渲染与条件渲染
  • 初始Flask框架
  • C++之STL--string
  • 移远三款主流5G模块RM500U,RM520N,RG200U比较
  • 电脑风扇转速不正常的原因
  • Flask框架
  • DeepSeek 赋能智能电网:从技术革新到全场景应用实践
  • Android 直接通过 app_process 启动的应用如何使用 Context
  • 3362. 零数组变换 III
  • PPP 拨号失败:ATD*99***1# ... failed
  • 610Hz!无惧环境光新薄膜!ROG全新电竞显示器亮相2025台北电脑展
  • 第七部分:第二节 - 在 Node.js 中连接和操作 MySQL:厨房与仓库的沟通渠道
  • STM32:深度解析RS-485总线与SP3485芯片