#P2957. 第1题-AI算法训练中的动态优先级经验回放
          
                        
                                    
                      
        
              - 
          
          
                      1000ms
            
          
                      Tried: 1666
            Accepted: 138
            Difficulty: 6
            
          
          
          
                       所属公司 : 
                              华为
                                
            
                        
              时间 :2025年5月14日-暑期实习
                              
                      
          
 
- 
                        算法标签>有序集合          
 
第1题-AI算法训练中的动态优先级经验回放
题解
题面描述
有一个经验池,需要支持三种操作,共有 N 次操作:
- 
插入经验
格式:+ id score
向池中插入一个编号为 id 的经验,初始优先级为 score。 - 
提取 TopK 经验
格式:- K
从当前经验池中提取优先级最高的 K 个经验的 id(按优先级降序,如果优先级相同按 id 升序);提取后,这些经验暂时从池中移除。如果此时池中剩余经验不足 K 个,则输出-1。 - 
更新优先级
格式:= id newScore
将编号为 id 的经验的优先级更新为 newScore;更新后,之前所有被提取(但尚未真正“消耗”)的经验都要重新回到池中,且优先级为最新值。 
若所有操作中没有任何一次提取操作,则输出 null。若输入行数与 N 不符,也输出 null。
思路
- 
核心数据结构
- 使用一个全局哈希表 
prio[id]存储每个经验的最新优先级。 - 使用一个最大堆 
pq,其中每个节点存储三元组(<当前优先级>, -id, <版本号>),版本号用于延迟删除(见下)。 - 使用一个缓冲区 
buffer保存每次提取出的、尚未真正“消耗”的经验节点。 
 - 使用一个全局哈希表 
 - 
延迟删除与版本号
- 每当插入或更新经验时,给该经验在哈希表中写入新优先级,并增加一个“版本号”标识。
 - 向堆中 push 时一并存入当前版本号。
 - 从堆中 pop 时,如果该节点的版本号与哈希表中的版本号不一致,说明这是过时条目,直接丢弃,继续 pop。
 
 - 
提取操作
- 循环从堆中 pop 合法节点,收集 K 个;同时把它们放入 
buffer。 - 如果不足 K,输出 
-1;否则输出收集到的 id 列表。 
 - 循环从堆中 pop 合法节点,收集 K 个;同时把它们放入 
 - 
更新操作
- 写入新的优先级与版本号到哈希表,并向堆中 push 新节点。
 - 重置缓冲区:将 
buffer中所有节点重新 push 回堆中并清空buffer,以实现“所有提取的经验回归池子”。 
 - 
复杂度分析
- 插入/更新:O(logM),其中 M 为当前堆大小。
 - 提取 K:最坏 O((K+V)logM),其中 V 是因延迟删除弹出的过时节点数,但均摊下仍然可接受。
 - 总体上 NlogN 级别,可在 N≤105 下通过。
 
 
C++
#include <bits/stdc++.h>
using namespace std;
// 全局版本号,每次插入或更新时自增
long long global_version = 0;
// 经验的最新优先级与版本号映射
unordered_map<int, pair<long long, long long>> prio;
// 最大堆节点:{优先级, -id, 版本号}
struct Node {
    long long score;
    int neg_id;
    long long version;
    bool operator<(const Node& o) const {
        if (score != o.score) return score < o.score; // 分数降序
        return neg_id < o.neg_id; // 同分 id 升序(neg_id 较大 => id 较小)
    }
};
priority_queue<Node> pq;
vector<Node> buffer;  // 存放提取出的节点,等待可能的恢复
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int N;
    if (!(cin >> N)) {
        cout << "null\n";
        return 0;
    }
    string op;
    bool hasExtract = false;
    for (int i = 0; i < N; i++) {
        if (!(cin >> op)) {
            cout << "null\n";
            return 0;
        }
        if (op == "+") {
            int id; long long score;
            if (!(cin >> id >> score)) { cout << "null\n"; return 0; }
            global_version++;
            prio[id] = {score, global_version};
            pq.push({score, -id, global_version});
        }
        else if (op == "-") {
            int K;
            if (!(cin >> K)) { cout << "null\n"; return 0; }
            hasExtract = true;
            // 先判断可用数量:总数 - 本轮已提取数
            size_t available = prio.size();
            if (buffer.size() > available) { cout << "null\n"; return 0; } // 防御
            available -= buffer.size();
            if ((size_t)K > available) {
                cout << "-1\n";
                continue; // 不改变任何状态
            }
            vector<int> res;
            res.reserve(K);
            // 提取恰好 K 个
            while ((int)res.size() < K) {
                if (pq.empty()) { cout << "null\n"; return 0; } // 按理不会发生
                Node top = pq.top(); pq.pop();
                int id = -top.neg_id;
                auto it = prio.find(id);
                // 懒删除:版本不匹配直接丢弃
                if (it == prio.end() || it->second.second != top.version) continue;
                // 有效条目:加入结果,并放入 buffer(表示本轮已取,不可再取)
                res.push_back(id);
                buffer.push_back(top);
            }
            for (int j = 0; j < K; j++) {
                cout << res[j] << (j+1<K ? ' ' : '\n');
            }
        }
        else if (op == "=") {
            int id; long long newScore;
            if (!(cin >> id >> newScore)) { cout << "null\n"; return 0; }
            global_version++;
            prio[id] = {newScore, global_version};
            pq.push({newScore, -id, global_version});
            // 恢复所有 buffer 中的节点(本轮提取全部回池)
            for (auto &nd : buffer) pq.push(nd);
            buffer.clear();
        }
        else {
            cout << "null\n";
            return 0;
        }
    }
    if (!hasExtract) {
        cout << "null\n";
    }
    return 0;
}
Python
import sys
import heapq
def main():
    data = sys.stdin.read().strip().split()
    if not data:
        print("null")
        return
    it = iter(data)
    try:
        N = int(next(it))
    except:
        print("null")
        return
    # prio: id -> (score, version)
    prio = {}
    pq = []  # 小根堆存 (-score, id, version)
    buffer = []  # 本轮已提取的节点等待回池
    global_version = 0
    has_extract = False
    for _ in range(N):
        try:
            op = next(it)
        except StopIteration:
            print("null")
            return
        if op == "+":
            # 插入
            try:
                id_ = int(next(it))
                score = int(next(it))
            except:
                print("null"); return
            global_version += 1
            prio[id_] = (score, global_version)
            heapq.heappush(pq, (-score, id_, global_version))
        elif op == "-":
            # 提取
            try:
                K = int(next(it))
            except:
                print("null"); return
            has_extract = True
            available = len(prio) - len(buffer)
            if K > available:
                print("-1")
                continue
            res = []
            while len(res) < K:
                if not pq:
                    print("null"); return
                score_neg, id_, ver = heapq.heappop(pq)
                s, v = prio.get(id_, (None, None))
                if v != ver:
                    continue  # 已过期
                res.append(id_)
                buffer.append((score_neg, id_, ver))
            print(" ".join(map(str, res)))
        elif op == "=":
            # 更新
            try:
                id_ = int(next(it))
                new_score = int(next(it))
            except:
                print("null"); return
            global_version += 1
            prio[id_] = (new_score, global_version)
            heapq.heappush(pq, (-new_score, id_, global_version))
            # 回池
            for node in buffer:
                heapq.heappush(pq, node)
            buffer.clear()
        else:
            print("null")
            return
    if not has_extract:
        print("null")
        return
if __name__ == "__main__":
    main()
Java
import java.io.*;
import java.util.*;
class Node {
    long score;
    int id;
    long version;
    Node(long s, int i, long v) {
        score = s; id = i; version = v;
    }
}
public class Main {
    public static void main(String[] args) throws IOException {
        FastScanner fs = new FastScanner(System.in);
        Integer NObj = fs.nextIntNullable();
        if (NObj == null) { System.out.println("null"); return; }
        int N = NObj;
        Map<Integer, long[]> prio = new HashMap<>(); // id -> {score, version}
        Comparator<Node> cmp = (a, b) -> {
            if (a.score != b.score) return Long.compare(b.score, a.score);
            return Integer.compare(a.id, b.id);
        };
        PriorityQueue<Node> pq = new PriorityQueue<>(cmp);
        List<Node> buffer = new ArrayList<>();
        long globalVersion = 0;
        boolean hasExtract = false;
        for (int opCnt = 0; opCnt < N; opCnt++) {
            String op = fs.nextToken();
            if (op == null) { System.out.println("null"); return; }
            if (op.equals("+")) {
                Integer id = fs.nextIntNullable();
                Long sc = fs.nextLongNullable();
                if (id == null || sc == null) { System.out.println("null"); return; }
                globalVersion++;
                prio.put(id, new long[]{sc, globalVersion});
                pq.add(new Node(sc, id, globalVersion));
            } else if (op.equals("-")) {
                Integer kObj = fs.nextIntNullable();
                if (kObj == null) { System.out.println("null"); return; }
                int K = kObj;
                hasExtract = true;
                int available = prio.size() - buffer.size();
                if (K > available) {
                    System.out.println("-1");
                    continue;
                }
                List<Integer> res = new ArrayList<>();
                while (res.size() < K) {
                    if (pq.isEmpty()) { System.out.println("null"); return; }
                    Node top = pq.poll();
                    long[] cur = prio.get(top.id);
                    if (cur == null || cur[1] != top.version) continue; // 过期
                    res.add(top.id);
                    buffer.add(top);
                }
                for (int i = 0; i < res.size(); i++) {
                    if (i > 0) System.out.print(" ");
                    System.out.print(res.get(i));
                }
                System.out.println();
            } else if (op.equals("=")) {
                Integer id = fs.nextIntNullable();
                Long ns = fs.nextLongNullable();
                if (id == null || ns == null) { System.out.println("null"); return; }
                globalVersion++;
                prio.put(id, new long[]{ns, globalVersion});
                pq.add(new Node(ns, id, globalVersion));
                // 回池
                for (Node nd : buffer) pq.add(nd);
                buffer.clear();
            } else {
                System.out.println("null");
                return;
            }
        }
        if (!hasExtract) System.out.println("null");
    }
    // 快速输入
    static class FastScanner {
        private final InputStream in;
        private final byte[] buf = new byte[1 << 16];
        private int ptr = 0, len = 0;
        FastScanner(InputStream is) { in = is; }
        private int read() throws IOException {
            if (ptr >= len) {
                len = in.read(buf);
                ptr = 0;
                if (len <= 0) return -1;
            }
            return buf[ptr++];
        }
        String nextToken() throws IOException {
            StringBuilder sb = new StringBuilder();
            int c;
            do { c = read(); } while (c != -1 && c <= ' ');
            if (c == -1) return null;
            while (c != -1 && c > ' ') {
                sb.append((char) c);
                c = read();
            }
            return sb.toString();
        }
        Integer nextIntNullable() throws IOException {
            String t = nextToken(); if (t == null) return null;
            return Integer.parseInt(t);
        }
        Long nextLongNullable() throws IOException {
            String t = nextToken(); if (t == null) return null;
            return Long.parseLong(t);
        }
    }
}
        题目内容
AI算力资源宝贵,在AI算法训练中,经验回放机制是通过存储和重用过去的经验数据来提高算法训练效率,以节省AI算力资源和提升算法训练迭代速度。为了进一步优化算法,我们使用优先级经验回放,根据每个经验的TD误差(TemporalDifferenceError)动态调整其采样优先级。你需要设计一个高效的数据结构和算法来支持以下操作:
插入经验:将新经验加入经验池子,并赋予初始优先级。
提取TopK经验:取出优先级最高的K个经验用于算法训练。提取经验操作后,经验池子中就少了已经提取出去的经验。若经验池中剩余的经验个数小于K,则返回−1。
更新优先级:根据训练后的TD误差,更新指定经验的优先级。每次更新优先级操作后,之前提取出来的经验又都回到经验池子中,即经验池子包含了所有插入过的经验,且优先级为最新更新后的优先级。
给定N个按时间顺序发生的操作(插入/提取/更新),请输出每次提取操作的结果。
输入描述
第一行为整数N,表示操作总数。
接下来N行,每行表示一个操作:
insert id score:插入ID为id的经验,初始优先级为score。insert操作用+表示。
update id newScore:将ID为id的经验优先级更新为newScore。update操作用=表示。
extract k:提取当前优先级最高的k个经验id(按经验的优先级降序排列,若优先级相同则按id升序排列)。extract操作用−表示。
约束条件: 1≤N≤105
1≤id≤105(保证同一id不会被重复插入)
0<score,newScore≤109
1≤K≤10001
输出描述
对每个extract操作,按顺序输出提取的id列表,用空格分隔。若有多次extract操作的输出,则后面的extract操作在前一次输出结果之后换行输出。注意:若extract操作的返回为−1,则输出−1。
若所有操作记录中没有extract操作,则返回null。
若入参的操作总数N和实际的操作行数不匹配,则返回null
样例1
输入
7
+ 1 5
+ 2 10
+ 3 7
- 2
= 3 20
- 1
- 1
输出
2 3
3
2
说明
初始经验池:[(1,5),(2,10),(3,7)] 提取Top2:2(10)→3(7)→输出2 3
更新后经验池:[(1,5),(2,10),(3,20)] 提取Top1:3(20)→输出3
经验池中剩余:[(1,5),(2,10)] 提取Top1:2(10)→输出2
样例2
输入
3
+ 1 5
+ 2 10
- 4
输出
-1
说明
插入了2个经验值,最后要提取4个经验值;池子中的经验值不够数量,故返回−1