#P3287. 第2题-网络整改
          
                        
                                    
                      
        
              - 
          
          
                      1000ms
            
          
                      Tried: 505
            Accepted: 116
            Difficulty: 4
            
          
          
          
                       所属公司 : 
                              华为
                                
            
                        
              时间 :2025年6月11日-暑期实习
                              
                      
          
 
- 
                        算法标签>动态规划          
 
第2题-网络整改
题解
题目描述
给定一棵以节点 1 为根的树型网络,包含 n 台设备(节点编号 1 到 n)。网络中任意两节点通过边相连,最后没有子节点的称为“边缘设备”。希望移除尽可能少的节点,使得剩下网络中所有边缘设备到根设备的距离都相同。输出最少需要移除的节点数。
思路
- 先从根节点 1 做一次 BFS/DFS,计算每个节点到根的初始距离 depth[v]。
 - 设定一个目标距离 H,希望所有保留后的边缘节点深度都为 H。
 - 对于每个节点 v,定义状态
- 若 depth[v]>H,则 dp[v][H]=−∞(此节点深度已超出目标,无法保留)。
 - 当 depth[v]=H 时,节点 v 必须成为边缘节点,故保留它本身计为 1。
 - 当 depth[v]<H 时,节点 v 必须至少保留一个子节点路径以达到深度 H,因此累加所有能达成的子树的最大保留节点数。
 
 - 根节点的 dp[1][H] 即是在目标深度 H 下可保留的最大节点数。遍历所有可能的 H(即树的最大深度范围),选取 maxHdp[1][H],则最少移除数为n−Hmaxdp[1][H].
 
C++
#include <bits/stdc++.h>
using namespace std;
const int INF = 1e9;
// 全局变量
int n;
vector<vector<int>> adj;
vector<int> depth;
vector<vector<int>> children;
int maxDepth;
// 计算每个节点深度并构建子树
void dfsDepth(int u, int p) {
    for (int v : adj[u]) {
        if (v == p) continue;
        depth[v] = depth[u] + 1;
        maxDepth = max(maxDepth, depth[v]);
        children[u].push_back(v);
        dfsDepth(v, u);
    }
}
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n;
    adj.assign(n+1, {});
    for (int i = 0; i < n-1; i++) {
        int u, v;
        cin >> u >> v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    depth.assign(n+1, 0);
    children.assign(n+1, {});
    maxDepth = 0;
    dfsDepth(1, 0);
    // dp[v][h]: 子树 v 在目标叶深度 h 时最大保留节点数
    // 为节省空间,用滚动数组:prev[h], cur[h]
    vector<int> best(n+1, -INF), nxt;
    int answer = 0;
    // 对每个候选深度 h 从 0 到 maxDepth
    for (int h = 0; h <= maxDepth; h++) {
        // 自底向上后序遍历:我们可以用一次栈模拟,也可按节点编号逆序(因为深度越大后序肯定处理先)
        // 这里简单地按深度从大到小分层遍历
        vector<vector<int>> byDepth(maxDepth+1);
        for (int v = 1; v <= n; v++) {
            byDepth[depth[v]].push_back(v);
        }
        best.assign(n+1, -INF);
        // 从最大深度层到 0 层
        for (int d = maxDepth; d >= 0; d--) {
            for (int v : byDepth[d]) {
                if (depth[v] > h) {
                    best[v] = -INF;
                } else if (depth[v] == h) {
                    // 变为叶子
                    best[v] = 1;
                } else {
                    int sum = 0;
                    for (int u : children[v]) {
                        if (best[u] > 0) sum += best[u];
                    }
                    if (sum > 0) best[v] = sum + 1;
                    else best[v] = -INF;
                }
            }
        }
        answer = max(answer, best[1]);
    }
    // 最少移除数 = 总数 - 最大保留数
    cout << (n - answer) << "\n";
    return 0;
}
Python
import sys
sys.setrecursionlimit(10000)
n = int(sys.stdin.readline())
adj = [[] for _ in range(n+1)]
for _ in range(n-1):
    u, v = map(int, sys.stdin.readline().split())
    adj[u].append(v)
    adj[v].append(u)
depth = [0]*(n+1)
children = [[] for _ in range(n+1)]
max_depth = 0
def dfs(u, p):
    global max_depth
    for v in adj[u]:
        if v == p: continue
        depth[v] = depth[u] + 1
        max_depth = max(max_depth, depth[v])
        children[u].append(v)
        dfs(v, u)
dfs(1, 0)
# dp[v][h] 用滚动数组 best[v] 存储当前 h 的值
answer = 0
for h in range(max_depth+1):
    # 按深度分层
    by_depth = [[] for _ in range(max_depth+1)]
    for v in range(1, n+1):
        by_depth[depth[v]].append(v)
    best = [-10**9]*(n+1)
    for d in range(max_depth, -1, -1):
        for v in by_depth[d]:
            if depth[v] > h:
                best[v] = -10**9
            elif depth[v] == h:
                best[v] = 1
            else:
                s = sum(best[u] for u in children[v] if best[u] > 0)
                best[v] = s + 1 if s > 0 else -10**9
    answer = max(answer, best[1])
print(n - answer)
Java
import java.io.*;
import java.util.*;
public class Main {
    static int n;
    static List<List<Integer>> adj;
    static int[] depth;
    static List<List<Integer>> children;
    static int maxDepth = 0;
    static void dfs(int u, int p) {
        for (int v : adj.get(u)) {
            if (v == p) continue;
            depth[v] = depth[u] + 1;
            maxDepth = Math.max(maxDepth, depth[v]);
            children.get(u).add(v);
            dfs(v, u);
        }
    }
    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        n = Integer.parseInt(br.readLine());
        adj = new ArrayList<>();
        for (int i = 0; i <= n; i++) adj.add(new ArrayList<>());
        for (int i = 0; i < n-1; i++) {
            StringTokenizer st = new StringTokenizer(br.readLine());
            int u = Integer.parseInt(st.nextToken()), v = Integer.parseInt(st.nextToken());
            adj.get(u).add(v);
            adj.get(v).add(u);
        }
        depth = new int[n+1];
        children = new ArrayList<>();
        for (int i = 0; i <= n; i++) children.add(new ArrayList<>());
        dfs(1, 0);
        int answer = 0;
        for (int h = 0; h <= maxDepth; h++) {
            List<List<Integer>> byDepth = new ArrayList<>();
            for (int i = 0; i <= maxDepth; i++) byDepth.add(new ArrayList<>());
            for (int v = 1; v <= n; v++) {
                byDepth.get(depth[v]).add(v);
            }
            int[] best = new int[n+1];
            Arrays.fill(best, Integer.MIN_VALUE / 2);
            for (int d = maxDepth; d >= 0; d--) {
                for (int v : byDepth.get(d)) {
                    if (depth[v] > h) {
                        best[v] = Integer.MIN_VALUE / 2;
                    } else if (depth[v] == h) {
                        best[v] = 1;
                    } else {
                        int sum = 0;
                        for (int u : children.get(v)) {
                            if (best[u] > 0) sum += best[u];
                        }
                        best[v] = (sum > 0 ? sum + 1 : Integer.MIN_VALUE / 2);
                    }
                }
            }
            answer = Math.max(answer, best[1]);
        }
        System.out.println(n - answer);
    }
}
        题目内容

在一个树形的网络拓扑中,有 n 台设备,编号 1 到 n ,其中我们固定 1 为根设备,如上图:根设备下可下挂多台设备(如设备编号 2、3 ),以此类推每一台设备下都可能下挂1台或者多台设备,最后没有下挂设备的设备成为边缘设备(如设备 3、5、6、7 )。
现在我们希望对网络进行整改,将组网中的部分设备移除,使得所有的边缘设备到根设备的距离相同,请你计算下最少需要移除多少台设备。
如上图:我们只需要移除 3 号和 5 号设备,可以使得剩下的所有边缘设备( 6、7 )到根设备的距离相同。
注:整个网络是单个连通的树型组网且没有环
输入描述
用例第一行为一个整数 n(3≤n≤5000) ,代表网络设备数目。
接下来 n 行每行包含两个整数 u,v(1≤u,v≤n,u=v) ,代表设备 u 与设备 v 相连接(注意仅代表链接关系,不表明确父子关系)。
注:我们保证每个设备的编号都小于等于 n ,且不重复;n 个网络设备,必然有 n−1 条连接。
输出描述
输出最少移除多少台设备,可以使得剩下的所有边缘设备到根设备距离都相同。
样例1
输入
7
1 2
1 3
2 4
2 5
4 6
4 7
输出
2
说明
如题目实例图中:我们移除 3 号和 5 号 2 台设备,可以使得剩下的所有边缘设备(6/7)到根设备的距离相同。
样例2
输入
5
4 1
2 4
5 1
5 3
输出
0
说明
该用例中的树形图为如下,可见不需要移除任何设备就满足边缘设备(2和3)到根设备1的距离都相等。

样例3
输入
7
1 2
2 3
3 4
1 5
1 6
1 7
输出
2
说明
该用例移除设备 4 后,再移除设备 3 即可。
