#P3676. 第4题-树上的红色联通块
-
1000ms
Tried: 21
Accepted: 3
Difficulty: 9
所属公司 :
美团
时间 :2025年9月13日-算法岗
-
算法标签>树上差分
第4题-树上的红色联通块
思路
计算连通块数量有一个经典的公式:
连通块数量=点数−边数在这个问题中,我们要求的是红色连通块的数量,所以公式可以相应地调整为:
红色连通块数量=红色结点数量−红色边数这里的“红色边”指的是连接两个红色结点的边。
因此,问题的核心就转化为了如何高效地求出所有操作结束后,树上总共有多少个不同的红色结点和多少条不同的红色边。
对于树上的路径操作问题,一个常见的技巧是树上差分,结合最近公共祖先 (LCA) 来实现。
-
统计红色边的数量 一条边 (x,parent(x)) 会被染红,当且仅当它至少被一次操作的路径所覆盖。一条从 u 到 v 的路径会覆盖边 (x,parent(x)),当且仅当 u 和 v 中一个点在 x 的子树中,另一个点不在。
我们可以使用一个差分数组
d来记录每条边被路径覆盖的次数。对于一次操作 (u,v),我们找到它们的最近公共祖先 l=lca(u,v)。路径 u→v 可以拆分为 u→l 和 v→l。- 路径 u→l 覆盖了从 u 到 l 的所有边。
- 路径 v→l 覆盖了从 v 到 l 的所有边。
我们可以通过在端点打标记的方式来一次性更新整条路径。具体操作如下:
d[u]++d[v]++d[l] -= 2
在处理完所有 q 次查询后,我们从叶子结点向上进行一次 DFS,计算差分数组的前缀和(在这里是子树和)。对于每个结点 x,执行
d[x] += sum(d[child]),其中child是 x 的所有子节点。 完成 DFS 后,d[x]的值就表示边 (x,parent(x)) 被所有路径覆盖的总次数。如果d[x] > 0,说明这条边是红色的。我们遍历所有非根结点,统计d[x] > 0的数量,即可得到红色边的总数。 -
统计红色结点的数量 一个结点 x 会被染红,当且仅当它至少被一次操作的路径所覆盖。一个结点被路径覆盖,有两种情况:
- 路径穿过该结点。例如,路径的一个端点在 x 的子树中,另一个端点不在。
- 该结点是路径的最高点,即路径的LCA。
第一种情况的数量,就是我们刚刚为边计算的
d[x]。d[x]表示有多少条路径穿过了边 (x,parent(x)),这些路径也必然穿过了结点 x。第二种情况,我们需要单独统计。我们可以使用另一个数组
lca_cnt,对于每次查询 (u,v),我们计算出 l=lca(u,v),然后执行lca_cnt[l]++。因此,一个结点 x 被路径覆盖的总次数就是
d[x] + lca_cnt[x]。如果这个值大于 0,说明结点 x 是红色的。我们遍历所有结点,统计d[x] + lca_cnt[x] > 0的数量,即可得到红色结点的总数。注意: 对于
(u, u)这样的查询,路径只包含一个点。按照上述d数组的更新方法d[u]+=2, d[lca(u,u)]-=2,d[u]没有变化。但是lca_cnt[u]会加一,所以d[u]+lca_cnt[u]的值会大于0,这正确地将点u标记为红色。
算法步骤总结:
-
通过一次 DFS 预处理出每个结点的深度
dep、父结点fa,为计算 LCA 做准备。 -
使用倍增法预处理出 LCA 查询需要的数据结构。
-
初始化差分数组
d和 LCA 计数数组lca_cnt为 0。 -
对于 q 次操作 (u,v),计算 l=lca(u,v),并更新数组:
d[u]++,d[v]++,d[l]-=2,lca_cnt[l]++。 -
执行一次 DFS,自底向上计算
d数组的子树和。 -
遍历所有结点,计算红色结点数
red_nodes和红色边数red_edges。- 如果
d[i] + lca_cnt[i] > 0,则red_nodes++。 - 如果
i不是根结点且d[i] > 0,则red_edges++。
- 如果
-
最终答案为
red_nodes - red_edges。
C++
#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;
const int MAXN = 200005;
const int LOGN = 18; // log2(200005) is approx 17.6
vector<int> adj[MAXN];
int parent[MAXN][LOGN];
int depth[MAXN];
int d[MAXN];
int lca_cnt[MAXN];
int n, q;
// 预处理DFS,计算深度和父节点
void dfs_pre(int u, int p, int dth) {
depth[u] = dth;
parent[u][0] = p;
for (int v : adj[u]) {
if (v != p) {
dfs_pre(v, u, dth + 1);
}
}
}
// 预处理LCA的倍增数组
void preprocess_lca() {
dfs_pre(1, 0, 0); // 假设1是根节点,父节点为0,深度为0
for (int j = 1; j < LOGN; ++j) {
for (int i = 1; i <= n; ++i) {
if (parent[i][j - 1] != 0) {
parent[i][j] = parent[parent[i][j - 1]][j - 1];
}
}
}
}
// 查询LCA
int lca(int u, int v) {
if (depth[u] < depth[v]) swap(u, v);
for (int j = LOGN - 1; j >= 0; --j) {
if (parent[u][j] != 0 && depth[parent[u][j]] >= depth[v]) {
u = parent[u][j];
}
}
if (u == v) return u;
for (int j = LOGN - 1; j >= 0; --j) {
if (parent[u][j] != 0 && parent[v][j] != 0 && parent[u][j] != parent[v][j]) {
u = parent[u][j];
v = parent[v][j];
}
}
return parent[u][0];
}
// DFS计算差分数组的子树和
void dfs_sum(int u, int p) {
for (int v : adj[u]) {
if (v != p) {
dfs_sum(v, u);
d[u] += d[v];
}
}
}
int main() {
ios_base::sync_with_stdio(false);
cin.tie(NULL);
cin >> n >> q;
for (int i = 0; i < n - 1; ++i) {
int u, v;
cin >> u >> v;
adj[u].push_back(v);
adj[v].push_back(u);
}
preprocess_lca();
for (int i = 0; i < q; ++i) {
int u, v;
cin >> u >> v;
int l = lca(u, v);
d[u]++;
d[v]++;
d[l] -= 2;
lca_cnt[l]++;
}
dfs_sum(1, 0);
long long red_nodes = 0;
long long red_edges = 0;
for (int i = 1; i <= n; ++i) {
// 如果节点被覆盖次数 > 0,则为红色节点
if (d[i] + lca_cnt[i] > 0) {
red_nodes++;
}
// 如果边(i, parent[i][0])被覆盖次数 > 0,则为红色边
if (i != 1 && d[i] > 0) { // 根节点没有父边
red_edges++;
}
}
cout << red_nodes - red_edges << endl;
return 0;
}
Python
import sys
# 增加递归深度限制以防树过深
sys.setrecursionlimit(200005 * 2)
def solve():
n, q = map(int, sys.stdin.readline().split())
adj = [[] for _ in range(n + 1)]
for _ in range(n - 1):
u, v = map(int, sys.stdin.readline().split())
adj[u].append(v)
adj[v].append(u)
LOGN = (n + 1).bit_length()
parent = [[0] * LOGN for _ in range(n + 1)]
depth = [-1] * (n + 1)
# 预处理DFS,计算深度和父节点
q_dfs = [(1, 0, 0)] # 使用栈代替递归进行DFS
depth[0] = -1
head = 0
while head < len(q_dfs):
u, p, d = q_dfs[head]
head += 1
depth[u] = d
parent[u][0] = p
for v in adj[u]:
if v != p:
q_dfs.append((v, u, d + 1))
# 预处理LCA的倍增数组
for j in range(1, LOGN):
for i in range(1, n + 1):
if parent[i][j - 1] != 0:
parent[i][j] = parent[parent[i][j - 1]][j - 1]
# 查询LCA
def lca(u, v):
if depth[u] < depth[v]:
u, v = v, u
for j in range(LOGN - 1, -1, -1):
if parent[u][j] != 0 and depth[parent[u][j]] >= depth[v]:
u = parent[u][j]
if u == v:
return u
for j in range(LOGN - 1, -1, -1):
if parent[u][j] != 0 and parent[v][j] != 0 and parent[u][j] != parent[v][j]:
u = parent[u][j]
v = parent[v][j]
return parent[u][0]
d = [0] * (n + 1)
lca_cnt = [0] * (n + 1)
for _ in range(q):
u, v = map(int, sys.stdin.readline().split())
l = lca(u, v)
d[u] += 1
d[v] += 1
d[l] -= 2
lca_cnt[l] += 1
# DFS计算差分数组的子树和
order = [item[0] for item in sorted(enumerate(depth), key=lambda x: x[1], reverse=True) if item[0] != 0]
for u in order:
p = parent[u][0]
if p != 0:
d[p] += d[u]
red_nodes = 0
red_edges = 0
for i in range(1, n + 1):
# 如果节点被覆盖次数 > 0,则为红色节点
if d[i] + lca_cnt[i] > 0:
red_nodes += 1
# 如果边(i, parent[i][0])被覆盖次数 > 0,则为红色边
if i != 1 and d[i] > 0:
red_edges += 1
print(red_nodes - red_edges)
solve()
Java
import java.io.*;
import java.util.*;
public class Main {
static int n, q;
static ArrayList<Integer>[] adj;
static int[][] parent;
static int[] depth;
static int[] d;
static int[] lca_cnt;
static int LOGN;
// 预处理DFS,计算深度和父节点
static void dfs_pre(int u, int p, int dth) {
depth[u] = dth;
parent[u][0] = p;
for (int v : adj[u]) {
if (v != p) {
dfs_pre(v, u, dth + 1);
}
}
}
// 预处理LCA的倍增数组
static void preprocess_lca() {
depth = new int[n + 1];
LOGN = (int) (Math.log(n) / Math.log(2)) + 1;
parent = new int[n + 1][LOGN];
dfs_pre(1, 0, 0); // 假设1是根节点
for (int j = 1; j < LOGN; ++j) {
for (int i = 1; i <= n; ++i) {
if (parent[i][j - 1] != 0) {
parent[i][j] = parent[parent[i][j - 1]][j - 1];
}
}
}
}
// 查询LCA
static int lca(int u, int v) {
if (depth[u] < depth[v]) {
int temp = u;
u = v;
v = temp;
}
for (int j = LOGN - 1; j >= 0; --j) {
if (parent[u][j] != 0 && depth[parent[u][j]] >= depth[v]) {
u = parent[u][j];
}
}
if (u == v) return u;
for (int j = LOGN - 1; j >= 0; --j) {
if (parent[u][j] != 0 && parent[v][j] != 0 && parent[u][j] != parent[v][j]) {
u = parent[u][j];
v = parent[v][j];
}
}
return parent[u][0];
}
// DFS计算差分数组的子树和
static void dfs_sum(int u, int p) {
for (int v : adj[u]) {
if (v != p) {
dfs_sum(v, u);
d[u] += d[v];
}
}
}
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
StringTokenizer st = new StringTokenizer(br.readLine());
n = Integer.parseInt(st.nextToken());
q = Integer.parseInt(st.nextToken());
adj = new ArrayList[n + 1];
for (int i = 0; i <= n; i++) {
adj[i] = new ArrayList<>();
}
for (int i = 0; i < n - 1; i++) {
st = new StringTokenizer(br.readLine());
int u = Integer.parseInt(st.nextToken());
int v = Integer.parseInt(st.nextToken());
adj[u].add(v);
adj[v].add(u);
}
preprocess_lca();
d = new int[n + 1];
lca_cnt = new int[n + 1];
for (int i = 0; i < q; i++) {
st = new StringTokenizer(br.readLine());
int u = Integer.parseInt(st.nextToken());
int v = Integer.parseInt(st.nextToken());
int l = lca(u, v);
d[u]++;
d[v]++;
d[l] -= 2;
lca_cnt[l]++;
}
dfs_sum(1, 0);
long red_nodes = 0;
long red_edges = 0;
for (int i = 1; i <= n; i++) {
// 如果节点被覆盖次数 > 0,则为红色节点
if (d[i] + lca_cnt[i] > 0) {
red_nodes++;
}
// 如果边(i, parent[i][0])被覆盖次数 > 0,则为红色边
if (i != 1 && d[i] > 0) {
red_edges++;
}
}
System.out.println(red_nodes - red_edges);
}
}
题目内容
小美拿到一棵 n 个结点的 树,初始都是白色,q 次操作。
给定 u,v ,把 u 到 v 的简单路径上的所有点染红。
请你输出树上最后有多少个红色连通块。
【名词解释】 树:指这样的一张图,其由 n 个节点和 n−1 条边构成,其上的任意两个点都连通且不存在环。
简单路径:在图上由若干顶点构成的序列,序列中顶点互不重复,且相邻顶点有边相连;路径长度为其中边的数量。
某一颜色的连通块:也称连通分量,满足,
-
是原图的一个子图;
-
连通块内的任意两个顶点之间都存在路径相连,且路径上的点也在连通块内;
-
连通块内所有顶点的颜色均为目标颜色;
-
是极大的,即不能再通过添加原图中的其他顶点而依旧保持连通性;
单独的点也构成一个连通块。连通块的大小即为连通块中顶点的数量。
输入描述
第一行两个整数 n,q(1≤n,q≤2×105) 表示结点总数和操作次数。
接下来 n−1 行,每行两个整数 u,v(1≤u,v≤n,u=v) ,表示结点 u 和 v 之间存在一条无向边,输入保证是一棵树。接下来 q 行,每行给出两个整数u,v(1≤u,υ≤n) ,表示将此简单路径上的所有结点染成红色。
输出描述
一个整数,最后当前树上存在多少个红色连通块。
样例1
输入
5 3
1 2
1 3
1 4
5 3
3 3
4 3
1 5
输出
1