#P3553. 第2题-大模型训练MOE场景路由优化算法
          
                        
                                    
                      
        
              - 
          
          
                      1000ms
            
          
                      Tried: 2641
            Accepted: 404
            Difficulty: 4
            
          
          
          
                       所属公司 : 
                              华为
                                
            
                        
              时间 :2025年9月3日-国内-AI
                              
                      
          
 
- 
                        算法标签>模拟          
 
第2题-大模型训练MOE场景路由优化算法
思路与方法
给定 n 个专家,平均分布在 m 张 NPU 上(每张卡上一组,组内专家编号连续)。算法分三步:
- 
按组取代表 组大小为 g=n/m。对每一组,找到组内最大概率以及对应的专家编号,作为该组代表值。
 - 
选路由目标 NPU(选 p 个组) 将所有组按“代表概率”从大到小排序,取前 p 个组(对应的 NPU)。若 p>m,直接输出
error。 - 
在选定的 p 个组里选 k 个专家 将这 p 个组中的所有专家收集起来,按概率从大到小挑选前 k 个专家编号作为最终路由目标。若可选专家数 p⋅g<k,输出
error。 为了结果可复现,概率相同时按专家编号小的优先。 
最后把选出的 k 个专家编号按升序输出(空格分隔,行尾无空格)。
复杂度分析
- 计算每组最大值:遍历一次,O(n)。
 - 选出前 p 个组:对 m 个代表排序,O(mlogm)(或用大小为 p 的堆 O(mlogp))。
 - 在 p 个组里选出前 k 个专家:对 p⋅g 个元素排序,O((p⋅g)log(p⋅g))(或用堆 O((p⋅g)logk))。
 - 总体:在数据范围 n≤104 下,直接排序实现已足够高效,代码更简洁。
 
实现要点
- 
组索引:第 i 组覆盖的专家编号区间为 [i⋅g, (i+1)⋅g−1]。
 - 
排序键:
- 选组时:按 
(组代表概率 desc, 组索引 asc)稳定选择。 - 选专家时:按 
(概率 desc, 专家编号 asc)。 
 - 选组时:按 
 - 
输出:最终 k 个编号再升序打印。
 
参考实现
Python 版本
import sys
def main():
    data = sys.stdin.read().strip().split()
    if len(data) < 4:
        print("error")
        return
    it = iter(data)
    try:
        n = int(next(it)); m = int(next(it)); p = int(next(it)); k = int(next(it))
    except:
        print("error"); return
    # 读取 n 个概率
    probs = []
    for _ in range(n):
        try:
            probs.append(float(next(it)))
        except:
            print("error"); return
    # 基本校验
    if n % m != 0:
        print("error"); return
    if p > m:
        print("error"); return
    g = n // m  # 每组大小
    # 1) 计算每组代表(最大概率及其专家编号)
    group_repr = []  # (max_prob, group_id, expert_idx_of_max)
    for gi in range(m):
        L = gi * g
        R = L + g
        max_prob = -1.0
        max_idx = -1
        # 组内扫描找最大值;并用较小编号打破平局
        for idx in range(L, R):
            pr = probs[idx]
            if pr > max_prob or (abs(pr - max_prob) < 1e-18 and idx < max_idx):
                max_prob = pr
                max_idx = idx
        group_repr.append((max_prob, gi, max_idx))
    # 2) 选择前 p 个组(按代表概率降序;组索引升序打破平局)
    group_repr.sort(key=lambda x: (-x[0], x[1]))
    chosen_groups = set([gi for _, gi, _ in group_repr[:p]])
    # 3) 收集这些组的所有专家并选前 k 个(按概率降序,编号升序)
    pool = []
    for gi in chosen_groups:
        L = gi * g
        R = L + g
        for idx in range(L, R):
            pool.append((probs[idx], idx))
    if len(pool) < k:
        print("error"); return
    pool.sort(key=lambda x: (-x[0], x[1]))
    chosen = [idx for _, idx in pool[:k]]
    chosen.sort()
    print(" ".join(map(str, chosen)))
if __name__ == "__main__":
    main()
Java 版本
import java.io.*;
import java.util.*;
public class Main {
    static class GroupRep {
        double maxProb;
        int gid;
        int idxOfMax;
        GroupRep(double p, int g, int idx) { maxProb = p; gid = g; idxOfMax = idx; }
    }
    public static void main(String[] args) throws Exception {
        Scanner sc = new Scanner(System.in);
        String s1 = sc.next();
        if (s1 == null) { System.out.println("error"); return; }
        int n = Integer.parseInt(s1);
        int m = Integer.parseInt(sc.next());
        int p = Integer.parseInt(sc.next());
        int k = Integer.parseInt(sc.next());
        double[] prob = new double[n];
        for (int i = 0; i < n; i++) {
            String t = sc.next();
            if (t == null) { System.out.println("error"); return; }
            prob[i] = Double.parseDouble(t);
        }
        if (n % m != 0) { System.out.println("error"); return; }
        if (p > m) { System.out.println("error"); return; }
        int g = n / m;
        // 1) 计算每组代表
        ArrayList<GroupRep> reps = new ArrayList<>(m);
        for (int gi = 0; gi < m; gi++) {
            int L = gi * g, R = L + g;
            double maxP = -1.0;
            int maxIdx = -1;
            for (int idx = L; idx < R; idx++) {
                double pr = prob[idx];
                if (pr > maxP || (Math.abs(pr - maxP) < 1e-18 && idx < maxIdx)) {
                    maxP = pr; maxIdx = idx;
                }
            }
            reps.add(new GroupRep(maxP, gi, maxIdx));
        }
        // 2) 选择前 p 个组
        reps.sort((a, b) -> {
            if (a.maxProb == b.maxProb) return Integer.compare(a.gid, b.gid);
            return Double.compare(b.maxProb, a.maxProb);
        });
        boolean[] chosenGroup = new boolean[m];
        for (int i = 0; i < p; i++) chosenGroup[reps.get(i).gid] = true;
        // 3) 选专家
        ArrayList<int[]> pool = new ArrayList<>(p * g); // [idx, prob按排序用,不存prob避免装箱?这里保留prob]
        ArrayList<double[]> poolWithProb = new ArrayList<>(p * g);
        for (int gi = 0; gi < m; gi++) {
            if (!chosenGroup[gi]) continue;
            int L = gi * g, R = L + g;
            for (int idx = L; idx < R; idx++) {
                poolWithProb.add(new double[]{prob[idx], idx});
            }
        }
        if (poolWithProb.size() < k) { System.out.println("error"); return; }
        poolWithProb.sort((x, y) -> {
            int c = Double.compare(y[0], x[0]); // 概率降序
            if (c != 0) return c;
            return Integer.compare((int)x[1], (int)y[1]); // 编号升序
        });
        int[] ans = new int[k];
        for (int i = 0; i < k; i++) ans[i] = (int)poolWithProb.get(i)[1];
        Arrays.sort(ans);
        StringBuilder out = new StringBuilder();
        for (int i = 0; i < k; i++) {
            if (i > 0) out.append(' ');
            out.append(ans[i]);
        }
        System.out.println(out.toString());
    }
}
C++ 版本
#include <bits/stdc++.h>
using namespace std;
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int n, m, p, k;
    if (!(cin >> n >> m >> p >> k)) {
        cout << "error\n";
        return 0;
    }
    vector<double> prob(n);
    for (int i = 0; i < n; ++i) {
        if (!(cin >> prob[i])) { cout << "error\n"; return 0; }
    }
    if (n % m != 0) { cout << "error\n"; return 0; }
    if (p > m) { cout << "error\n"; return 0; }
    int g = n / m;
    // 1) 每组代表(最大概率及其索引)
    struct Rep { double mx; int gid; int idx; };
    vector<Rep> reps; reps.reserve(m);
    for (int gi = 0; gi < m; ++gi) {
        int L = gi * g, R = L + g;
        double mx = -1.0; int idx = -1;
        for (int i = L; i < R; ++i) {
            double pr = prob[i];
            if (pr > mx || (fabs(pr - mx) < 1e-18 && i < idx)) {
                mx = pr; idx = i;
            }
        }
        reps.push_back({mx, gi, idx});
    }
    // 2) 选前 p 个组
    sort(reps.begin(), reps.end(), [](const Rep& a, const Rep& b){
        if (a.mx == b.mx) return a.gid < b.gid;          // 概率相同按组号升序
        return a.mx > b.mx;                               // 概率降序
    });
    vector<char> chosen(m, 0);
    for (int i = 0; i < p; ++i) chosen[reps[i].gid] = 1;
    // 3) 在选定组中选前 k 个专家
    vector<pair<double,int>> pool; pool.reserve(p * g);
    for (int gi = 0; gi < m; ++gi) if (chosen[gi]) {
        int L = gi * g, R = L + g;
        for (int i = L; i < R; ++i) pool.push_back({prob[i], i});
    }
    if ((int)pool.size() < k) { cout << "error\n"; return 0; }
    sort(pool.begin(), pool.end(), [](const auto& a, const auto& b){
        if (a.first == b.first) return a.second < b.second; // 编号升序
        return a.first > b.first;                           // 概率降序
    });
    vector<int> ans; ans.reserve(k);
    for (int i = 0; i < k; ++i) ans.push_back(pool[i].second);
    sort(ans.begin(), ans.end());
    for (int i = 0; i < k; ++i) {
        if (i) cout << ' ';
        cout << ans[i];
    }
    cout << '\n';
    return 0;
}
        题目内容
MOE 模型训练时,token 根据概率发送到 topk 个不同的专家进行计算。这些专家分布在多个 NPU 卡上。Device−Limitedr outing 算法将 token 的路由目标限制在 P 个 NPU 上,可以有效降低通信成本。具体的:
- 
把 n 个专家平均分配在 m 个 NPU 上,每个 NPU 上的专家为一个组;设 n 个专家的编号为 N=[0,1,2,…,n−1] ,同一个专家组上的专家编号是连续的;
 - 
每个专家对应一个概率,表示被路由到的可能性;用每个组中的最大概率作为本组代表,从所有组中选择概率最大的 p 个组,其所在的 NPU 即为路由目标限制 NPU ;
 - 
再从上述 p 个 NPU 对应的所有专家概率中选择 k 个最大的概率对应的专家编号作为最终路由目标。
 
试着编写一段程序,实现以上路由算法。
输入描述
第一行有 4 个处于区间 [1,10000] 之内的整数,第 1 个表示专家的个数 n ,第 2 个表示 NPU 个数 m ,第 3 个表示路由目标限制 NPU 个数 p ,第 4 个表示目标路由专家个数 k ;
第二行有 n 个处于区间 (0,1) 之内的浮点数,表示每个专家对应的概率值,这 n 个数对应的专家的编号为 [0,1,2,...,n−1] ;
输出描述
如果,n 不能被 m 整除或者获取不到 k 个专家编号,输出 error ;
否则,按照从小到大的顺序,输出 k 个专家编号,任意相邻两数之间有空格,最后一个数字(行尾没有空格)
样例1
输入
8 4 4 2
0.5 0.01 0.09 0.023 0.027 0.05 0.1 0.2
输出
0 7
说明
将专家分成 4 组,分别为:(1)0.5 0.01 (2)0.09 0.023 (3)0.027 0.05 (4)0.1 0.2
限定专家为 4 ,则 4 组都被选定,从选定的 4 组中,选择 2 个专家,分别是 0.5 和 0.2 对应的专家,对应的编号分别是 0 和 7 ,按照升序排个列,输出: 0 7
样例2
输入
8 4 5 2
0.3 0.04 0.06 0.45 0.05 0.01 0.03 0.06
输出
error
说明
NPU 一共只有 4 个,需要限定 5 个 NPU ,不满足条件,输出 error