#P2796. 第4题-最长的路径
-
1000ms
Tried: 5
Accepted: 2
Difficulty: 8
所属公司 :
美团
时间 :2025年4月5日-算法岗
-
算法标签>树
第4题-最长的路径
题目分析
在一棵树中,给定两个点 x 和 y,要求计算这两点之间经过的最长简单路径的权值和。树上任意两个节点之间只有一条简单路径,因此我们可以从路径的性质出发,思考如何求解。
问题转化
通过题意可以得到以下两个关键点:
- 树的最长路径问题:树上两个节点之间的最长路径就是其路径权值的和。
- 路径中经过两个指定点:给定两个节点 x 和 y,我们要求的是一条路径,路径上经过这两个节点,且路径的长度最大。
因此,我们的任务可以分为两部分:
- 计算树上任意两个节点之间的距离。
- 计算经过两个节点的最长路径,这涉及到如何处理从 x 到 y 路径上的其他节点。
核心算法
-
LCA(最低公共祖先)算法:可以使用二进制提升(Binary Lifting)算法来预处理 LCA 查询。预处理的时间复杂度为
O(n log n),查询的时间复杂度为O(log n)。 -
DFS处理子树信息:在 DFS 中,我们维护以下信息:
- 根节点到 u 的路径。
- 节点 u 的子树中的最大路径。
- 节点 u 的子树中的第二大路径。
- 节点 u 向上延伸的最大路径。
-
查询:对于每个查询,首先通过 LCA 获得两点之间的路径长度,然后分别计算路径两端的子树最大路径和其他路径的最大扩展。
时间复杂度分析
- 预处理阶段:DFS 遍历树的时间复杂度为
O(n),LCA 的二进制提升预处理的时间复杂度为O(n log n)。 - 查询阶段:每个查询的 LCA 查询时间为
O(log n),计算路径上其他节点的最长路径需要常数时间。因此每个查询的时间复杂度为O(log n)。
整体时间复杂度为 O(n log n + m log n),其中 n 为节点数,m 为查询数。
代码实现
Python 代码
import sys
sys.setrecursionlimit(1 << 20) # 递归深度上限
input = sys.stdin.readline
# ---------- 读入 ----------
n, m = map(int, input().split())
adj = [[] for _ in range(n + 1)]
for _ in range(n - 1):
u, v, w = map(int, input().split())
adj[u].append((v, w))
adj[v].append((u, w))
LOG = (n).bit_length() # ⌈log2 n⌉
# ---------- 一些全局数组 ----------
depth = [0]*(n + 1) # 深度
dist = [0]*(n + 1) # 根到该点的距离
dp = [[0]*LOG for _ in range(n + 1)]# 2^k 祖先
mx1 = [[0, -1] for _ in range(n + 1)]# [最长向下链, 次长向下链]
mx2 = [0]*(n + 1) # 向上的最长链(不在子树内)
# ---------- 1. 建立深度、距离和 1 级祖先 ----------
def dfs_parent(u: int, p: int) -> None:
dp[u][0] = p
for v, w in adj[u]:
if v == p:
continue
depth[v] = depth[u] + 1
dist[v] = dist[u] + w
dfs_parent(v, u)
dfs_parent(1, 0)
# ---------- 2. 倍增表 ----------
for k in range(1, LOG):
for v in range(1, n + 1):
dp[v][k] = dp[dp[v][k-1]][k-1]
# ---------- 3. 求每个结点向下的两条最长链 ----------
def dfs_down(u: int, p: int) -> None:
best1, best2 = 0, -1
for v, w in adj[u]:
if v == p:
continue
dfs_down(v, u)
cand = mx1[v][0] + w
if cand > best1:
best2, best1 = best1, cand
elif cand > best2:
best2 = cand
mx1[u][0], mx1[u][1] = best1, best2
dfs_down(1, 0)
# ---------- 4. 求每个结点向上的最长链 ----------
def dfs_up(u: int, p: int, up_val: int) -> None:
mx2[u] = up_val
# 预处理本结点所有孩子的最长向下链的前两大
top1, top2, child_top1 = -1, -1, -1
for v, _ in adj[u]:
if v == p:
continue
val = mx1[v][0]
if val > top1:
top2, top1, child_top1 = top1, val, v
elif val > top2:
top2 = val
for v, w in adj[u]:
if v == p:
continue
# 经过父方向的候选
cand = up_val + w
# 经过兄弟子树的候选
if v == child_top1 and top2 != -1:
cand = max(cand, top2 + w)
elif v != child_top1 and top1 != -1:
cand = max(cand, top1 + w)
dfs_up(v, u, cand)
dfs_up(1, 0, 0)
# ---------- 5. 最近公共祖先 ----------
def lca(u: int, v: int) -> int:
if depth[u] < depth[v]:
u, v = v, u
diff = depth[u] - depth[v]
for k in range(LOG - 1, -1, -1):
if diff >> k & 1:
u = dp[u][k]
if u == v:
return u
for k in range(LOG - 1, -1, -1):
if dp[u][k] and dp[u][k] != dp[v][k]:
u = dp[u][k]
v = dp[v][k]
return dp[u][0]
# ---------- 6. 处理询问 ----------
out_lines = []
for _ in range(m):
x, y = map(int, input().split())
if x == y: # 经过同一个点
down1, down2 = mx1[x]
down2 = max(down2, 0) # 可能没有次大链
if mx2[x] == 0: # 根节点或向上链为空
ans = down1 + down2
else:
ans = max(mx2[x] + down1, down1 + down2)
else: # 经过两点
z = lca(x, y)
dis = dist[x] + dist[y] - 2*dist[z]
ans = dis + mx1[x][0] + mx1[y][0]
out_lines.append(str(ans))
sys.stdout.write("\n".join(out_lines))
C++
#include<bits/stdc++.h>
using namespace std;
int n, m;
const int N = 1e5 + 5;
vector<array<int, 2>> e[N];
int sz[N], mxlen2[N], len[N];
array<int, 2> mxlen1[N];
int dp[N][20];
int depth[N];
bool vise[N];
void DFS(int k) {
//预处理出DP数组
vise[k] = true;
for (int i = 0; dp[k][i]; i++) {
dp[k][i + 1] = dp[dp[k][i]][i];
}
for (auto [to, m] : e[k]) {
//求解直接公共祖先;
if (vise[to])continue;
depth[to] = depth[k] + 1;
dp[to][0] = k;
DFS(to);
}
return;
}
int lca(int u, int v) {
if (depth[u] < depth[v])swap(u, v);
//弹节点
int k = log2(depth[u] - depth[v]);
for (int i = k; i >= 0; i--) {
if (depth[dp[u][i]] >= depth[v])u = dp[u][i];
}
if (u == v)return u;
//查询
k = log2(depth[u]);
for (int i = k; i >= 0; i--) {
if (dp[u][i] == dp[v][i])continue;
u = dp[u][i];
v = dp[v][i];
}
return dp[u][0];
}
void dfs0(int u, int fa) {
vector<int> vals;
mxlen1[u][0] = 0;
mxlen1[u][1] = -1;// 最大值和次大值
for (auto [v, w] : e[u]) {
if (v == fa)continue;
len[v] = len[u] + w;
dfs0(v, u);
vals.push_back(mxlen1[v][0]+w);
}
int sz = vals.size();
if(sz==0)return ;
sort(vals.begin(), vals.end());
mxlen1[u][0] = vals[sz-1];
if(sz >= 2)mxlen1[u][1] = vals[sz - 2];
}
void dfs1(int u, int fa, int up_mx) {
mxlen2[u] = up_mx;
vector<int> vals;
for (auto [v, w] : e[u]) {
if (v == fa)continue;
vals.push_back(mxlen1[v][0]);
}
sort(vals.begin(), vals.end());
int sz = vals.size();
for (auto [v, w] : e[u]) {
if (v == fa)continue;
int mx = up_mx + w;
for (int i = sz - 1; i >= max(sz - 2, 0); i--) {
if (vals[i] != mxlen1[v][0]) {
mx = max(mx, vals[i] + w);
}
}
dfs1(v, u, mx);
}
}
signed main () {
cin >> n >> m;
for (int i = 1; i <= n - 1; i++) {
int u, v, c;
cin >> u >> v >> c;
e[u].push_back({v, c});
e[v].push_back({u, c});
}
DFS(1);
dfs0(1, 0);
dfs1(1, 0, 0);
for (int i = 0; i < m; i++) {
int x, y;
cin >> x >> y;
int dis = len[x] + len[y] - 2 * len[lca(x, y)];
int ans = 0;
if(x==y){
if(mxlen2[x] == 0){
ans = mxlen1[x][0] + (mxlen1[x][1] == -1 ? 0 : mxlen1[x][1]);
}
else{
ans = max(mxlen2[x] + mxlen1[x][0],mxlen1[x][0] + mxlen1[x][1]);
}
}
else{
ans = dis + mxlen1[x][0] + mxlen1[y][0];
}
cout<<ans<<endl;
}
return 0;
}
Java代码
import java.util.*;
public class Main {
static int n, m;
static final int N = (int) 1e5 + 5;
static List<int[]>[] e = new ArrayList[N];
static int[] sz = new int[N], mxlen2 = new int[N], len = new int[N];
static int[][] mxlen1 = new int[N][2];
static int[][] dp = new int[N][20];
static int[] depth = new int[N];
static boolean[] vise = new boolean[N];
// 深度优先搜索 (DFS) 计算祖先数组
static void DFS(int k) {
vise[k] = true;
for (int i = 0; dp[k][i] != 0; i++) {
dp[k][i + 1] = dp[dp[k][i]][i];
}
for (int[] edge : e[k]) {
int to = edge[0], m = edge[1];
if (vise[to]) continue;
depth[to] = depth[k] + 1;
dp[to][0] = k;
DFS(to);
}
}
// 计算最近公共祖先 (LCA)
static int lca(int u, int v) {
if (depth[u] < depth[v]) {
int temp = u;
u = v;
v = temp;
}
int k = (int) Math.floor(Math.log(depth[u] - depth[v]) / Math.log(2));
for (int i = k; i >= 0; i--) {
if (depth[dp[u][i]] >= depth[v]) {
u = dp[u][i];
}
}
if (u == v) return u;
k = (int) Math.floor(Math.log(depth[u]) / Math.log(2));
for (int i = k; i >= 0; i--) {
if (dp[u][i] != dp[v][i]) {
u = dp[u][i];
v = dp[v][i];
}
}
return dp[u][0];
}
// 计算每个结点向下的最大长度和次大长度
static void dfs0(int u, int fa) {
List<Integer> vals = new ArrayList<>();
mxlen1[u][0] = 0;
mxlen1[u][1] = -1; // 最大值和次大值
for (int[] edge : e[u]) {
int v = edge[0], w = edge[1];
if (v == fa) continue;
len[v] = len[u] + w;
dfs0(v, u);
vals.add(mxlen1[v][0] + w);
}
int sz = vals.size();
if (sz == 0) return;
Collections.sort(vals);
mxlen1[u][0] = vals.get(sz - 1);
if (sz >= 2) mxlen1[u][1] = vals.get(sz - 2);
}
// 计算每个结点向上的最大长度
static void dfs1(int u, int fa, int up_mx) {
mxlen2[u] = up_mx;
List<Integer> vals = new ArrayList<>();
for (int[] edge : e[u]) {
int v = edge[0], w = edge[1];
if (v == fa) continue;
vals.add(mxlen1[v][0]);
}
Collections.sort(vals);
int sz = vals.size();
for (int[] edge : e[u]) {
int v = edge[0], w = edge[1];
if (v == fa) continue;
int mx = up_mx + w;
for (int i = sz - 1; i >= Math.max(sz - 2, 0); i--) {
if (vals.get(i) != mxlen1[v][0]) {
mx = Math.max(mx, vals.get(i) + w);
}
}
dfs1(v, u, mx);
}
}
public static void main(String[] args) {
Scanner scanner = new Scanner(System.in);
n = scanner.nextInt();
m = scanner.nextInt();
// 初始化邻接表
for (int i = 1; i <= n; i++) {
e[i] = new ArrayList<>();
}
// 输入树的边
for (int i = 1; i <= n - 1; i++) {
int u = scanner.nextInt();
int v = scanner.nextInt();
int c = scanner.nextInt();
e[u].add(new int[]{v, c});
e[v].add(new int[]{u, c});
}
DFS(1);
dfs0(1, 0);
dfs1(1, 0, 0);
// 处理查询
for (int i = 0; i < m; i++) {
int x = scanner.nextInt();
int y = scanner.nextInt();
int dis = len[x] + len[y] - 2 * len[lca(x, y)];
int ans = 0;
if (x == y) {
if (mxlen2[x] == 0) {
ans = mxlen1[x][0] + (mxlen1[x][1] == -1 ? 0 : mxlen1[x][1]);
} else {
ans = Math.max(mxlen2[x] + mxlen1[x][0], mxlen1[x][0] + mxlen1[x][1]);
}
} else {
ans = dis + mxlen1[x][0] + mxlen1[y][0];
}
System.out.println(ans);
}
scanner.close();
}
}
题目内容
游游很喜欢树,这一天他在研究树上的路径,他遇到了一个难题,现在邀请你帮助他解决该问题。
在一棵树上,两个点并不一定能确定一条链,但是可以找到一条经过这两个点最长的一条链。
你有一棵n个点的树,树上每条边都有一个权值,定义一条简单路径的长度为这条简单路径上的边权和,对于给定的两个点x,y,你需要回答在树上经过这
两个点的最长简单路径是多少。
树上的路径从节点u到节点v的简单路径定义为从节点u出发,以节点v为终点,随意在树上走,不经过重复的点和边走出来的序列。可以证明,在树上,任意两个节点间有且仅有一条简单路径。
输入描述
第一行两个数n,m(1≤n,m≤105)。
接下来n−1行,每行3个数ui,vi,di(1≤ui,vi≤n,1≤di≤10),表示树的第i条边。,
接下来m行,每行2个数x,y,表示一次询问。
输出描述
共m行,每行一个整数ans,表示你的答案。
样例1
输入
4 4
1 2 1
1 3 2
1 4 1
2 1
4 3
1 4
2 4
输出
3
3
3
2