解题思路
本题使用的算法是矩阵乘法、数值稳定 softmax 和排序选 Top−S。
先对每个 token 计算它到每个专家的门控分数:
score[i][e]=Xi×Wg,e
然后对同一个 token 的 E 个专家分数做数值稳定的 softmax,得到归一化门控权重 gate[i][e]。
题目内容
混合专家模型(Mixture of Experts, MoE)是一种有效的增强模型的参数数量同时保持计算量的方式。传统Token选择专家(Token Choice)存在负载不均衡问题:热门专家过载,冷门专家闲置。Expert Choice Routing(ECR)反转选择关系:每个专家主动选择Top-S个Token,保持天然平衡。
要求不调用pytorch、tensorflow等库实现Expert Choice Routing机制。给定一批token表示、门控权重和专家变换,完成:
- 计算专家门控分数与归一化权重,权重使用数值稳定softmax方法进行归一化
- 每个专家独立选择Top-S个token,分数相同时优先选择序号小的token
- 对选中token执行专家变换并加权聚合
- 输出最终表示,注意未被任何专家选中处理的token置为长度d的0向量
输入描述
第一行:四个正整数 N d E S
- N:token数量(序列长度)
- d:特征维度
- E:专家数量
- S:每个专家选择的token数(容量),S≤N
接下来:
- N 行,每行 d 个浮点数:输入token矩阵 X
- d 行,每行 E 个浮点数:门控权重矩阵 Wg(转置存储,即每行对应输入维度,每列对应专家)
- E 个块,每块:
- d 行,每行 d 个浮点数:专家 e 的变换矩阵 We
输出描述
输出表示 Y,N 行,每行 d 个浮点数,保留2位小数
注意输入格式违反约束要求,则返回为0
样例1
输入
4 3 2 3
1.0 0.5 0.2
0.8 0.3 0.9
0.2 0.7 0.4
0.5 0.1 0.6
0.3 0.7
0.5 0.2
0.2 0.1
1.0 0.0 0.0
0.0 1.0 0.0
0.0 0.0 1.0
0.5 0.0 0.0
0.0 0.5 0.0
0.0 0.0 0.5
输出
0.28 0.14 0.06
0.59 0.22 0.66
0.11 0.38 0.22
0.37 0.07 0.44
说明
输入 N=4 d=3 E=2 S=3
X=[1.0 0.5 0.2; 0.8 0.3 0.9; 0.2 0.7 0.4; 0.5 0.1 0.6]
Wg=[0.3 0.7; 0.5 0.2; 0.2 0.1]
We=(按专家顺序的两个 3×3 矩阵)
第一步:计算专家分数矩阵(近似),依次计算每个token X[i]⋅Wg
- token 0:[1.0∗0.3+0.5∗0.5+0.2∗0.2, 1.0∗0.7+0.5∗0.2+0.2∗0.1]=[0.59,0.82],softmax之后 [0.442752,0.557248]
- token 1:[0.8∗0.3+0.3∗0.5+0.9∗0.2, 0.8∗0.7+0.3∗0.2+0.9∗0.1]=[0.57,0.71],softmax之后 [0.465057,0.534943]
- token 2:[0.2∗0.3+0.7∗0.5+0.4∗0.2, 0.2∗0.7+0.7∗0.2+0.4∗0.1]=[0.49,0.32],softmax之后 [0.542398,0.457602]
- token 3:[0.5∗0.3+0.1∗0.5+0.6∗0.2, 0.5∗0.7+0.1∗0.2+0.6∗0.1]=[0.32,0.43],softmax之后 [0.472528,0.527472]
第二步:根据步骤1的得分,Top-S选择
- Expert 0:[0,1,1,1](选择token 1, 2, 3)
- Expert 1:[1,1,0,1](选择token 0, 1, 3)
第三步:聚合计算
- token 0 被Expert 0选中,输出 =0.557248∗X[0]We[1]=[0.28 0.14 0.06]
- token 1 被Expert 0, 1选中,输出 =0.465057∗X[1]We[0]+0.534943∗X[1]We[1]=[0.59 0.22 0.66]
- token 2 被Expert 0选中,输出 =0.542398∗X[2]We[0]=[0.11 0.38 0.22]
- token 3 被Expert 0, 1选中,输出 =0.472528∗X[2]We[0]+0.527472∗X[2]We[1]=[0.37 0.07 0.44]

样例2
输入
3 2 2 0
1.0 0.5
0.5 1.0
0.2 0.3
0.1 0.2
0.8 0.6
0.2 0.4
0.9 0.1
0.1 0.8
0.0 0.1
输出
0
说明
输入 N=3, d=2, E=2, S=0,其中输入第一行的S违反了正整数约束