
Chapter 7-Fine-tuning to follow instructions

7.4-Creating data loaders for an instruction dataset

  • We only need to plug the InstructionDataset objects and the custom_collate_fn function into PyTorch data loaders

  • Use the following code to initialize the device setting

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Note:
    # Uncommenting the following lines will allow the code to run on Apple Silicon chips, if applicable,
    # which is much faster than on an Apple CPU (as measured on an M3 MacBook Air).
    # However, the resulting loss values may be slightly different.
    #if torch.cuda.is_available():
    #    device = torch.device("cuda")
    #elif torch.backends.mps.is_available():
    #    device = torch.device("mps")
    #else:
    #    device = torch.device("cpu")

    print("Device:", device)

    """Output"""
    Device: cuda
    

    The device parameter and allowed_max_length of custom_collate_fn are pre-set to the device variable and 1024, respectively. That way, we no longer need to pass these two arguments manually whenever customized_collate_fn is called later.

    from functools import partial

    customized_collate_fn = partial(
        custom_collate_fn,
        device=device,
        allowed_max_length=1024
    )
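
    If functools.partial is unfamiliar, the self-contained toy sketch below (the function and names are purely illustrative, not part of the chapter's code) shows what pre-binding keyword arguments does:

    from functools import partial

    def shift_and_scale(x, shift=0.0, scale=1.0):
        # A toy function with two keyword arguments we want to pre-bind.
        return (x + shift) * scale

    # partial returns a new callable with shift and scale already filled in,
    # just as customized_collate_fn has device and allowed_max_length filled in.
    normalize = partial(shift_and_scale, shift=-5.0, scale=0.1)

    print(normalize(15.0))  # same as shift_and_scale(15.0, shift=-5.0, scale=0.1) -> 1.0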
    

    Next, we set up the data loaders, but this time we use our custom collate function for the batching process.

    from torch.utils.data import DataLoader

    num_workers = 0
    batch_size = 8

    torch.manual_seed(123)

    train_dataset = InstructionDataset(train_data, tokenizer)
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        collate_fn=customized_collate_fn,
        shuffle=True,
        drop_last=True,
        num_workers=num_workers
    )

    val_dataset = InstructionDataset(val_data, tokenizer)
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        collate_fn=customized_collate_fn,
        shuffle=False,
        drop_last=False,
        num_workers=num_workers
    )

    test_dataset = InstructionDataset(test_data, tokenizer)
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        collate_fn=customized_collate_fn,
        shuffle=False,
        drop_last=False,
        num_workers=num_workers
    )
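
    For reference, custom_collate_fn itself was defined earlier in the chapter. A minimal sketch of this kind of collate function, assuming the <|endoftext|> padding token ID 50256 and the -100 placeholder value inspected below, might look like the following (collate_sketch is an illustrative name, not necessarily the chapter's exact implementation):

    import torch

    def collate_sketch(batch, pad_token_id=50256, ignore_index=-100,
                       allowed_max_length=None, device="cpu"):
        # Pad every sequence to the length of the longest one in the batch (+1 extra pad token).
        batch_max_length = max(len(item) + 1 for item in batch)
        inputs_lst, targets_lst = [], []

        for item in batch:
            padded = item + [pad_token_id] * (batch_max_length - len(item))
            inputs = torch.tensor(padded[:-1])   # inputs: drop the last token
            targets = torch.tensor(padded[1:])   # targets: shifted one position to the left

            # Keep the first pad token as a target, replace the rest with the ignore index.
            mask = targets == pad_token_id
            indices = torch.nonzero(mask).squeeze()
            if indices.numel() > 1:
                targets[indices[1:]] = ignore_index

            # Optionally truncate to a maximum sequence length.
            if allowed_max_length is not None:
                inputs = inputs[:allowed_max_length]
                targets = targets[:allowed_max_length]

            inputs_lst.append(inputs)
            targets_lst.append(targets)

        return torch.stack(inputs_lst).to(device), torch.stack(targets_lst).to(device)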
    

    Let's take a look at what the dimensions of the input and target batches look like.

    print("Train loader:")
    for inputs, targets in train_loader:print(inputs.shape, targets.shape)"""输出"""
    Train loader:
    torch.Size([8, 61]) torch.Size([8, 61])
    torch.Size([8, 76]) torch.Size([8, 76])
    ......
    torch.Size([8, 69]) torch.Size([8, 69])
    

    From the output above, we can see that all batches have a batch size of 8 but different sequence lengths; the first [8, 61], for example, means a batch size of 8 with 61 tokens per training example in that batch. The lengths vary because the collate function pads each batch only to the longest sequence within that batch rather than to a global maximum. Let's double-check that the inputs contain the <|endoftext|> padding tokens corresponding to token ID 50256 by printing the contents of the first training example in the inputs batch.

    print(inputs[0])

    """Output"""
    tensor([21106,   318,   281, 12064,   326,  8477,   257,  4876,    13, 19430,
              257,  2882,   326, 20431, 32543,   262,  2581,    13,   198,   198,
            21017, 46486,    25,   198, 30003,  6525,   262,  6827,  1262,   257,
              985,   576,    13,   198,   198, 21017, 23412,    25,   198,   464,
             5156,   318,   845, 13779,    13,   198,   198, 21017, 18261,    25,
              198,   464,  5156,   318,   355, 13779,   355,   257,  4936,    13,
            50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256],
           device='cuda:0')
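
    To make the padding visible as plain text, the token IDs can also be decoded back into a string. The sketch below assumes tokenizer is the tiktoken GPT-2 tokenizer used throughout the chapter, in which ID 50256 maps to <|endoftext|>:

    import tiktoken

    tokenizer = tiktoken.get_encoding("gpt2")

    # Decode the first training example; the trailing 50256 IDs show up as
    # repeated <|endoftext|> strings at the end of the decoded text.
    print(tokenizer.decode(inputs[0].tolist()))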
    

    Similarly, let's double-check that the targets contain the -100 placeholder tokens.

    print(targets[0])

    """Output"""
    tensor([  318,   281, 12064,   326,  8477,   257,  4876,    13, 19430,   257,
             2882,   326, 20431, 32543,   262,  2581,    13,   198,   198, 21017,
            46486,    25,   198, 30003,  6525,   262,  6827,  1262,   257,   985,
              576,    13,   198,   198, 21017, 23412,    25,   198,   464,  5156,
              318,   845, 13779,    13,   198,   198, 21017, 18261,    25,   198,
              464,  5156,   318,   355, 13779,   355,   257,  4936,    13, 50256,
             -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100],
           device='cuda:0')
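
    The reason -100 works as a placeholder is that PyTorch's cross-entropy loss skips this label value by default (ignore_index=-100), so the masked padding positions contribute nothing to the training loss. A minimal demonstration with toy logits (not produced by the model):

    import torch
    import torch.nn.functional as F

    logits = torch.tensor([[2.0, 0.5],
                           [1.0, 3.0]])

    loss_all = F.cross_entropy(logits, torch.tensor([0, 1]))
    # The second position is labeled -100, so only the first one is averaged into the loss.
    loss_masked = F.cross_entropy(logits, torch.tensor([0, -100]))

    print(loss_all, loss_masked)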
    
