#P4476. 第3题-Prompt上下文信息精简:找出二叉树中的最大值子树
-
1000ms
Tried: 29
Accepted: 5
Difficulty: 7
所属公司 :
华为
时间 :2025年11月19日-AI方向
-
算法标签>动态规划
第3题-Prompt上下文信息精简:找出二叉树中的最大值子树
解题思路
题目里“子树”的定义:
-
可以在一棵树中选择任意一个结点作为“根”;
-
对于这个根的左右子树,我们可以选择保留或裁掉:
- 如果某个子树对总和的贡献 ≤ 0,就可以整个裁掉(不保留这一支);
- 只保留“贡献为正”的子树分支;
-
目标:找到和最大的“裁剪后子树”,并输出这棵裁剪后的子树(仍然用完全二叉树层序数组表示,被裁掉的分支用
null,最后去掉末尾多余的null)。
这本质上就是“在树上找最大权连通子图(允许剪枝)”,非常经典的写法是 树形 DP + 可选子树:
一、数组表示的二叉树结构
与之前相同,输入是一棵“完全二叉树”的一维数组:
- 根:下标
0 - 左子:
2 * i + 1 - 右子:
2 * i + 2 null表示该位置没有结点(空)
我们仍然先解析出:
vals[i]:结点值(对于null,值随便填,反正不用)valid[i]:该位置是否是一个真实结点(非null)
二、核心 DP:允许裁剪子树的最大和
对每个真实结点 i,定义:
dp[i]= 以i为根,在允许删掉任意“贡献 ≤ 0 的子树”的前提下, 能得到的 最大子树和
状态转移(自底向上,从 n-1 到 0):
-
如果
i不是有效结点(valid[i] == false),- 我们不以它为根建树,
dp[i]记为 0 即可(不会被真正使用)。
- 我们不以它为根建树,
-
如果
i是有效结点:-
左子下标:
l = 2*i + 1 -
右子下标:
r = 2*i + 2 -
左子贡献:
- 如果
l在范围内且valid[l] == true,则可用dp[l] - 否则为 0
- 如果
-
右子贡献同理
-
当前根的最佳和:
$$dp[i] = vals[i] + \max(0, \text{leftDp}) + \max(0, \text{rightDp}) $$ -
这里的
max(0, dp[child])正是“如果这个子树贡献为正,就保留;否则就整个裁掉”的含义。
-
计算完 dp[i] 后,我们用它去更新全局最大值:
-
维护:
bestSum:当前所有结点中最大的dp[i]bestRoot:达到bestSum的下标i
最终:
bestRoot就是那棵“最大值裁剪子树”的根- 和为
bestSum
注意:如果整棵树都是负数, 这个 DP 仍然会选出某个单个结点(值最大的那个)作为答案。
三、如何还原“被裁剪后的子树”结构
有了 bestRoot 和整棵树的 dp[],要构造输出数组:
-
从
bestRoot开始做 BFS(广度优先),队列中保存:(原数组下标 originalIndex, 新树下标 newIndex)- 新树也用完全二叉树规则:左子
2*newIndex+1,右子2*newIndex+2
-
对于队头
(oi, ni):-
确保结果数组
res的长度大于ni,不够就用null(None)补齐 -
设置
res[ni] = vals[oi] -
然后处理原树中
oi的左右孩子:-
child = 2 * oi + 1(左)或2 * oi + 2(右) -
条件:
child在数组范围内valid[child] == true(是真实结点)- 且
dp[child] > 0(说明这棵子树被保留下来)
-
满足条件则将
(child, 2*ni+1 或 2*ni+2)入队
-
-
如果
dp[child] <= 0,表示这棵子树被“剪掉”,在新树对应位置将保持为null。
-
-
BFS 完成后,
res中有若干null:- 中间的
null是被剪掉的子树位置,必须保留 - 末尾连续的
null是“完全二叉树填充”的冗余,需要按题意 从尾部删掉
- 中间的
这样就得到最终输出
四、小结
和“必须保留所有后代”的版本不同,这里:
- DP 转移多了
max(0, 子树和),允许剪掉坏分支; - 仍然是 O(n) 的一次自底向上的 DP;
- 再配合 BFS + 剪枝条件
dp[child] > 0来还原树结构。
复杂度分析
设数组长度为 n:
- 计算
dp[i](自底向上):每个下标访问一次,O(n) - BFS 还原最大子树:每个被保留结点最多被访问一次,O(n)
- 去掉末尾多余的
null:O(n)
总体:
-
时间复杂度:O(n)
-
空间复杂度:O(n)
dp数组 O(n)valid/vals/ 队列 / 结果数组 O(n)
代码实现
Python
import sys
from ast import literal_eval
from collections import deque
def max_pruned_subtree(arr):
n = len(arr)
if n == 0:
return []
# 标记每个位置是否为真实结点(非 None)
valid = [x is not None for x in arr]
# dp[i]:以 i 为根,允许裁掉贡献不为正的子树后,能得到的最大子树和
dp = [0] * n
best_sum = None # 全局最大值
best_root = -1 # 对应的根下标
# 自底向上 DP
for i in range(n - 1, -1, -1):
if not valid[i]:
dp[i] = 0
continue
left = 2 * i + 1
right = 2 * i + 2
left_dp = dp[left] if left < n and valid[left] else 0
right_dp = dp[right] if right < n and valid[right] else 0
# 允许裁掉贡献不为正的子树
cur = arr[i]
if left_dp > 0:
cur += left_dp
if right_dp > 0:
cur += right_dp
dp[i] = cur
if best_sum is None or cur > best_sum:
best_sum = cur
best_root = i
if best_root == -1:
return []
# BFS 构造被裁剪后的最大子树(完全二叉树形式)
res = []
q = deque()
# 队列元素:(原数组下标, 新树下标)
q.append((best_root, 0))
while q:
oi, ni = q.popleft()
# 保证 res 长度足够
while len(res) <= ni:
res.append(None)
res[ni] = arr[oi]
left = 2 * oi + 1
right = 2 * oi + 2
# 左子树保留条件:存在、为真实结点、dp > 0
if left < n and valid[left] and dp[left] > 0:
q.append((left, 2 * ni + 1))
# 右子树保留条件:存在、为真实结点、dp > 0
if right < n and valid[right] and dp[right] > 0:
q.append((right, 2 * ni + 2))
# 去掉末尾多余的 None
while res and res[-1] is None:
res.pop()
return res
def main():
s = sys.stdin.readline().strip()
if not s:
return
# 将输入中的 'null' 替换为 Python 的 None,便于 literal_eval 解析
s = s.replace('null', 'None')
arr = literal_eval(s) # 得到包含 int 和 None 的列表
ans = max_pruned_subtree(arr)
# 输出格式:[1,-1,null,2,...]
out = '[' + ','.join('null' if x is None else str(x) for x in ans) + ']'
print(out)
if __name__ == "__main__":
main()
Java
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.*;
public class Main {
// 允许裁掉贡献不为正子树的最大子树
private static List<String> maxPrunedSubtree(int[] vals, boolean[] valid) {
int n = vals.length;
List<String> res = new ArrayList<>();
if (n == 0) return res;
long[] dp = new long[n]; // dp[i]:以 i 为根的最大裁剪子树和
long bestSum = Long.MIN_VALUE;
int bestRoot = -1;
// 自底向上 DP
for (int i = n - 1; i >= 0; i--) {
if (!valid[i]) {
dp[i] = 0;
continue;
}
int left = 2 * i + 1;
int right = 2 * i + 2;
long leftDp = (left < n && valid[left]) ? dp[left] : 0;
long rightDp = (right < n && valid[right]) ? dp[right] : 0;
long cur = vals[i];
if (leftDp > 0) cur += leftDp;
if (rightDp > 0) cur += rightDp;
dp[i] = cur;
if (cur > bestSum) {
bestSum = cur;
bestRoot = i;
}
}
if (bestRoot == -1) {
return res;
}
// BFS 构造裁剪后的子树
Queue<Integer> qOrig = new LinkedList<>(); // 原树下标队列
Queue<Integer> qNew = new LinkedList<>(); // 新树下标队列
qOrig.offer(bestRoot);
qNew.offer(0);
while (!qOrig.isEmpty()) {
int oi = qOrig.poll();
int ni = qNew.poll();
// 保证 res 长度足够,用 "null" 占位
while (res.size() <= ni) {
res.add("null");
}
res.set(ni, String.valueOf(vals[oi]));
int left = 2 * oi + 1;
int right = 2 * oi + 2;
// 左子树:存在且有效且 dp > 0 才保留
if (left < n && valid[left] && dp[left] > 0) {
qOrig.offer(left);
qNew.offer(2 * ni + 1);
}
// 右子树同理
if (right < n && valid[right] && dp[right] > 0) {
qOrig.offer(right);
qNew.offer(2 * ni + 2);
}
}
// 去掉末尾多余的 "null"
int last = res.size() - 1;
while (last >= 0 && "null".equals(res.get(last))) {
res.remove(last);
last--;
}
return res;
}
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
String s = br.readLine();
if (s == null || s.trim().isEmpty()) {
return;
}
s = s.trim();
// 去掉首尾的 [ ]
if (s.startsWith("[")) s = s.substring(1);
if (s.endsWith("]")) s = s.substring(0, s.length() - 1);
s = s.trim();
if (s.isEmpty()) {
System.out.println("[]");
return;
}
String[] parts = s.split(",");
int n = parts.length;
int[] vals = new int[n];
boolean[] valid = new boolean[n];
for (int i = 0; i < n; i++) {
String p = parts[i].trim();
if (p.equals("null") || p.equals("None") || p.length() == 0) {
valid[i] = false;
} else {
valid[i] = true;
vals[i] = Integer.parseInt(p);
}
}
List<String> ans = maxPrunedSubtree(vals, valid);
// 输出:[a,b,null,...]
StringBuilder sb = new StringBuilder();
sb.append("[");
for (int i = 0; i < ans.size(); i++) {
if (i > 0) sb.append(",");
sb.append(ans.get(i));
}
sb.append("]");
System.out.println(sb.toString());
}
}
C++
#include <bits/stdc++.h>
using namespace std;
// 允许裁剪非正贡献子树的最大子树
vector<string> maxPrunedSubtree(const vector<long long> &vals, const vector<bool> &valid) {
int n = (int)vals.size();
vector<string> res;
if (n == 0) return res;
vector<long long> dp(n, 0); // dp[i]:以 i 为根的最大裁剪子树和
long long bestSum = LLONG_MIN;
int bestRoot = -1;
// 自底向上 DP
for (int i = n - 1; i >= 0; --i) {
if (!valid[i]) {
dp[i] = 0;
continue;
}
int left = 2 * i + 1;
int right = 2 * i + 2;
long long leftDp = (left < n && valid[left]) ? dp[left] : 0;
long long rightDp = (right < n && valid[right]) ? dp[right] : 0;
long long cur = vals[i];
if (leftDp > 0) cur += leftDp;
if (rightDp > 0) cur += rightDp;
dp[i] = cur;
if (cur > bestSum) {
bestSum = cur;
bestRoot = i;
}
}
if (bestRoot == -1) {
return res;
}
// BFS 构造裁剪后的子树(完全二叉树形式)
queue<int> qOrig, qNew;
qOrig.push(bestRoot);
qNew.push(0);
while (!qOrig.empty()) {
int oi = qOrig.front(); qOrig.pop();
int ni = qNew.front(); qNew.pop();
// 确保 res 长度足够,用 "null" 填充
while ((int)res.size() <= ni) {
res.push_back("null");
}
res[ni] = to_string(vals[oi]);
int left = 2 * oi + 1;
int right = 2 * oi + 2;
// 左子树:存在且有效且 dp > 0 才保留
if (left < n && valid[left] && dp[left] > 0) {
qOrig.push(left);
qNew.push(2 * ni + 1);
}
// 右子树同理
if (right < n && valid[right] && dp[right] > 0) {
qOrig.push(right);
qNew.push(2 * ni + 2);
}
}
// 去掉末尾多余的 "null"
int last = (int)res.size() - 1;
while (last >= 0 && res[last] == "null") {
res.pop_back();
--last;
}
return res;
}
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
string s;
if (!getline(cin, s)) {
return 0;
}
// 去掉首尾空白
while (!s.empty() && isspace((unsigned char)s.back())) s.pop_back();
int pos = 0;
while (pos < (int)s.size() && isspace((unsigned char)s[pos])) pos++;
s = s.substr(pos);
if (s.empty()) {
return 0;
}
// 去掉首尾的 [ ]
if (!s.empty() && s.front() == '[') s.erase(s.begin());
if (!s.empty() && s.back() == ']') s.pop_back();
// 再清理首尾空白
while (!s.empty() && isspace((unsigned char)s.back())) s.pop_back();
while (!s.empty() && isspace((unsigned char)s.front())) s.erase(s.begin());
if (s.empty()) {
cout << "[]\n";
return 0;
}
// 按逗号切分
vector<long long> vals;
vector<bool> valid;
string token;
stringstream ss(s);
while (getline(ss, token, ',')) {
int l = 0, r = (int)token.size() - 1;
while (l <= r && isspace((unsigned char)token[l])) l++;
while (r >= l && isspace((unsigned char)token[r])) r--;
if (l > r) {
// 空串,视为 null
valid.push_back(false);
vals.push_back(0);
continue;
}
string t = token.substr(l, r - l + 1);
if (t == "null" || t == "None") {
valid.push_back(false);
vals.push_back(0);
} else {
valid.push_back(true);
long long v = stoll(t);
vals.push_back(v);
}
}
vector<string> ans = maxPrunedSubtree(vals, valid);
// 输出为 [a,b,c] 格式
cout << "[";
for (int i = 0; i < (int)ans.size(); ++i) {
if (i > 0) cout << ",";
cout << ans[i];
}
cout << "]\n";
return 0;
}
题目内容
描述: Prompt 应用面临的一个首要问题就是 Token 的长度和精确度问题,如何精简 Prompt 的 token 长度一直是大模型应用中的难题。假设 Prompt 的 token 序列是一颗二叉树,给定这样一颗二叉树,该二叉树的每个节点都有一个值,可以是正负值,也可以是 0 ,请返回该二又树的最大值子树。每颗子树的值为该子树所有节点值的和。
注意:
输入和输出数据的格式要求:(1)二叉树是完全二叉树;(2)二叉树节点数据是通过宽度优先搜索遍历获取;(3)遍历出的二叉树节点数据是以一维数组的形式存储。(4)如果一颗二叉树的左节点不存在,就以 null 补齐。
举例:如果节点 A 和 B 是兄弟节点,它们两个的父节点是 C ,A 无子节点,B 有子节点 D 和 E ,那么这棵树的数组为 tree=[C,A,B,null,null,D,E];如果 B 只有左子节点 D ,则 tree=[C,A,B,null,null,D] ; 如果 B 只有右子节点 E ,则 tree=[C,A,B,null,null,null,E];
示例1

输入:[3,2,5]
输出: [3,2,5]
示例2

输入: [−5,−1,3,null,null,4,7]
输出: [3,4,7]
示例3

输入: [5,−1,3,null,null,4,7]
输出: [5,null,3,null,null,4,7]
输入描述
输入:二叉树是一颗完全二叉树,节点数据是通过宽度优先搜索遍历的,以一维数组结构表示,null 代表为空的叶子节点。
以示例 2 为例,输入为 [−5,−1,3,null,null,4,7],−1 节点虽然是叶子节点,但在完全二叉树中需要明确它的两个子节点,这两个子节点为 null 。
输出描述
输出:输出最大值子树,也以宽度优先搜索完全二叉树后的数组结构表示的方式作为输出。
以示例 2 为例,由于 [−5,−1,3] 该子树为负值,不应当与子树 [3,4,7] 合一起,所以 [3,4,7] 是最大子树。
以示例 3 为例,最大值子树是 [5,null,3,null,null,4,7],根节点 5 的所有左子树节点用 null 补 齐。
样例1
输入
[-5,-1,3,null,null,4,7]
输出
[3,4,7]
说明
最大子树 max−sub−tree 是 [3,4,7]
样例2
输入
[-1,null,1,null,null,-1,-1,null,null,null,null,2,1,-3,-1,null,null,null,null,null,null,null,null,2,1,3,8]
输出
[1,-1,null,2,1,null,null,2,1,3,8]
说明
最大子树 max−sub−tree 是 [1,−1,null,2,1,null,null,2,1,3,8]