#P3769. 第1题-节点树的简单路径
          
                        
                                    
                      
        
              - 
          
          
                      1000ms
            
          
                      Tried: 35
            Accepted: 12
            Difficulty: 5
            
          
          
          
                       所属公司 : 
                              字节
                                
            
                        
              时间 :2025年9月21日
                              
                      
          
 
- 
                        算法标签>DFS          
 
第1题-节点树的简单路径
解题思路
给定一棵以 1 为根的树,每个节点有权值 ai。对每个节点 i,要求统计在从根 1 到该节点 i 的简单路径上,有多少无序对 (x,y)(x=y)满足 ax=ay。
核心做法:根到当前节点的前缀统计 + DFS(栈模拟)
- 
沿着树做一次从根出发的 DFS。维护一张“当前路径上各权值出现次数”的表
freq,以及“当前路径上的相等权值对总数”pairs。 - 
当进入一个节点 u 时(把它加入根到当前的路径):
- 设该点的权值编号为 
id,此权值在路径上已有次数为c=freq[id]。 - 新加入 u 会与这 c 个同值点各形成 1 个新对,所以 
pairs += c,然后freq[id]++。 - 此时记录该节点答案 
ans[u] = pairs(路径 1→u 上的总相等对数)。 
 - 设该点的权值编号为 
 - 
当离开节点 u 时(从路径上移除):
- 先 
freq[id]--,此时路径上还剩freq[id]个同值点;与刚才移除的 u 相关的配对数正是这freq[id]个,于是pairs -= freq[id]完成回溯。 
 - 先 
 
为使计数结构高效,使用值域压缩将所有 ai 压成 1…m,用整型数组 freq[m] 统计即可(也可用哈希表,效果相近)。
用显式栈模拟 DFS,避免深递归爆栈;每个点“进栈一次、退栈一次”。
复杂度分析
- 每个节点仅“进入/退出”各一次,且每次更新是 O(1);
 - 值域压缩耗时 O(nlogn)(排序),若用哈希可做到期望 O(n)。
 - 总时间复杂度:O(nlogn)(含压缩)/ 期望 O(n)(哈希)。
 - 额外空间:
freq、答案、栈和邻接表,均为 O(n)。 
代码实现
Python
# 题意:对每个节点 i,统计从 1 到 i 的路径上,权值相等的无序点对数量
# 做法:DFS + 路径频次维护(值域压缩),显式栈模拟,进点加、出点撤销
import sys
def solve():
    data = list(map(int, sys.stdin.buffer.read().split()))
    it = iter(data)
    n = next(it)
    a = [0] * (n + 1)
    for i in range(1, n + 1):
        a[i] = next(it)
    # 建图(无向)
    g = [[] for _ in range(n + 1)]
    for _ in range(n - 1):
        u = next(it); v = next(it)
        g[u].append(v); g[v].append(u)
    # 值域压缩
    uniq = sorted(set(a[1:]))
    idx = {v: i for i, v in enumerate(uniq)}  # 0..m-1
    b = [0] * (n + 1)
    for i in range(1, n + 1):
        b[i] = idx[a[i]]
    m = len(uniq)
    freq = [0] * m               # 路径上各权值出现次数
    ans = [0] * (n + 1)          # 答案
    pairs = 0                    # 路径上相等权值对数(long long 级别)
    # 显式栈:state=0 进点;state=1 退点
    stack = [(1, 0, 0)]          # (节点, 父亲, 状态)
    while stack:
        u, p, st = stack.pop()
        if st == 0:
            idu = b[u]
            c = freq[idu]
            pairs += c            # 新点与已有同值点形成的对
            freq[idu] = c + 1
            ans[u] = pairs
            stack.append((u, p, 1))
            for v in g[u]:
                if v != p:
                    stack.append((v, u, 0))
        else:
            idu = b[u]
            freq[idu] -= 1
            pairs -= freq[idu]    # 撤销与该值剩余点形成的配对
    print(' '.join(str(ans[i]) for i in range(1, n + 1)))
if __name__ == "__main__":
    solve()
Java
// 题意:对每个节点 i,统计根到 i 的路径上相同权值的无序点对数量
// 做法:值域压缩 + 显式栈 DFS,路径频次与配对数同步维护
import java.io.*;
import java.util.*;
public class Main {
    // —— 快速输入,适配 n <= 2e5 的规模 ——
    static final class FastScanner {
        private final InputStream in = System.in;
        private final byte[] buffer = new byte[1 << 16];
        private int ptr = 0, len = 0;
        private int read() throws IOException {
            if (ptr >= len) {
                len = in.read(buffer);
                ptr = 0;
                if (len <= 0) return -1;
            }
            return buffer[ptr++];
        }
        int nextInt() throws IOException {
            int c, sgn = 1, x = 0;
            do { c = read(); } while (c <= ' ');
            if (c == '-') { sgn = -1; c = read(); }
            while (c > ' ') {
                x = x * 10 + c - '0';
                c = read();
            }
            return x * sgn;
        }
    }
    public static void main(String[] args) throws Exception {
        FastScanner fs = new FastScanner();
        int n = fs.nextInt();
        int[] a = new int[n + 1];
        for (int i = 1; i <= n; i++) a[i] = fs.nextInt();
        // 建图(无向)
        ArrayList<Integer>[] g = new ArrayList[n + 1];
        for (int i = 1; i <= n; i++) g[i] = new ArrayList<>();
        for (int i = 0; i < n - 1; i++) {
            int u = fs.nextInt(), v = fs.nextInt();
            g[u].add(v); g[v].add(u);
        }
        // 值域压缩
        int[] arr = new int[n];
        for (int i = 1; i <= n; i++) arr[i - 1] = a[i];
        Arrays.sort(arr);
        int m = 0;
        int[] uniq = new int[n];
        for (int x : arr) {
            if (m == 0 || x != uniq[m - 1]) uniq[m++] = x;
        }
        int[] b = new int[n + 1];
        for (int i = 1; i <= n; i++) {
            int id = Arrays.binarySearch(uniq, 0, m, a[i]);
            b[i] = id; // 0..m-1
        }
        int[] freq = new int[m];          // 路径频次
        long[] ans = new long[n + 1];
        long pairs = 0L;
        // 显式栈:state=0 进点;state=1 退点
        int cap = 2 * n + 5;
        int[] stNode = new int[cap];
        int[] stPar  = new int[cap];
        int[] stType = new int[cap];
        int top = 0;
        stNode[top] = 1; stPar[top] = 0; stType[top] = 0; top++;
        while (top > 0) {
            top--;
            int u = stNode[top], p = stPar[top], type = stType[top];
            if (type == 0) {
                int id = b[u];
                int c = freq[id];
                pairs += c;        // 新点与已有同值点形成的新对
                freq[id] = c + 1;
                ans[u] = pairs;
                // 回退标记
                stNode[top] = u; stPar[top] = p; stType[top] = 1; top++;
                for (int v : g[u]) if (v != p) {
                    stNode[top] = v; stPar[top] = u; stType[top] = 0; top++;
                }
            } else {
                int id = b[u];
                freq[id]--;
                pairs -= freq[id]; // 撤销与该值剩余点形成的配对
            }
        }
        StringBuilder sb = new StringBuilder();
        for (int i = 1; i <= n; i++) {
            if (i > 1) sb.append(' ');
            sb.append(ans[i]);
        }
        System.out.println(sb.toString());
    }
}
C++
// 题意:对每个节点 i,统计从 1 到 i 的路径上,权值相同的无序点对数量
// 做法:值域压缩 + 显式栈 DFS,进入时增加贡献,退出时撤销贡献
#include <bits/stdc++.h>
using namespace std;
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int n;
    if (!(cin >> n)) return 0;
    vector<long long> a(n + 1);
    for (int i = 1; i <= n; ++i) cin >> a[i];
    // 建图
    vector<vector<int>> g(n + 1);
    for (int i = 0; i < n - 1; ++i) {
        int u, v; cin >> u >> v;
        g[u].push_back(v);
        g[v].push_back(u);
    }
    // 值域压缩
    vector<long long> all(a.begin() + 1, a.end());
    sort(all.begin(), all.end());
    all.erase(unique(all.begin(), all.end()), all.end());
    int m = (int)all.size();
    vector<int> b(n + 1);
    for (int i = 1; i <= n; ++i) {
        b[i] = int(lower_bound(all.begin(), all.end(), a[i]) - all.begin()); // 0..m-1
    }
    vector<int> freq(m, 0);    // 路径频次
    vector<long long> ans(n + 1, 0);
    long long pairs = 0;
    // 显式栈:(u,p,state) state=0 进点;1 退点
    vector<tuple<int,int,int>> st;
    st.reserve(2 * n + 5);
    st.emplace_back(1, 0, 0);
    while (!st.empty()) {
        auto [u, p, stt] = st.back(); st.pop_back();
        if (stt == 0) {
            int id = b[u];
            int c = freq[id];
            pairs += c;          // 新增与已有同值点的配对
            freq[id] = c + 1;
            ans[u] = pairs;
            st.emplace_back(u, p, 1);
            for (int v : g[u]) if (v != p) st.emplace_back(v, u, 0);
        } else {
            int id = b[u];
            freq[id]--;
            pairs -= freq[id];   // 撤销该节点贡献
        }
    }
    for (int i = 1; i <= n; ++i) {
        if (i > 1) cout << ' ';
        cout << ans[i];
    }
    cout << '\n';
    return 0;
}
        题目内容
Bingbong 有一棵 n 个节点的树,编号为 1~ n ,节点 i 的权值为 ai 。
现在你需要计算对于任意的节点 i∈∣1,n∣ ,有多少对 (x,y)(x=y) ,满足 x,y 两个节点均在 1→i 的简单路径上,且 ax=ay 。
输入描述
第一行一个数 n(2≦n≦2×105) ,表示节点总数。
第二行 n 个整数,表示节点 i 的权值 ai(1≦n≦109) 。
接下来 n−1 行,每行 2 个整数 u,v(1≦u,v≦n) ,表示当前无向边连接 u,v 两个节点。
保证输入是一棵树。
输出描述
输出包含一行,共 n 个整数,每个整数之同以空格隔开,含义如题所示。
样例1
输入
4
1 1 2 2
1 2
2 3
3 4
输出
0 1 1 2