当前位置: 首页 > news >正文

动学学深度学习03-线性神经网络

动学学深度学习pytorch

参考地址:https://zh.d2l.ai/

文章目录

  • 动学学深度学习pytorch
    • 1-第03章-线性神经网络
      • 1. 线性回归
        • 1.1 什么是线性回归?
        • 1.2 如何表示线性回归的预测公式?
      • 2. 损失函数
        • 2.1 什么是损失函数?
        • 2.2 如何表示整个训练集的平均损失?
      • 3. 解析解
        • 3.1 什么是解析解?
        • 3.2 解析解的局限性是什么?
      • 4. 小批量随机梯度下降(SGD)
        • 4.1 SGD一次更新步骤的伪代码?
        • 4.2 超参数有哪些?
      • 5. 从线性回归到单层神经网络
        • 5.1 如何把线性回归视为神经网络?
        • 5.2 全连接层的定义?
      • 6. Softmax 回归(多类分类)
        • 6.1 什么是 Softmax 函数?
        • 6.2 交叉熵损失如何定义?
      • 7. 从零实现 vs 简洁实现(PyTorch)
        • 7.1 从零实现的关键步骤?
        • 7.2 简洁实现用到的 PyTorch 高级 API?
      • 8. 矢量化加速
        • 8.1 为什么矢量化重要?
        • 8.2 示例:向量相加
      • 9. 信息论视角的交叉熵
        • 9.1 熵 H(P) 的公式?
        • 9.2 交叉熵 H(P,Q) 的公式?
      • 10. Fashion-MNIST 数据集
        • 10.1 数据集规模?
        • 10.2 如何用 PyTorch 加载?


1-第03章-线性神经网络

1. 线性回归

1.1 什么是线性回归?

线性回归(linear regression)是一种经典的统计学习方法,用于建立自变量 x 和因变量 y 之间的线性关系模型,即假设 y 可以表示为 x 中元素的加权和,再加上一个偏置项 b。

1.2 如何表示线性回归的预测公式?

对于单样本特征向量 x∈ℝᵈ,预测值 ŷ 的向量化公式为
y^=w⊤x+b\hat y = w^\top x + by^=wx+b
其中 w 为权重向量,b 为偏置标量。


2. 损失函数

2.1 什么是损失函数?

损失函数(loss function)用来量化模型预测值与真实值之间的差距。在线性回归中,最常用的损失函数是平方误差(均方误差):

l(i)(w,b)=12(y^(i)−y(i))2l^{(i)}(w,b)=\frac12\bigl(\hat y^{(i)}-y^{(i)}\bigr)^2l(i)(w,b)=21(y^(i)y(i))2

2.2 如何表示整个训练集的平均损失?

L(w,b)=1n∑i=1nl(i)(w,b)=12n∑i=1n(w⊤x(i)+b−y(i))2L(w,b)=\frac1n\sum_{i=1}^n l^{(i)}(w,b)=\frac1{2n}\sum_{i=1}^n\bigl(w^\top x^{(i)}+b-y^{(i)}\bigr)^2L(w,b)=n1i=1nl(i)(w,b)=2n1i=1n(wx(i)+by(i))2


3. 解析解

3.1 什么是解析解?

解析解(analytical solution)指通过数学公式一次性求得的参数最优值,无需迭代优化。
线性回归的解析解为
w∗=(X⊤X)−1X⊤yw^*=(X^\top X)^{-1}X^\top yw=(XX)1Xy
其中 X 为 n×d 的设计矩阵,y 为 n×1 的标签向量。

3.2 解析解的局限性是什么?

仅适用于能写成闭合形式的问题;对大规模数据或复杂模型(如深度网络)难以直接应用。


4. 小批量随机梯度下降(SGD)

4.1 SGD一次更新步骤的伪代码?
  1. 随机采样小批量 B(大小 |B|)。
  2. 计算平均梯度:
    g←1∣B∣∑i∈B∇w,bl(i)(w,b)g \leftarrow \frac1{|B|}\sum_{i\in B}\nabla_{w,b}\,l^{(i)}(w,b)gB1iBw,bl(i)(w,b)
  3. 更新参数:
    w←w−ηgw,b←b−ηgbw \leftarrow w - \eta\,g_w,\quad b \leftarrow b - \eta\,g_bwwηgw,bbηgb
4.2 超参数有哪些?
  • 学习率 η
  • 批量大小 |B|(batch size)

5. 从线性回归到单层神经网络

5.1 如何把线性回归视为神经网络?

把输入视为输入层(d 个神经元),计算层为单个全连接(Dense)层,权重矩阵 W∈ℝ^{1×d},偏置 b∈ℝ,因此网络层数为 1。

5.2 全连接层的定义?

每个输入特征都与每个输出单元相连,计算为
oj=∑ixiwij+bjo_j = \sum_i x_i w_{ij} + b_joj=ixiwij+bj


6. Softmax 回归(多类分类)

6.1 什么是 Softmax 函数?

把未规范化的 logit 向量 o∈ℝ^q 映射为概率分布:

y^j=eoj∑k=1qeok\hat y_j = \frac{e^{o_j}}{\sum_{k=1}^q e^{o_k}}y^j=k=1qeokeoj

保证 0 ≤ ŷ_j ≤ 1 且 Σ_j ŷ_j = 1。

6.2 交叉熵损失如何定义?

对于独热标签 y 和预测概率 ŷ:

l(y,y^)=−∑j=1qyjlog⁡y^jl(y,\hat y)=-\sum_{j=1}^q y_j\log\hat y_jl(y,y^)=j=1qyjlogy^j


7. 从零实现 vs 简洁实现(PyTorch)

7.1 从零实现的关键步骤?
  1. 生成/加载数据
  2. 初始化参数
  3. 定义模型 linreg(X,w,b)net(X)
  4. 定义损失 squared_loss / cross_entropy
  5. 定义优化器 sgd
  6. 循环 for epoch / for batch 训练
7.2 简洁实现用到的 PyTorch 高级 API?
  • nn.Sequential(nn.Linear(...)) 定义模型
  • nn.MSELoss() / nn.CrossEntropyLoss() 定义损失
  • torch.optim.SGD 定义优化器
  • data.DataLoader 构建高效数据迭代器

8. 矢量化加速

8.1 为什么矢量化重要?

利用 GPU/CPU 的并行矩阵运算,避免 Python for-loop,可带来数量级加速。

8.2 示例:向量相加
# 慢
c = torch.zeros(n)
for i in range(n):c[i] = a[i] + b[i]# 快
d = a + b

9. 信息论视角的交叉熵

9.1 熵 H§ 的公式?

H(P)=−∑jP(j)log⁡P(j)H(P)=-\sum_j P(j)\log P(j)H(P)=jP(j)logP(j)
表示真实分布 P 的不确定性。

9.2 交叉熵 H(P,Q) 的公式?

H(P,Q)=−∑jP(j)log⁡Q(j)H(P,Q)=-\sum_j P(j)\log Q(j)H(P,Q)=jP(j)logQ(j)
衡量用模型分布 Q 编码真实分布 P 所需的平均比特数。模型越准,H(P,Q) 越接近 H§。


10. Fashion-MNIST 数据集

10.1 数据集规模?
  • 训练集:60 000 张 28×28 灰度图
  • 测试集:10 000 张
  • 10 个类别:T-shirt、Trouser、Pullover、Dress、Coat、Sandal、Shirt、Sneaker、Bag、Ankle boot
10.2 如何用 PyTorch 加载?
trans = transforms.ToTensor()
train_ds = torchvision.datasets.FashionMNIST(root='../data', train=True, transform=trans, download=True)
train_iter = torch.utils.data.DataLoader(train_ds, batch_size=256, shuffle=True)

http://www.xdnf.cn/news/1338319.html

相关文章:

  • hadoop-3.3.6和hbase-2.4.13
  • Linux下Docker版本升级保姆攻略
  • 数据结构之排序大全(4)
  • LLaVA-3D,Video-3D LLM,VG-LLM,SPAR论文解读
  • WebSocket通信:sockjs与stomp.js的完美搭档
  • 【问题思考】为什么需要文件后缀?(gemini完成)
  • Web3 的发展挑战:技术、监管与生态的多重困境
  • 机器学习聚类算法
  • 什么是默克尔树
  • 缓存与Redis
  • C++---辗转相除法
  • 2025-08-21 Python进阶1——控制流语句
  • 【网络运维】Shell:变量数值计算
  • SQL-leetcode—3451. 查找无效的 IP 地址
  • 从vue2到vue3
  • C++STL-stack和queue的使用及底层实现
  • 基于单片机教室照明灯控制系统
  • Jenkins+GitLab在CentOS7上的自动化部署方案
  • 新手向:Python 3.12 新特性实战详解
  • PAT 1076 Forwards on Weibo
  • linux 差分升级简介
  • git增加ignore文件
  • 健康常识查询系统|基于java和小程序的健康常识查询系统设计与实现(源码+数据库+文档)
  • UEM终端防御一体化
  • 2026 济南玉米及淀粉深加工展:从原料到创新产品的完整解决方案
  • AI Agent与LLM区别
  • Jmeter接口测试之文件上传
  • QT的项目pro qmake编译
  • 【51单片机学习】AT24C02(I2C)、DS18B20(单总线)、LCD1602(液晶显示屏)
  • Prompt魔法:提示词工程与ChatGPT行业应用读书笔记:提示词设计全能指南