当前位置: 首页 > java >正文

Qwen2.5-VL - 多模态旋转位置嵌入(Multimodal Rotary Position Embedding, MRoPE)

Qwen2.5-VL - 多模态旋转位置嵌入(Multimodal Rotary Position Embedding, MRoPE)

flyfish

多模态旋转位置嵌入(Multimodal Rotary Position Embedding, MRoPE) 是 Qwen2-VL 及 Qwen2.5-VL 模型中用于处理多模态输入的关键技术,它通过扩展传统 RoPE(Rotary Position Embedding),实现了对文本、图像和视频等不同模态数据的统一位置编码。

Multimodal:多模态(指支持文本、图像、视频等多种数据类型)。
Rotary Position:旋转位置(源于旋转位置编码技术 RoPE,通过三角函数实现位置信息的编码)。
Embedding:嵌入(在深度学习中,指将位置信息映射为模型可处理的向量表示)。

传统RoPE在处理文本时只能理解一维的序列顺序,就像看一部没有时间轴的纪录片,面对视频和图像时完全无法感知时间流动和空间布局——比如两段帧率不同的视频,传统方法会把“10秒20帧”和“5秒20帧”都当作简单的帧编号序列,根本分不清动作快慢;面对图片时,也无法区分“左上角的猫”和“右下角的球”的空间位置。而MRoPE的核心创新在于给AI构建了一套“三维时空坐标系”:在时间维度上,它用second_per_grid_ts参数将视频帧绑定真实时间,比如10秒20帧的视频每帧代表0.5秒,时间ID按0.5、1.0秒递增,5秒20帧的视频每帧代表0.25秒,时间ID按0.25、0.5秒递增,让AI能感知事件的真实节奏;在空间维度上,它把图像/视频帧划分为类似棋盘格的网格,用h_indexw_index生成每个块的(高度,宽度)坐标,比如2×2网格中左上角块对应(0,0),右上角对应(0,1),使模型能识别空间位置;在多模态衔接上,它确保视觉和文本的位置ID连续,比如视觉部分ID到100,文本就从101开始,让AI理解文字与画面的对应关系。

Qwen2.5的MRoPE进一步实现了绝对时间对齐,不再依赖帧编号,而是按视频实际时长计算时间ID——比如3秒的视频,不管采样3帧还是6帧,第1秒对应ID=1,第2秒对应ID=2,时间ID间隔会根据采样帧自动调整(3帧间隔1,6帧间隔0.5),代码中的time_tensor = expanded_range * second_per_grid_t * 2里,second_per_grid_t就是每秒的时间ID增量,2是放大倍数,让AI更敏感捕捉时间差异。这种设计让MRoPE在实际应用中展现出强大能力:在视频问答中,能通过时间ID精准定位“球员射门在第5秒”;在图文文档解析时,给图片“左侧柱状图”标上(0,0)坐标,让文字“左侧”与之关联;在动态手势识别中,通过时间ID间隔区分“1秒1帧的缓慢挥手”和“1秒4帧的快速挥手”。

MRoPE就像AI的“时空翻译器”,将视频的时间先后、图像的上下左右、文字的段落顺序,全部转化为“时间-高度-宽度”的三维坐标语言,让多模态模型不仅能“看到”信息,还能理解信息间的时空逻辑——这就好比人类看电影时能同时把握剧情的时间线、画面的空间布局和台词的前后关联。

MRoPE算法

传统RoPE

1. 旋转操作的复数表示

对于位置 m m m处的向量 x m ∈ R d x_m \in \mathbb{R}^d xmRd,将其拆分为两个维度为 d / 2 d/2 d/2的子向量 x m ( 1 ) x_m^{(1)} xm(1) x m ( 2 ) x_m^{(2)} xm(2),RoPE的旋转操作可表示为:
RoPE ( x m , m ) = [ x m ( 1 ) cos ⁡ ( m θ ) − x m ( 2 ) sin ⁡ ( m θ ) x m ( 2 ) cos ⁡ ( m θ ) + x m ( 1 ) sin ⁡ ( m θ ) ] \text{RoPE}(x_m, m) = \begin{bmatrix} x_m^{(1)} \cos(m\theta) - x_m^{(2)} \sin(m\theta) \\ x_m^{(2)} \cos(m\theta) + x_m^{(1)} \sin(m\theta) \end{bmatrix} RoPE(xm,m)=[xm(1)cos(mθ)xm(2)sin(mθ)xm(2)cos(mθ)+xm(1)sin(mθ)]
其中, θ = { θ 1 , θ 2 , … , θ d / 2 } \theta = \{\theta_1, \theta_2, \ldots, \theta_{d/2}\} θ={θ1,θ2,,θd/2}是一组可学习的频率参数。

2. 点积形式

RoPE通过旋转操作保持了位置感知的点积性质:
RoPE ( q m , m ) ⋅ RoPE ( k n , n ) = q m ⋅ k n cos ⁡ ( ( m − n ) θ ) + ( q m ⋅ k ~ n ) sin ⁡ ( ( m − n ) θ ) \text{RoPE}(q_m, m) \cdot \text{RoPE}(k_n, n) = q_m \cdot k_n \cos((m-n)\theta) + (q_m \cdot \tilde{k}_n) \sin((m-n)\theta) RoPE(qm,m)RoPE(kn,n)=qmkncos((mn)θ)+(qmk~n)sin((mn)θ)
其中, k ~ n \tilde{k}_n k~n k n k_n kn的特定排列。

MRoPE的三维扩展公式

1. 三维位置编码分解

MRoPE将位置信息分解为时间 t t t、高度 h h h、宽度 w w w三个维度的旋转操作:
MRoPE ( x , t , h , w ) = RoPE t ( x ) ⊕ RoPE h ( x ) ⊕ RoPE w ( x ) \text{MRoPE}(x, t, h, w) = \text{RoPE}_t(x) \oplus \text{RoPE}_h(x) \oplus \text{RoPE}_w(x) MRoPE(x,t,h,w)=RoPEt(x)RoPEh(x)RoPEw(x)
其中, ⊕ \oplus 表示三个维度的旋转操作的组合,通常通过张量拼接或加权求和实现。

2. 时间维度的绝对编码

Qwen2.5-VL引入绝对时间编码,将实际时间间隔映射为位置ID:
时间ID ( i ) = t 0 + t 1 − t 0 N ⋅ i ⋅ s \text{时间ID}(i) = t_0 + \frac{t_1 - t_0}{N} \cdot i \cdot s 时间ID(i)=t0+Nt1t0is
其中:

  • t 0 t_0 t0 t 1 t_1 t1为视频的起始和结束时间,
  • N N N为总帧数,
  • i i i为当前帧索引,
  • s s s为可学习的缩放因子(代码中对应second_per_grid_t * 2)。

三维位置ID的生成公式

1. 时间网格的位置ID计算

代码中的时间位置ID生成对应公式:
时间ID ( t ) = t ⋅ Δ t ⋅ s \text{时间ID}(t) = t \cdot \Delta t \cdot s 时间ID(t)=tΔts
其中:

  • t t t为时间块索引,
  • Δ t \Delta t Δt为每个时间块的秒数(second_per_grid_t),
  • s s s为缩放因子(代码中为2)。
2. 空间网格的位置ID计算

高度和宽度的位置ID通过网格索引生成:
高度ID ( h ) = h ( h = 0 , 1 , … , H − 1 ) \text{高度ID}(h) = h \quad (h = 0, 1, \ldots, H-1) 高度ID(h)=h(h=0,1,,H1)
宽度ID ( w ) = w ( w = 0 , 1 , … , W − 1 ) \text{宽度ID}(w) = w \quad (w = 0, 1, \ldots, W-1) 宽度ID(w)=w(w=0,1,,W1)
其中, H H H W W W分别为高度和宽度方向的网格数。

多模态融合的位置连续性公式

1. 视觉与文本的位置衔接

文本部分的起始位置ID为视觉部分的最大位置ID加1:
文本起始ID = max ⁡ ( 视觉时间ID , 视觉高度ID , 视觉宽度ID ) + 1 \text{文本起始ID} = \max(\text{视觉时间ID}, \text{视觉高度ID}, \text{视觉宽度ID}) + 1 文本起始ID=max(视觉时间ID,视觉高度ID,视觉宽度ID)+1

2. 整体位置ID序列

对于包含视觉和文本的混合序列,位置ID序列可表示为:
位置ID = [ 视觉ID 1 , 视觉ID 2 , … , 视觉ID M , 文本起始ID , 文本起始ID + 1 , … ] \text{位置ID} = [\text{视觉ID}_1, \text{视觉ID}_2, \ldots, \text{视觉ID}_M, \text{文本起始ID}, \text{文本起始ID}+1, \ldots] 位置ID=[视觉ID1,视觉ID2,,视觉IDM,文本起始ID,文本起始ID+1,]
其中, M M M为视觉token的数量。

基于大语言模型中的图像和视频的时间、高度和宽度维度,计算三维旋转位置编码索引。

原理说明:每个嵌入序列包含视觉嵌入和文本嵌入,或仅包含文本嵌入。对于纯文本嵌入序列,旋转位置嵌入与现代大语言模型相同。示例:input_ids: [T T T T T],这里T代表文本。时间位置ID: [0, 1, 2, 3, 4]高度位置ID: [0, 1, 2, 3, 4]宽度位置ID: [0, 1, 2, 3, 4]对于视觉和文本混合嵌入序列,我们为视觉部分计算三维旋转位置嵌入,为文本部分计算一维旋转位置嵌入。示例:时间维度(Temporal):3个时间块,表示视频在时间上的不同片段。高度维度(Height):2个高度块,垂直划分每一帧。宽度维度(Width):2个宽度块,水平划分每一帧。我们还有一些重要参数:fps(每秒帧数):视频的帧率,设为1。这意味着每秒处理一帧。tokens_per_second:这是一个关键参数。它决定了概念上一秒视频间隔内包含多少"时间步"或"时间token"。在这种情况下,我们每秒有25个token。因此,视频的每一秒将由25个不同的时间点表示。它本质上定义了时间粒度。temporal_patch_size:构成一个时间块的帧数。这里是2帧。interval:时间位置ID的步长,计算为tokens_per_second * temporal_patch_size / fps。在这种情况下,25 * 2 / 1 = 50。这意味着每个时间块的时间位置ID相差50。input_ids: [V V V V V V V V V V V V T T T T T],这里V代表视觉。视觉时间位置ID: [0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100]视觉高度位置ID: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]视觉宽度位置ID: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]文本时间位置ID: [101, 102, 103, 104, 105]文本高度位置ID: [101, 102, 103, 104, 105]文本宽度位置ID: [101, 102, 103, 104, 105]这里我们将文本的起始位置ID计算为视觉最大位置ID加1。参数:input_ids (`torch.LongTensor`,形状为`(batch_size, sequence_length)`):输入序列在词汇表中的token索引。如果提供了注意力掩码,填充token将被忽略。image_grid_thw (`torch.LongTensor`,形状为`(num_images, 3)`,可选):大语言模型中每个图像特征的时间、高度和宽度维度。video_grid_thw (`torch.LongTensor`,形状为`(num_videos, 3)`,可选):大语言模型中每个视频特征的时间、高度和宽度维度。second_per_grid_ts (`torch.Tensor`,形状为`(num_videos)`,可选):3D位置ID中每个时间网格的时间间隔(以秒为单位)。attention_mask (`torch.Tensor`,形状为`(batch_size, sequence_length)`,可选):用于避免在填充token索引上执行注意力计算的掩码。掩码值选择为`[0, 1]`:- 1表示token**未被掩码**,- 0表示token**被掩码**。返回:position_ids (`torch.LongTensor`,形状为`(3, batch_size, sequence_length)`)mrope_position_deltas (`torch.Tensor`,形状为`(batch_size)`)
def get_rope_index_25(spatial_merge_size: Optional[int] = 2,input_ids: Optional[torch.LongTensor] = None,image_grid_thw: Optional[torch.LongTensor] = None,video_grid_thw: Optional[torch.LongTensor] = None,second_per_grid_ts: Optional[torch.Tensor] = None,attention_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:image_token_id = 151655video_token_id = 151656vision_start_token_id = 151652mrope_position_deltas = []if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):total_input_ids = input_idsif attention_mask is None:attention_mask = torch.ones_like(total_input_ids)position_ids = torch.ones(3,input_ids.shape[0],input_ids.shape[1],dtype=input_ids.dtype,device=input_ids.device,)image_index, video_index = 0, 0attention_mask = attention_mask.to(total_input_ids.device)for i, input_ids in enumerate(total_input_ids):input_ids = input_ids[attention_mask[i] == 1]image_nums, video_nums = 0, 0vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)vision_tokens = input_ids[vision_start_indices + 1]image_nums = (vision_tokens == image_token_id).sum()video_nums = (vision_tokens == video_token_id).sum()input_tokens = input_ids.tolist()llm_pos_ids_list: list = []st = 0remain_images, remain_videos = image_nums, video_numsfor _ in range(image_nums + video_nums):if image_token_id in input_tokens and remain_images > 0:ed_image = input_tokens.index(image_token_id, st)else:ed_image = len(input_tokens) + 1if video_token_id in input_tokens and remain_videos > 0:ed_video = input_tokens.index(video_token_id, st)else:ed_video = len(input_tokens) + 1if ed_image < ed_video:t, h, w = (image_grid_thw[image_index][0],image_grid_thw[image_index][1],image_grid_thw[image_index][2],)second_per_grid_t = 0image_index += 1remain_images -= 1ed = ed_imageelse:t, h, w = (video_grid_thw[video_index][0],video_grid_thw[video_index][1],video_grid_thw[video_index][2],)if second_per_grid_ts is not None:second_per_grid_t = second_per_grid_ts[video_index]else:second_per_grid_t = 1.0video_index += 1remain_videos -= 1ed = ed_videollm_grid_t, llm_grid_h, llm_grid_w = (t.item(),h.item() // spatial_merge_size,w.item() // spatial_merge_size,)text_len = ed - stst_idx = (llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0)llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)range_tensor = torch.arange(llm_grid_t).view(-1, 1)expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w)time_tensor = expanded_range * second_per_grid_t * 2time_tensor_long = time_tensor.long()t_index = time_tensor_long.flatten()h_index = (torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten())w_index = (torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten())llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)st = ed + llm_grid_t * llm_grid_h * llm_grid_wif st < len(input_tokens):st_idx = (llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0)text_len = len(input_tokens) - stllm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)return position_ids, mrope_position_deltaselse:if attention_mask is not None:position_ids = attention_mask.long().cumsum(-1) - 1position_ids.masked_fill_(attention_mask == 0, 1)position_ids = (position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device))max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]else:position_ids = (torch.arange(input_ids.shape[1], device=input_ids.device).view(1, 1, -1).expand(3, input_ids.shape[0], -1))mrope_position_deltas = torch.zeros([input_ids.shape[0], 1],device=input_ids.device,dtype=input_ids.dtype,)return position_ids, mrope_position_deltas
http://www.xdnf.cn/news/12068.html

相关文章:

  • 计算机操作系统知识点总结②
  • 天机学堂(我的课表)
  • winform下DevExpress中datagridview中数据批量保存不上
  • 【python深度学习】Day 44 预训练模型
  • 安装 Nginx
  • 一则systemctl service诡异问题
  • GAN模式崩塌难题:成因分析与多维度解决方案
  • stripe支付测试,ngrok无法使用?免费vscode端口转发,轻松简单!
  • 第八部分:第四节 - 列表渲染与条件渲染:根据订单显示不同内容
  • [Java 基础]类,面向对象的蓝图
  • Windows 下载、安装、配置和使用Node
  • BUU MISC(持续更新)
  • Java 中实现线程的创建和启动
  • [ACM MM 2024]Lite-Mind:Towards Efficient and Robust Brain Representation
  • MySQL对数据库用户的操作
  • VS Code开发项目,配置ESlint自动修复脚本
  • 高防CDN有用吗?它的防护效果怎么样?
  • 1. 数据库基础
  • 卫星的“太空陀螺”:反作用轮如何精准控制姿态?
  • 蓝桥云课ROS一键配置teb教程更新-250604
  • 嵌入式就业难不难?
  • 【趣味Html】第11课:动态闪烁发光粒子五角星
  • 力扣刷题Day 70:在排序数组中查找元素的第一个和最后一个位置(34)
  • Visual Studio 2022 在 Windows 11 添加资源时崩溃问题分析与解决方案
  • [Linux] Linux GPIO应用编程深度解析与实践指南(代码示例)
  • JAVA实战开源项目:医院药品管理系统 (Vue+SpringBoot) 附源码
  • 数组1 day7
  • zabbix 6 监控 docker 容器
  • Linux 库文件的查看和管理
  • 解决 Java 项目中 “zip END header not found“ 错误