当前位置: 首页 > ds >正文

python打卡day39@浙大疏锦行

知识点回顾

  1. 图像数据的格式:灰度和彩色数据
  2. 模型的定义
  3. 显存占用的4种地方
    1. 模型参数+梯度参数
    2. 优化器参数
    3. 数据批量所占显存
    4. 神经元输出中间状态
  4. batchisize和训练的关系

 1. 图像数据格式
- 灰度图像 :单通道,像素值范围通常0-255,形状为(H, W)或(1, H, W)
- 彩色图像 :三通道(RGB),形状为(3, H, W)或(H, W, 3)
 2. 模型定义要点
- 由多个神经网络层组成(卷积层、全连接层等)
- 每层包含可训练参数(权重和偏置)
- 需要定义前向传播(forward)逻辑
3. 显存占用的4个主要部分

a. 模型参数+梯度参数

# 以PyTorch为例,每个参数会占用:
显存 = 参数数量 × 4字节(float32)
梯度占用相同大小的显存

b. 优化器参数

# 例如Adam优化器会为每个参数保存:
- 一阶动量(m)
- 二阶动量(v)
# 显存占用 = 参数数量 × 4 × 2

c. 数据批量所占显存

显存 = batch_size × 单样本数据量 × 4字节

d. 神经元输出中间状态

# 前向传播时各层的输出需要保存
# 用于反向传播计算梯度
显存 ≈ Σ(各层输出张量大小 × batch_size × 4)

4.batchisize和训练的关系

import torch
import torch.nn as nn
from torch.utils.data import DataLoader# 1. 定义简单模型
class SimpleModel(nn.Module):def __init__(self):super().__init__()self.fc = nn.Linear(10, 2)  # 10维输入到2维输出# 2. 创建模拟数据集
data = torch.randn(1000, 10)        # 1000个样本
targets = torch.randint(0, 2, (1000,)) 
dataset = torch.utils.data.TensorDataset(data, targets)# 3. 不同batch_size的影响对比
for batch_size in [1, 32, 1024]:loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)model = SimpleModel()optimizer = torch.optim.SGD(model.parameters(), lr=0.01)# 训练监控print(f"\nBatch Size: {batch_size}")for epoch in range(3):for x, y in loader:optimizer.zero_grad()output = model(x)loss = nn.CrossEntropyLoss()(output, y)loss.backward()optimizer.step()print(f"Epoch {epoch} Loss: {loss.item():.4f}", end='\r')

关键点说明:
1. batch_size=1 (随机梯度下降):
   
   - 每个样本更新一次参数
   - 损失波动剧烈(高方差)
   - 适合在线学习场景
2. batch_size=32 (常用值):
   
   - 平衡了梯度稳定性和计算效率
   - 损失曲线相对平滑
3. batch_size=1024 (大批量):
   
   - 梯度方向最稳定
   - 需要更大的学习率(可用线性缩放规则)
   - 可能需梯度累积(若显存不足)

http://www.xdnf.cn/news/9976.html

相关文章:

  • vite配置一个css插件
  • MySQL字段为什么要求定义为not null ?
  • 约瑟夫问题
  • insightface==0.7.3 编译失败
  • 从时钟精度看晶振频率稳定度的重要性
  • 12-后端Web实战(登录认证)
  • 实验设计与分析(第6版,Montgomery)第5章析因设计引导5.7节思考题5.4 R语言解题
  • Linux文件操作、文件夹操作
  • 【前端】使用grid布局封装断点式进度条
  • Flannel 支持的后端
  • 交集、差集、反选
  • 蓝牙和wifi相关的杂项内容总结
  • Executors面试题
  • apptrace 的优势以及对 App 的价值
  • 【Stable Diffusion 1.5 】在 Unet 中每个 Cross Attention 块中的张量变化过程
  • 磁盘管理无法删除卷,虚拟磁盘管理器:不支持该请求
  • Attention Is All You Need论文阅读笔记
  • Wirtinger Flow算法的matlab实现和python实现
  • 【前端】Twemoji(Twitter Emoji)
  • RV1126-OPENCV Mat理解
  • 某东 h5st第8个参数 指纹加密纯算解析
  • 模型微调之对齐微调KTO
  • MySQL的binlog有有几种录入格式分别有什么区别 ?
  • VSCode的下载与安装(2025亲测有效)
  • LLaMaFactory 微调QwenCoder模型
  • Windows 中禁止在桌面放置文件以保持桌面整洁
  • 深入详解编译与链接:翻译环境和运行环境,翻译环境:预编译+编译+汇编+链接,运行环境
  • does not provide an export named ‘getActiveHead‘
  • 集成均衡功能电池保护芯片在大功率移动电源的应用,创芯微CM1341-DAT、杰华特JW3312、赛微微电CW1244、中颖SH366006
  • 从Homebrew找到openssl.cnf文件并拷贝到Go项目下使用