#P4305. 第3题-子树与节点对的距离和
-
1000ms
Tried: 2
Accepted: 1
Difficulty: 8
所属公司 :
米哈游
时间 :2025年10月26日
-
算法标签>euler 序前缀和
第3题-子树与节点对的距离和
解题思路
给定一棵以 1 为根的树。对任意节点 u,记其子树为 S(u)。每次询问要计算
v,w∈S(u),v<w∑dist(v,w)其中 dist 为树上最短路径边数。
关键算法与核心想法
-
边贡献法 一对点的距离等于其路径上经过的边数。对子树 S(u) 内的一条父子边 (p,x),切断后得到两个连通块:大小分别为 ∣A∣=size[x]、∣B∣=size[u]−size[x]。所有跨越这条边的点对恰有 ∣A∣⋅∣B∣ 个,每个对距离贡献 1。 因而
$$\text{ans}[u] \;=\; \sum_{x\in S(u),\,x\neq u}\text{size}[x]\cdot\big(\text{size}[u]-\text{size}[x]\big). $$ -
Euler 序 + 前缀和 设整棵树以 1 为根,做一次 DFS,得到:
tin[u]、tout[u]:子树区间 [tin[u],tout[u]];size[u]:以根为 1 的全局子树大小;euler[]:按进入时间的节点序列。 则 S(u) 的所有节点在euler上是一段连续区间。
记
$$S_1(u)=\sum_{x\in S(u)}\text{size}[x],\qquad S_2(u)=\sum_{x\in S(u)}\text{size}[x]^2 . $$由上式化简可得一个非常简洁的公式(把 x=u 的项并入再相减):
$$\boxed{\;\text{ans}[u] = \text{size}[u]\cdot S_1(u) - S_2(u)\;} $$因为
$$S_1(u)=P_1[tout[u]]-P_1[tin[u]-1],\quad S_2(u)=P_2[tout[u]]-P_2[tin[u]-1]. $$euler上区间可用前缀和 P1,P2 在 O(1) 求得: -
实现方法
- 建图、以 1 为根做一次 DFS(可递归或迭代)求
size/tin/tout/euler。 - 在
euler序上构建两组前缀和prefSize与prefSq。 - 每次询问节点 u 即按上式 ans[u]=size[u]⋅S1(u)−S2(u) 计算。
- 结果可能达到 O(n2) 量级,需使用 64 位整型。
- 建图、以 1 为根做一次 DFS(可递归或迭代)求
复杂度分析
- 预处理(一次 DFS + 两个前缀和):O(n)。
- 每次查询:O(1)。
- 总时间复杂度:O(n+m)。
- 空间复杂度:O(n)(存图与若干数组)。
代码实现
Python
import sys
sys.setrecursionlimit(1 << 25)
# 功能函数:返回每个查询节点的答案列表
def solve_tree_pair_sum(n, edges, queries):
g = [[] for _ in range(n + 1)]
for u, v in edges:
g[u].append(v)
g[v].append(u)
tin = [0] * (n + 1)
tout = [0] * (n + 1)
size = [0] * (n + 1)
euler = []
timer = [0]
# 递归 DFS,计算进入时间、子树大小以及 euler 序
def dfs(u, p):
timer[0] += 1
tin[u] = timer[0]
euler.append(u)
for v in g[u]:
if v == p:
continue
dfs(v, u)
size[u] += size[v]
size[u] += 1
tout[u] = timer[0]
dfs(1, 0)
# 在 euler 序上做前缀和
prefSize = [0] * (n + 1)
prefSq = [0] * (n + 1)
for i in range(1, n + 1):
x = euler[i - 1]
prefSize[i] = prefSize[i - 1] + size[x]
prefSq[i] = prefSq[i - 1] + size[x] * size[x]
# 查询
ans = []
for u in queries:
s1 = prefSize[tout[u]] - prefSize[tin[u] - 1]
s2 = prefSq[tout[u]] - prefSq[tin[u] - 1]
res = size[u] * s1 - s2
ans.append(str(res))
return ans
def main():
data = list(map(int, sys.stdin.buffer.read().split()))
it = iter(data)
n = next(it); m = next(it)
edges = [(next(it), next(it)) for _ in range(n - 1)]
queries = [next(it) for _ in range(m)]
out = solve_tree_pair_sum(n, edges, queries)
sys.stdout.write("\n".join(out))
if __name__ == "__main__":
main()
Java
import java.io.*;
import java.util.*;
/* ACM 风格,主类名必须为 Main */
public class Main {
// 功能函数:计算所有查询答案
static List<Long> solve(int n, List<int[]> edges, int[] queries) {
List<Integer>[] g = new ArrayList[n + 1];
for (int i = 1; i <= n; i++) g[i] = new ArrayList<>();
for (int[] e : edges) {
g[e[0]].add(e[1]);
g[e[1]].add(e[0]);
}
int[] tin = new int[n + 1];
int[] tout = new int[n + 1];
long[] size = new long[n + 1];
int[] parent = new int[n + 1];
int[] iter = new int[n + 1];
int[] order = new int[n]; // euler 序(进入时间)
int timer = 0;
// 迭代 DFS,计算 tin/tout 与 euler 序
Deque<Integer> st = new ArrayDeque<>();
st.push(1);
parent[1] = 0;
while (!st.isEmpty()) {
int u = st.peek();
if (tin[u] == 0) { // 第一次到达
tin[u] = ++timer;
order[timer - 1] = u;
}
if (iter[u] < g[u].size()) {
int v = g[u].get(iter[u]++);
if (v == parent[u]) continue;
parent[v] = u;
st.push(v);
} else { // 退出结点
tout[u] = timer;
st.pop();
}
}
// 自底向上计算 size
for (int i = n - 1; i >= 0; i--) {
int u = order[i];
long s = 1;
for (int v : g[u]) if (parent[v] == u) s += size[v];
size[u] = s;
}
// euler 序上的前缀和
long[] prefSize = new long[n + 1];
long[] prefSq = new long[n + 1];
for (int i = 1; i <= n; i++) {
int x = order[i - 1];
prefSize[i] = prefSize[i - 1] + size[x];
prefSq[i] = prefSq[i - 1] + size[x] * size[x];
}
// 回答查询
List<Long> ans = new ArrayList<>();
for (int u : queries) {
long s1 = prefSize[tout[u]] - prefSize[tin[u] - 1];
long s2 = prefSq[tout[u]] - prefSq[tin[u] - 1];
long res = size[u] * s1 - s2;
ans.add(res);
}
return ans;
}
public static void main(String[] args) throws Exception {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
StringTokenizer st;
st = new StringTokenizer(br.readLine());
int n = Integer.parseInt(st.nextToken());
int m = Integer.parseInt(st.nextToken());
List<int[]> edges = new ArrayList<>();
for (int i = 0; i < n - 1; i++) {
st = new StringTokenizer(br.readLine());
int u = Integer.parseInt(st.nextToken());
int v = Integer.parseInt(st.nextToken());
edges.add(new int[]{u, v});
}
int[] queries = new int[m];
for (int i = 0; i < m; i++) {
queries[i] = Integer.parseInt(br.readLine().trim());
}
List<Long> ans = solve(n, edges, queries);
StringBuilder sb = new StringBuilder();
for (long x : ans) sb.append(x).append('\n');
System.out.print(sb.toString());
}
}
C++
#include <bits/stdc++.h>
using namespace std;
// 功能函数:返回每个查询的答案
vector<long long> solve(int n, const vector<pair<int,int>>& edges, const vector<int>& queries) {
vector<vector<int>> g(n + 1);
for (auto e : edges) {
g[e.first].push_back(e.second);
g[e.second].push_back(e.first);
}
vector<int> tin(n + 1), tout(n + 1), parent(n + 1, 0), iterIdx(n + 1, 0);
vector<int> euler; euler.reserve(n);
int timer = 0;
// 迭代 DFS,生成 tin/tout 与 euler 序
vector<int> st; st.push_back(1);
while (!st.empty()) {
int u = st.back();
if (tin[u] == 0) { // 首次到达
tin[u] = ++timer;
euler.push_back(u);
}
if (iterIdx[u] < (int)g[u].size()) {
int v = g[u][iterIdx[u]++];
if (v == parent[u]) continue;
parent[v] = u;
st.push_back(v);
} else {
tout[u] = timer;
st.pop_back();
}
}
// 自底向上计算 size
vector<long long> sz(n + 1, 0);
for (int i = n - 1; i >= 0; --i) {
int u = euler[i];
long long s = 1;
for (int v : g[u]) if (parent[v] == u) s += sz[v];
sz[u] = s;
}
// euler 序前缀和
vector<long long> prefSize(n + 1, 0), prefSq(n + 1, 0);
for (int i = 1; i <= n; ++i) {
int x = euler[i - 1];
prefSize[i] = prefSize[i - 1] + sz[x];
prefSq[i] = prefSq[i - 1] + sz[x] * sz[x];
}
// 查询
vector<long long> ans;
ans.reserve(queries.size());
for (int u : queries) {
long long s1 = prefSize[tout[u]] - prefSize[tin[u] - 1];
long long s2 = prefSq[tout[u]] - prefSq[tin[u] - 1];
long long res = sz[u] * s1 - s2;
ans.push_back(res);
}
return ans;
}
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
int n, m;
if (!(cin >> n >> m)) return 0;
vector<pair<int,int>> edges;
edges.reserve(n - 1);
for (int i = 0; i < n - 1; ++i) {
int u, v; cin >> u >> v;
edges.emplace_back(u, v);
}
vector<int> qs(m);
for (int i = 0; i < m; ++i) cin >> qs[i];
vector<long long> ans = solve(n, edges, qs);
for (auto x : ans) cout << x << "\n";
return 0;
}
题目内容
给定一棵节点数为 n 的,树的根节点为 1 。
对树中的任意节点 u ,定义其子树为以为根的所有节点集合,记为 S(u) 。
现有 m 次查询,每次查询给定一个节点 u ,请你计算子树 S(u) 中所有节点对 (v,w) 之间的距离和,
$\sum_{v, w \in S(u), v<w} \operatorname{dist}(v, w),$
其中 dist(v,w) 表示节点 v 与 w 之间的距离。
【名词解释】
-
树:树是一种连接无环的无向图。
-
子树:子树指给定节点及其所有后代节点组成的连通子图。
-
距离:距离表示树中两节点之间的边数最短路径长度。
输入描述
第一行输入两个整数 n 和 m(1≦n,m≦2×105),分别表示树的节点数量与查询次数。
接下来 n−1 行,每行输入两个整数 ui,vi(1≦ui,vi≦n;ui=vi),表示树的一条边。
接下来 m 行,每行输入一个整数 u(1≦u≦n) ,表示一次查询的节点编号。
输出描述
对于每次查询,在一行上输出一个整数,表示对应节点子树中所有节点对距离和。
样例1
输入
5 3
1 2
1 3
3 4
3 5
1
3
4
输出
18
4
0
说明
在这个样例中:
-
节点 1 的子树为 {1,2,3,4,5},其所有 10 对距离之和为 18 ;
-
节点 3 的子树为 {3,4,5},距离和 1+1+2=4 ;
-
节点 4 的子树仅为自身,贡献距离和 0 .
样例2
输入
3 2
1 2
2 3
2
3
输出
1
0
说明
在这个样例中:
- 节点 2 的子树为 {2,3},距离和 1 ;
- 节点 3 的子树为 {3},距离和 0 .