#P3596. 第3题-最小差值
-
1000ms
Tried: 13
Accepted: 6
Difficulty: 6
所属公司 :
阿里
时间 :2025年9月6日-阿里淘天
-
算法标签>二分动态规划
第3题-最小差值
题目理解
给定一棵有 n 个节点的树,节点 i 初值为 si。树的“稳态度”
C(s)=(u,v)∈Emax∣su−sv∣允许把至多 k 个节点的取值改为任意整数,问最小能把稳态度降到多少。
关键等价
若把某个节点的值改掉,则它与所有相邻边的差值都可以被我们“调到很小”(甚至为 0)。 因此,当我们尝试把目标稳态度限制为某个阈值 L 时:
- 对于边 e=(u,v),若 ∣su−sv∣≤L 称为好边,不需要处理;
- 若 ∣su−sv∣>L 称为坏边,则必须修改它的至少一个端点,才能让该边的差值不超过 L。
于是,问题转化为:
在“坏边”构成的边集上,是否存在一个顶点覆盖(选点集与每条坏边相邻)大小 ≤k?
树上求最小顶点覆盖可以用树形 DP 一次线性计算;再通过对 L 二分即可得到答案最小值。
解题思路
算法框架
-
预处理每条边的权值 we=∣su−sv∣,并取 W=maxwe。
-
二分答案 L∈[0,W]。对每个候选 L:
- 把 (we>L) 的边视为坏边;
- 在整棵树上做 DP 计算覆盖所有坏边的最小选点数 cover(L);
- 若 cover(L)≤k,说明该 L 可行,右边界收缩;否则左边界增大。
-
输出最小可行的 L。
树形 DP 设计
将树根设为 1(任意根皆可)。对每个点 u 定义:
dp[u][0]:在以 u 为根的子树中,不选 u 时覆盖所有坏边的最小选点数;dp[u][1]:在以 u 为根的子树中,选择 u 时覆盖所有坏边的最小选点数。
转移(设 v 是 u 的儿子,边 (u,v) 的权为 w):
-
若 w>L(坏边)
dp[u][0] += dp[v][1](不选 u 必须选 v 才能覆盖 (u,v));dp[u][1] += min(dp[v][0], dp[v][1])(选了 u,该边已被覆盖,儿子任意)。
-
若 w≤L(好边)
- 两种状态都与这条边无关:
dp[u][0] += min(dp[v][0], dp[v][1]),dp[u][1] += min(dp[v][0], dp[v][1])。
- 两种状态都与这条边无关:
-
初值:
dp[u][1]额外 +1(因为选择了 u)。
根的答案是 min(dp[1][0], dp[1][1])。拿它与 k 比较判断 L 是否可行。
复杂度分析
- 一次 DP 为 O(n);
- L 的取值最多到 maxwe(本题 si≤100,所以 W≤100),二分为 O(logW)。 总复杂度 O(nlogW),空间 O(n)。
参考实现
Python
import sys
sys.setrecursionlimit(1 << 20)
def dfs(u, p, L, adj):
dp0, dp1 = 0, 1 # dp1 初值 +1,表示选择 u
for v, w in adj[u]:
if v == p:
continue
c0, c1 = dfs(v, u, L, adj)
if w > L: # 坏边
dp0 += c1 # 不选 u 必须选 v
dp1 += min(c0, c1) # 选 u 则随意
else: # 好边
t = min(c0, c1)
dp0 += t
dp1 += t
return dp0, dp1
def check(L, adj, n, k):
dp0, dp1 = dfs(1, 0, L, adj)
return min(dp0, dp1) <= k
def main():
data = list(map(int, sys.stdin.read().strip().split()))
it = iter(data)
n = next(it); k = next(it)
s = [0] + [next(it) for _ in range(n)]
adj = [[] for _ in range(n + 1)]
maxw = 0
for _ in range(n - 1):
u = next(it); v = next(it)
w = abs(s[u] - s[v])
maxw = max(maxw, w)
adj[u].append((v, w))
adj[v].append((u, w))
# 二分 L
lo, hi = 0, maxw
while lo < hi:
mid = (lo + hi) // 2
if check(mid, adj, n, k):
hi = mid
else:
lo = mid + 1
print(lo)
if __name__ == "__main__":
main()
Java 实现
import java.io.*;
import java.util.*;
public class Main {
static class Edge {
int to, w;
Edge(int t, int w){ this.to=t; this.w=w; }
}
static List<Edge>[] g;
static int n, k;
static int[] dfs(int u, int p, int L){
int dp0 = 0; // 不选 u
int dp1 = 1; // 选 u(计入 1)
for (Edge e: g[u]) {
int v = e.to, w = e.w;
if (v == p) continue;
int[] cv = dfs(v, u, L);
if (w > L) { // 坏边
dp0 += cv[1]; // 不选 u -> 必须选 v
dp1 += Math.min(cv[0], cv[1]); // 选 u -> 随意
} else { // 好边
int t = Math.min(cv[0], cv[1]);
dp0 += t;
dp1 += t;
}
}
return new int[]{dp0, dp1};
}
static boolean check(int L){
int[] res = dfs(1, 0, L);
return Math.min(res[0], res[1]) <= k;
}
public static void main(String[] args) throws Exception {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
StringTokenizer st;
st = new StringTokenizer(br.readLine());
n = Integer.parseInt(st.nextToken());
k = Integer.parseInt(st.nextToken());
st = new StringTokenizer(br.readLine());
int[] s = new int[n+1];
for (int i=1;i<=n;i++) s[i] = Integer.parseInt(st.nextToken());
g = new ArrayList[n+1];
for (int i=1;i<=n;i++) g[i] = new ArrayList<>();
int maxw = 0;
for (int i=0;i<n-1;i++){
st = new StringTokenizer(br.readLine());
int u = Integer.parseInt(st.nextToken());
int v = Integer.parseInt(st.nextToken());
int w = Math.abs(s[u]-s[v]);
maxw = Math.max(maxw, w);
g[u].add(new Edge(v, w));
g[v].add(new Edge(u, w));
}
int lo = 0, hi = maxw;
while (lo < hi){
int mid = (lo + hi) >>> 1;
if (check(mid)) hi = mid;
else lo = mid + 1;
}
System.out.println(lo);
}
}
C++
#include <bits/stdc++.h>
using namespace std;
struct Edge{ int to, w; };
int n, k;
vector<vector<Edge>> g;
pair<int,int> dfs(int u, int p, int L){
int dp0 = 0; // 不选 u
int dp1 = 1; // 选 u
for (auto &e : g[u]){
int v = e.to, w = e.w;
if (v == p) continue;
auto cv = dfs(v, u, L);
if (w > L){
dp0 += cv.second; // 不选 u -> 必选 v
dp1 += min(cv.first, cv.second); // 选 u -> 随意
}else{
int t = min(cv.first, cv.second); // 好边
dp0 += t;
dp1 += t;
}
}
return {dp0, dp1};
}
bool check(int L){
auto res = dfs(1, 0, L);
return min(res.first, res.second) <= k;
}
int main(){
ios::sync_with_stdio(false);
cin.tie(nullptr);
cin >> n >> k;
vector<int> s(n+1);
for (int i=1;i<=n;i++) cin >> s[i];
g.assign(n+1, {});
int maxw = 0;
for (int i=0;i<n-1;i++){
int u, v; cin >> u >> v;
int w = abs(s[u] - s[v]);
maxw = max(maxw, w);
g[u].push_back({v, w});
g[v].push_back({u, w});
}
int lo = 0, hi = maxw;
while (lo < hi){
int mid = (lo + hi) / 2;
if (check(mid)) hi = mid;
else lo = mid + 1;
}
cout << lo << "\n";
return 0;
}
题目内容
有一棵 n 个节点的树,其中每个节点 i 初始有一个整数值 si 。树上有 n−1 条边。定义树的稳定度为:
C(s)=(u,v)∈Emax∣su−sv∣Levko 可以修改至多 k 个节点的取值(每个节点的新值可以是任意整数),请问修改后最小的稳定度是多少。
输入描述
第一行两个整数 n,k(1≤n≤1000,0≤k≤n−1) 。
第二行 n 个整数 s1,s2,…,sn(0≤si≤100) 。
接下来 n−1 行,每行两个整数 u,v,表示一条边。
输出描述
输出一个整数,表示最小稳定度。
样例1
输入
5 2
4 7 4 7 4
1 2
2 3
3 4
4 5
输出
0
样例2
输入
6 3
1 2 3 7 8 9
1 2
2 3
3 4
4 5
5 6
输出
1