当前位置: 首页 > ds >正文

自回归建模模型(AR)

参考网站:什么是自回归模型 | IBM

基本说明:

AR是一种强大的最常用于时间序列分析和预测的机器学习技术,使用时间序列先前时间步长的一个或者多个值来创建回归模型。

用同一变数例如x的之前各期,亦即x1至xt-1来预测本期xt的表现,并假设它们为一线性关系。因为这是从回归分析中的线性回归发展而来,只是不用x预测y,而是用x预测 x(自己);所以叫做自回归。


   自回归模型假设当前时刻的数据仅依赖于历史时刻的数据,通过条件概率分解序列的联合分布:                             ​​​​​​​        ​​​​​​​        ​​​​​​​  

其中: c是常数项;被假设为平均数等于0,标准差等于的随机误差值;被假设为对于任何的t都不变。文字叙述为:X的期望值等于一个或数个落后期的线性组合,加常数项,加随机误差。


生成过程

  1. 逐步预测:每次基于已生成的部分序列预测下一个元素(如GPT通过上文预测下一个词)。

  2. 迭代采样:通过随机采样(如从softmax分布中采样)或贪婪搜索生成新元素,并将新元素反馈到模型中以继续生成后续内容。

  3. 典型架构:Transformer的解码器(如GPT)或因果卷积网络(如WaveNet),通过掩码机制确保仅依赖历史信息。


实际应用中的主要限制

  1. 计算效率问题

    • 序列长度限制:生成长度为$N$的序列需$N$次前向计算,导致延迟高(如长文本生成)。

    • 内存瓶颈:Transformer的注意力机制内存消耗随序列长度平方增长($O(N^2)$)。

  2. 长程依赖建模困难

    • 尽管Transformer优于RNN,但远距离依赖仍可能因注意力权重分散或梯度消失而失效(如生成连贯的长文档)。

  3. 误差累积与暴露偏差

    • 训练-测试不一致:训练时使用真实历史数据(Teacher Forcing),而测试时依赖模型自身生成的历史,错误会逐步累积(Exposure Bias)。

    • 模式坍塌:倾向于生成高频但低多样性的内容(如重复短语)。

  4. 可控生成挑战

    难以精确控制生成内容的属性(如情感、风格),需额外引入约束或后处理。

改进技术手段

  1. 效率优化

    • 稀疏注意力:如Longformer的局部+全局注意力、Reformer的局部敏感哈希(LSH)注意力,将复杂度降至$O(N\log N)$。

    • 分块生成:将序列分段处理(如Image Transformer对图像分块)。

    • 模型蒸馏:训练小型化模型(如DistilGPT-2)保持性能的同时减少计算量。

  2. 长序列建模改进

    • 记忆机制:如Transformer-XL通过循环记忆模块保留跨段信息。

    • 递归结构:将Transformer与RNN结合(如Compressive Transformer)增强长程记忆。

  3. 缓解误差累积

    • 计划采样(Scheduled Sampling):逐步混合训练时的真实输入与模型生成输入。

    • 强化学习:通过策略梯度(如RLHF)直接优化生成序列的整体质量。

  4. 可控生成技术

    • 条件控制:在输入中嵌入控制信号(如CTRL模型的领域控制前缀)。

    • 解码约束:束搜索(Beam Search)中引入禁止重复n-gram等规则。

    • 能量模型:如GeDi通过辅助模型引导生成方向。
    • 并行化生成

      • 非自回归模型(NAR):如Mask-Predict通过迭代掩码预测实现并行解码(牺牲部分质量换取速度)。

      • 半自回归:部分步骤并行化(如Blockwise Parallel Decoding)。


自回归和回归区别:

特性自回归模型 (AR)非自回归模型 (NAR)
生成方式逐步生成,严格顺序依赖并行生成,一步预测所有位置
速度慢(需$O(N)$次前向计算)快(仅需$O(1)$次前向计算)
质量高质量,上下文连贯可能因独立性假设降低连贯性
训练目标最大化似然$P(x_t|x_{<t})$直接建模$P(x_{1:T}|c)$(c为条件)
典型模型GPT、Transformer-DecoderBART、T5、Masked-LM
应用场景文本生成、音乐生成机器翻译、文本摘要(需快速场景)

# 自回归生成(顺序)
for t in range(T):x_t = model(x_<t)  # 依赖历史# 非自回归生成(并行)
x_1:T = model(c)       # 直接输出全部序列

应用场景

1. 自回归模型
  • 自然语言生成:GPT-3的故事创作、ChatGPT的对话生成。

  • 时间序列预测:股票价格预测(ARIMA)、天气建模。

  • 语音合成:WaveNet生成逼真语音波形。

  • 代码生成:GitHub Copilot的代码补全。

2. 非自回归模型
  • 机器翻译:Google的NAT(Non-Autoregressive Translation)。

  • 文本摘要:快速生成摘要(如BART的并行解码)。

  • 图像生成:部分扩散模型的并行去噪步骤。


代码示例:

import torch
import torch.nn as nnclass ARModel(nn.Module):def __init__(self, vocab_size, hidden_size):super().__init__()self.embed = nn.Embedding(vocab_size, hidden_size)self.rnn = nn.LSTM(hidden_size, hidden_size)self.head = nn.Linear(hidden_size, vocab_size)def forward(self, x):# x: [seq_len, batch_size]x = self.embed(x)  # [seq_len, batch_size, hidden_size]outputs, _ = self.rnn(x)return self.head(outputs)  # [seq_len, batch_size, vocab_size]# 生成示例(贪婪搜索)
def generate_ar(model, start_token, max_len):tokens = [start_token]for _ in range(max_len):logits = model(torch.tensor([tokens[-1]]))  # 预测下一步next_token = logits.argmax(-1).item()       # 贪婪选择tokens.append(next_token)return tokens

http://www.xdnf.cn/news/9687.html

相关文章:

  • C++进阶--C++11(03)
  • 一种字典树的Python实现
  • 什么是数字化转型,如何系统性重构业务逻辑
  • Android 构建系统中常见的 .mk 文件及其作用
  • 涨薪技术|0到1学会性能测试第88课-Web_service_call函数
  • Spring AI Alibaba 发布企业级 MCP 分布式部署方案
  • LeetCode 169:多数元素 - 摩尔投票法的精妙解法
  • 【freertos-kernel】queue(发送)
  • # Python 语音助手本地的ollama实现
  • nt!MmMapViewInSystemCache函数分析PointerPte的填充
  • AD/DA HAL库API
  • 内容中台的构建基础是什么?
  • King3399(ubuntu文件系统)iic(i2c)功能测试
  • MP4视频文件播放Demo(附源码)
  • 头歌之动手学人工智能-Pytorch 之autograd
  • 算法 Arrays.sort()函数自定义排序(Comparator 接口)
  • [网页五子棋][匹配模块]服务器开发、用户管理器(创建匹配请求/响应对象、处理连接成功、处理下线)
  • 根据jvm源码剖析类加载机制
  • Python爬虫实战:研究Tornado框架相关技术
  • [Vue组件]半环进度显示器
  • 小猴子摆玩具
  • 计算机网络第一章计算机网络概述(竟成)
  • 小白成长之路-Linux操作系统-进程管理
  • 【机器人编程基础】python中的常用数据类型
  • ElasticSearch查询指定时间内出现的次数/2秒内出现的次数
  • 我们来学mysql -- 输出一份“数据备份还原”sh脚本
  • 手写字魔法消除1:数据集说明(含下载链接)
  • Kruskal算法剖析与py/cpp/Java语言实现
  • linux中基础IO(上)
  • 浅谈 JavaScript 性能优化