当前位置: 首页 > news >正文

LSTM模型进行天气预测Pytorch版本

LSTM模型进行天气预测Pytorch版本


1-参考网址

  • Anaconda结合Pytorch使用参考

2-动手实践

1-创建Pytorch环境

# 1-Anacanda使用Python3.9
conda create -n LSTM3.9 python=3.9
conda activate LSTM3.9# 2-使用cudatoolkit=11.8
conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia
conda install pytorch torchvision torchaudio cudatoolkit=11.8 -c pytorch# 3-安装所需依赖包
pip install matplotlib # 4-查看GPU使用命令
nvidia-smi
watch -n 1 nvidia-smi

2-执行LSTM脚本

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader# 设置随机种子以确保结果可重复
torch.manual_seed(42)# 示例:加载和预处理 Jena Climate 数据集(假设数据已加载到一个 NumPy 数组中)
# 在实际应用中,你需要根据实际情况加载和处理数据
data = np.random.rand(1000, 14)  # 假设有 1000 个时间点,14 个特征# 数据归一化
data_mean = data.mean(axis=0)
data_std = data.std(axis=0)
data = (data - data_mean) / data_std# 定义时间序列数据集
class TimeSeriesDataset(Dataset):def __init__(self, data, seq_length):self.data = dataself.seq_length = seq_lengthdef __len__(self):return len(self.data) - self.seq_lengthdef __getitem__(self, idx):x = self.data[idx:idx + self.seq_length, :-1]  # 输入特征:除最后一个特征外的所有特征y = self.data[idx + self.seq_length, -1]  # 目标:最后一个特征作为预测目标return torch.tensor(x, dtype=torch.float32), torch.tensor(y, dtype=torch.float32)seq_length = 24  # 序列长度,例如过去 24 个小时的数据
dataset = TimeSeriesDataset(data, seq_length)
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)# 定义 LSTM 模型
class LSTMModel(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(LSTMModel, self).__init__()self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)self.fc = nn.Linear(hidden_size, output_size)def forward(self, x):out, _ = self.lstm(x)out = self.fc(out[:, -1, :])  # 使用最后一个时间步的输出进行预测return outinput_size = data.shape[1] - 1  # 输入特征数
hidden_size = 64  # 隐藏层单元数
output_size = 1  # 输出特征数(预测目标)
model = LSTMModel(input_size, hidden_size, output_size)# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 训练模型
num_epochs = 10
for epoch in range(num_epochs):model.train()train_loss = 0for x_batch, y_batch in train_loader:optimizer.zero_grad()outputs = model(x_batch)loss = criterion(outputs, y_batch.unsqueeze(1))loss.backward()optimizer.step()train_loss += loss.item() * x_batch.size(0)train_loss = train_loss / len(train_loader.dataset)model.eval()test_loss = 0with torch.no_grad():for x_batch, y_batch in test_loader:outputs = model(x_batch)loss = criterion(outputs, y_batch.unsqueeze(1))test_loss += loss.item() * x_batch.size(0)test_loss = test_loss / len(test_loader.dataset)print(f'Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}')# 预测
model.eval()
with torch.no_grad():test_inputs = data[-seq_length:, :-1].astype(np.float32)test_inputs = torch.tensor(test_inputs).unsqueeze(0)predicted = model(test_inputs)predicted = predicted.item() * data_std[-1] + data_mean[-1]  # 反归一化# 打印预测结果
print(f'Predicted value: {predicted:.4f}')# 可视化结果(示例)
plt.figure(figsize=(10, 6))
plt.plot(data[-100:, -1] * data_std[-1] + data_mean[-1], label='True Values')
plt.plot(len(data) - 1, predicted, 'ro', label='Predicted Value')
plt.xlabel('Time')
plt.ylabel('Temperature')
plt.title('Temperature Prediction')
plt.legend()
plt.show()
http://www.xdnf.cn/news/653059.html

相关文章:

  • 索尼PS4模拟器shadPS4最新版 v0.9.0 提升PS4模拟器的兼容性
  • 【Linux】基础IO
  • 提问:鲜羊奶是解决育儿Bug的补丁吗?
  • mysql存储过程(if、case、begin...end、while、repeat、loop、cursor游标)的使用
  • 从0开始学习R语言--Day10--时间序列分析数据
  • 手机平板等设备租赁行业MDM方案解析
  • OpenCV计算机视觉实战(8)——图像滤波详解
  • vite常见面试问题
  • 新书速览|ASP.NET MVC高效构建Web应用
  • 精益数据分析(87/126):市场-产品契合度重构——现有产品寻找新市场的实战指南
  • springboot 微服务下部署AI服务
  • 2025年5月26日工作总结
  • 论文阅读:2024 arxiv Prompt Injection attack against LLM-integrated Applications
  • c#基础07(调试与异常捕捉)
  • [Git] 如何将已经执行的修改操作撤销
  • 力扣热题100之LRU缓存机制
  • 力扣 394.字符串解码
  • mysql-tpcc-mysql压测工具使用
  • 【Java工程师面试全攻略】Day2:Java集合框架面试全解析
  • 榕壹云物品回收系统实战案例:基于ThinkPHP+MySQL+UniApp的二手物品回收小程序开发与优化
  • 【运维】OpenWrt DNS重绑定保护配置指南:解决内网域名解析问题
  • 项目亮点 封装request请求模块
  • 2025年- H51-Lc159 --199. 二叉树的右视图(层序遍历,队列)--Java版
  • AI学习笔记二十八:使用ESP32 CAM和YOLOV5实现目标检测
  • 使用docker容器部署Elasticsearch和Kibana
  • Rk3568 Andorid 11 ,根据prop属性的值控制是否禁止u盘连接
  • 倚光科技在二元衍射面加工技术上的革新:引领光学元件制造新方向​
  • 拓扑光子混沌算法
  • 开源第三方库发展现状
  • 《软件工程》第 9 章 - 软件详细设计