#P3287. 第2题-网络整改
-
1000ms
Tried: 515
Accepted: 118
Difficulty: 4
所属公司 :
华为
时间 :2025年6月11日-暑期实习
-
算法标签>动态规划
第2题-网络整改
题解
题目描述
给定一棵以节点 1 为根的树型网络,包含 n 台设备(节点编号 1 到 n)。网络中任意两节点通过边相连,最后没有子节点的称为“边缘设备”。希望移除尽可能少的节点,使得剩下网络中所有边缘设备到根设备的距离都相同。输出最少需要移除的节点数。
思路
- 先从根节点 1 做一次 BFS/DFS,计算每个节点到根的初始距离 depth[v]。
- 设定一个目标距离 H,希望所有保留后的边缘节点深度都为 H。
- 对于每个节点 v,定义状态
- 若 depth[v]>H,则 dp[v][H]=−∞(此节点深度已超出目标,无法保留)。
- 当 depth[v]=H 时,节点 v 必须成为边缘节点,故保留它本身计为 1。
- 当 depth[v]<H 时,节点 v 必须至少保留一个子节点路径以达到深度 H,因此累加所有能达成的子树的最大保留节点数。
- 根节点的 dp[1][H] 即是在目标深度 H 下可保留的最大节点数。遍历所有可能的 H(即树的最大深度范围),选取 maxHdp[1][H],则最少移除数为n−Hmaxdp[1][H].
C++
#include <bits/stdc++.h>
using namespace std;
const int INF = 1e9;
// 全局变量
int n;
vector<vector<int>> adj;
vector<int> depth;
vector<vector<int>> children;
int maxDepth;
// 计算每个节点深度并构建子树
void dfsDepth(int u, int p) {
for (int v : adj[u]) {
if (v == p) continue;
depth[v] = depth[u] + 1;
maxDepth = max(maxDepth, depth[v]);
children[u].push_back(v);
dfsDepth(v, u);
}
}
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
cin >> n;
adj.assign(n+1, {});
for (int i = 0; i < n-1; i++) {
int u, v;
cin >> u >> v;
adj[u].push_back(v);
adj[v].push_back(u);
}
depth.assign(n+1, 0);
children.assign(n+1, {});
maxDepth = 0;
dfsDepth(1, 0);
// dp[v][h]: 子树 v 在目标叶深度 h 时最大保留节点数
// 为节省空间,用滚动数组:prev[h], cur[h]
vector<int> best(n+1, -INF), nxt;
int answer = 0;
// 对每个候选深度 h 从 0 到 maxDepth
for (int h = 0; h <= maxDepth; h++) {
// 自底向上后序遍历:我们可以用一次栈模拟,也可按节点编号逆序(因为深度越大后序肯定处理先)
// 这里简单地按深度从大到小分层遍历
vector<vector<int>> byDepth(maxDepth+1);
for (int v = 1; v <= n; v++) {
byDepth[depth[v]].push_back(v);
}
best.assign(n+1, -INF);
// 从最大深度层到 0 层
for (int d = maxDepth; d >= 0; d--) {
for (int v : byDepth[d]) {
if (depth[v] > h) {
best[v] = -INF;
} else if (depth[v] == h) {
// 变为叶子
best[v] = 1;
} else {
int sum = 0;
for (int u : children[v]) {
if (best[u] > 0) sum += best[u];
}
if (sum > 0) best[v] = sum + 1;
else best[v] = -INF;
}
}
}
answer = max(answer, best[1]);
}
// 最少移除数 = 总数 - 最大保留数
cout << (n - answer) << "\n";
return 0;
}
Python
import sys
sys.setrecursionlimit(10000)
n = int(sys.stdin.readline())
adj = [[] for _ in range(n+1)]
for _ in range(n-1):
u, v = map(int, sys.stdin.readline().split())
adj[u].append(v)
adj[v].append(u)
depth = [0]*(n+1)
children = [[] for _ in range(n+1)]
max_depth = 0
def dfs(u, p):
global max_depth
for v in adj[u]:
if v == p: continue
depth[v] = depth[u] + 1
max_depth = max(max_depth, depth[v])
children[u].append(v)
dfs(v, u)
dfs(1, 0)
# dp[v][h] 用滚动数组 best[v] 存储当前 h 的值
answer = 0
for h in range(max_depth+1):
# 按深度分层
by_depth = [[] for _ in range(max_depth+1)]
for v in range(1, n+1):
by_depth[depth[v]].append(v)
best = [-10**9]*(n+1)
for d in range(max_depth, -1, -1):
for v in by_depth[d]:
if depth[v] > h:
best[v] = -10**9
elif depth[v] == h:
best[v] = 1
else:
s = sum(best[u] for u in children[v] if best[u] > 0)
best[v] = s + 1 if s > 0 else -10**9
answer = max(answer, best[1])
print(n - answer)
Java
import java.io.*;
import java.util.*;
public class Main {
static int n;
static List<List<Integer>> adj;
static int[] depth;
static List<List<Integer>> children;
static int maxDepth = 0;
static void dfs(int u, int p) {
for (int v : adj.get(u)) {
if (v == p) continue;
depth[v] = depth[u] + 1;
maxDepth = Math.max(maxDepth, depth[v]);
children.get(u).add(v);
dfs(v, u);
}
}
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
n = Integer.parseInt(br.readLine());
adj = new ArrayList<>();
for (int i = 0; i <= n; i++) adj.add(new ArrayList<>());
for (int i = 0; i < n-1; i++) {
StringTokenizer st = new StringTokenizer(br.readLine());
int u = Integer.parseInt(st.nextToken()), v = Integer.parseInt(st.nextToken());
adj.get(u).add(v);
adj.get(v).add(u);
}
depth = new int[n+1];
children = new ArrayList<>();
for (int i = 0; i <= n; i++) children.add(new ArrayList<>());
dfs(1, 0);
int answer = 0;
for (int h = 0; h <= maxDepth; h++) {
List<List<Integer>> byDepth = new ArrayList<>();
for (int i = 0; i <= maxDepth; i++) byDepth.add(new ArrayList<>());
for (int v = 1; v <= n; v++) {
byDepth.get(depth[v]).add(v);
}
int[] best = new int[n+1];
Arrays.fill(best, Integer.MIN_VALUE / 2);
for (int d = maxDepth; d >= 0; d--) {
for (int v : byDepth.get(d)) {
if (depth[v] > h) {
best[v] = Integer.MIN_VALUE / 2;
} else if (depth[v] == h) {
best[v] = 1;
} else {
int sum = 0;
for (int u : children.get(v)) {
if (best[u] > 0) sum += best[u];
}
best[v] = (sum > 0 ? sum + 1 : Integer.MIN_VALUE / 2);
}
}
}
answer = Math.max(answer, best[1]);
}
System.out.println(n - answer);
}
}
题目内容

在一个树形的网络拓扑中,有 n 台设备,编号 1 到 n ,其中我们固定 1 为根设备,如上图:根设备下可下挂多台设备(如设备编号 2、3 ),以此类推每一台设备下都可能下挂1台或者多台设备,最后没有下挂设备的设备成为边缘设备(如设备 3、5、6、7 )。
现在我们希望对网络进行整改,将组网中的部分设备移除,使得所有的边缘设备到根设备的距离相同,请你计算下最少需要移除多少台设备。
如上图:我们只需要移除 3 号和 5 号设备,可以使得剩下的所有边缘设备( 6、7 )到根设备的距离相同。
注:整个网络是单个连通的树型组网且没有环
输入描述
用例第一行为一个整数 n(3≤n≤5000) ,代表网络设备数目。
接下来 n 行每行包含两个整数 u,v(1≤u,v≤n,u=v) ,代表设备 u 与设备 v 相连接(注意仅代表链接关系,不表明确父子关系)。
注:我们保证每个设备的编号都小于等于 n ,且不重复;n 个网络设备,必然有 n−1 条连接。
输出描述
输出最少移除多少台设备,可以使得剩下的所有边缘设备到根设备距离都相同。
样例1
输入
7
1 2
1 3
2 4
2 5
4 6
4 7
输出
2
说明
如题目实例图中:我们移除 3 号和 5 号 2 台设备,可以使得剩下的所有边缘设备(6/7)到根设备的距离相同。
样例2
输入
5
4 1
2 4
5 1
5 3
输出
0
说明
该用例中的树形图为如下,可见不需要移除任何设备就满足边缘设备(2和3)到根设备1的距离都相等。

样例3
输入
7
1 2
2 3
3 4
1 5
1 6
1 7
输出
2
说明
该用例移除设备 4 后,再移除设备 3 即可。
