当前位置: 首页 > web >正文

【PyTorch】深度学习实践——第二章:线性模型

参考:刘二老师的《PyTorch深度学习实践》完结合集

本章实现了一个简单的线性回归模型,用于学习输入x和输出y之间的线性关系(y=w*x)。

一、代码细节

1.数据准备

x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
  • 定义了训练数据,x和y之间显然是y=2x的关系,只是我们自己知道计算机不知道。

2.模型定义

def forward(x):return x * w
  • 非常简单的线性模型,只有一个权重参数w

3.损失函数

def loss(x, y):y_pred = forward(x)return (y_pred - y) * (y_pred - y)
  • 使用均方误差(MSE)作为损失函数

4.训练循环

for w in np.arange(0.0, 4.1, 0.1):print('w=',w)l_sum = 0for x_val, y_val in zip(x_data, y_data):y_pred_val = forward(x_val)loss_val = loss(x_val, y_val)l_sum += loss_valprint("MSE", l_sum / 3)w_list.append(w)mse_list.append(l_sum / 3)
  • 遍历w的可能值(0.0到4.0,步长0.1)
  • 对每个w值,计算在所有训练数据上的总损失
  • 计算并存储平均MSE

5.可视化

plt.plot(w_list, mse_list)
plt.xlabel('w')
plt.ylabel('Loss')
plt.show()

6.找最优解

min_mse = min(mse_list)
optimal_w = w_list[mse_list.index(min_mse)]
print(f"\nOptimal weight: {optimal_w:.1f} (MSE = {min_mse:.2f})")

二、完整代码

import numpy as np
import matplotlib.pyplot as plt# 训练数据
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]# 前向传播函数
def forward(x):return x * w# 损失函数
def loss(x, y):y_pred = forward(x)return (y_pred - y) * (y_pred - y)# 存储权重和对应的MSE值
w_list = []
mse_list = []# 遍历不同的权重值
for w in np.arange(0.0, 4.1, 0.1):print("w =", w)l_sum = 0  # 累计损失# 计算当前权重下的预测值和损失for x_val, y_val in zip(x_data, y_data):y_pred_val = forward(x_val)loss_val = loss(x_val, y_val)l_sum += loss_valprint("\t", x_val, y_val, y_pred_val, loss_val)# 计算并存储平均MSEprint("MSE:", l_sum / 3)w_list.append(w)mse_list.append(l_sum / 3)# 可视化结果
plt.plot(w_list, mse_list)
plt.title('Loss for different weights')
plt.xlabel('w')
plt.ylabel('Loss')
plt.show()# 找到最优权重
min_mse = min(mse_list)
optimal_w = w_list[mse_list.index(min_mse)]
print(f"\nOptimal weight: {optimal_w:.1f} (MSE = {min_mse:.2f})")
http://www.xdnf.cn/news/5875.html

相关文章:

  • 【数据结构】——栈和队列OJ
  • python酒店健身俱乐部管理系统
  • iPaaS 集成平台如何解决供应链响应速度问题?
  • Spring AI 开发本地deepseek对话快速上手笔记
  • 07_Java中的锁
  • 系统平衡与企业挑战
  • Tomcat与纯 Java Socket 实现远程通信的区别
  • 中国人工智能智能体研究报告
  • Linux的文件查找与压缩
  • 关于cleanRL Q-learning
  • Java集合框架详解与使用场景示例
  • MySQL 5.7在CentOS 7.9系统下的安装(下)——给MySQL设置密码
  • Android NDK 高版本交叉编译:为何无需配置 FLAGS 和 INCLUDES
  • org.slf4j.MDC介绍-笔记
  • 集成DHTMLX 预订排期调度组件实践指南:如何实现后端数据格式转换
  • web 自动化之 yaml 数据/日志/截图
  • Boundary Attention Constrained Zero-Shot Layout-To-Image Generation
  • 配置hadoop集群-启动集群
  • apache2的默认html修改
  • 【前端三剑客】Ajax技术实现前端开发
  • ETL 数据集成平台与数据仓库的关系及 ETL 工具推荐
  • 前端流行框架Vue3教程:15. 组件事件
  • kafka----初步安装与配置
  • PROFIBUS DP转ModbusTCP网关模块于污水处理系统的成功应用案例解读​
  • C++中的各式类型转换
  • 序列化和反序列化(hadoop)
  • RabbitMQ发布订阅模式深度解析与实践指南
  • 解决 CentOS 7 镜像源无法访问的问题
  • 爬虫请求频率应控制在多少合适?
  • cocos creator 3.8 下的 2D 改动