当前位置: 首页 > ds >正文

深度学习3.6 softmax回归的从零开始实现

本章节引入3.5的数据集

import torch
from IPython import display
from d2l import torch as d2lbatch_size = 256 #迭代器批量
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

3.6.1 初始化模型参数

num_inputs = 784 # 权重矩阵长度
num_outputs = 10 # 类别数量
W = torch.normal(0, 0.01, size=(num_inputs, num_outputs), requires_grad=True) # 权重矩阵
b = torch.zeros(num_outputs, requires_grad=True) # 偏置

图像尺寸28*28像素
‌权重W‌:从均值为0、标准差0.01的正态分布采样,形状 [784, 10]。
‌偏置b‌:初始化为全0,形状 [10]。
‌梯度追踪‌:requires_grad=True 启用自动微分。

3.6.2 定义softmax操作

def softmax(X):X_exp = torch.exp(X) # 处理计算自然指数函数e的幂(GPU计算效率高)partition = X_exp.sum(1, keepdim=True) # 0:列,1:行,计算为x行1列张量return X_exp / partition # 归一化-概率[[1/3,2/3],[3/7,4/7]]X = torch.normal(0, 1, (2, 5)) # torch.normal 用于生成服从‌正态分布(高斯分布)‌的随机数张量,支持多种参数形式(均值,标准差,(形状))
X_prob = softmax(X) # 概率
X_prob, X_prob.sum(1) # 概率和=1

在这里插入图片描述

3.6.3 定义模型

def net(X):a1 = X.reshape((-1, W.shape[0])) # 保持[*,len(W)]a2 = torch.matmul(a1, W) # torch.matmul矩阵乘法return softmax(a2 + b) # 返回对应概率

展平输入:X.reshape((-1, 784))(将 [batch_size,1,28,28] 转为 [batch_size,784])。
线性变换:XW+b(输出 [batch_size,10])。
Softmax归一化:得到每个类别的概率分布。

3.6.4 定义损失函数

y = torch.tensor([0, 2])
y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
y_hat[[0, 1], y]

tensor([0.1000, 0.5000])
高级索引 : 索引列表会按‌位置配对‌,从y_hat中提取特定位置的元素
‌第一个元素‌:y_hat[0行, y[0]=0列] → 0.1
‌第二个元素‌:y_hat[1行, y[1]=2列] → 0.5

def cross_entropy(y_hat, y):return - torch.log(y_hat[range(len(y_hat)), y])cross_entropy(y_hat, y)

tensor([2.3026, 0.6931])

3.6.5 分类精度

def accuracy(y_hat, y):if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:y_hat = y_hat.argmax(axis=1)cmp = y_hat.type(y.dtype) == yreturn float(cmp.type(y.dtype).sum())
http://www.xdnf.cn/news/1228.html

相关文章:

  • ULVAC VTR-350MERH紧凑型真空蒸发器Compact Vacuum Evaporator 含电路图,安装手,工艺规范,操作工序说明
  • 【漫话机器学习系列】217.监督式深度学习的核心法则(Supervised Deep Learning Rule Of Thumb)
  • 数据结构与算法-顺序表应用
  • MySQL_MCP_Server_pro接入cherry_studio实现大模型操作数据库
  • 进阶篇 第 5 篇:现代预测方法 - Prophet 与机器学习特征工程
  • Linux 系统监控进阶:htop 命令详解与高效运维
  • 算法基础_数据结构【KMP + Trie 树 + 并查集 】
  • sql server tempdb库的字符集和用户库字符集不一样
  • 大模型时代下的人工智能专业就业:机遇与挑战并存
  • U535982 J-A 小梦的AB交换 题解
  • 【springsecurity oauth2授权中心】自定义登录页和授权确认页 P2
  • [Android]豆包爱学v4.5.0小学到研究生 题目Ai解析
  • qt调用deepseek的API开发(附带源码)
  • IPoIB驱动接收路径深度解析:从数据包到协议栈
  • 全本地化智能数字人
  • Java 性能优化:如何在资源受限的环境下实现高效运行?
  • Apache PDFBox
  • 【延迟双删】简单解析
  • 基于无障碍跳过广告-基于节点跳过广告
  • 比特币三种扩容路径Nubit、Babylon、Bitlayer分析
  • spark和Hadoop的之间的对比和联系
  • VMware Workstation 10.0.0 完整安装与激活指南零配置
  • [贪心_3] 摆动序列 | 最长递增子序列
  • 植被参数遥感反演技术革命!AI+Python支持向量机/随机森林/神经网络/CNN/LSTM/迁移学习在植被参数反演中的实战应用与优化
  • ESM 内功心法:化解 require 中的夺命一击!
  • 用语言模型训练出图像生成和理解能力:Liquid 框架 论文速读
  • 从零开始创建MCP Server实战指南
  • 描述城市出行需求模式的复杂网络视角:大规模起点-目的地需求网络的图论分析
  • 牛客算法题目刷——链表总结
  • 软考高级信息系统项目管理师的【干系人参与度评估矩阵】详解