当前位置: 首页 > news >正文

Qwen3中的MoE是如何平衡专家负载的?

近年来,混合专家(Mixture of Experts, MoE)架构因其在扩展模型容量的同时保持计算效率的潜力,在大型语言模型领域受到了广泛关注。Qwen3系列模型也采用了MoE架构,通过稀疏激活特定的“专家”网络来处理不同的输入。然而,MoE模型的一个核心挑战在于如何确保各个专家之间的负载均衡,避免某些专家过载而另一些专家空闲。本文将基于Qwen3的开源代码,深入分析其负载均衡损失函数(load_balancing_loss_func)的设计与实现。

在这里插入图片描述

MoE与负载均衡的重要性

在MoE模型中,一个门控网络(Gating Network)决定将每个输入token路由到哪些专家进行处理。理想情况下,我们希望所有专家都能得到充分利用,并且每个专家都能学到独特的知识。如果路由机制出现偏差,导致大部分token被路由到少数几个专家,就会出现以下问题:

  • 1.专家过载与资源浪费:少数专家计算压力过大,而其他专家则处于空闲状态,导致计算资源利用不均。
  • 2.训练不稳定:不均衡的负载可能导致训练过程不稳定,模型难以收敛。
  • 3.模型性能下降:专家未能充分特化,模型整体性能可能受损。

在这里插入图片描述

因此,引入一个辅助的负载均衡损失函数至关重要,它能够惩罚不均衡的路由行为,鼓励token在专家间均匀分布。

Qwen3中的负载均衡机制借鉴了Switch Transformer论文[1]中的公式(4)至(6):
在这里插入图片描述

辅助损失函数的目标是使每个专家的token分配比例(tokens_per_expert)和路由概率(router_prob_per_expert)尽可能均匀。

让我们一起剖析load_balancing_loss_func函数的具体实现:


def load_balancing_loss_func(gate_logits: Union[torch.Tensor, Tuple[torch.Tensor], None],num_experts: Optional[int] = None,top_k=2,attention_mask: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, int]:r"""Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the lossfunction presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing betweenexperts is too unbalanced.Args:gate_logits:Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors ofshape [batch_size X sequence_length, num_experts].num_experts:Number of expertstop_k:The number of experts to route per-token, can be also interpreted as the `top-k` routingparameter.attention_mask (`torch.Tensor`, *optional*):The attention_mask used in forward functionshape [batch_size X sequence_length] if not None.Returns:The auxiliary loss."""if gate_logits is None or not isinstance(gate_logits, tuple):return 0if isinstance(gate_logits, tuple):compute_device = gate_logits[0].deviceconcatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)_, selected_experts = torch.topk(routing_weights, top_k, dim=-1)expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)if attention_mask is None:# Compute the percentage of tokens routed to each expertstokens_per_expert = torch.mean(expert_mask.float(), dim=0)# Compute the average probability of routing to these expertsrouter_prob_per_expert = torch.mean(routing_weights, dim=0)else:batch_size, sequence_length = attention_mask.shapenum_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)# Compute the mask that masks all padding tokens as 0 with the same shape of expert_maskexpert_attention_mask = (attention_mask[None, :, :, None, None].expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)).reshape(-1, top_k, num_experts).to(compute_device))# Compute the percentage of tokens routed to each expertstokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(expert_attention_mask, dim=0)# Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expertrouter_per_expert_attention_mask = (attention_mask[None, :, :, None].expand((num_hidden_layers, batch_size, sequence_length, num_experts)).reshape(-1, num_experts).to(compute_device))# Compute the average probability of routing to these expertsrouter_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(router_per_expert_attention_mask, dim=0)overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
return overall_loss * num_experts

该函数接收门控网络输出的gate_logits、专家总数num_experts、每个token选择的专家数量top_k,以及可选的attention_mask(用于处理padding)。
具体的步骤为:

1.门控输出(Gate Logits)

if gate_logits is None or not isinstance(gate_logits, tuple):
return 0

if isinstance(gate_logits, tuple):compute_device = gate_logits[0].deviceconcatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)

模型中可能有多层MoE模块,gate_logits是一个元组,包含了每一层MoE的门控输出。这里首先将所有层的gate_logits在第0维度(通常是token维度)上拼接起来。这意味着负载均衡损失是跨所有MoE层、所有token统一计算的。

2.计算路由权重和选择专家

routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)_, selected_experts = torch.topk(routing_weights, top_k, dim=-1)expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)

使用softmax将gate_logits转换为路由权重(routing_weights),表示每个token分配到各专家的概率。
再根据routing_weights,为每个token选出概率最高的top_k个专家。最后对selected_experts进行one-hot编码,生成一个掩码,标记了哪些专家被选中。

3.计算核心指标

这是负载均衡损失计算的核心。

# Compute the percentage of tokens routed to each experts
tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(expert_attention_mask, dim=0
)# Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
router_per_expert_attention_mask = (attention_mask[None, :, :, None].expand((num_hidden_layers, batch_size, sequence_length, num_experts)).reshape(-1, num_experts).to(compute_device)
)

tokens_per_expert反映了每个专家在所有top_k选择中被选中的“密度”或“比例”。如果一个专家被频繁选中(即使不是首选),其对应的值会较高。在Qwen3的实现中,如果attention_mask存在,会先根据attention_mask过滤掉padding token的贡献。
router_prob_per_expert表示门控网络分配给每个专家的平均概率值。同样,如果attention_mask存在,会排除padding token的影响。

4.计算最终损失

overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))return overall_loss * num_experts

根据公式 Loss = N ⋅ ∑ f i ⋅ P i \text{Loss} = N \cdot \sum f_i \cdot P_i Loss=NfiPi,计算tokens_per_expert和 router_prob_per_expert的点积,并乘以专家数量num_experts,得到最终的辅助损失。
这个计算方式旨在同时鼓励:

  • 1.专家被选中的频率(tokens_per_expert)应该均衡。
  • 2.门控网络对所选专家的置信度(router_prob_per_expert)应该高。

通过将这两者相乘,如果一个专家被频繁选中但门控概率很低,或者门控概率很高但很少被选中,损失都会相应调整。
该损失值越大,说明专家路由越不均衡,模型会通过优化减少这种不均衡。

总结:

Qwen3通过实现Switch Transformer思想的负载均衡损失函数,有效地解决了MoE架构中的专家负载不均问题。该函数通过统计每个专家接收到的token比例以及门控网络分配给各专家的平均概率,构建了一个惩罚项。
这个惩罚项被加到模型的总损失中,在训练过程中引导门控网络学习更均衡的路由策略。这不仅保证了计算资源的有效利用,也促进了各个专家的特化学习,最终有助于提升模型的整体性能和训练稳定性。理解这一机制对于深入掌握和应用MoE大模型至关重要。

[1] Fedus W, Zoph B, Shazeer N. Switch transformers: Scaling to trillion parameter models with simple and efficient sparsity[J]. Journal of Machine Learning Research, 2022, 23(120): 1-39.

http://www.xdnf.cn/news/347545.html

相关文章:

  • 跨线程和跨进程通信还有多种方式对比
  • JS 下载data:image/png;base64, 图片
  • 告别手动输入密码:基于SSHPass的自动化文件传输实践告别手动输入密码:基于SSHPass的自动化文件传输实践
  • Marin说PCB之器件的3D数模匹配失效案例
  • 在微程序控制器中,各概念之间的详细关系
  • IEEE出版|2025年物联网、数据科学与先进计算国际学术会议(IDSAC2025)
  • MyBatis 动态 SQL 完整笔记
  • 深泽多层电路在PCB行业中属于什么水平
  • laravel 使用异步队列,context带的上下文造成反序列化出问题
  • sql server限制用户只能访问特定表
  • PWN基础-ROP技术-ret2syscall-64位程序栈溢出利用
  • el-table合并单元
  • 【基础知识】李雅普诺夫方程与李雅普诺夫函数
  • 985高校查重率“隐性阈值”:低于5%可能被重点审查!
  • 从艾米・阿尔文看 CTO 的多面特质与成长路径
  • 英皇娱乐X乐华娱乐携手造星!“英皇乐华青少年艺人培训班”正式启动!
  • 深度学习-159-综述之混合专家模型和推理模型以及工作流和智能体的概念
  • Elastic:如何构建由 AI 驱动的数字客户体验策略
  • 计算机网络-LDP工作过程详解
  • 代码随想录算法训练营第60期第三十天打卡
  • C++之set和map的运用
  • MySQL 数据库
  • AI人工智能在交通物流领域的应用
  • web 自动化之 Selenium 元素定位和浏览器操作
  • 探索 C++ 在行业应用与技术融合中的核心价值
  • Baklib构建AI就绪知识管理体系
  • 湖北理元理律师事务所的企业债务重组实践:挽救实体经济的法律处方
  • B站pwn教程笔记-8
  • 验证码(笔记)
  • IndoorLink 新一代旗舰电子讲解器,四大革新技术开启破冰之旅