当前位置: 首页 > news >正文

动手学深度学习——线性回归 + 基础优化算法

# matplotlib inline
import random
import torch
from d2l import torch as d2ldef synthetic_data(w, b, num_examples):  """生成 y = Xw + b + 噪声。"""X = torch.normal(0, 1, (num_examples, len(w)))y = torch.matmul(X, w) + by += torch.normal(0, 0.01, y.shape)return X, y.reshape((-1, 1))true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)print('features:', features[0], '\nlabel:', labels[0])
d2l.set_figsize()
d2l.plt.scatter(features[:, (1)].detach().numpy(),labels.detach().numpy(), 1);

X = torch.normal(0, 1, (num_examples, len(w)))

创建一个形状为 (num_examples, len(w)) 的张量 X,其中每个元素服从均值为 0、标准差为 1 的正态分布。


y = torch.matmul(X, w) + b
计算目标值 y,即线性回归模型的预测结果。

数学形式y_i = \sum_{j=1}^d X_{ij} w_j + b


y += torch.normal(0, 0.01, y.shape)
在真实值上添加少量噪声,使数据更接近真实场景。


return X, y.reshape((-1, 1)
返回特征矩阵 X 和目标值 y,其中 y 被调整为二维列向量,形状为 (num_examples, 1)

解释

  • reshape((-1, 1)):将 y 重塑为一列。
    • -1 表示自动推导维度大小,这里会推导为 num_examples
    • 例如,如果 num_examples=5y 的最终形状就是 (5, 1)

true_w = torch.tensor([2, -3.4])
true_b = 4.2

features, labels = synthetic_data(true_w, true_b, 1000)

定义线性模型的真实权重和偏置项。

调用自定义函数 synthetic_data 生成模拟训练数据。


d2l.set_figsize()
设置画布的大小。

d2l.plt.scatter(features[:, (1)].detach().numpy(),
labels.detach().numpy(), 1)
绘制第二个特征(x₂)与标签(y)之间的关系图,用散点图展示数据分布。

逐步解析

(1) features[:, (1)]

  • 取所有样本的第 2 个特征列(索引为 1)。
  • 结果是形状为 (1000,) 的张量。

(2) .detach().numpy()

  • detach():从计算图中分离张量,阻止 PyTorch 跟踪梯度。
  • .numpy():将张量转换为 NumPy 数组,Matplotlib 需要 NumPy 格式才能绘图。

(3) labels.detach().numpy()

  • 标签数据也需要转换为 NumPy 格式。
  • 原始 labels 形状为 (1000, 1),Matplotlib 会自动展平。

(4) scatter(x, y, 1)

  • scatter 用于绘制散点图。
  • 参数:
    1. x:横坐标,这里是第 2 个特征。
    2. y:纵坐标,这里是标签值。
    3. 1:每个点的大小,数值越小点越小。
def data_iter(batch_size, features, labels):num_examples = len(features)indices = list(range(num_examples))# 这些样本是随机读取的,没有特定的顺序random.shuffle(indices)for i in range(0, num_examples, batch_size):batch_indices = torch.tensor(indices[i: min(i + batch_size, num_examples)])yield features[batch_indices], labels[batch_indices]

def data_iter(batch_size, features, labels):
定义了一个数据迭代器生成函数,它将数据集分批(mini-batch)返回,用于模型训练。

在深度学习中,通常不会一次性将全部数据送入模型,而是分成若干批次(mini-batch)逐步训练,提高效率并利用 GPU 并行计算。


num_examples = len(features)

len(features) 返回特征矩阵中的样本数量。


indices = list(range(num_examples))
range(num_examples) 生成 0num_examples-1 的整数序列。

再用 list() 转换为列表。


random.shuffle(indices)
将索引列表 随机打乱,保证每次迭代时小批量数据顺序不同。

random.shuffle()原地操作,不会返回新列表。

原因
如果数据始终按原顺序输入模型,模型可能学到数据顺序模式,导致泛化能力差。
随机打乱样本顺序是**随机梯度下降法(SGD)**的必要步骤之一。


for i in range(0, num_examples, batch_size):
batch_indices = torch.tensor(
indices[i: min(i + batch_size, num_examples)])
yield features[batch_indices], labels[batch_indices]

batch_size 为步长,从 0 开始遍历整个数据集。

  • i + batch_size

    • 计算当前批次的结束位置。
  • min(i + batch_size, num_examples)

    • i + batch_size 与总样本数的较小值,防止越界。
  • indices[i : min(...)]

    • 切片操作,获取当前批次对应的样本索引。
  • torch.tensor()

    • 将 Python 列表转换为 PyTorch 张量,便于后续张量索引。
    • yield:将函数变为生成器(generator),每次调用只返回一个小批量数据,而不是一次性返回所有数据。
    • features[batch_indices]:根据批次索引提取当前批次的特征。
    • labels[batch_indices]:提取对应的标签。

PyTorch 提供了现成的数据加载工具 DataLoader,作用和此函数类似:

from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(features, labels)
data_iter = DataLoader(dataset, batch_size=3, shuffle=True)

dataset = TensorDataset(features, labels)
把多个第一维长度相同的张量(这里是 featureslabels)打包成一个可索引的数据集对象。

data_iter = DataLoader(dataset, batch_size=3, shuffle=True)
datasetmini-batch 迭代产出,并根据需要随机打乱样本。

  • dataset:数据集对象(如 TensorDataset 或自定义 Dataset)。
  • batch_size=3:每个 batch 3 个样本。len(data_iter)ceil(N/3)(若 drop_last=True 则是 floor(N/3))。
  • shuffle=True:每个 epoch 前随机打乱数据(底层是 RandomSampler)。

http://www.xdnf.cn/news/1451017.html

相关文章:

  • 服务器异常负载排查手册 · 隐蔽进程篇
  • Android AI客户端开发(语音与大模型部署)面试题大全
  • Tomcat 服务器全方位指南:安装、配置、部署与实战优化
  • Sentinel 与 Feign 整合详解:实现服务调用的流量防护
  • Clang 编译器:下载安装指南与实用快捷键全解析
  • C++类和对象(上):从设计图到摩天大楼的构建艺术
  • 蔚来汽车前制动器设计及热性能分析cad+三维图+设计说明书
  • MySQL SM4 UDF 安装与使用
  • 【计算机网络(自顶向下方法 第7版)】第一章 计算机网络概述
  • 《D (R,O) Grasp:跨机械手灵巧抓取的机器人 - 物体交互统一表示》论文解读
  • 实战演练(二):结合路由与状态管理,构建一个小型博客前台
  • Java基础知识点汇总(五)
  • 修订版!Uniapp从Vue3编译到安卓环境踩坑记录
  • 新手向:AI IDE+AI 辅助编程
  • 开源视频剪辑工具推荐
  • 经典资金安全案例分享:支付系统开发的血泪教训
  • Hadoop(七)
  • 数说故事 | 2025年运动相机数据报告,深挖主流品牌运营策略及行业趋势​
  • HarmonyOS路由导航方案演进:HMRouter基于Navigation封装,使用更方便
  • 【软考架构】嵌入式系统及软件
  • Shadcn UI – 开发者首选的高性能、高定制化 React 组件库
  • Flutter之riverpod状态管理详解
  • 第1章 Jenkins概述与架构
  • ⸢ 肆 ⸥ ⤳ 默认安全:安全建设方案 ➭ b.安全资产建设
  • HTTP性能优化
  • Rust 文件操作终极实战指南:从基础读写到进阶锁控,一文搞定所有 IO 场景
  • 设计模式3 创建模式之Singleton模式
  • 大数据工程师认证推荐项目:基于Spark+Django的学生创业分析可视化系统技术价值解析
  • 基于 EasyExcel + 线程池 解决 POI 导出时的内存溢出与超时问题
  • 如何简单理解状态机、流程图和时序图