#P3589. 第3题-祖先节点
-
1000ms
Tried: 23
Accepted: 6
Difficulty: 8
所属公司 :
美团
时间 :2025年9月6日-算法岗
-
算法标签>树上背包
第3题-祖先节点
解题思路
关键观察
在以 u 为根的子树中,若选择了结点 u,则它的整个子树都不能再选任何结点;若不选 u,则可以在它的各个子树中独立选择,且不同子树之间互不影响(不同子树的结点不可能互为祖先)。
于是可做树形背包 DP:
-
记 dp[u][k]:在结点 u 的子树中,选出恰好 k 个两两不成祖孙关系的点时的最大权值和(不可行记为 −∞),并规定 dp[u][0]=0。
-
转移:
-
初始
cur = [0](表示已合并 0 个孩子时的方案)。 -
依次把每个孩子 v 的数组 dp[v] 与
$$ \text{cur}[i+j] = \max(\text{cur}[i+j],\ \text{cur}[i] + dp[v][j]). $$cur做“背包卷积式”的合并:这对应“不选 u”且在各子树中独立选取。
-
合并完所有孩子后,再考虑“选 u”这一种方案:此时在 u 的子树内只能选 u 自身,得到
dp[u][1]=max(dp[u][1], au).对于 k≥2,“选 u”无法贡献,因为子树内其它点都不能选。
-
-
最终答案:f(k)=dp[1][k]。若为 −∞,输出 −1。
复杂度分析
- 设合并过程为“多子树多项式卷积(取 max+ 加法)”,总状态数为 ∑usubsize(u)≤n2。
- 时间复杂度:O(n2)(总背包合并)。
- 空间复杂度:实现一(Python 递归返回 dp):O(n) 额外空间(加上邻接表);实现二(C++/Java 存每个结点 dp):最坏 O(n2),对 n≤5000 可承受。
- 由于每个测试文件所有 n 之和 ≤5000,整体可轻松通过。
代码实现
Python
import sys
sys.setrecursionlimit(1 << 20)
def solve():
data = list(map(int, sys.stdin.buffer.read().split()))
it = iter(data)
T = next(it)
NEG = -10**30 # 负无穷的替代值,远小于最大可能答案
out = []
for _ in range(T):
n = next(it)
a = [0] * (n + 1)
for i in range(1, n + 1):
a[i] = next(it)
g = [[] for _ in range(n + 1)]
for _ in range(n - 1):
u = next(it); v = next(it)
g[u].append(v); g[v].append(u)
# 合并两个 dp 数组,返回新数组
def merge(A, B):
la, lb = len(A), len(B)
C = [NEG] * (la + lb - 1)
for i in range(la):
if A[i] <= NEG // 2: # 不可行状态跳过
continue
ai = A[i]
for j in range(lb):
bj = B[j]
if bj <= NEG // 2:
continue
val = ai + bj
if val > C[i + j]:
C[i + j] = val
return C
# DFS 返回以 u 为根子树的 dp 数组(dp[k])
def dfs(u, p):
cur = [0] # 只选 0 个的最佳和为 0
for v in g[u]:
if v == p:
continue
dv = dfs(v, u)
cur = merge(cur, dv)
# 选 u 的方案:仅能贡献 1 个
if len(cur) < 2:
cur.append(NEG)
if a[u] > cur[1]:
cur[1] = a[u]
return cur
dp_root = dfs(1, 0)
ans = []
for k in range(1, n + 1):
if k >= len(dp_root) or dp_root[k] <= NEG // 2:
ans.append("-1")
else:
ans.append(str(dp_root[k]))
out.append(" ".join(ans))
print("\n".join(out))
if __name__ == "__main__":
solve()
C++
#include <bits/stdc++.h>
using namespace std;
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
const long long NEG = -(1LL << 60); // 负无穷替代
int T;
if (!(cin >> T)) return 0;
while (T--) {
int n; cin >> n;
vector<long long> a(n + 1);
for (int i = 1; i <= n; ++i) cin >> a[i];
vector<vector<int>> g(n + 1);
for (int i = 0; i < n - 1; ++i) {
int u, v; cin >> u >> v;
g[u].push_back(v);
g[v].push_back(u);
}
// 迭代 DFS:求父亲与后序顺序
vector<int> par(n + 1, 0), order;
order.reserve(n);
stack<int> st;
st.push(1); par[1] = -1;
while (!st.empty()) {
int u = st.top(); st.pop();
order.push_back(u);
for (int v : g[u]) if (v != par[u]) {
par[v] = u; st.push(v);
}
}
reverse(order.begin(), order.end()); // 后序
// 每个结点一份 dp 数组
vector<vector<long long>> dp(n + 1);
auto merge = [&](const vector<long long>& A, const vector<long long>& B) {
int la = (int)A.size(), lb = (int)B.size();
vector<long long> C(la + lb - 1, NEG);
for (int i = 0; i < la; ++i) if (A[i] > NEG / 2) {
for (int j = 0; j < lb; ++j) if (B[j] > NEG / 2) {
C[i + j] = max(C[i + j], A[i] + B[j]);
}
}
return C;
};
for (int u : order) {
vector<long long> cur(1, 0); // 只选 0 个
for (int v : g[u]) if (par[v] == u) {
cur = merge(cur, dp[v]);
}
if ((int)cur.size() < 2) cur.resize(2, NEG);
cur[1] = max(cur[1], a[u]); // 选 u
dp[u].swap(cur);
}
// 输出 f(1..n)
for (int k = 1; k <= n; ++k) {
long long val = (k < (int)dp[1].size() ? dp[1][k] : NEG);
if (k > 1) cout << ' ';
if (val <= NEG / 2) cout << -1;
else cout << val;
}
cout << "\n";
}
return 0;
}
Java
import java.io.*;
import java.util.*;
public class Main {
static final long NEG = -(1L << 60);
public static void main(String[] args) throws Exception {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
StringTokenizer st;
st = new StringTokenizer(br.readLine());
int T = Integer.parseInt(st.nextToken());
StringBuilder out = new StringBuilder();
while (T-- > 0) {
int n = Integer.parseInt(br.readLine().trim());
long[] a = new long[n + 1];
st = new StringTokenizer(br.readLine());
for (int i = 1; i <= n; i++) a[i] = Long.parseLong(st.nextToken());
ArrayList<Integer>[] g = new ArrayList[n + 1];
for (int i = 1; i <= n; i++) g[i] = new ArrayList<>();
for (int i = 0; i < n - 1; i++) {
st = new StringTokenizer(br.readLine());
int u = Integer.parseInt(st.nextToken());
int v = Integer.parseInt(st.nextToken());
g[u].add(v); g[v].add(u);
}
// 迭代 DFS:父亲 + 后序序列
int[] par = new int[n + 1];
Arrays.fill(par, 0);
int[] order = new int[n];
int idx = 0;
Deque<Integer> stack = new ArrayDeque<>();
stack.push(1); par[1] = -1;
while (!stack.isEmpty()) {
int u = stack.pop();
order[idx++] = u;
for (int v : g[u]) if (v != par[u]) { par[v] = u; stack.push(v); }
}
// 反转得到后序
for (int i = 0, j = n - 1; i < j; i++, j--) {
int tmp = order[i]; order[i] = order[j]; order[j] = tmp;
}
// dp[u] 为长度可变的 long[],dp[u][k] 表示在 u 子树选 k 个的最优和
long[][] dp = new long[n + 1][];
for (int t = 0; t < n; t++) {
int u = order[t];
long[] cur = new long[]{0L}; // 只选 0 个
for (int v : g[u]) if (par[v] == u) {
cur = merge(cur, dp[v]);
}
if (cur.length < 2) {
long[] nx = new long[2];
Arrays.fill(nx, NEG);
nx[0] = cur[0];
cur = nx;
}
cur[1] = Math.max(cur[1], a[u]); // 选 u
dp[u] = cur;
}
// 输出 f(1..n)
for (int k = 1; k <= n; k++) {
long val = (k < dp[1].length) ? dp[1][k] : NEG;
if (k > 1) out.append(' ');
out.append(val <= NEG / 2 ? -1 : val);
}
out.append('\n');
}
System.out.print(out.toString());
}
// 将两段 dp 做背包式合并
static long[] merge(long[] A, long[] B) {
int la = A.length, lb = B.length;
long[] C = new long[la + lb - 1];
Arrays.fill(C, NEG);
for (int i = 0; i < la; i++) {
long ai = A[i];
if (ai <= NEG / 2) continue; // 不可行状态
for (int j = 0; j < lb; j++) {
long bj = B[j];
if (bj <= NEG / 2) continue;
long val = ai + bj;
if (val > C[i + j]) C[i + j] = val;
}
}
return C;
}
}
题目内容
给定一棵树,根节点为 1 ,其中第 u 个节点有点权 au ,定义 f(k) 为:
选择树上 k 个互不相同的节点。你需要保证这 k 个节点两两不成祖先-子孙关系。f(k) 为在所有可能的选择方案里最大的点权和。如果不存在任何一种合法的选择方案,则 f(k)=−1 。
计算 f(1),f(2),…,f(n) 。
【名词解释】
祖先节点:在一棵以 u 为根的树中,若点 x 在 u 到 v 的简单路径上,且 x=v 则称 x 是 v 的祖先节点。根节点没有祖先节点。
输入描述
每个测试文件均包含多组测试数据,第一行输入一个整数 T(1≦T≦1000) 代表数据组数,每组测试数据描述如下:
第一行输入一个正整数 n(1≦n≦5000),代表树中的节点数量。
第二行输入 n 个正整数 a1,a2,...,an(1≦ai≦109) 表示每个节点的点权。
接下来 n−1 行,第 i 行插人两个正整数 ui,vi(1≦ui,vi≦n) ,表示树上一条从节点 ui 到 vi 的边。输入保证是一棵合法的树。
除此之外,保证单个测试文件的 n 之和不超过 5000 。
输出描述
输出一行 n 个整数,其中第 i 个整数代表 f(i) 。
样例1
输入
2
1
114514
4
5 3 3 3
1 2
1 3
4 1
输出
114514
5 6 9 -1