当前位置: 首页 > ds >正文

LLM基础6_在未标记数据上进行预训练

基于GitHub项目:https://github.com/datawhalechina/llms-from-scratch-cn

  • 预训练:让模型大量阅读各种文本

  • 微调:教模型完成特定任务(如问答)

评估文本生成模型

使用GPT生成文本

工具:

  • tiktoken:文本→数字ID(就像给每个单词编号)

  • generate_text_simple:让模型"接着上一个词说"

如何衡量模型好坏?

交叉熵损失 (Cross-Entropy Loss):衡量模型预测的下一个词有多不准,值越小越好(理想值接近0)

困惑度 (Perplexity):更直观的评分:"模型平均要在多少词里猜中正确答案",越小越好(理想值接近1)

模型训练(关键步骤)

1.优化器:模型的学习教练

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

2.训练循环

for epoch in range(100):  # 学100遍for batch in train_loader:# 1. 清空上次"学到的错误"optimizer.zero_grad() # 2. 尝试预测下一个词logits = model(batch) # 3. 计算错误程度loss = cross_entropy(logits, targets)# 4. 找出改进方向loss.backward()# 5. 调整模型参数optimizer.step()

3.进度监控

if epoch % 10 == 0:print(f"Epoch {epoch}: Loss={loss:.2f}, Perplexity={torch.exp(loss):.0f}")print(generate_text("Once upon a time"))  # 看生成效果
概念作用
预训练让模型学会语言规律
交叉熵损失衡量预测错误程度
困惑度直观评估模型水平
批次训练分段处理大数据

 训练大型语言模型(LLM)

核心训练函数解析

def train_model_simple(model, train_loader, val_loader, optimizer, device, num_epochs,eval_freq, eval_iter, start_context):# ...训练循环...

学习循环(核心)

让大模型反复学习:

for epoch in range(10):  # 学10遍数据for batch in train_loader:  # 每次做一道题optimizer.zero_grad()   # 清空上次错误loss = calc_loss_batch()  # 计算错误程度loss.backward()         # 分析错在哪optimizer.step()        # 改正错误

定期评估:

if global_step % 5 == 0:  # 每5步train_loss, val_loss = evaluate_model()  # 评估print(f"训练错误率: {train_loss:.3f}, 考试错误率: {val_loss:.3f}")

展示:

generate_and_print_sample(model, start_context="Every effort moves you"  # 前文提示
)
# 输出:Every effort moves you,. (初期) → 后期变成完整句子

关键问题:为什么过拟合?

问题解决方法
训练数据太少用更大数据集
训练时间太长早停机制
模型太复杂简化模型或正则化

 解码策略

核心目标:控制文本生成的“随机性”或“多样性”

最简单的文本生成方法 generate_text_simple

  1. 模型根据当前输入序列预测下一个词(标记)的概率分布。

  2. 总是选择概率最高的那个词 (torch.argmax)。

问题是什么?

  • 太死板: 同一个开头,模型每次生成的文本一模一样

  • 缺乏多样性: 没有惊喜,没有创意,不适合写故事、对话等需要变化的场景。

解决方案:引入“采样”和“控制”
我们不总是选最确定的那一个,而是根据概率分布随机选一个,但要用技巧控制这个随机过程,让它既有趣又合理。

1. 温度缩放 (Temperature Scaling):调节概率分布的“软硬”程度

原理: 在计算 Softmax 概率 之前,把模型的原始输出分数 (logits) 除以一个参数 T (温度)。

        probs = torch.softmax(logits / T, dim=-1)

温度 (T) 的作用:

  • T = 1: 原始概率分布。这是我们通常使用的“正常”温度。
  • T > 1 (高温): 加热!让概率分布变得更平滑、均匀

        效果: 高概率的词优势变小,低概率的词机会变大。生成文本更随机、更丰富、更有创意,但也可能包含更多错误或不合理内容。

        比喻: 像把冰块融化成了水,各种可能性更容易流动混合。

  • T < 1 (低温): 降温!让概率分布变得更尖锐、集中

        效果: 高概率的词优势更大,低概率的词几乎没机会。生成文本更保守、更可预测、更接近“总是选最高概率”的模式,多样性降低。

        比喻: 像把水冻成了冰,只有最坚固(概率最高)的结构能存在。

  • T = 0: 极端低温!直接退化成 torch.argmax总是选概率最高的词,完全没有随机性。
if temperature > 0.0:logits = logits / temperature  # 应用温度缩放probs = torch.softmax(logits, dim=-1)  # 计算新概率next_token_id = torch.multinomial(probs, num_samples=1)  # 根据新概率分布随机采样一个词
else: # T=0, 退化成贪心next_token_id = torch.argmax(logits, dim=-1, keepdim=True)

2.Top-K 采样 (Top-K Sampling):限制候选池的大小 

  • 为什么需要它? 高温 (T > 1) 虽然能增加多样性,但也可能让一些概率极低、完全不合适的词被偶然选中(比如输入是“Every effort moves you”,输出可能抽到“pizza”)。

  • 原理: 在应用温度缩放之前(或之后),我们只考虑概率最高的前 K 个词,其他词的概率直接设为零(或负无穷),然后再重新计算概率分布并进行采样。

  • 步骤:

    1. 获取原始 logits。

    2. 找出 logits 值最大的前 K 个 (torch.topk(logits, k)

    3. 创建一个新的 logits 张量:只有这 Top-K 个词保留原值,其他词的值都设成一个非常小的数(比如 -float('inf'))

top_logits, _ = torch.topk(logits, top_k)  # 找到前K个最大的logits
min_topk_logit = top_logits[:, -1]  # 第K大的logit值,作为阈值
# 创建新logits:小于阈值的都变成负无穷,大于等于阈值的保留原值
new_logits = torch.where(logits < min_topk_logit, torch.tensor(float('-inf')).to(logits.device), logits)

    #对这个新的 logits 张量应用温度缩放 (如果需要) 和 Softmax,得到新的概率分布。

    #从这个新的、只包含 Top-K 词 的概率分布中采样 (torch.multinomial)。

    Top-K 采样效果:保证了采样只在相对合理的词中进行。与高温 (T > 1结合使用效果最佳:在合理的范围内增加多样性,大大降低了生成“pizza”这种荒谬词的概率。

    (K 的选择很重要:K 大 (如 50, 100):候选池大,多样性高。K 小 (如 10, 20):候选池小,文本更保守、更集中。K = 1:退化成 torch.argmax (贪心)。)

    升级版的 generate 函数,加入温度和 Top-K 控制

    def generate(model, idx, max_new_tokens, context_size, temperature, top_k=None):for _ in range(max_new_tokens):idx_cond = idx[:, -context_size:]  # 取最后 context_size 个词作为当前上下文with torch.no_grad():logits = model(idx_cond)  # 模型预测logits = logits[:, -1, :]     # 只取最后一个时间步的logits (预测下一个词)# 1. 应用 Top-K 采样 (如果指定了 top_k)if top_k is not None:top_logits, _ = torch.topk(logits, top_k)min_topk_val = top_logits[:, -1]  # 当前批次中每个样本的第K大logit值logits = torch.where(logits < min_topk_val[:, None], # 比较并屏蔽torch.tensor(float('-inf')).to(logits.device),logits)# 2. 应用温度缩放 (如果 temperature > 0)if temperature > 0.0:logits = logits / temperatureprobs = torch.softmax(logits, dim=-1)next_token_id = torch.multinomial(probs, num_samples=1)else:  # temperature=0, 退化成贪心搜索next_token_id = torch.argmax(logits, dim=-1, keepdim=True)# 将新生成的词添加到序列中idx = torch.cat((idx, next_token_id), dim=1)return idx

    使用示例

    torch.manual_seed(123)  # 设置随机种子保证结果可复现
    input_text = "Every effort moves you"
    token_ids = text_to_token_ids(input_text, tokenizer)output_ids = generate(model=model,idx=token_ids,max_new_tokens=20,    # 生成20个新词context_size=1024,     # 模型能看的上下文长度temperature=1.5,      # 用较高的温度增加多样性top_k=50              # 只在概率最高的前50个词里采样
    )output_text = token_ids_to_text(output_ids, tokenizer)
    print(output_text)
    # 示例输出: Every effort moves you know began to my surprise to the end it was such a laugh that there: "sweet of an
    1. 为什么需要解码策略? 打破贪心搜索 (argmax) 的单调性,增加文本生成的多样性趣味性

    2. 温度缩放 (T):

      • T > 1增加随机性,概率分布更平滑。适合创意写作。

      • T < 1减少随机性,概率分布更尖锐。适合需要准确、保守的场合。

      • T = 0退化成贪心搜索,总是选概率最高的词。

    3. Top-K 采样 (K):

      • 限制采样只在概率最高的前 K 个词中进行。

      • 防止高温时选择到完全不合适的低概率词。

      • K 控制候选池大小,平衡多样性合理性

    http://www.xdnf.cn/news/13475.html

    相关文章:

  • HTML盒子模型
  • 1.一起学习仓颉-编译环境,ide,输出hello,world
  • GitLab Web 界面创建分支后pathspec ... did not match any file(s)
  • MNIST数据集上朴素贝叶斯分类器(MATLAB)
  • 扁平表+递归拼树思想
  • cf2117E
  • 【Pandas】pandas DataFrame interpolate
  • echarts 数据大屏(无UI设计 极简洁版)
  • [2025CVPR]DeepVideo-R1:基于难度感知回归GRPO的视频强化微调框架详解
  • 黄晓军所长:造血干细胞移植后晚期效应及患者健康相关生存质量
  • SQL进阶之旅 Day 23:事务隔离级别与性能优化
  • CentOS 安装Python 3教程
  • 38 C 语言字符串搜索与分割函数详解:strchr、strrchr、strpbrk、strstr、strcspn、strtok
  • 现代汽车在巴黎和得克萨斯州宣传其混合动力汽车为「两全其美之选」
  • CppCon 2015 学习:Extreme Type Safety with Opaque Typedefs
  • 从走线到互连:优化高速信号路径设计的快速指南
  • vue 监听页面滚动
  • carla与ros坐标变换
  • iOS 抖音首页头部滑动标签的实现
  • 【DAY45】 Tensorboard使用介绍
  • 《高等数学》(同济大学·第7版)第三章第五节“函数的极值与最大值最小值“
  • github.com 链接127.0.0.1
  • 征程 6E/M|如何解决量化部署时 mul 与 bool 类型数据交互的问题
  • 《为什么 String 是 final 的?Java 字符串池机制全面解析》
  • MySql简述
  • 基于GeoTools求解GeoTIFF的最大最小值方法
  • 搞了两天的win7批处理脚本问题
  • SaaS(软件即服务)和 PaaS(平台即服务)的定义及区别(服务对象不同、管理责任边界、典型应用场景)
  • GO自带日志库log包解释
  • 【二】12.关于中断