当前位置: 首页 > news >正文

深度学习3.3 线性回归的简洁实现

步骤操作作用
前向计算net(X)计算预测值 y_hat = Xw + b
损失计算loss(y_hat, y)量化预测误差,驱动参数更新
反向传播l.backward()计算参数梯度
参数更新trainer.step()根据梯度调整参数,逼近最优解
梯度清零trainer.zero_grad()防止梯度累积(必须放在 backward() 之后,step() 之前)
训练监控loss(net(features), labels)评估模型整体性能,避免过拟合或欠拟合

3.3.1 生成数据集

import numpy as np
import torch
from torch.utils import data
from d2l import torch as d2ltrue_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = d2l.synthetic_data(true_w, true_b, 1000)

3.3.2 读取数据集

def load_array(data_arrays, batch_size, is_train=True):dataset = data.TensorDataset(*data_arrays)return data.DataLoader(dataset, batch_size, shuffle=is_train)batch_size = 10
data_iter = load_array((features, labels), batch_size)
next(iter(data_iter))

数据加载器 (DataLoader)
‌数据集封装‌:TensorDataset 将特征和标签包装为 PyTorch 数据集。‌
批量加载‌:DataLoader 按 batch_size=10 加载数据,训练时打乱数据 (shuffle=True)。

3.3.3 定义模型

from torch import nn
net = nn.Sequential(nn.Linear(2, 1))

3.3.4 初始化模型参数

net[0].weight.data.normal_(0, 0.01) # 权重初始化
net[0].bias.data.fill_(0) # 偏置初始化

3.3.5 定义损失函数

loss = nn.MSELoss() # 均方误差损失

3.3.6 定义优化算法

trainer = torch.optim.SGD(net.parameters(), lr=0.03)  # 随机梯度下降

3.3.7 训练

num_epochs = 3
for epoch in range(num_epochs):for X, y in data_iter:l = loss(net(X), y)     # 前向计算损失trainer.zero_grad()      # 清零梯度l.backward()            # 反向传播trainer.step()          # 参数更新# 计算并输出整个训练集的损失l = loss(net(features), labels)print(f'epoch{epoch + 1}, loss{l:f}')

epoch1, loss0.000205
epoch2, loss0.000094
epoch3, loss0.000094

# 输出参数估计误差
w = net[0].weight.data
print(f'w的估计误差:{true_w - w.reshape(true_w.shape)}')
b = net[0].bias.data
print(f'b的估计误差:{true_b - b}')

w的估计误差:tensor([5.9402e-04, 4.6015e-05])
b的估计误差:tensor([0.0001])

http://www.xdnf.cn/news/73261.html

相关文章:

  • 代码实战保险花销预测
  • AXOP38802: 400nA 超低功耗通用双通道运算放大器
  • JumpServer多用户VNC桌面配置指南:实现多端口远程访问
  • KDD2024 | BCGNN解读
  • 读文献先读图:韦恩图怎么看?
  • 第 2 篇:初探时间序列 - 可视化与基本概念
  • 【源码】【Java并发】【AQS】从ReentrantLock、Semaphore、CutDownLunch、CyclicBarrier看AQS源码
  • JFrog Artifactory 制品库命令行操作指南
  • Java虚拟机之GC收集器对比解读
  • 多线程初阶(1.2)
  • 爬虫学习——Item封装数据与Item Pipeline处理数据
  • 垂直机械硬盘与叠瓦机械硬盘的区别及数据恢复难度
  • Kubeflow 快速入门实战(三) - Qwen2.5 微调全流程
  • 影刀RPA - 简单易用且功能强大的自动化工具
  • mybatis plus 多条件查询注意查询条件顺序
  • 2025年渗透测试面试题总结-拷打题库09(题目+回答)
  • LangChain4j-第二篇 |实现声明式 AI 服务 AiService:简化 AI 集成新范式
  • Linux Wlan-四次握手(eapol)框架流程
  • Transformer到MoE:聚客AI大模型核心技术栈完全指南
  • 第一篇:从哲学到管理——实践论与矛盾论如何重塑企业思维
  • c++基础·列表初始化
  • Linux系统-cat命令/more命令/less命令
  • Kubernetes集群超配节点容量
  • MCP的发展历程
  • 批量创建同名文件夹并整理文件至对应文件夹
  • Day5-UFS总结
  • 基于vue框架的电脑配件网上商城18xsv(程序+源码+数据库+调试部署+开发环境)带论文文档1万字以上,文末可获取,系统界面在最后面。
  • aws文件存储服务——S3介绍使用代码集成
  • 第5章:MCP框架详解
  • Python 之 __file__ 变量导致打包 exe 后路径输出不一致的问题