当前位置: 首页 > news >正文

lesson04-简单回归案例实战(理论+代码)

理解线性回归及梯度下降优化

引言

在机器学习的基础课程中,我们经常遇到的一个重要概念就是线性回归。今天,我们将深入探讨这一主题,并通过具体的例子来了解如何利用梯度下降方法对模型进行优化。

线性回归简介

线性回归是一种统计方法,用于确定两个变量之间的关系。简单来说,如果我们有一个自变量 XX 和因变量 YY,线性回归可以帮助我们找到一条最佳拟合直线,这条直线可以用公式 Y=WX+bY=WX+b 来表示,其中 WW 是权重,bb 是偏置。

损失函数

为了评估模型的好坏,我们需要定义一个损失函数。对于线性回归而言,通常使用平方误差作为损失函数,即 loss=(WX+b−y)2loss=(WX+b−y)2。

梯度下降优化

梯度下降是一种迭代优化算法,用来最小化损失函数。每次迭代过程中,我们会更新参数 WW 的值,具体更新规则为 w′=w−lr×∇loss/∇ww′=w−lr×∇loss/∇w,这里的 lrlr 表示学习率,控制着每一步调整的幅度。

迭代优化

通过不断调整 WW 和 bb 的值,使得损失函数逐渐减小,直到达到局部或全局最小值点。这个过程需要多次迭代计算,直至满足预设的停止条件为止。

下一课时预告

接下来的一课时,我们将一起探索著名的MNIST手写数字识别任务,敬请期待!

结语

感谢大家的关注与支持,希望今天的分享能够加深您对线性回归以及梯度下降算法的理解。让我们共同期待下一节课的到来吧!

实战代码

import numpy as np# y = wx + b
def compute_error_for_line_given_points(b, w, points):totalError = 0for i in range(0, len(points)):x = points[i, 0]y = points[i, 1]totalError += (y - (w * x + b)) ** 2return totalError / float(len(points))def step_gradient(b_current, w_current, points, learningRate):b_gradient = 0w_gradient = 0N = float(len(points))for i in range(0, len(points)):x = points[i, 0]y = points[i, 1]b_gradient += -(2/N) * (y - ((w_current * x) + b_current))w_gradient += -(2/N) * x * (y - ((w_current * x) + b_current))new_b = b_current - (learningRate * b_gradient)new_m = w_current - (learningRate * w_gradient)return [new_b, new_m]def gradient_descent_runner(points, starting_b, starting_m, learning_rate, num_iterations):b = starting_bm = starting_mfor i in range(num_iterations):b, m = step_gradient(b, m, np.array(points), learning_rate)return [b, m]def run():points = np.genfromtxt("data.csv", delimiter=",")learning_rate = 0.0001initial_b = 0 # initial y-intercept guessinitial_m = 0 # initial slope guessnum_iterations = 1000print("Starting gradient descent at b = {0}, m = {1}, error = {2}".format(initial_b, initial_m,compute_error_for_line_given_points(initial_b, initial_m, points)))print("Running...")[b, m] = gradient_descent_runner(points, initial_b, initial_m, learning_rate, num_iterations)print("After {0} iterations b = {1}, m = {2}, error = {3}".format(num_iterations, b, m,compute_error_for_line_given_points(b, m, points)))if __name__ == '__main__':run()

🧠 一、代码概述

这段代码的主要目的是:

  • 使用一个简单的线性模型:y = mx + b
  • 给定一个二维数据集 data.csv,其中每行有两个值:x 和 y
  • 使用梯度下降算法迭代地更新 m 和 b,使得预测的 y 尽可能接近真实值
  • 最终输出经过多次迭代后的最优 m 和 b 值,并计算最终误差

📁 二、文件结构说明

  1. 导入库

    import numpy as np
    • 引入 NumPy 库,用于高效的数值计算和数组操作。
  2. 函数定义

    • compute_error_for_line_given_points(b, w, points)
      计算当前直线的平均平方误差(MSE)
    • step_gradient(b_current, w_current, points, learningRate)
      执行一次梯度下降步骤,返回更新后的 b 和 m
    • gradient_descent_runner(points, starting_b, starting_m, learning_rate, num_iterations)
      迭代运行梯度下降过程
    • run()
      主函数,加载数据、调用训练函数、打印结果
  3. 主程序入口

if __name__ == '__main__':run()

 

📌 三、函数详解

1. compute_error_for_line_given_points(b, w, points)

功能:

计算当前模型参数下的均方误差(Mean Squared Error, MSE)

公式:

MSE=1N∑i=1N(yi−(wxi+b))2MSE=N1​i=1∑N​(yi​−(wxi​+b))2

参数:
  • b: 当前截距(bias / y-intercept)
  • w: 当前斜率(weight / slope)
  • points: 数据点集合,是一个二维数组,每行表示一个 (x, y) 点
返回值:
  • 平均误差值(越小越好)

2. step_gradient(b_current, w_current, points, learningRate)

功能:

执行一次梯度下降步骤,根据当前的 bm 更新它们的值。

核心公式(梯度下降更新规则):

b′=b−η⋅∂MSE∂bb′=b−η⋅∂b∂MSE​

m′=m−η⋅∂MSE∂mm′=m−η⋅∂m∂MSE​

其中:

  • ηη 是学习率(learning rate)
  • 梯度是通过对损失函数分别对 b 和 m 求导得到的
导数推导:

∂MSE∂b=2N∑i=1N(yi−(mxi+b))⋅(−1)∂b∂MSE​=N2​i=1∑N​(yi​−(mxi​+b))⋅(−1)

∂MSE∂m=2N∑i=1N(yi−(mxi+b))⋅(−xi)∂m∂MSE​=N2​i=1∑N​(yi​−(mxi​+b))⋅(−xi​)

你在代码中实现了这两个梯度的累加。

返回值:
  • [new_b, new_m]:更新后的模型参数

3. gradient_descent_runner(...)

功能:

循环执行 step_gradient 多次,完成完整的梯度下降过程。

参数:
  • points: 数据集
  • starting_bstarting_m: 初始参数
  • learning_rate: 学习率
  • num_iterations: 迭代次数
输出:
  • 最终的 b 和 m

4. run()

功能:
  • 加载 CSV 数据文件
  • 设置初始参数
  • 调用梯度下降函数进行训练
  • 打印训练前后误差和参数变化

输出结果展示

这表明经过 1000 次迭代后,模型已经基本收敛。

http://www.xdnf.cn/news/705025.html

相关文章:

  • C#·常用快捷键
  • 论文笔记:DreamDiffusion
  • DeepSeek进阶教程:实时数据分析与自动化决策系统
  • Web攻防-SQL注入增删改查盲注延时布尔报错有无回显错误处理
  • 【论文阅读】《PEACE: Empowering Geologic Map Holistic Understanding with MLLMs》
  • 模块化集成建筑(MiC建筑):颠覆传统的未来建造革命
  • 基于本地化大模型的智能编程助手全栈实践:从模型部署到IDE深度集成学习心得
  • 51c视觉~3D~合集3
  • 【SpringBoot】零基础全面解析SpringBoot配置文件
  • sass基础语法
  • Vite打包优化实践:从分包到性能提升
  • 自学嵌入式 day 25 - 系统编程 标准io 缓冲区 文件io
  • git+svn+sourcetree客户端下载和安装教程
  • DeepSeek R1开源模型的技术突破与AI产业格局的重构
  • Nacos | 三种方式的配置中心,整合Springboot3.x + yaml文件完成 0错误 自动刷新(亲测无误)
  • 单片机——keil5
  • WSL 开发环境搭建指南:Java 11 + 中间件全家桶安装实战
  • STM32开发全解析:从环境搭建到项目实战的技术文档撰写指南
  • 代谢组数据分析(二十五):代谢组与蛋白质组数据分析的异同
  • day13 leetcode-hot100-23(链表2)
  • xLSTM技术介绍
  • 技术文档写作大纲
  • JWT 不对外,Session ID 对外:构建安全可控的微服务认证架构
  • jenkins集成gitlab实现自动构建
  • 力扣-最长回文子串
  • 【课堂笔记】EM算法
  • stm32cube ide如何将工具链替换成arm-none-eabi-gcc
  • 零基础设计模式——结构型模式 - 代理模式
  • 安装flash-attention失败的终极解决方案(WINDOWS环境)
  • 按照状态实现自定义排序的方法