当前位置: 首页 > ops >正文

7.3 Organizing data into training batches

Chapter 7-Fine-tuning to follow instructions

7.3 Organizing data into training batches

  • 下一步是构建训练批次

    定义一种方法,以确保我们的模型在微调过程中接收格式化的训练数据。如下图所示,我们通过几个步骤处理此数据集批处理。

  • 首先,我们实现一个“InstructionDataset”类,该类对数据集中的所有输入进行预标记,类似于第6章中的“SpamDataset”

    import torch
    from torch.utils.data import Datasetclass InstructionDataset(Dataset):def __init__(self, data, tokenizer):self.data = data# Pre-tokenize textsself.encoded_texts = []for entry in data:instruction_plus_input = format_input(entry)response_text = f"\n\n### Response:\n{entry['output']}"full_text = instruction_plus_input + response_textself.encoded_texts.append(tokenizer.encode(full_text))def __getitem__(self, index):return self.encoded_texts[index]def __len__(self):return len(self.data)
    

    与第 6 章类似,我们希望批量收集多个训练示例以加速训练;这需要将所有输入填充到相似的长度,使用 ‘<|endftext|>’ 标记作为填充标记。

    import tiktoken
    tokenizer = tiktoken.get_encoding("gpt2")print(tokenizer.encode("<|endoftext|>", allowed_special={"<|endoftext|>"}))"""输出"""
    [50256]
    
  • 我们通过开发一个可以传递给数据加载器的自定义排序函数来采用更复杂的方法。这个自定义排序函数将每个批次中的训练示例填充到相同的长度,同时允许不同批次具有不同的长度。这种方法 minimizesunnecessary 填充,只扩展序列以匹配每个批次中最长的一个,而不是整个数据集,如下图所示

    def custom_collate_draft_1(batch,pad_token_id=50256,device="cpu"
    ):# Find the longest sequence in the batch# and increase the max length by +1, which will add one extra# padding token belowbatch_max_length = max(len(item)+1 for item in batch)# Pad and prepare inputsinputs_lst = []for item in batch:new_item = item.copy()# Add an <|endoftext|> tokennew_item += [pad_token_id]# Pad sequences to batch_max_lengthpadded = (new_item + [pad_token_id] *(batch_max_length - len(new_item)))# Via padded[:-1], we remove the extra padded token# that has been added via the +1 setting in batch_max_length# (the extra padding token will be relevant in later codes)inputs = torch.tensor(padded[:-1])inputs_lst.append(inputs)# Convert list of inputs to tensor and transfer to target deviceinputs_tensor = torch.stack(inputs_lst).to(device)return inputs_tensor
    

    我们实现的custom_collate_draft_1被设计为集成到PyTorch DataLoader中,但它也可以作为一个独立的工具。在这里,我们独立地使用它来测试和验证它是否按预期运行。让我们在三个不同的输入上尝试一下,我们希望将它们组装成一个批处理,其中每个示例都被填充到相同的长度

    inputs_1 = [0, 1, 2, 3, 4]
    inputs_2 = [5, 6]
    inputs_3 = [7, 8, 9]batch = (inputs_1,inputs_2,inputs_3
    )print(custom_collate_draft_1(batch))"""输出"""
    tensor([[    0,     1,     2,     3,     4],[    5,     6, 50256, 50256, 50256],[    7,     8,     9, 50256, 50256]])
    

  • 到目前为止我们只将输入返回给 LLM;但是,对于 LLM 训练,我们还需要目标值,与预先训练 LLM 类似,目标是向右移动 1 个位置的输入,因此 LLM 学会预测下一个令牌。

    以下更新的排序函数从输入token ID生成目标tokenID:

    def custom_collate_draft_2(batch,pad_token_id=50256,device="cpu"
    ):# Find the longest sequence in the batchbatch_max_length = max(len(item)+1 for item in batch)# Pad and prepare inputsinputs_lst, targets_lst = [], []for item in batch:new_item = item.copy()# Add an <|endoftext|> tokennew_item += [pad_token_id]# Pad sequences to max_lengthpadded = (new_item + [pad_token_id] *(batch_max_length - len(new_item)))inputs = torch.tensor(padded[:-1])  # Truncate the last token for inputstargets = torch.tensor(padded[1:])  # Shift +1 to the right for targetsinputs_lst.append(inputs)targets_lst.append(targets)# Convert list of inputs to tensor and transfer to target deviceinputs_tensor = torch.stack(inputs_lst).to(device)targets_tensor = torch.stack(targets_lst).to(device)return inputs_tensor, targets_tensor
    
    inputs, targets = custom_collate_draft_2(batch)
    print('inputs:\n', inputs)
    print('targets:\n', targets)"""输出"""
    inputs:tensor([[    0,     1,     2,     3,     4],[    5,     6, 50256, 50256, 50256],[    7,     8,     9, 50256, 50256]])targets:tensor([[    1,     2,     3,     4, 50256],[    6, 50256, 50256, 50256, 50256],[    8,     9, 50256, 50256, 50256]])
    
  • 接下来,我们引入一个ignore_index值,用一个新值替换所有填充tokenID;这个ignore_index的目的是我们可以忽略损失函数中的填充值(稍后会详细介绍)

    具体来说,这意味着我们将与’50256’对应的令牌ID替换为’-100’,如下所示

    此外,我们还引入了“允许的最大长度”(allowed_max_length)这一参数,以便在需要时对样本长度加以限制。如果您打算使用长度超过 GPT-2 模型所支持的 1024 个词元上下文大小的自有数据集,这个参数就会派上用场 。

    inputs, targets = custom_collate_fn(batch)
    print('inputs:\n', inputs)
    print('targets:\n', targets)"""输出"""
    inputs:tensor([[    0,     1,     2,     3,     4],[    5,     6, 50256, 50256, 50256],[    7,     8,     9, 50256, 50256]])targets:tensor([[    1,     2,     3,     4, 50256],[    6, 50256,  -100,  -100,  -100],[    8,     9, 50256,  -100,  -100]])
    
  • 修改后的排序函数按预期工作,通过插入令牌ID-100来更改目标列表。这种调整背后的逻辑是什么?让我们探索一下这种修改的潜在目的。

    为了说明的目的,让我们假设我们有一个小分类任务,有 2 个类标签,0 和 1,类似于第 6 章、如果我们有以下 logits 值(模型最后一层的输出),我们计算以下损失

    logits_1 = torch.tensor([[-1.0, 1.0],  # 1st training example[-0.5, 1.5]]  # 2nd training example
    )
    targets_1 = torch.tensor([0, 1])loss_1 = torch.nn.functional.cross_entropy(logits_1, targets_1)
    print(loss_1)"""输出"""
    tensor(1.1269)
    

    现在,如预期的那样,再添加一个训练示例将影响损失

    logits_2 = torch.tensor([[-1.0, 1.0],[-0.5, 1.5],[-0.5, 1.5]]  # New 3rd training example
    )
    targets_2 = torch.tensor([0, 1, 1])loss_2 = torch.nn.functional.cross_entropy(logits_2, targets_2)
    print(loss_2)"""输出"""
    tensor(0.7936)
    

    让我们看看如果我们将其中一个示例的类标签替换为-100会发生什么

    targets_3 = torch.tensor([0, 1, -100])loss_3 = torch.nn.functional.cross_entropy(logits_2, targets_3)
    print(loss_3)
    print("loss_1 == loss_3:", loss_1 == loss_3)"""输出"""
    tensor(1.1269)
    loss_1 == loss_3: tensor(True)
    

    我们可以看到,3个训练示例的结果损失与从2个训练示例计算的损失相同,这表明交叉熵损失函数忽略了带有 -100 标签的训练示例,默认情况下PyTorch有 cross_entropy(..., ignore_index=-100) 设置来忽略与标签 -100 对应的示例,利用 -100 ignore_index 能忽略用于将训练示例填充到等长的批次中额外的文本结束(填充)令牌,但我们不想忽略文本结束(填充)令牌(50256)的第一个实例,因其能在响应完成时向LLM发出信号 (说白了就是标注每句话的最后一个字符)。

    在实践中,掩盖与指令对应的目标token ID也很常见,如下图所示


7.4-Creating data loaders for an instruction dataset

http://www.xdnf.cn/news/9882.html

相关文章:

  • 易路 iBuilder:解构企业 AI 落地困境,重构智能体时代生产力范式
  • 顶刊SCS | 基于视觉语言大模型推理分割的建筑足迹尺度功能分类, 样本数据和代码已开源!
  • QNAP MEMOS 域名访问 SSL(Lucky)
  • 广州邮科高频开关电源:以创新科技赋能通信能源绿色未来
  • 工控机安装lubuntu系统
  • Med-R1论文阅读理解-1
  • 我的3种AI写作节奏搭配模型,适合不同类型写作者
  • 企业级Spring MVC高级主题与实用技术讲解
  • 互联网大厂Java求职面试:云原生微服务架构设计与AI大模型集成实战
  • 页面输入数据的表格字段(如 Web 表单或表格控件)与后台数据库进行交互时常用的两种方式
  • 第十三篇:MySQL 运维自动化与可观测性建设实践指南
  • 一句话开发Chrome摸鱼插件
  • @Docker Compose 部署 Pushgateway
  • Idea 配置 Maven 环境
  • YC-8002型综合变配电监控自动化系统
  • Pytorch Geometric官方例程pytorch_geometric/examples/link_pred.py环境安装教程及图数据集制作
  • MES管理系统:Java+Vue,含源码与文档,实现生产过程实时监控、调度与优化,提升制造企业效能
  • MySql(七)
  • 深入浅出:使用DeepSeek开发小程序的完整指南
  • Express教程【003】:Express获取查询参数
  • 软件测试|FIT故障注入测试工具——ISO 26262合规下的智能汽车安全验证引擎
  • 题目 3293: 蓝桥杯2024年第十五届决赛真题-数位翻转
  • 编程技能:格式化打印01,vsprintf 函数族简介
  • 相机--双目立体相机
  • iOS 集成网易云信IM
  • Edge浏览器怎样开启兼容模式
  • t014-项目申报管理系统 【springBoot 含源码】
  • 推荐3个优秀wordpress主题
  • Electron-vite【实战】MD 编辑器 -- 文件列表(含右键快捷菜单,重命名文件,删除本地文件,打开本地目录等)
  • 基于分布式状态机的集装箱智能道口软件架构方法