当前位置: 首页 > ds >正文

基于pytorch.nn模块实现线性模型

课程:b站up 跟李沐学AI 的系列课程 《动手学深度学习》

笔记:本笔记基于以上课程所写,有疑问的地方,可评论区留言,或自行观看b站视频

完整代码

import torch
from torch import nn, optim
from torch.nn import Sequential, Linear
from torch.utils import data# ========== 数据生成与模型定义部分 ==========# 生成真实参数(用于模拟数据生成)
true_w = torch.tensor([2, -3.4])  # 真实权重向量(2维输入)
true_b = 4.2  # 真实偏置项# 定义合成数据生成函数(模拟真实数据分布)
def synthetic_data(w, b, num_examples):"""生成线性关系数据集 y = Xw + b + 高斯噪声参数:w -- 真实权重向量 (2,)b -- 真实偏置 (标量)num_examples -- 样本数量返回:X -- 特征矩阵 (num_examples, 2)y -- 标签向量 (num_examples, 1)"""# 生成特征矩阵:从标准正态分布采样(均值0,方差1)X = torch.normal(0, 1, (num_examples, len(w)))# 计算线性关系结果y = torch.matmul(X, w) + b# 添加高斯噪声(模拟测量误差,标准差0.01)y += torch.normal(0, 0.01, y.shape)# 调整标签形状为列向量return X, y.reshape((-1, 1))# 生成1000个样本的合成数据集
features, labels = synthetic_data(true_w, true_b, 1000)
print("features shape : ", features.shape)  # (1000, 2)
print("labels shape : ", labels.shape)  # (1000, 1)
print('features:', features[:3], '\nlabel:', labels[:3])# 创建数据加载器(批量处理)
batch_size = 10
# 将特征和标签组合为数据集,打乱顺序后分批加载
data_iter = data.DataLoader(dataset=data.TensorDataset(features, labels),  # 数据集包装batch_size=batch_size,  # 每批10个样本shuffle=True  # 训练前打乱顺序
)# ========== 模型定义与初始化 ==========# 定义线性回归模型(使用Sequential容器)
net = Sequential(Linear(2, 1))  # 输入2维,输出1维的线性层# 参数初始化(重要!影响模型收敛)
# 权重初始化:均值0,标准差0.01的正态分布
net[0].weight.data.normal_(0, 0.01)
# 偏置初始化:全零
net[0].bias.data.fill_(0)# ========== 训练配置 ==========# 超参数设置
lr = 0.03  # 学习率(控制参数更新步长)
num_epochs = 3  # 训练轮数(遍历整个数据集的次数)# 损失函数:均方误差(MSE)
loss = nn.MSELoss()# 优化器:随机梯度下降(SGD)
optimizer = optim.SGD(params=net.parameters(),  # 需优化的参数lr=lr  # 学习率
)# ========== 训练循环 ==========for epoch in range(num_epochs):# 遍历每个批次的数据for X, y in data_iter:# 前向传播:计算预测值y_hat = net(X)# 计算当前批次的损失l = loss(y_hat, y)# 反向传播准备optimizer.zero_grad()  # 清空梯度(防止梯度累积)l.backward()  # 自动计算梯度# 参数更新optimizer.step()  # 根据梯度更新参数# 每轮结束后计算整个数据集的损失(验证效果)with torch.no_grad():  # 禁用梯度计算(节省内存)# 计算所有样本的预测值y_hat_all = net(features)# 计算平均损失total_loss = loss(y_hat_all, labels)print(f'epoch {epoch + 1}, loss {total_loss.item():.4f}')# ========== 结果验证 ==========# 提取训练得到的参数
train_w = net[0].weight.data  # 权重参数
train_b = net[0].bias.data  # 偏置参数# 打印训练结果与真实参数对比
print("训练得到的权重:", train_w, "\n真实权重:", true_w)
print("训练得到的偏置:", train_b, "\n真实偏置:", true_b)

输出结果

features shape :  torch.Size([1000, 2])
labels shape :  torch.Size([1000, 1])
features: tensor([[-0.7123,  0.7392],[ 0.9638, -1.7860],[ 1.6612, -0.6634]]) 
label: tensor([[ 0.2535],[12.2078],[ 9.7877]])
epoch 1, loss 0.0002
epoch 2, loss 0.0001
epoch 3, loss 0.0001
训练得到的权重: tensor([[ 1.9998, -3.4003]]) 
真实权重: tensor([ 2.0000, -3.4000])
训练得到的偏置: tensor([4.2002]) 
真实偏置: 4.2Process finished with exit code 0

http://www.xdnf.cn/news/15122.html

相关文章:

  • c语言中的数组II
  • OpenCV图片操作100例:从入门到精通指南(4)
  • (C++)任务管理系统(正式版)(迭代器)(list列表基础教程)(STL基础知识)
  • Android-重学kotlin(协程源码第一阶段)新学习总结
  • STM32-看门狗
  • (5)机器学习小白入门 YOLOv:数据需求与图像不足应对策略
  • qml加载html以及交互
  • Qt去噪面板搭建
  • Flutter基础(前端教程⑦-Http和卡片)
  • 【EGSR2025】材质+扩散模型+神经网络相关论文整理随笔(二)
  • 图片的拍摄创建日期怎么改?保护好图片信息安全的好方法
  • 李宏毅NLP-9-语音转换
  • 全球发展币GDEV:从中国出发,走向全球的数字发展合作蓝图
  • 本地Qwen中医问诊小程序系统开发
  • kubernetes存储入门
  • Flutter编译安卓应用时遇到的compileDebugJavaWithJavac和compileDebugKotlin版本不匹配的问题
  • 【c++学习记录】状态模式,实现一个登陆功能
  • huggingface笔记:文本生成Text generation
  • WinUI3入门16:Order自定义排序
  • WouoUI-Page移植
  • 一个vue项目的基本构成
  • 实时音视频通过UDP打洞实现P2P优先通信
  • 方法论汇总
  • ACE-Step:AI音乐生成基础模型
  • 【python】 time_str = time_str.strip() 与 time_str = str(time_str).strip() 的区别
  • Mac安装Docker(使用orbstack代替)
  • 云原生详解:构建现代化应用的未来
  • 【Node.js】文本与 pdf 的相互转换
  • eslint扁平化配置
  • 牛市来临之际,如何用期权抢占反弹先机?