当前位置: 首页 > web >正文

Reason-ModernColBERT论文速览:内存受限设置下深度对比学习批量大小的扩展

一、引言

论文《Scaling Deep Contrastive Learning Batch Size under Memory Limited Setup》主要探讨了在内存受限环境下,如何通过梯度缓存技术扩大对比学习的批量大小。对比学习是一种有效的表示学习方法,它通过将相关的数据点在嵌入空间中拉近,不相关的数据点推远来学习数据的表示。尽管已有研究表明,使用大量负样本的批量对比损失能够提升表示学习的质量,但这种方法需要将整个批量数据及其激活值存储在GPU内存中,限制了批量大小。

二、相关工作

论文回顾了对比学习和深度网络内存优化的相关研究。对比学习最初用于概率语言模型,后来被Word2Vec用于学习词嵌入。最近的研究则利用对比学习进行无监督预训练和监督训练密集检索器。深度网络内存优化方面,已有技术包括梯度检查点方法和可逆激活函数等,但这些方法在对比学习中的有效性尚未得到确认。

三、方法论

论文提出了梯度缓存技术,该技术通过将对比损失的反向传播过程分为两部分:从损失到表示,以及从表示到模型参数。通过预先计算表示的梯度并存储在缓存中,可以打破编码器参数更新中的数据依赖性,从而实现将大批次的梯度更新分解为多个小子批次更新。

3.1 预备知识

论文定义了对比损失的数学表达式,并指出每个求和项依赖于整个批次,这要求所有数据都必须适应内存。

3.2 计算分析

论文分析了对比损失的计算及其梯度,指出反向传播可以分为两部分:从损失到表示,以及从表示到模型参数。通过观察到,给定前一部分,后一部分在批量样本之间是独立的。

3.3 梯度缓存技术

论文详细描述了梯度缓存技术的四个步骤:

  1. 无图前向:为每个批次实例运行额外的编码器前向传递以获取表示,不构建计算图。

  2. 表示梯度计算和缓存:基于步骤1中的表示计算对比损失,并构造相应的计算图,然后运行反向传递以填充每个表示的梯度。

  3. 子批次梯度累积:逐个子批次运行编码器前向传递以计算表示,并构建相应的计算图,然后从缓存中取该子批次的表示梯度,并运行反向传播。

  4. 优化:处理完所有子批次后,执行优化器步骤以更新模型参数。

3.4 多GPU训练

论文讨论了多GPU训练时如何通过跨GPU通信来计算所有示例的梯度,使用all-gather操作使所有GPU都能获得所有表示。

四、实验

论文通过在密集段落检索器(DPR)中实现梯度缓存技术,并在Natural Question数据集上评估不同方法的顶级命中准确性,来验证该方法的有效性。实验结果表明,梯度缓存方法在单个消费级GPU上能够以约20%的运行时间增加,再现以往需要多个专业GPU训练的最新模型。

4.1 检索准确性

论文比较了DPR参考系统、顺序更新、梯度累积和梯度缓存系统的性能。梯度缓存方法在Top-5、Top-20和Top-100命中准确性方面均优于DPR参考系统,尤其是在将批量大小增加到512时,性能有所提升。

4.2 训练速度

论文比较了梯度缓存和累积方法的训练速度,发现梯度缓存方法可以稳定地扩展到更大的批量更新,并且只增加了20%的表示预计算时间。

五、扩展到深度距离函数

论文讨论了如何将梯度缓存技术扩展到深度距离函数,通过引入额外的距离梯度缓存来实现。

六、结论

论文总结了梯度缓存技术能够在资源受限的硬件上保持对比学习的准确性,并允许研究人员在不被GPU内存限制的情况下推进研究。

核心技术汇总表在这里插入图片描述

http://www.xdnf.cn/news/8720.html

相关文章:

  • IDA插件 MIPSROP的安装和使用方法
  • 电子人的分水岭-FPGA模电和数电
  • 大模型智能体入门扫盲——基于camel的概述
  • 嵌入式<style>设计模式
  • DeepSeek 赋能数字农业:从智慧种植到产业升级的全链条革新
  • 可编程运动控制器行业2025数据分析报告
  • CodeBuddy实现图片水印添加工具
  • Ntfs!ReadIndexBuffer函数分析之根目录读取索引缓冲区的一个例子
  • STM32 USART串口通信
  • Nginx-详解(二)
  • SOC-ESP32S3部分:11-任务创建
  • 事务处理与事务隔离
  • uni-app(5):Vue3语法基础上
  • Eigen 直线拟合/曲线拟合/圆拟合/椭圆拟合
  • Kotlin MultiPlatform 跨平台版本的记账 App
  • 39-居住证管理系统(小程序)
  • NRM:快速切换 npm 镜像源的管理工具指南
  • C/C++---隐式显式转换
  • World of Warcraft [CLASSIC] 80 Hunter [Grandel] VS Onyxia
  • 什么是深度学习中的层次分类问题?
  • C++静态成员变量
  • 使用 AWK 分析 CSV 文件中的数据模式
  • C++ --- string
  • 【MPC控制 - 从ACC到自动驾驶】车辆纵向动力学建模与离散化:MPC的“数字蓝图”
  • JavaScripts 中parseInt的作用
  • uniapp-商城-67-shop(3-品牌信息显示,弹窗显示完整品牌信息,弹窗拨打电话、地图定位)
  • CMSIS-NN:2.神经网络到CMSIS-NN的转换
  • 基于亚博K210开发板——lvgl 图形化实验
  • AI 笔记 - 模型优化 - 注意力机制在目标检测上的使用
  • Nat Rev Genet | 如果DNA序列能“说话”?深度学习S2E(序列2表达)模型正在听懂基因组的调控秘密!