LongT5: 针对长序列的高效文本到文本Transformer
摘要
最近的研究表明,增加输入长度或增大模型规模均能提升基于Transformer架构的神经模型性能。本文中,我们介绍了LongT5,这是一个同时探索输入长度与模型规模扩展效应新模型。具体而言,我们将长输入Transformer(如ETC)中的注意力机制理念与摘要预训练(如PEGASUS)策略相结合,融入可扩展的T5架构之中。由此诞生了一种新型注意力机制——瞬态全局(TGlobal)注意力,它借鉴了ETC的局部/全局注意力机制,但无需额外的辅助输入。我们在多项摘要生成和问答任务上实现了最先进的成果,并在这些任务上超越了原始T5模型的表现。我们已开源了模型架构、训练代码及预训练模型检查点。
1 引言
诸如BERT(Devlin等人,2019)及其变体(Liu等人,2019;Radford等人,2019;Raffel等人,2019a;Lewis等人,2020)的Transformer模型,在许多具有挑战性的自然语言处理任务中已取得顶尖成绩。此外,近期关于长输入Transformer的研究(Ainslie等人,2020;Zaheer等人,2020b;Beltagy等人,2020;Tay等人,2021)显示,提升Transformer能处理的输入长度可带来进一步的性能提升。同时,增大模型规模也被证实能在多项任务中促进性能提升(Kaplan等人,2020)。
本文中,我们提出了一种名为LongT5的新模型,旨在同时探究输入长度与模型规模扩展的影响。为此,我们将长输入Transformer的注意力机制和预训练理念整合进可扩展的T5(Raffel等人,2019a)模型架构中。如图1所示,所得模型在处理长序列输入需求的多个任务上达到了业界领先水平。
在注意力机制方面,我们设计了一种新型注意力机制,称为瞬态全局(TGlobal)注意力,它借鉴了ETC的局部/全局机制(Ainslie等人,2020年)。重要的是,TGlobal注意力无需ETC中额外的辅助输入,从而适应T5架构。ETC局部/全局机制的核心思想是在注意力机制中引入局部稀疏性,以降低处理长输入时的二次成本。具体而言,ETC仅允许输入中的token(称为长输入)关注其局部邻域,并通过一个称为全局记忆的辅助输入,使长输入中的token能够间接相互关注。此机制的一个缺点是,每个新问题都需要设计这一辅助全局输入。为了使其适应T5,我们的新TGlobal机制在每一注意力层动态合成这些全局token(作为输入中token组的聚合)。实验表明,该机制在相同输入长度下与完全注意力相比仅有小幅性能下降,但使模型能够扩展到更长的输入长度,带来显著的性能提升。
在预训练方面,我们采用了PEGASUS模型(Zhang等人,2019a)的预训练策略。该策略最初为抽象摘要设计,但在我们的实验中,我们发现它也能提升问答等其他任务的模型性能,因此我们将其应用于LongT5。其核心思想是从文档中掩码关键(主要)句子,并要求模型将其复现为单一字符串,如同生成摘要一般。
我们在多个摘要生成和问答任务上评估了LongT5(详见第4.2.1节和第4.3.1节对这些数据集的详细描述)。得益于输入长度和模型规模的扩展,我们在许多任务上取得了最先进的成果。
本研究的主要贡献包括:
- 一种新的Transformer架构——LongT5,能够同时扩展输入长度和模型规模。
- 一种新的注意力机制(TGlobal),模仿ETC的局部/全局机制,但可作为现有Transformer架构(如T5)中常规注意力的直接替代。
- 分析了普通T5和LongT5模型在输入长度和模型规模变化时的性能表现(将两者推至遇到内存问题前的最大长度),以理解性能与计算成本之间的权衡。
- 在arXiv、PubMed、BigPatent、MediaSum和TriviaQA数据集上取得了最先进的成果。对于Natural Questions,我们使用了与原始任务略有不同的表述,因此未宣称达到最先进水平。
- 我们开源了模型架构、训练代码及预训练模型检查点。
2. T5
T5(Raffel等人,2019a)是一种基于Transformer的文本到文本预训练语言模型,因其将基于文本的语言问题统一转换为文本到文本格式的框架及易于通过模型并行扩展参数数量(从60M到11B)而广受欢迎。凭借完全注意力Transformer,T5已成功应用于许多NLP任务,但这些任务仅需较短的输入序列。这是由于输入序列长度的二次计算增长导致内存消耗增加和训练时间延长。最近,Press等人(2021年)探索了在推理时扩展T5风格模型至比训练时更长的序列,但在训练期间如何扩展T5风格模型的输入序列长度仍待深入研究。
3. LongT5
3.1 LongT5架构:
我们扩展了原始T5编码器,采用全局-局部注意力稀疏模式(Ainslie等人,2020年;Zaheer等人,2020a)以处理长输入。对于本文报告的工作,我们使用了标准T5解码器,因为所有考虑的任务均需相对较短的输出序列长度。在架构上,T5与LongT5的主要区别在于注意力机制。我们为LongT5实验了两种注意力机制变体,如图2所示:(1)局部注意力和(2)瞬态全局注意力(TGlobal)。两种变体均保留了T5的多个特性:相对位置表示、支持示例打包及与T5检查点的兼容性。
3.1.1 局部注意力:
对于局部注意力,我们简单地将T5中的编码器自注意力操作替换为稀疏滑动窗口局部注意力操作,遵循ETC(Ainslie等人,2020年)的实现。具体而言,对于给定的局部半径r,此公式仅允许每个token关注其左右各r个token(见图2.a)。我们发现r=127在实践中已足够,其中rrr为左右邻域token的数量。局部注意力未引入任何新参数,并轻松适应示例打包所需的注意力掩码。对于给定的r,复杂度与输入序列长度l呈线性关系:O(l×r)O(l × r)O(l×r)。
3.1.2 瞬态全局注意力
为了使输入token在编码器的每一层中能够以比局部注意力的局部半径更长的范围相互交互,我们引入了瞬态全局注意力,作为对ETC全局-局部注意力机制的改进,采用“固定块”模式。具体而言,我们将输入序列划分为大小为k的块,并为每个块计算一个全局token,通过对块内每个token的嵌入进行求和(然后归一化)得到(见图2.b)。在计算注意力时,我们允许每个输入token不仅像局部注意力那样关注附近的token,还可以关注每个全局token。我们称这些全局token为“瞬态”的,因为与ETC式的全局-局部注意力模式不同,这些token是在每次注意力操作中动态构建(随后丢弃)的,无需决定哪些输入token应被视为“全局”。
TGlobal注意力仅引入了少量新参数:(1)T5风格的相对位置偏置,表示输入token所在块与每个全局token所在块之间的距离;(2)T5风格的层归一化参数,用于归一化每个全局token的嵌入。其余参数与T5相同,我们通过额外掩码输入token对其他示例的全局token的注意力来适应序列打包。我们发现块大小k=16在实践中已足够。因此,TGlobal注意力在局部注意力的基础上引入了l∗l/kl∗l/kl∗l/k个额外的注意力键值对(l个输入token,关注l/kl/kl/k个全局token;如图2.b中最右侧矩形所示),因此对于输入序列长度l,复杂度为O(l(r+l/k))O(l(r + l/k))O(l(r+l/k))。
3.2 PEGASUS 主要句子生成预训练
T5 采用 span corruption(片段掩码) 作为预训练目标,即:
- 连续的输入 token 片段 会被 掩码(mask token)替换,
- 然后模型学习 重建被掩码的 token。
虽然这种方法有效,但最新研究(Liu et al., 2019;Zhang et al., 2019b)表明:
- 精心设计的预测目标 能够 显著提升模型性能,
- 原因之一 是 让模型预测更具信息量的 token,可以 迫使模型学习更深层次的语义理解。
PEGASUS 的创新点
- 受此启发,我们探索了 屏蔽(masking)并生成(generating)文本中的主要句子(principle sentences)。
- 具体而言,我们采用 Zhang et al. (2019a) 提出的 “Gap Sentences Generation with Principle IndUniq” 策略,
- 该方法最初用于摘要任务的预训练(Summarization Pre-training)。
主要句子(Principle Sentences)的选择方法
按照 Zhang et al. (2019a) 的方法,我们进行以下操作:
- 根据 ROUGE-F1 分数(Lin, 2004),选取得分最高的 m 个句子(Principle Sentences)。
- 计算公式:
si=rouge(xi;D∖{xi};∀i)s_i = \text{rouge}(x_i; D \setminus \{x_i\}; \forall i) si=rouge(xi;D∖{xi};∀i)- 其中:
- iii 为 句子索引,
- DDD 为 文档中所有句子的集合,
- 每个句子 独立评分(Ind),
- 每个 n-gram 只计算一次(Uniq)。
- 其中:
4. 实验
4.1 配置
- LongT5 采用 JAX 和 Flaxformer 库实现。
- 与 T5.1.1 相同的设置,我们测试 3 种模型规模:
- Base 版(约 220M 参数)
- Large 版(约 770M 参数)
- XL 版(约 3B 参数)
- 采用 T5.1.1 的同款英文 SentencePiece 词表(包含 32000 个 token)。
- 所有实验采用:
- 批量大小(Batch Size):128,
- 优化器(Optimizer):Adafactor。
解码策略(Decoding Strategy)
- 所有实验均采用贪心解码(Greedy Decoding),而非束搜索(Beam Search),
- 即使在测试集上,我们也使用贪心解码,
- 这样做的原因是:
- 保持与开发集(dev setup)一致,
- 如果使用束搜索,结果可能会进一步提升。
4.1.1 预训练
- 我们对 LongT5 模型 进行了 100 万步(1M steps) 的预训练。
- 采用 输入序列长度 4096,输出序列长度 910。
- 学习率调度 方式与 T5 相同,采用 反平方根衰减(inverse square-root learning rate schedule),即:
1/max(step,warmupsteps)1/\sqrt{max(step,warm_upsteps)} 1/max(step,warmupsteps)
warm-up steps 设定为 10000。 - 与 T5.1.1 相同,我们在 C4 数据集(Raffel et al., 2019b) 上进行预训练,不使用 dropout。
- 预训练目标(Pre-training Objective)
- 如 3.2 节 所述,我们采用 PEGASUS 主要句子生成(Principle Sentences Generation) 作为预训练目标。
- 其配置与 Zhang et al. (2019a) 用于更大模型的设定相似,
- 唯一不同:屏蔽句子比例(masked sentence ratio)设为 0.2,而非 0.458。
- 在 5.3 节 我们会进行消融实验(ablation study),对比 Principle Sentences Generation 与 Span Corruption 之间的效果差异。
4.1.2 微调
- 所有任务的微调 采用:
- 固定学习率(constant learning rate):0.001
- dropout 率(dropout rate):0.1
- 摘要任务(Summarization tasks)
- 输入长度 设定为 4096、8192、16384
- 输出长度 设定为 512
- 问答任务(QA tasks)
- 输入长度 从 512 开始,并扩展到 36864
- 输出长度 设定为 128
4.2 摘要任务评测
我们选择摘要任务作为模型基准测试(benchmark),因为 摘要任务需要处理长上下文(long context)并进行文本生成(generative nature)。
4.2.1 数据集
LongT5 在以下 6 个数据集 上进行测试:
-
CNN / Daily Mail(Nallapati et al., 2016)
- 输入:CNN 和 Daily Mail 的新闻报道
- 目标:新闻报道的摘要(summary bullets)
-
PubMed(Cohan et al., 2018)
- 输入:PubMed 上的科学文献
- 目标:对应文献的摘要(abstract)
-
arXiv(Cohan et al., 2018)
- 类似于 PubMed,但数据来源于 arXiv
-
BigPatent(Sharma et al., 2019)
- 输入:美国专利文档
- 目标:对应专利的摘要
-
MediaSum(Zhu et al., 2021)
- 输入:CNN 和 NPR 的访谈记录
- 目标:对应访谈的主题和概要
-
Multi-News(Fabbri et al., 2019)
- 任务:对多个相关新闻文档进行总结,生成 人工撰写的摘要(human-written summary)
数据集统计信息(Table 1 数据集统计)
- 数据集包含的样本数(train/validation/test splits)
- 输入序列长度的均值、最大值、中位数和 90% 分位数
由于这些数据集的 输入文本较长,因此 需要能够处理长文本的模型。
- 我们纳入 CNN / Daily Mail 作为基准任务(benchmark task),主要目的是:
- 测试 TGlobal 注意力(TGlobal attention)对模型的影响,
- 尽管 CNN / Daily Mail 的输入长度比其他数据集短,但仍然能提供重要的参考。
4.2.2 结果
我们将 LongT5 与多个 最先进(top-performing) 方法进行对比,包括:
- BigBird-PEGASUS(Zaheer et al., 2020b)
- HAT-BART(Rohde et al., 2021)
- DANCER PEGASUS(Gidiotis and Tsoumakas, 2020)
- PRIMER(Xiao et al., 2021)
- TG-MultiSum(Cui and Hu, 2021)
- LED(Beltagy et al., 2020)
- BART(Zhu et al., 2021)应用版本
在这些比较中,我们采用了 常见的评估指标(common evaluation metrics):
- ROUGE-1
- ROUGE-2
- ROUGE-L
温馨提示:
阅读全文请访问"AI深语解构" LongT5: 针对长序列的高效文本到文本Transformer