#P3553. 第2题-大模型训练MOE场景路由优化算法
-
1000ms
Tried: 2737
Accepted: 416
Difficulty: 4
所属公司 :
华为
时间 :2025年9月3日-国内-AI
-
算法标签>模拟
第2题-大模型训练MOE场景路由优化算法
思路与方法
给定 n 个专家,平均分布在 m 张 NPU 上(每张卡上一组,组内专家编号连续)。算法分三步:
-
按组取代表 组大小为 g=n/m。对每一组,找到组内最大概率以及对应的专家编号,作为该组代表值。
-
选路由目标 NPU(选 p 个组) 将所有组按“代表概率”从大到小排序,取前 p 个组(对应的 NPU)。若 p>m,直接输出
error。 -
在选定的 p 个组里选 k 个专家 将这 p 个组中的所有专家收集起来,按概率从大到小挑选前 k 个专家编号作为最终路由目标。若可选专家数 p⋅g<k,输出
error。 为了结果可复现,概率相同时按专家编号小的优先。
最后把选出的 k 个专家编号按升序输出(空格分隔,行尾无空格)。
复杂度分析
- 计算每组最大值:遍历一次,O(n)。
- 选出前 p 个组:对 m 个代表排序,O(mlogm)(或用大小为 p 的堆 O(mlogp))。
- 在 p 个组里选出前 k 个专家:对 p⋅g 个元素排序,O((p⋅g)log(p⋅g))(或用堆 O((p⋅g)logk))。
- 总体:在数据范围 n≤104 下,直接排序实现已足够高效,代码更简洁。
实现要点
-
组索引:第 i 组覆盖的专家编号区间为 [i⋅g, (i+1)⋅g−1]。
-
排序键:
- 选组时:按
(组代表概率 desc, 组索引 asc)稳定选择。 - 选专家时:按
(概率 desc, 专家编号 asc)。
- 选组时:按
-
输出:最终 k 个编号再升序打印。
参考实现
Python 版本
import sys
def main():
data = sys.stdin.read().strip().split()
if len(data) < 4:
print("error")
return
it = iter(data)
try:
n = int(next(it)); m = int(next(it)); p = int(next(it)); k = int(next(it))
except:
print("error"); return
# 读取 n 个概率
probs = []
for _ in range(n):
try:
probs.append(float(next(it)))
except:
print("error"); return
# 基本校验
if n % m != 0:
print("error"); return
if p > m:
print("error"); return
g = n // m # 每组大小
# 1) 计算每组代表(最大概率及其专家编号)
group_repr = [] # (max_prob, group_id, expert_idx_of_max)
for gi in range(m):
L = gi * g
R = L + g
max_prob = -1.0
max_idx = -1
# 组内扫描找最大值;并用较小编号打破平局
for idx in range(L, R):
pr = probs[idx]
if pr > max_prob or (abs(pr - max_prob) < 1e-18 and idx < max_idx):
max_prob = pr
max_idx = idx
group_repr.append((max_prob, gi, max_idx))
# 2) 选择前 p 个组(按代表概率降序;组索引升序打破平局)
group_repr.sort(key=lambda x: (-x[0], x[1]))
chosen_groups = set([gi for _, gi, _ in group_repr[:p]])
# 3) 收集这些组的所有专家并选前 k 个(按概率降序,编号升序)
pool = []
for gi in chosen_groups:
L = gi * g
R = L + g
for idx in range(L, R):
pool.append((probs[idx], idx))
if len(pool) < k:
print("error"); return
pool.sort(key=lambda x: (-x[0], x[1]))
chosen = [idx for _, idx in pool[:k]]
chosen.sort()
print(" ".join(map(str, chosen)))
if __name__ == "__main__":
main()
Java 版本
import java.io.*;
import java.util.*;
public class Main {
static class GroupRep {
double maxProb;
int gid;
int idxOfMax;
GroupRep(double p, int g, int idx) { maxProb = p; gid = g; idxOfMax = idx; }
}
public static void main(String[] args) throws Exception {
Scanner sc = new Scanner(System.in);
String s1 = sc.next();
if (s1 == null) { System.out.println("error"); return; }
int n = Integer.parseInt(s1);
int m = Integer.parseInt(sc.next());
int p = Integer.parseInt(sc.next());
int k = Integer.parseInt(sc.next());
double[] prob = new double[n];
for (int i = 0; i < n; i++) {
String t = sc.next();
if (t == null) { System.out.println("error"); return; }
prob[i] = Double.parseDouble(t);
}
if (n % m != 0) { System.out.println("error"); return; }
if (p > m) { System.out.println("error"); return; }
int g = n / m;
// 1) 计算每组代表
ArrayList<GroupRep> reps = new ArrayList<>(m);
for (int gi = 0; gi < m; gi++) {
int L = gi * g, R = L + g;
double maxP = -1.0;
int maxIdx = -1;
for (int idx = L; idx < R; idx++) {
double pr = prob[idx];
if (pr > maxP || (Math.abs(pr - maxP) < 1e-18 && idx < maxIdx)) {
maxP = pr; maxIdx = idx;
}
}
reps.add(new GroupRep(maxP, gi, maxIdx));
}
// 2) 选择前 p 个组
reps.sort((a, b) -> {
if (a.maxProb == b.maxProb) return Integer.compare(a.gid, b.gid);
return Double.compare(b.maxProb, a.maxProb);
});
boolean[] chosenGroup = new boolean[m];
for (int i = 0; i < p; i++) chosenGroup[reps.get(i).gid] = true;
// 3) 选专家
ArrayList<int[]> pool = new ArrayList<>(p * g); // [idx, prob按排序用,不存prob避免装箱?这里保留prob]
ArrayList<double[]> poolWithProb = new ArrayList<>(p * g);
for (int gi = 0; gi < m; gi++) {
if (!chosenGroup[gi]) continue;
int L = gi * g, R = L + g;
for (int idx = L; idx < R; idx++) {
poolWithProb.add(new double[]{prob[idx], idx});
}
}
if (poolWithProb.size() < k) { System.out.println("error"); return; }
poolWithProb.sort((x, y) -> {
int c = Double.compare(y[0], x[0]); // 概率降序
if (c != 0) return c;
return Integer.compare((int)x[1], (int)y[1]); // 编号升序
});
int[] ans = new int[k];
for (int i = 0; i < k; i++) ans[i] = (int)poolWithProb.get(i)[1];
Arrays.sort(ans);
StringBuilder out = new StringBuilder();
for (int i = 0; i < k; i++) {
if (i > 0) out.append(' ');
out.append(ans[i]);
}
System.out.println(out.toString());
}
}
C++ 版本
#include <bits/stdc++.h>
using namespace std;
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
int n, m, p, k;
if (!(cin >> n >> m >> p >> k)) {
cout << "error\n";
return 0;
}
vector<double> prob(n);
for (int i = 0; i < n; ++i) {
if (!(cin >> prob[i])) { cout << "error\n"; return 0; }
}
if (n % m != 0) { cout << "error\n"; return 0; }
if (p > m) { cout << "error\n"; return 0; }
int g = n / m;
// 1) 每组代表(最大概率及其索引)
struct Rep { double mx; int gid; int idx; };
vector<Rep> reps; reps.reserve(m);
for (int gi = 0; gi < m; ++gi) {
int L = gi * g, R = L + g;
double mx = -1.0; int idx = -1;
for (int i = L; i < R; ++i) {
double pr = prob[i];
if (pr > mx || (fabs(pr - mx) < 1e-18 && i < idx)) {
mx = pr; idx = i;
}
}
reps.push_back({mx, gi, idx});
}
// 2) 选前 p 个组
sort(reps.begin(), reps.end(), [](const Rep& a, const Rep& b){
if (a.mx == b.mx) return a.gid < b.gid; // 概率相同按组号升序
return a.mx > b.mx; // 概率降序
});
vector<char> chosen(m, 0);
for (int i = 0; i < p; ++i) chosen[reps[i].gid] = 1;
// 3) 在选定组中选前 k 个专家
vector<pair<double,int>> pool; pool.reserve(p * g);
for (int gi = 0; gi < m; ++gi) if (chosen[gi]) {
int L = gi * g, R = L + g;
for (int i = L; i < R; ++i) pool.push_back({prob[i], i});
}
if ((int)pool.size() < k) { cout << "error\n"; return 0; }
sort(pool.begin(), pool.end(), [](const auto& a, const auto& b){
if (a.first == b.first) return a.second < b.second; // 编号升序
return a.first > b.first; // 概率降序
});
vector<int> ans; ans.reserve(k);
for (int i = 0; i < k; ++i) ans.push_back(pool[i].second);
sort(ans.begin(), ans.end());
for (int i = 0; i < k; ++i) {
if (i) cout << ' ';
cout << ans[i];
}
cout << '\n';
return 0;
}
题目内容
MOE 模型训练时,token 根据概率发送到 topk 个不同的专家进行计算。这些专家分布在多个 NPU 卡上。Device−Limitedr outing 算法将 token 的路由目标限制在 P 个 NPU 上,可以有效降低通信成本。具体的:
-
把 n 个专家平均分配在 m 个 NPU 上,每个 NPU 上的专家为一个组;设 n 个专家的编号为 N=[0,1,2,…,n−1] ,同一个专家组上的专家编号是连续的;
-
每个专家对应一个概率,表示被路由到的可能性;用每个组中的最大概率作为本组代表,从所有组中选择概率最大的 p 个组,其所在的 NPU 即为路由目标限制 NPU ;
-
再从上述 p 个 NPU 对应的所有专家概率中选择 k 个最大的概率对应的专家编号作为最终路由目标。
试着编写一段程序,实现以上路由算法。
输入描述
第一行有 4 个处于区间 [1,10000] 之内的整数,第 1 个表示专家的个数 n ,第 2 个表示 NPU 个数 m ,第 3 个表示路由目标限制 NPU 个数 p ,第 4 个表示目标路由专家个数 k ;
第二行有 n 个处于区间 (0,1) 之内的浮点数,表示每个专家对应的概率值,这 n 个数对应的专家的编号为 [0,1,2,...,n−1] ;
输出描述
如果,n 不能被 m 整除或者获取不到 k 个专家编号,输出 error ;
否则,按照从小到大的顺序,输出 k 个专家编号,任意相邻两数之间有空格,最后一个数字(行尾没有空格)
样例1
输入
8 4 4 2
0.5 0.01 0.09 0.023 0.027 0.05 0.1 0.2
输出
0 7
说明
将专家分成 4 组,分别为:(1)0.5 0.01 (2)0.09 0.023 (3)0.027 0.05 (4)0.1 0.2
限定专家为 4 ,则 4 组都被选定,从选定的 4 组中,选择 2 个专家,分别是 0.5 和 0.2 对应的专家,对应的编号分别是 0 和 7 ,按照升序排个列,输出: 0 7
样例2
输入
8 4 5 2
0.3 0.04 0.06 0.45 0.05 0.01 0.03 0.06
输出
error
说明
NPU 一共只有 4 个,需要限定 5 个 NPU ,不满足条件,输出 error