当前位置: 首页 > news >正文

【人工智障生成日记1】从零开始训练本地小语言模型


🎯 从零开始训练本地小语言模型:MiniGPT + TinyStories(4090Ti)

🧭 项目背景

本项目旨在以学习为目的,从头构建一个完整的本地语言模型训练管线。目标是:

  • ✅ 不依赖外部云计算
  • ✅ 完全本地运行(RTX 4090Ti)
  • ✅ 从零构建数据加载、模型结构、训练与推理逻辑
  • ✅ 阶段性掌握 LLM 微调与部署的关键技能

🛠️ 开发环境

项目配置
操作系统Windows 10
GPUNVIDIA RTX 4090Ti
CUDA 驱动版本 12.1(cu121
Python 版本3.10
虚拟环境.venv310(指定 Python 3.10)

📦 项目结构

toy-transformer/
├── data_loader.py        # 加载 TinyStories 数据集
├── model.py              # MiniGPT 模型实现
├── train_resume.py       # 支持断点训练的主循环
├── generate.py           # 推理与生成函数
├── checkpoint_latest.pth # 自动保存的训练权重
├── .venv310/             # 虚拟环境

🧠 技术路线

1. 数据加载

  • 使用 HuggingFace datasets 加载 TinyStories
  • Tokenizer 使用 GPT-2 默认分词器
  • 启用 paddingtruncation,统一 max_length=128

2. 模型构建

  • 自定义实现 MiniGPT

    • 小型 Transformer(Embedding + 多层 Self-Attention + Linear head)
    • 使用 GPT-2 的 vocab
    • 无 pretraining,全从零学起

3. 模型训练

  • 使用 torch.nn.CrossEntropyLoss,忽略 pad_token_id
  • 优化器为 AdamW
  • 使用 PyTorch AMP (torch.amp.autocast) 启用混合精度
  • 使用 GradScaler 动态控制精度
  • 使用 tqdm 进度条显示训练状态
  • 支持自动保存断点(checkpoint_latest.pth

4. 推理逻辑

  • 自定义 generate_text() 实现逐 token 自回归生成
  • 使用 softmax + sampling,多轮测试发现 collapse 问题

✅ 成果展示

训练状态

  • 成功训练了 10000 steps,loss 降至 0.05 以下
  • 支持断点恢复训练,训练速度约 25 it/s

在这里插入图片描述

推理结果(初步)

在这里插入图片描述

Prompt: "Once upon a time"
Output: time time time time time time...

Prompt: "Deng Chao is"
Output: is is is is is is is is is...

🚨 模型已出现 token collapse 问题(复读)——为后续结构调整和防过拟合提供重要依据。


🚧 存在问题与后续优化

问题原因计划
文本输出高度重复模型 collapse,过拟合高频 token使用 dropout,top-k sampling,减少训练步数
模型容量不足结构太小,表达力差增加 embedding 和层数;或微调现成模型
学不到语义没有预训练语料带来的 inductive bias尝试 distilgpt2 + TinyStories 微调

📅 今日工作小结

  • ✅ 成功建立 GPU 环境并启用 CUDA(Python 3.10 + PyTorch cu121)
  • ✅ 完整实现数据加载、tokenizer、模型结构与训练管线
  • ✅ 训练了 10000 steps,保存了断点模型
  • ✅ 实现推理接口并测试多个 prompt
  • ✅ 发现模型陷入 token collapse,为后续优化积累经验

🚀 明日计划

  • 替换 softmax 采样为 top-k + temperature 采样,缓解 collapse
  • 加入 dropout / layernorm 防止模型坍缩
  • 尝试使用 distilgpt2 微调,比较从零训练 vs 微调效果
  • 训练可视化(tensorboard / wandb)

🧨 本地语言模型训练踩坑记录(2025.05)


🧱 环境搭建相关

❌ 坑 1:安装了 PyTorch,但无法使用 CUDA

  • 表现torch.cuda.is_available() 返回 False,训练跑在 CPU 上,GPU 利用率为 0%

  • 原因:初始环境为 Python 3.13,而 PyTorch GPU 版本不支持该版本

  • 解决方案

    1. 安装 Python 3.10(py -3.10

    2. 使用 python3.10 -m venv .venv310 创建虚拟环境

    3. 使用官方源安装支持 CUDA 的 PyTorch:

      pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
      

🧪 数据处理与训练相关

❌ 坑 2:模型在 GPU 上训练速度不升反降

  • 表现:CUDA 激活后训练反而更慢

  • 原因:虽然模型 to(device),但输入数据没有显式 .to(device)

  • 解决方案

    • 使用:

      input_ids = batch["input_ids"].to(device)
      

      而不是:

      batch["input_ids"].to(device)  # ⚠ 无效!
      

❌ 坑 3:tqdm 报错 IProgress not found

  • 表现:使用 from tqdm.notebook import tqdm 报错

  • 解决方案

    • 快速替换为:

      from tqdm import tqdm
      
    • 或安装依赖:

      pip install ipywidgets
      jupyter nbextension enable --py widgetsnbextension
      

📦 模型训练相关

❌ 坑 4:训练 loss 降不下来 or 降到 0.0000 太快

  • 表现:训练 1 个 epoch 后 loss ≈ 0,后续 epoch 训练跳过

  • 原因:训练步数被 step_count >= max_steps 提前终止,epoch 实际未执行

  • 解决方案

    • 使用 total_step 替代 step_count 并每轮累加
    • 或改为基于 max_epochs 控制训练轮数

❌ 坑 5:训练后模型生成“词语复读机”(collapse)

  • 表现:生成 output 全是 "time time time...""is is is..."

  • 原因

    • 模型太小,表达能力差
    • 学习率太大或步数太多导致过拟合高频 token
  • 解决方案

    • 启用 dropout 正则
    • 使用 top-k + temperature 控制采样策略
    • 更换为 distilgpt2 微调方案或扩大学习语料

🔐 安全性提示

❌ 坑 6:PyTorch 警告 torch.load() 存在安全隐患

  • 表现:加载 checkpoint 时出现 FutureWarning: weights_only=False

  • 解决方案(建议但非必须):

    • 明确添加参数:

      torch.load(checkpoint_path, weights_only=True)
      

http://www.xdnf.cn/news/584281.html

相关文章:

  • 【无标题】西门子S7-1500PLC与西门子V90 PN伺服通讯控制项目程序项目程序,共有8轴,编码器信号直接输入到变频器内。
  • 系统架构设计(十八):ATAM
  • 《棒球百科》棒球运动规则·棒球1号位
  • 【竖排繁体识别】如何将竖排繁体图片文字识别转横排繁体,转横排简体导出文本文档,基于WPF和腾讯OCR的实现方案
  • 免费轻量便携截图 录屏 OCR 翻译四合一!提升办公效率
  • 解决weman框架redis报错:Class “llluminatelRedis\RedisManager“ not found
  • 【Java高阶面经:数据库篇】18、分布式事务:如何在分库分表中实现高性能与一致性?
  • 零基础设计模式——第二部分:创建型模式 - 原型模式
  • HCIP(广域网)
  • Normalized Blind Deconvolution论文阅读
  • UART串口两种连接方式
  • 笔记本6GB本地可跑的图生视频项目(FramePack)
  • EtpBot:安卓自动化脚本开发神器
  • 了解Android studio 初学者零基础推荐(2)
  • 正则表达式篇
  • element ui 表格实现单选
  • v2.0 技术篇目录-研究生如何选择编程技术
  • iOS工厂模式
  • uniapp-商城-65-shop(1-品牌信息显示,将数据库信息同步到vuex的state)
  • 如何构建一个简单的AI Agent(极简指南)
  • 深度学习入门到实战:用PyTorch打通数学、张量与模型训练全链路​
  • 使用 A2A Python SDK 实现 CurrencyAgent
  • 开闭原则 (Open/Closed Principle, OCP)
  • leetcode hot100刷题日记——10.螺旋矩阵
  • day33 python深度学习入门
  • jmeter登录接口生成一批token并写入csv文件
  • 浪潮Inspur服务器产品线概述
  • 【paddle】常见的数学运算
  • Ubuntu 22.04上升级npm版本
  • 升级node@22后运行npm install报错 distutils not found