#P3769. 第1题-节点树的简单路径
-
1000ms
Tried: 41
Accepted: 15
Difficulty: 5
所属公司 :
字节
时间 :2025年9月21日
-
算法标签>DFS
第1题-节点树的简单路径
解题思路
给定一棵以 1 为根的树,每个节点有权值 ai。对每个节点 i,要求统计在从根 1 到该节点 i 的简单路径上,有多少无序对 (x,y)(x=y)满足 ax=ay。
核心做法:根到当前节点的前缀统计 + DFS(栈模拟)
-
沿着树做一次从根出发的 DFS。维护一张“当前路径上各权值出现次数”的表
freq,以及“当前路径上的相等权值对总数”pairs。 -
当进入一个节点 u 时(把它加入根到当前的路径):
- 设该点的权值编号为
id,此权值在路径上已有次数为c=freq[id]。 - 新加入 u 会与这 c 个同值点各形成 1 个新对,所以
pairs += c,然后freq[id]++。 - 此时记录该节点答案
ans[u] = pairs(路径 1→u 上的总相等对数)。
- 设该点的权值编号为
-
当离开节点 u 时(从路径上移除):
- 先
freq[id]--,此时路径上还剩freq[id]个同值点;与刚才移除的 u 相关的配对数正是这freq[id]个,于是pairs -= freq[id]完成回溯。
- 先
为使计数结构高效,使用值域压缩将所有 ai 压成 1…m,用整型数组 freq[m] 统计即可(也可用哈希表,效果相近)。
用显式栈模拟 DFS,避免深递归爆栈;每个点“进栈一次、退栈一次”。
复杂度分析
- 每个节点仅“进入/退出”各一次,且每次更新是 O(1);
- 值域压缩耗时 O(nlogn)(排序),若用哈希可做到期望 O(n)。
- 总时间复杂度:O(nlogn)(含压缩)/ 期望 O(n)(哈希)。
- 额外空间:
freq、答案、栈和邻接表,均为 O(n)。
代码实现
Python
# 题意:对每个节点 i,统计从 1 到 i 的路径上,权值相等的无序点对数量
# 做法:DFS + 路径频次维护(值域压缩),显式栈模拟,进点加、出点撤销
import sys
def solve():
data = list(map(int, sys.stdin.buffer.read().split()))
it = iter(data)
n = next(it)
a = [0] * (n + 1)
for i in range(1, n + 1):
a[i] = next(it)
# 建图(无向)
g = [[] for _ in range(n + 1)]
for _ in range(n - 1):
u = next(it); v = next(it)
g[u].append(v); g[v].append(u)
# 值域压缩
uniq = sorted(set(a[1:]))
idx = {v: i for i, v in enumerate(uniq)} # 0..m-1
b = [0] * (n + 1)
for i in range(1, n + 1):
b[i] = idx[a[i]]
m = len(uniq)
freq = [0] * m # 路径上各权值出现次数
ans = [0] * (n + 1) # 答案
pairs = 0 # 路径上相等权值对数(long long 级别)
# 显式栈:state=0 进点;state=1 退点
stack = [(1, 0, 0)] # (节点, 父亲, 状态)
while stack:
u, p, st = stack.pop()
if st == 0:
idu = b[u]
c = freq[idu]
pairs += c # 新点与已有同值点形成的对
freq[idu] = c + 1
ans[u] = pairs
stack.append((u, p, 1))
for v in g[u]:
if v != p:
stack.append((v, u, 0))
else:
idu = b[u]
freq[idu] -= 1
pairs -= freq[idu] # 撤销与该值剩余点形成的配对
print(' '.join(str(ans[i]) for i in range(1, n + 1)))
if __name__ == "__main__":
solve()
Java
// 题意:对每个节点 i,统计根到 i 的路径上相同权值的无序点对数量
// 做法:值域压缩 + 显式栈 DFS,路径频次与配对数同步维护
import java.io.*;
import java.util.*;
public class Main {
// —— 快速输入,适配 n <= 2e5 的规模 ——
static final class FastScanner {
private final InputStream in = System.in;
private final byte[] buffer = new byte[1 << 16];
private int ptr = 0, len = 0;
private int read() throws IOException {
if (ptr >= len) {
len = in.read(buffer);
ptr = 0;
if (len <= 0) return -1;
}
return buffer[ptr++];
}
int nextInt() throws IOException {
int c, sgn = 1, x = 0;
do { c = read(); } while (c <= ' ');
if (c == '-') { sgn = -1; c = read(); }
while (c > ' ') {
x = x * 10 + c - '0';
c = read();
}
return x * sgn;
}
}
public static void main(String[] args) throws Exception {
FastScanner fs = new FastScanner();
int n = fs.nextInt();
int[] a = new int[n + 1];
for (int i = 1; i <= n; i++) a[i] = fs.nextInt();
// 建图(无向)
ArrayList<Integer>[] g = new ArrayList[n + 1];
for (int i = 1; i <= n; i++) g[i] = new ArrayList<>();
for (int i = 0; i < n - 1; i++) {
int u = fs.nextInt(), v = fs.nextInt();
g[u].add(v); g[v].add(u);
}
// 值域压缩
int[] arr = new int[n];
for (int i = 1; i <= n; i++) arr[i - 1] = a[i];
Arrays.sort(arr);
int m = 0;
int[] uniq = new int[n];
for (int x : arr) {
if (m == 0 || x != uniq[m - 1]) uniq[m++] = x;
}
int[] b = new int[n + 1];
for (int i = 1; i <= n; i++) {
int id = Arrays.binarySearch(uniq, 0, m, a[i]);
b[i] = id; // 0..m-1
}
int[] freq = new int[m]; // 路径频次
long[] ans = new long[n + 1];
long pairs = 0L;
// 显式栈:state=0 进点;state=1 退点
int cap = 2 * n + 5;
int[] stNode = new int[cap];
int[] stPar = new int[cap];
int[] stType = new int[cap];
int top = 0;
stNode[top] = 1; stPar[top] = 0; stType[top] = 0; top++;
while (top > 0) {
top--;
int u = stNode[top], p = stPar[top], type = stType[top];
if (type == 0) {
int id = b[u];
int c = freq[id];
pairs += c; // 新点与已有同值点形成的新对
freq[id] = c + 1;
ans[u] = pairs;
// 回退标记
stNode[top] = u; stPar[top] = p; stType[top] = 1; top++;
for (int v : g[u]) if (v != p) {
stNode[top] = v; stPar[top] = u; stType[top] = 0; top++;
}
} else {
int id = b[u];
freq[id]--;
pairs -= freq[id]; // 撤销与该值剩余点形成的配对
}
}
StringBuilder sb = new StringBuilder();
for (int i = 1; i <= n; i++) {
if (i > 1) sb.append(' ');
sb.append(ans[i]);
}
System.out.println(sb.toString());
}
}
C++
// 题意:对每个节点 i,统计从 1 到 i 的路径上,权值相同的无序点对数量
// 做法:值域压缩 + 显式栈 DFS,进入时增加贡献,退出时撤销贡献
#include <bits/stdc++.h>
using namespace std;
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
int n;
if (!(cin >> n)) return 0;
vector<long long> a(n + 1);
for (int i = 1; i <= n; ++i) cin >> a[i];
// 建图
vector<vector<int>> g(n + 1);
for (int i = 0; i < n - 1; ++i) {
int u, v; cin >> u >> v;
g[u].push_back(v);
g[v].push_back(u);
}
// 值域压缩
vector<long long> all(a.begin() + 1, a.end());
sort(all.begin(), all.end());
all.erase(unique(all.begin(), all.end()), all.end());
int m = (int)all.size();
vector<int> b(n + 1);
for (int i = 1; i <= n; ++i) {
b[i] = int(lower_bound(all.begin(), all.end(), a[i]) - all.begin()); // 0..m-1
}
vector<int> freq(m, 0); // 路径频次
vector<long long> ans(n + 1, 0);
long long pairs = 0;
// 显式栈:(u,p,state) state=0 进点;1 退点
vector<tuple<int,int,int>> st;
st.reserve(2 * n + 5);
st.emplace_back(1, 0, 0);
while (!st.empty()) {
auto [u, p, stt] = st.back(); st.pop_back();
if (stt == 0) {
int id = b[u];
int c = freq[id];
pairs += c; // 新增与已有同值点的配对
freq[id] = c + 1;
ans[u] = pairs;
st.emplace_back(u, p, 1);
for (int v : g[u]) if (v != p) st.emplace_back(v, u, 0);
} else {
int id = b[u];
freq[id]--;
pairs -= freq[id]; // 撤销该节点贡献
}
}
for (int i = 1; i <= n; ++i) {
if (i > 1) cout << ' ';
cout << ans[i];
}
cout << '\n';
return 0;
}
题目内容
Bingbong 有一棵 n 个节点的树,编号为 1~ n ,节点 i 的权值为 ai 。
现在你需要计算对于任意的节点 i∈∣1,n∣ ,有多少对 (x,y)(x=y) ,满足 x,y 两个节点均在 1→i 的简单路径上,且 ax=ay 。
输入描述
第一行一个数 n(2≦n≦2×105) ,表示节点总数。
第二行 n 个整数,表示节点 i 的权值 ai(1≦n≦109) 。
接下来 n−1 行,每行 2 个整数 u,v(1≦u,v≦n) ,表示当前无向边连接 u,v 两个节点。
保证输入是一棵树。
输出描述
输出包含一行,共 n 个整数,每个整数之同以空格隔开,含义如题所示。
样例1
输入
4
1 1 2 2
1 2
2 3
3 4
输出
0 1 1 2