当前位置: 首页 > java >正文

softmax回归的从零开始实现

  • softmax的组成

  1. 对每个项求幂(使用exp);
  2. 对每一行求和(小批量中每个样本是一行),得到每个样本的规范化常数;
  3. 将每一行除以其规范化常数,确保结果的和为1。
  • softmax的表达式

在这里插入图片描述
分母或规范化常数,有时也称为配分函数(其对数称为对数-配分函数)。

  • 为什么使用 softmax?

多分类问题:将模型输出解释为类别概率。
梯度优化:softmax 与交叉熵损失结合时,梯度计算更简单高效。
可解释性:输出概率直接反映模型对各类别的置信度。

  • 一个完整的softmax 回归模型?

def net(X):return softmax(torch.matmul(X.reshape((-1, W.shape[0])), W) + b)

模型结构解析

  1. 输入处理:X.reshape((-1, W.shape[0]))
  • X:输入图像张量,形状为 [batch_size, 1, 28, 28](批量大小 × 通道数 × 高度 × 宽度)。
  • reshape:将图像展平为二维张量,形状变为 [batch_size, 784](W.shape[0]=784)。
  • -1:自动计算批量大小,适应不同的输入批次。
  1. 线性变换:torch.matmul(…, W) + b矩阵乘法:
  • X_flattened @ W,其中:
  • X_flattened 形状:[batch_size, 784]。
    
  • W 形状:[784, 10](784 个输入特征 → 10 个输出类别)。
    
  • 结果形状:[batch_size, 10]。
    
  • 添加偏置:b 形状为 [10],通过广播机制加到每个样本上。
  1. 概率转换:softmax(…)
  • 将线性变换的输出(logits)转换为概率分布,确保:
  • 每个元素 ∈ [0, 1]。
    
  • 每行元素和为 1。2. 
    

数学公式
在这里插入图片描述

  • 为什么用交叉熵?

直观意义:惩罚模型对真实标签的低置信度预测。例如,若模型对真实标签的预测概率为 0.1(接近错误),损失会很大;若为 0.9(接近正确),损失会很小。
数学意义:交叉熵是两个概率分布(预测分布 y_hat 和真实分布 y)之间差异的度量,真实分布可视为 “one-hot 编码”(如 y=0 对应 [1,0,0])。

  • 解释预测类别提取?

if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:y_hat = y_hat.argmax(axis=1)

输入检查:
若 y_hat 是多维张量(如 [batch_size, num_classes])且类别数 > 1,则进行处理。
argmax(axis=1):
对每行(每个样本),取概率最大的索引作为预测类别。
y_hat = torch.tensor([[0.1, 0.3, 0.6], # 预测类别为2(索引从0开始) [0.3, 0.6, 0.1]]) # 预测类别为1 print(y_hat.argmax(axis=1) ) # 输出:tensor([2, 1])

  • 预测类别与真实标签的对比?

准确率是分类问题最基本的评估指标,它衡量模型预测正确的比例。

y_hat = torch.tensor([[0.1, 0.3, 0.6],  # 样本1的预测概率分布[0.3, 0.2, 0.5]])  # 样本2的预测概率分布y = torch.tensor([0, 2])  # 真实标签:样本1属于类别0,样本2属于类别2
  1. 第一个样本的预测分析
    预测概率:[0.1, 0.3, 0.6]
    类别 0 概率:0.1
    类别 1 概率:0.3
    类别 2 概率:0.6(最大值)
    预测类别:概率最大的索引 → 类别 2(索引从 0 开始)
    真实标签:0
    结果:预测错误 ❌
  2. 第二个样本的预测分析
    预测概率:[0.3, 0.2, 0.5]
    类别 0 概率:0.3
    类别 1 概率:0.2
    类别 2 概率:0.5(最大值)
    预测类别:概率最大的索引 → 类别 2
    真实标签:2
    结果:预测正确 ✅
  3. 准确率的计算
    在这里插入图片描述
  • accuracy函数和evaluate_accuracy函数两者的区别和联系?

accuracy函数:
核心特点
输入:单个批次的预测结果(概率矩阵或类别索引)和真实标签。
输出:该批次中预测正确的样本数量(例如:15个样本预测正确)。
用途:计算小批量数据的准确率基数。

 def accuracy(y_hat, y):"""计算单个批次的预测正确数量"""if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:y_hat = y_hat.argmax(axis=1)  # 将概率矩阵转为类别索引cmp = y_hat.type(y.dtype) == y    # 逐元素比较预测结果与真实标签return float(cmp.type(y.dtype).sum())  # 返回正确预测的样本数(如15.0)

evaluate_accuracy 函数
核心特点
输入:模型和整个数据集的迭代器。
输出:整个数据集的准确率(比例)(例如:87.5%)。
依赖:调用 accuracy 计算每个批次的正确数,再通过 Accumulator 累加所有批次的结果。

def evaluate_accuracy(net, data_iter):"""计算整个数据集上的准确率"""if isinstance(net, torch.nn.Module):net.eval()  # 设置模型为评估模式metric = Accumulator(2)  # 创建累加器(正确数、总数)with torch.no_grad():for X, y in data_iter:metric.add(accuracy(net(X), y), y.numel())  # 累加每批次的正确数和总数return metric[0] / metric[1]  # 返回全局准确率(如0.875)
功能accuracyevaluate_accuracy
作用范围单个批次的数据整个数据集(所有批次)
输入参数预测结果 y_hat 和真实标签 y模型 net 和数据迭代器 data_iter
返回值正确预测的样本数(如 15.0)准确率(如 0.875)
核心逻辑比较预测与标签,统计正确数遍历所有批次,累计正确数和总数,计算比例
是否需要 Accumulator是(需要累计多批次结果)
是否设置模型模式是(自动设置为评估模式)

协作关系
这两个函数通常配合使用:

  • accuracy 负责计算单个批次的正确预测数。
  • evaluate_accuracy 调用 accuracy 处理每个批次,并累计结果得到全局准确率。

为什么需要这种设计?

  1. 模块化设计
  • accuracy 专注于单个批次的计算,逻辑简单纯粹。
  • evaluate_accuracy 专注于跨批次的累计,复用 accuracy 的逻辑。
  1. 内存效率
  • 无需存储所有批次的预测结果,只需维护两个累加值(正确数和总数)。
  • 适合处理大规模数据集(如 ImageNet 有 100 万样本)。
  1. 灵活性
  • accuracy 可单独用于调试或分析特定批次的预测结果。
  • evaluate_accuracy 可用于验证集、测试集或不同模型的比较。

总结

  • accuracy 是微观层面的工具,计算单个批次的正确预测数。
  • evaluate_accuracy 是宏观层面的工具,基于 accuracy 计算整个数据集的准确率。
    两者结合实现了高效、模块化的模型评估流程。
  • 损失计算与梯度更新的两种模式

优化器类型损失处理梯度更新逻辑适用场景
PyTorch 内置优化器l.mean().backward()updater.zero_grad() + step()标准训练流程,自动处理批量
自定义优化器l.sum().backward()手动传入批量大小 X.shape[0]教学演示,理解优化器原理

为什么区分处理?
内置优化器(如torch.optim.SGD)默认基于批量均值计算梯度,而自定义优化器可能需要总损失和批量大小来调整学习率(如 lr =learning_rate / batch_size)。

  • 完整代码

"""
文件名: 3.6 softmax回归的从零开始实现
作者: 墨尘
日期: 2025/7/11
项目名: dl_env
备注: 
"""
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l
from IPython.display import display
import matplotlib.pyplot as plt
import platform
import matplotlib.font_manager as fm
"""softmax函数"""def softmax(X):X_exp = torch.exp(X)partition = X_exp.sum(1, keepdim=True)return X_exp / partition  # 这里应用了广播机制"""定义模型"""
"""
这个 net 函数实现了一个完整的 softmax 回归模型,它:
将 28×28 的图像展平为 784 维向量。
通过线性变换(矩阵乘法 + 偏置)计算每个类别的得分。
使用 softmax 将得分转换为概率分布。
"""def net(X):# return softmax(torch.matmul(X.reshape((-1, W.shape[0])), W) + b) 逐行详解# 概率转换:softmax(...)# 将线性变换的输出(logits)转换为概率分布,确保:# 每个元素 ∈ [0, 1]。# 每行元素和为 1。return softmax(# 矩阵乘法:X_flattened @ W,其中:# X_flattened 形状:[batch_size, 784]。# W 形状:[784, 10](784 个输入特征 → 10 个输出类别)。# 结果形状:[batch_size, 10]。# 添加偏置:b 形状为 [10],通过广播机制加到每个样本上。torch.matmul(  # 线性变换:torch.matmul(..., W) + b# X:输入图像张量,形状为 [batch_size, 1, 28, 28](批量大小 × 通道数 × 高度 × 宽度)。# reshape:将图像展平为二维张量,形状变为 [batch_size, 784](W.shape[0]=784)。# -1:自动计算批量大小,适应不同的输入批次。X.reshape((-1, W.shape[0])),  # 输入处理:X.reshape((-1, W.shape[0]))W) + b)"""交叉熵损失函数
用于衡量模型预测概率分布与真实标签之间的差异
"""def cross_entropy(y_hat, y):# 对提取的概率取自然对数(torch.log),再取负数(-)。# 概率越接近 1,负对数越小(损失越小);# 概率越接近 0,负对数越大(损失越大)。return - torch.log(# 提取真实标签对应的预测概率y_hat[range(len(y_hat)), y])"""分类精度
用于计算模型预测的准确率(Accuracy),即预测正确的样本数占总样本数的比例
"""def accuracy(y_hat, y):  # @save"""预测类别提取"""# 详解:解释预测类别提取?if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:y_hat = y_hat.argmax(axis=1)  # 将概率矩阵转换为类别索引cmp = y_hat.type(y.dtype) == y  # 比较预测类别与真实标签return float(cmp.type(y.dtype).sum())  # 统计正确预测的数量# 用于评估模型在给定数据集上准确率(Accuracy)的函数
# 灵活的统计量累加器
"""这个类用于动态累计多个变量,特别适合在深度学习中统计批次数据的指标(如准确率、损失值)。"""
"""metric = Accumulator(2)  # 创建两个变量的累加器
# 第一次累加:正确1个,总数2个
metric.add(1, 2)  # metric.data → [1.0, 2.0]
# 第二次累加:正确2个,总数3个
metric.add(2, 3)  # metric.data → [3.0, 5.0]
# 获取统计结果
print(f"总正确数: {metric[0]}")  # 3.0
print(f"总样本数: {metric[1]}")  # 5.0
print(f"准确率: {metric[0]/metric[1]}")  # 0.6"""class Accumulator:  # @save"""在n个变量上累加"""def __init__(self, n):self.data = [0.0] * n  # 初始化n个累加器变量(如正确数、总数)def add(self, *args):self.data = [a + float(b) for a, b in zip(self.data, args)]  # 累加新值def reset(self):self.data = [0.0] * len(self.data)  # 重置为0def __getitem__(self, idx):return self.data[idx]  # 支持索引访问(如metric[0]获取第一个变量)"""
evaluate_accuracy 函数:模型准确率评估
这个函数利用 Accumulator 遍历整个数据集,计算模型的全局准确率。
"""
def evaluate_accuracy(net, data_iter):  # @save"""计算在指定数据集上模型的精度"""if isinstance(net, torch.nn.Module):net.eval()  # 将模型设置为评估模式metric = Accumulator(2)  # 正确预测数、预测总数  # 创建两个变量的累加器with torch.no_grad():  # 禁用梯度计算,加速评估for X, y in data_iter:  # 遍历每个批次# 计算当前批次的正确预测数和样本总数metric.add(accuracy(net(X), y), y.numel())# 返回全局准确率 = 总正确数 / 总样本数return metric[0] / metric[1]
"""绘制数据的实用程序类Animator
这个 Animator 类是一个用于动态更新图表的工具,特别适合在训练过程中实时可视化损失、准确率等指标的变化。
"""
"""
Animator 类通过以下方式实现动态图表:
初始化配置:设置图表标题、坐标轴标签、图例等。
数据添加:通过 add(x, y) 方法持续添加新数据点。
实时更新:每次添加数据后,自动刷新图表并清除旧输出。
"""
import matplotlib.pyplot as plt
import timeimport matplotlib
matplotlib.use('TkAgg')  # 强制设置后端
import matplotlib.pyplot as plt
import matplotlib.font_manager as fmclass Animator:"""实时绘制训练曲线的工具类(适配PyCharm)"""def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,ylim=None, xscale='linear', yscale='linear',fmts=('-', 'm--', 'g-.', 'r:'), nrows=1, ncols=1,figsize=(10, 6)):# 配置中文字体self._setup_fonts()self.fig, self.axes = plt.subplots(nrows, ncols, figsize=figsize)if nrows * ncols == 1:self.axes = [self.axes]# 设置坐标轴配置函数self.config_axes = lambda: self._set_axes(self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend)self.X, self.Y, self.fmts = None, None, fmts# 开启交互模式plt.ion()self.fig.show()  # 初始化显示窗口def _setup_fonts(self):chinese_fonts = ['SimHei', 'Microsoft YaHei', 'SimSun']available_fonts = [f.name for f in fm.fontManager.ttflist]found_font = next((f for f in chinese_fonts if f in available_fonts), None)if found_font:plt.rcParams["font.family"] = found_fontprint(f"已设置中文字体: {found_font}")else:print("警告: 未找到可用中文字体")plt.rcParams['axes.unicode_minus'] = Falsedef _set_axes(self, ax, xlabel, ylabel, xlim, ylim, xscale, yscale, legend):ax.set_xlabel(xlabel)ax.set_ylabel(ylabel)ax.set_xscale(xscale)ax.set_yscale(yscale)ax.set_xlim(xlim)ax.set_ylim(ylim)if legend:ax.legend(legend)ax.grid(True)def add(self, x, y):# 处理输入数据if not hasattr(y, "__len__"):y = [y]n = len(y)x = [x] * n if not hasattr(x, "__len__") else x# 初始化数据存储if not self.X:self.X = [[] for _ in range(n)]self.Y = [[] for _ in range(n)]# 添加数据点for i, (a, b) in enumerate(zip(x, y)):if a is not None and b is not None:self.X[i].append(a)self.Y[i].append(b)# 绘制并刷新try:self.axes[0].cla()for x_list, y_list, fmt in zip(self.X, self.Y, self.fmts):self.axes[0].plot(x_list, y_list, fmt, linewidth=1.5)self.config_axes()self.fig.canvas.draw()  # 刷新画布self.fig.canvas.flush_events()  # 处理事件except Exception as e:print(f"绘图错误: {e}")def show(self):"""训练结束后保持窗口不关闭"""#plt.ion()  # 确保交互模式是开启的(可选)#self.fig.canvas.draw()  # 最后刷新一次plt.ioff()  # 关闭交互模式(关键)plt.show(block=True)  # 阻塞窗口,直到手动关闭(关键)"""训练"""
# net:待训练的模型(如之前定义的 softmax 回归模型)。
# train_iter:训练数据集的迭代器(DataLoader),用于批量加载数据。
# loss:损失函数(如交叉熵损失)。
# updater:参数更新器(如梯度下降优化器),负责更新模型参数(W 和 b)。
def train_epoch_ch3(net, train_iter, loss, updater):  #@save"""训练模型一个迭代周期(定义见第3章)"""# 将模型设置为训练模式#对于包含 Dropout、BatchNorm 等层的模型,训练模式会启用随机化操作(如 Dropout 随机丢弃神经元),确保模型正常训练if isinstance(net, torch.nn.Module):net.train()# 存储:[训练损失总和, 正确预测数, 总样本数]"""Accumulator(3) 用于累计整个迭代周期的三个关键指标:所有批次的损失总和(metric[0])。所有批次的正确预测数(metric[1])。所有批次的总样本数(metric[2])。
"""metric = Accumulator(3)for X, y in train_iter:#前向传播:计算预测值y_hat = net(X)# 计算损失l = loss(y_hat, y)# 根据优化器类型更新参数if isinstance(updater, torch.optim.Optimizer):# 使用PyTorch内置的优化器和损失函数   # 情况1:使用PyTorch内置优化器(如SGD)updater.zero_grad()   # 清零梯度l.mean().backward() # 计算梯度(损失取均值,适应批量梯度下降)updater.step()  # 更新参数else:# 使用定制的优化器和损失函数  # 情况2:使用自定义优化器(如手动实现的SGD)l.sum().backward()   #计算梯度(损失取总和,自定义优化器需手动处理批量大小)updater(X.shape[0]) # 传入批量大小,更新参数# 累计指标metric.add(float(l.sum()), accuracy(y_hat, y), y.numel())# 返回训练损失和训练精度"""返回值:第一个值:平均训练损失(总损失 ÷ 总样本数)。第二个值:训练准确率(正确预测数 ÷ 总样本数)。"""return metric[0] / metric[2], metric[1] / metric[2]"""这个 train_ch3 函数实现了一个完整的模型训练流程,包括多轮迭代训练、实时可视化和结果验证。"""
# net:待训练的模型(如 softmax 回归网络)。
# train_iter:训练数据集迭代器(批量加载训练数据)。
# test_iter:测试数据集迭代器(评估模型泛化能力)。
# loss:损失函数(如交叉熵损失)。
# num_epochs:训练轮数(整个数据集被遍历的次数)。
# updater:优化器(如 SGD、Adam 等,用于更新模型参数)。
def train_ch3(net, train_iter, test_iter, loss, num_epochs, updater):  #@save"""训练模型(定义见第3章)"""# Animator:动态更新图表的工具(之前讨论过),用于实时显示训练损失、训练准确率和测试准确率。# 图表配置:# x 轴:训练轮数(epoch)。# y 轴范围:0.3~0.9(确保曲线显示在合理区间)。# 图例:显示三条曲线(训练损失、训练准确率、测试准确率)。animator = Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0.3, 0.9],legend=['train loss', 'train acc', 'test acc'])# 多轮训练循环# 每轮训练:# train_epoch_ch3:完成一个完整轮次的训练,返回 (train_loss, train_acc)。# evaluate_accuracy:在测试集上评估模型准确率。# animator.add:将当前轮次的三个指标(训练损失、训练准确率、测试准确率)添加到图表中,触发可视化更新。for epoch in range(num_epochs):train_metrics = train_epoch_ch3(net, train_iter, loss, updater)test_acc = evaluate_accuracy(net, test_iter)animator.add(epoch + 1, train_metrics + (test_acc,))#     训练结果验证# 断言检查:确保训练结果符合预期(防止严重错误):# 训练损失应小于 0.5。# 训练准确率应在 0.7~1.0 之间。# 测试准确率应在 0.7~1.0 之间。train_loss, train_acc = train_metricsassert train_loss < 0.5, train_lossassert train_acc <= 1 and train_acc > 0.7, train_accassert test_acc <= 1 and test_acc > 0.7, test_acc# 训练完全结束后,保持窗口#animator.show()lr = 0.1
# updater 函数定义了一个自定义优化器,用于手动实现参数更新(如随机梯度下降)
def updater(batch_size):return d2l.sgd([W, b], lr, batch_size)"""预测"""
#predict_ch3 函数用于可视化模型在测试集上的预测结果,通过对比真实标签和预测标签,直观展示模型的分类效果。
# net:训练好的模型(如之前定义的 softmax 回归模型)。
# test_iter:测试数据集的迭代器(DataLoader),用于加载测试样本。
# n:要可视化的样本数量(默认展示 6 个样本)。def predict_ch3(net, test_iter, n=6):  # @save"""预测标签(适配PyCharm环境)"""# 1. 只取第一个批次的测试数据for X, y in test_iter:break  # 跳出循环,保留X(图像)和y(真实标签)# 2. 转换标签为文本(如0→"T恤")trues = d2l.get_fashion_mnist_labels(y)  # 真实标签文本preds = d2l.get_fashion_mnist_labels(net(X).argmax(axis=1))  # 预测标签文本# 3. 生成标题(真实标签 + 预测标签,换行显示)titles = [f"{true}\n{pred}" for true, pred in zip(trues, preds)]# 4. 显示图像(适配PyCharm)d2l.show_images(X[0:n].reshape((n, 28, 28)),  # 提取前n张图,调整形状为(n,28,28)1, n,  # 布局:1行n列titles=titles[0:n]  # 对应前n张图的标题)# 关键:PyCharm中保持图像窗口不关闭plt.show(block=True)  # 阻塞窗口,手动关闭后才继续执行
if __name__ == '__main__':batch_size = 256train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)# 初始化模型参数num_inputs = 784num_outputs = 10# 在softmax回归中,我们的输出与类别一样多。 因为我们的数据集有10个类别,# 所以网络输出维度为10。 因此,权重将构成一个的矩阵, 偏置将构成一个的行向量。# 与线性回归一样,我们将使用正态分布初始化我们的权重W,偏置b初始化为0。W = torch.normal(0, 0.01,  # 正态分布,均值为0,标准差0.01size=(num_inputs, num_outputs),# 权重矩阵的形状 在 Fashion-MNIST 分类中,num_inputs=784(28×28 像素),num_outputs=10(10 个类别),则 W 的形状为 [784, 10]。requires_grad=True)  # 启用自动微分(Autograd),PyTorch 会跟踪 W 上的所有运算,以便后续计算梯度b = torch.zeros(num_outputs, requires_grad=True)# 简要回顾一下sum运算符如何沿着张量中的特定维度工作X = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])print(X.sum(0, keepdim=True), X.sum(1, keepdim=True))"""实现softmax由三个步骤组成:对每个项求幂(使用exp);对每一行求和(小批量中每个样本是一行),得到每个样本的规范化常数;将每一行除以其规范化常数,确保结果的和为1。"""# 从正态分布(均值 0,标准差 1)随机采样,生成形状为 (2, 5) 的张量。X = torch.normal(0, 1, (2, 5))X_prob = softmax(X)print(X_prob)# 概率值,每行和为 1:print(X_prob.sum(1))#  定义损失函数# 从预测概率矩阵中提取真实标签对应的概率,为计算交叉熵损失做准备。y = torch.tensor([0, 2])  # 两个样本的真实标签:类别0和类别2# 预测概率矩阵"""y_hat 含义:
第 1 行 [0.1, 0.3, 0.6]:对第 1 个样本的预测,类别 0 概率 0.1,类别 1 概率 0.3,类别 2 概率 0.6。
第 2 行 [0.3, 0.2, 0.5]:对第 2 个样本的预测,类别 0 概率 0.3,类别 1 概率 0.2,类别 2 概率 0.5。
"""y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])"""y_hat[[0, 1], [0, 2]]  # 等价写法第 0 行第 0 列 → y_hat[0, 0] = 0.1第 1 行第 2 列 → y_hat[1, 2] = 0.5"""print(y_hat[[0, 1], y])# 实现交叉熵损失函数cross_entropy(y_hat, y)"""分类精度"""# 计算单个批次的正确预测数 accuracy(y_hat, y)# 计算单个批次数据的准确率(比例)print(accuracy(y_hat, y) / len(y))# 基于 accuracy 计算整个数据集的准确率。print(evaluate_accuracy(net, test_iter))#训练num_epochs = 10train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, updater)# 预测predict_ch3(net, test_iter)
  • 实验结果:

在这里插入图片描述
在这里插入图片描述

tensor([[5., 7., 9.]]) tensor([[ 6.],
[15.]])
tensor([[0.3271, 0.0393, 0.0964, 0.3759, 0.1613],
[0.1693, 0.0101, 0.5148, 0.0452, 0.2607]])
tensor([1.0000, 1.0000])
tensor([0.1000, 0.5000])
0.5
0.0652
已设置中文字体: SimHei

  • 整体逻辑流程:

该代码实现了一个完整的softmax回归模型,用于解决Fashion-MNIST数据集的图像分类任务(10类衣物分类)。整体逻辑可分为**数据准备、核心组件定义、模型训练、结果可视化**四个部分,各部分衔接紧密,形成从数据到预测的完整流程。### 一、数据准备与环境配置
代码首先导入必要的库(PyTorch、数据加载工具、可视化工具等),并配置中文字体显示(确保训练曲线和预测结果的中文标签正常显示)。  
通过`d2l.load_data_fashion_mnist(batch_size)`加载Fashion-MNIST数据集,返回训练集迭代器`train_iter`和测试集迭代器`test_iter`,用于批量加载28×28的灰度图像及对应标签。### 二、核心组件定义
这部分是模型的“骨架”,包括softmax函数、模型结构、损失函数、评估指标等核心功能。#### 1. softmax函数:将输出转换为概率分布
softmax函数的作用是将模型输出的“原始得分(logits)”转换为概率分布(每个类别概率∈[0,1],且所有类别概率和为1)。  
实现逻辑:  
- 对每个得分做指数运算(`torch.exp(X)`),避免负数影响;  
- 对每行(每个样本)的指数结果求和(`partition`),作为归一化常数;  
- 每个元素除以该行的归一化常数(`X_exp / partition`),得到概率分布。  #### 2. 模型定义(net函数):softmax回归的前向传播
模型功能:将输入图像转换为10类别的概率分布。  
实现逻辑:  
- 图像展平:将28×28的输入图像(形状`[batch_size, 1, 28, 28]`)展平为784维向量(`X.reshape((-1, 784))`);  
- 线性变换:通过权重矩阵`W`(形状`[784, 10]`)和偏置`b`(形状`[10]`)计算每个类别的原始得分(`torch.matmul(展平向量, W) + b`);  
- 概率转换:用softmax函数将原始得分转换为10类别的概率分布(形状`[batch_size, 10]`)。  #### 3. 损失函数:交叉熵损失
用于衡量模型预测概率与真实标签的差异,是训练的“优化目标”。  
实现逻辑(`cross_entropy函数`):  
- 从预测概率矩阵`y_hat`中,提取真实标签`y`对应的概率(`y_hat[range(len(y_hat)), y]`);  
- 对该概率取负对数(`-torch.log(...)`),概率越接近1,损失越小(符合“预测越准,损失越低”的直观逻辑)。  #### 4. 评估指标:分类准确率
用于衡量模型的预测效果(正确预测样本数占总样本数的比例)。  
- `accuracy函数`:比较模型预测的类别(`y_hat.argmax(axis=1)`,取概率最大的类别)与真实标签`y`,统计正确数;  
- `Accumulator类`:累加多个批次的指标(如“总正确数”“总样本数”),方便计算全局准确率;  
- `evaluate_accuracy函数`:利用上述工具遍历整个数据集(训练集或测试集),计算模型在该数据集上的整体准确率。  #### 5. 可视化工具:Animator类
用于实时可视化训练过程,动态展示三个关键指标的变化:训练损失、训练准确率、测试准确率。  
功能:每轮训练后更新图表,支持中文显示,适配PyCharm环境(确保图像窗口不闪退)。  ### 三、模型训练
通过“前向传播计算损失→反向传播更新参数→多轮迭代优化”的流程,让模型从“随机猜测”逐步收敛到“准确分类”。#### 1. 参数初始化
- 权重`W`:从均值0、标准差0.01的正态分布中随机初始化(形状`[784, 10]`);  
- 偏置`b`:初始化为0(形状`[10]`);  
- 学习率`lr=0.1`,用于控制参数更新的步长。  #### 2. 训练核心函数
- `train_epoch_ch3函数`:实现一个训练周期(遍历一次训练集):  - 遍历`train_iter`中的每个批次数据`(X, y)`;  - 前向传播:计算预测概率`y_hat = net(X)`;  - 计算损失:`l = cross_entropy(y_hat, y)`;  - 反向传播:通过梯度下降(`updater`)更新`W``b`(清零梯度→计算梯度→更新参数);  - 累加指标:记录该批次的总损失、正确数、总样本数。  - `train_ch3函数`:控制多轮训练(`num_epochs=10`):  - 每轮调用`train_epoch_ch3`完成训练,得到该轮的训练损失和训练准确率;  - 调用`evaluate_accuracy`计算模型在测试集上的准确率;  - 用`Animator`将三个指标(训练损失、训练准确率、测试准确率)实时可视化;  - 训练结束后验证指标合理性(如损失<0.5,准确率>0.7)。  ### 四、预测与结果展示
训练完成后,通过`predict_ch3函数`验证模型的实际分类效果:  
- 从测试集取少量样本(默认6个),用训练好的`net`模型预测类别;  
- 可视化样本图像、真实标签(如“T恤”)和预测标签(如“T恤”或“衬衫”);  
- 直观对比模型预测与真实结果,检验分类效果。  ### 整体流程总结
1. 数据输入:加载Fashion-MNIST数据集,批量读取图像和标签;  
2. 模型计算:通过softmax回归将图像映射为10类概率分布;  
3. 损失与优化:用交叉熵损失衡量误差,通过梯度下降更新参数(`W``b`);  
4. 评估与可视化:实时跟踪训练/测试指标,训练后展示具体样本的预测结果;  
5. 目标:让模型在10类衣物分类任务上达到较高准确率(最终测试准确率通常>0.8)。  整个代码从零实现了softmax回归的核心逻辑,涵盖了“模型定义-损失计算-参数优化-评估可视化”的完整深度学习流程。
http://www.xdnf.cn/news/15133.html

相关文章:

  • php的原生类
  • 《棒球规则介绍》领队和主教练谁说了算·棒球1号位
  • Express实现定时任务
  • PBR渲染
  • 软件开发那些基础事儿:需求、模型与生命周期
  • 大模型在卵巢癌预测及诊疗方案制定中的应用研究
  • 河南专升本2026年练习题、真题和2000题每日一节
  • 分割网络Segformer
  • 【B题解题思路】2025APMCM亚太杯中文赛B题解题思路+可运行代码参考(无偿分享)
  • 设计模式(结构型)-适配器模式
  • c++——浅拷贝和深拷贝、浅赋值和深赋值
  • 基于强化学习的智能推荐系统优化实践
  • c/c++拷贝函数
  • 字节豆包又一个新功能,超级实用,4 种玩法,你肯定用得上!(建议收藏)
  • 力扣_二叉搜索树_python版本
  • 聚焦数据资源建设与应用,浙江省质科院赴景联文科技调研交流
  • 【龙泽科技】新能源汽车维护与动力蓄电池检测仿真教学软件【吉利几何G6】
  • Elasticsearch 滚动(Scroll)用法、使用场景及与扫描(Scan)的区别
  • DIDCTF-蓝帽杯
  • 【经典面经】C++新特性 TCP完整收发数据 TLS1.2 TLS1.3
  • 【C++】全套数据结构算法-线性表讲解(1)
  • Anaconda及Conda介绍及使用
  • 注意力机制十问
  • 简单记录一下Debug的折磨历程
  • 汽车级MCU选型新方向:eVTOL垂桨控制监控芯片的替代选型技术分析
  • 巨人网络持续加强AI工业化管线,Lovart国内版有望协同互补
  • UI前端大数据可视化实战技巧:如何利用数据故事化提升用户参与度?
  • 云暴露面分析完整指南
  • Qt:布局管理器Layout
  • [Meetily后端框架] 多模型-Pydantic AI 代理-统一抽象 | SQLite管理