#P2796. 第4题-最长的路径
          
                        
                                    
                      
        
              - 
          
          
                      1000ms
            
          
                      Tried: 5
            Accepted: 2
            Difficulty: 8
            
          
          
          
                       所属公司 : 
                              美团
                                
            
                        
              时间 :2025年4月5日-算法岗
                              
                      
          
 
- 
                        算法标签>树          
 
第4题-最长的路径
题目分析
在一棵树中,给定两个点 x 和 y,要求计算这两点之间经过的最长简单路径的权值和。树上任意两个节点之间只有一条简单路径,因此我们可以从路径的性质出发,思考如何求解。
问题转化
通过题意可以得到以下两个关键点:
- 树的最长路径问题:树上两个节点之间的最长路径就是其路径权值的和。
 - 路径中经过两个指定点:给定两个节点 x 和 y,我们要求的是一条路径,路径上经过这两个节点,且路径的长度最大。
 
因此,我们的任务可以分为两部分:
- 计算树上任意两个节点之间的距离。
 - 计算经过两个节点的最长路径,这涉及到如何处理从 x 到 y 路径上的其他节点。
 
核心算法
- 
LCA(最低公共祖先)算法:可以使用二进制提升(Binary Lifting)算法来预处理 LCA 查询。预处理的时间复杂度为
O(n log n),查询的时间复杂度为O(log n)。 - 
DFS处理子树信息:在 DFS 中,我们维护以下信息:
- 根节点到 u 的路径。
 - 节点 u 的子树中的最大路径。
 - 节点 u 的子树中的第二大路径。
 - 节点 u 向上延伸的最大路径。
 
 - 
查询:对于每个查询,首先通过 LCA 获得两点之间的路径长度,然后分别计算路径两端的子树最大路径和其他路径的最大扩展。
 
时间复杂度分析
- 预处理阶段:DFS 遍历树的时间复杂度为 
O(n),LCA 的二进制提升预处理的时间复杂度为O(n log n)。 - 查询阶段:每个查询的 LCA 查询时间为 
O(log n),计算路径上其他节点的最长路径需要常数时间。因此每个查询的时间复杂度为O(log n)。 
整体时间复杂度为 O(n log n + m log n),其中 n 为节点数,m 为查询数。
代码实现
Python 代码
import sys
sys.setrecursionlimit(1 << 20)          # 递归深度上限
input = sys.stdin.readline
# ---------- 读入 ----------
n, m = map(int, input().split())
adj = [[] for _ in range(n + 1)]
for _ in range(n - 1):
    u, v, w = map(int, input().split())
    adj[u].append((v, w))
    adj[v].append((u, w))
LOG = (n).bit_length()                  # ⌈log2 n⌉
# ---------- 一些全局数组 ----------
depth  = [0]*(n + 1)                    # 深度
dist   = [0]*(n + 1)                    # 根到该点的距离
dp     = [[0]*LOG for _ in range(n + 1)]# 2^k 祖先
mx1    = [[0, -1] for _ in range(n + 1)]# [最长向下链, 次长向下链]
mx2    = [0]*(n + 1)                    # 向上的最长链(不在子树内)
# ---------- 1. 建立深度、距离和 1 级祖先 ----------
def dfs_parent(u: int, p: int) -> None:
    dp[u][0] = p
    for v, w in adj[u]:
        if v == p:
            continue
        depth[v]   = depth[u] + 1
        dist[v]    = dist[u]  + w
        dfs_parent(v, u)
dfs_parent(1, 0)
# ---------- 2. 倍增表 ----------
for k in range(1, LOG):
    for v in range(1, n + 1):
        dp[v][k] = dp[dp[v][k-1]][k-1]
# ---------- 3. 求每个结点向下的两条最长链 ----------
def dfs_down(u: int, p: int) -> None:
    best1, best2 = 0, -1
    for v, w in adj[u]:
        if v == p:
            continue
        dfs_down(v, u)
        cand = mx1[v][0] + w
        if cand > best1:
            best2, best1 = best1, cand
        elif cand > best2:
            best2 = cand
    mx1[u][0], mx1[u][1] = best1, best2
dfs_down(1, 0)
# ---------- 4. 求每个结点向上的最长链 ----------
def dfs_up(u: int, p: int, up_val: int) -> None:
    mx2[u] = up_val
    # 预处理本结点所有孩子的最长向下链的前两大
    top1, top2, child_top1 = -1, -1, -1
    for v, _ in adj[u]:
        if v == p:
            continue
        val = mx1[v][0]
        if val > top1:
            top2, top1, child_top1 = top1, val, v
        elif val > top2:
            top2 = val
    for v, w in adj[u]:
        if v == p:
            continue
        # 经过父方向的候选
        cand = up_val + w
        # 经过兄弟子树的候选
        if v == child_top1 and top2 != -1:
            cand = max(cand, top2 + w)
        elif v != child_top1 and top1 != -1:
            cand = max(cand, top1 + w)
        dfs_up(v, u, cand)
dfs_up(1, 0, 0)
# ---------- 5. 最近公共祖先 ----------
def lca(u: int, v: int) -> int:
    if depth[u] < depth[v]:
        u, v = v, u
    diff = depth[u] - depth[v]
    for k in range(LOG - 1, -1, -1):
        if diff >> k & 1:
            u = dp[u][k]
    if u == v:
        return u
    for k in range(LOG - 1, -1, -1):
        if dp[u][k] and dp[u][k] != dp[v][k]:
            u = dp[u][k]
            v = dp[v][k]
    return dp[u][0]
# ---------- 6. 处理询问 ----------
out_lines = []
for _ in range(m):
    x, y = map(int, input().split())
    if x == y:                          # 经过同一个点
        down1, down2 = mx1[x]
        down2 = max(down2, 0)           # 可能没有次大链
        if mx2[x] == 0:                 # 根节点或向上链为空
            ans = down1 + down2
        else:
            ans = max(mx2[x] + down1, down1 + down2)
    else:                               # 经过两点
        z   = lca(x, y)
        dis = dist[x] + dist[y] - 2*dist[z]
        ans = dis + mx1[x][0] + mx1[y][0]
    out_lines.append(str(ans))
sys.stdout.write("\n".join(out_lines))
C++
#include<bits/stdc++.h>
using namespace std;
int n, m;
const int N = 1e5 + 5;
vector<array<int, 2>> e[N];
int sz[N], mxlen2[N], len[N];
array<int, 2> mxlen1[N];
int dp[N][20];
int depth[N];
bool vise[N];
void DFS(int k) {
  //预处理出DP数组
  vise[k] = true;
  for (int i = 0; dp[k][i]; i++) {
    dp[k][i + 1] = dp[dp[k][i]][i];
  }
  for (auto [to, m] : e[k]) {
    //求解直接公共祖先;
    if (vise[to])continue;
    depth[to] = depth[k] + 1;
    dp[to][0] = k;
    DFS(to);
  }
  return;
}
int lca(int u, int v) {
  if (depth[u] < depth[v])swap(u, v);
  //弹节点
  int k = log2(depth[u] - depth[v]);
  for (int i = k; i >= 0; i--) {
    if (depth[dp[u][i]] >= depth[v])u = dp[u][i];
  }
  if (u == v)return u;
  //查询
  k = log2(depth[u]);
  for (int i = k; i >= 0; i--) {
    if (dp[u][i] == dp[v][i])continue;
    u = dp[u][i];
    v = dp[v][i];
  }
  return dp[u][0];
}
void dfs0(int u, int fa) {
  vector<int> vals;
  mxlen1[u][0] = 0;
  mxlen1[u][1] = -1;// 最大值和次大值
  for (auto [v, w] : e[u]) {
    if (v == fa)continue;
    len[v] = len[u] + w;
    dfs0(v, u);
    vals.push_back(mxlen1[v][0]+w);
  }
  int sz = vals.size();
  if(sz==0)return ;
  sort(vals.begin(), vals.end());
  mxlen1[u][0] = vals[sz-1];
  if(sz >= 2)mxlen1[u][1] = vals[sz - 2];
}
void dfs1(int u, int fa, int up_mx) {
  mxlen2[u] = up_mx;
  vector<int> vals;
  for (auto [v, w] : e[u]) {
    if (v == fa)continue;
    vals.push_back(mxlen1[v][0]);
  }
  sort(vals.begin(), vals.end());
  int sz = vals.size();
  for (auto [v, w] : e[u]) {
    if (v == fa)continue;
    int mx = up_mx + w;
    for (int i = sz - 1; i >= max(sz - 2, 0); i--) {
      if (vals[i] != mxlen1[v][0]) {
        mx = max(mx, vals[i] + w);
      }
    }
    dfs1(v, u, mx);
  }
}
signed main () {
  cin >> n >> m;
  for (int i = 1; i <= n - 1; i++) {
    int u, v, c;
    cin >> u >> v >> c;
    e[u].push_back({v, c});
    e[v].push_back({u, c});
  }
  DFS(1);
  dfs0(1, 0);
  dfs1(1, 0, 0);
  for (int i = 0; i < m; i++) {
    int x, y;
    cin >> x >> y;
    int dis = len[x] + len[y] - 2 * len[lca(x, y)];
    int ans = 0;
    if(x==y){
      if(mxlen2[x] == 0){
        ans = mxlen1[x][0] + (mxlen1[x][1] == -1 ? 0 : mxlen1[x][1]);
      }
      else{
        ans = max(mxlen2[x] + mxlen1[x][0],mxlen1[x][0] + mxlen1[x][1]);
      }
    }
    else{
      ans = dis + mxlen1[x][0] + mxlen1[y][0];
    }
    cout<<ans<<endl;
  }
  return 0;
}
Java代码
import java.util.*;
public class Main {
    static int n, m;
    static final int N = (int) 1e5 + 5;
    static List<int[]>[] e = new ArrayList[N];
    static int[] sz = new int[N], mxlen2 = new int[N], len = new int[N];
    static int[][] mxlen1 = new int[N][2];
    static int[][] dp = new int[N][20];
    static int[] depth = new int[N];
    static boolean[] vise = new boolean[N];
    // 深度优先搜索 (DFS) 计算祖先数组
    static void DFS(int k) {
        vise[k] = true;
        for (int i = 0; dp[k][i] != 0; i++) {
            dp[k][i + 1] = dp[dp[k][i]][i];
        }
        for (int[] edge : e[k]) {
            int to = edge[0], m = edge[1];
            if (vise[to]) continue;
            depth[to] = depth[k] + 1;
            dp[to][0] = k;
            DFS(to);
        }
    }
    // 计算最近公共祖先 (LCA)
    static int lca(int u, int v) {
        if (depth[u] < depth[v]) {
            int temp = u;
            u = v;
            v = temp;
        }
        int k = (int) Math.floor(Math.log(depth[u] - depth[v]) / Math.log(2));
        for (int i = k; i >= 0; i--) {
            if (depth[dp[u][i]] >= depth[v]) {
                u = dp[u][i];
            }
        }
        if (u == v) return u;
        k = (int) Math.floor(Math.log(depth[u]) / Math.log(2));
        for (int i = k; i >= 0; i--) {
            if (dp[u][i] != dp[v][i]) {
                u = dp[u][i];
                v = dp[v][i];
            }
        }
        return dp[u][0];
    }
    // 计算每个结点向下的最大长度和次大长度
    static void dfs0(int u, int fa) {
        List<Integer> vals = new ArrayList<>();
        mxlen1[u][0] = 0;
        mxlen1[u][1] = -1; // 最大值和次大值
        for (int[] edge : e[u]) {
            int v = edge[0], w = edge[1];
            if (v == fa) continue;
            len[v] = len[u] + w;
            dfs0(v, u);
            vals.add(mxlen1[v][0] + w);
        }
        int sz = vals.size();
        if (sz == 0) return;
        Collections.sort(vals);
        mxlen1[u][0] = vals.get(sz - 1);
        if (sz >= 2) mxlen1[u][1] = vals.get(sz - 2);
    }
    // 计算每个结点向上的最大长度
    static void dfs1(int u, int fa, int up_mx) {
        mxlen2[u] = up_mx;
        List<Integer> vals = new ArrayList<>();
        for (int[] edge : e[u]) {
            int v = edge[0], w = edge[1];
            if (v == fa) continue;
            vals.add(mxlen1[v][0]);
        }
        Collections.sort(vals);
        int sz = vals.size();
        for (int[] edge : e[u]) {
            int v = edge[0], w = edge[1];
            if (v == fa) continue;
            int mx = up_mx + w;
            for (int i = sz - 1; i >= Math.max(sz - 2, 0); i--) {
                if (vals.get(i) != mxlen1[v][0]) {
                    mx = Math.max(mx, vals.get(i) + w);
                }
            }
            dfs1(v, u, mx);
        }
    }
    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in);
        n = scanner.nextInt();
        m = scanner.nextInt();
        // 初始化邻接表
        for (int i = 1; i <= n; i++) {
            e[i] = new ArrayList<>();
        }
        // 输入树的边
        for (int i = 1; i <= n - 1; i++) {
            int u = scanner.nextInt();
            int v = scanner.nextInt();
            int c = scanner.nextInt();
            e[u].add(new int[]{v, c});
            e[v].add(new int[]{u, c});
        }
        DFS(1);
        dfs0(1, 0);
        dfs1(1, 0, 0);
        // 处理查询
        for (int i = 0; i < m; i++) {
            int x = scanner.nextInt();
            int y = scanner.nextInt();
            int dis = len[x] + len[y] - 2 * len[lca(x, y)];
            int ans = 0;
            if (x == y) {
                if (mxlen2[x] == 0) {
                    ans = mxlen1[x][0] + (mxlen1[x][1] == -1 ? 0 : mxlen1[x][1]);
                } else {
                    ans = Math.max(mxlen2[x] + mxlen1[x][0], mxlen1[x][0] + mxlen1[x][1]);
                }
            } else {
                ans = dis + mxlen1[x][0] + mxlen1[y][0];
            }
            System.out.println(ans);
        }
        scanner.close();
    }
}
        题目内容
游游很喜欢树,这一天他在研究树上的路径,他遇到了一个难题,现在邀请你帮助他解决该问题。
在一棵树上,两个点并不一定能确定一条链,但是可以找到一条经过这两个点最长的一条链。
你有一棵n个点的树,树上每条边都有一个权值,定义一条简单路径的长度为这条简单路径上的边权和,对于给定的两个点x,y,你需要回答在树上经过这
两个点的最长简单路径是多少。
树上的路径从节点u到节点v的简单路径定义为从节点u出发,以节点v为终点,随意在树上走,不经过重复的点和边走出来的序列。可以证明,在树上,任意两个节点间有且仅有一条简单路径。
输入描述
第一行两个数n,m(1≤n,m≤105)。
接下来n−1行,每行3个数ui,vi,di(1≤ui,vi≤n,1≤di≤10),表示树的第i条边。,
接下来m行,每行2个数x,y,表示一次询问。
输出描述
共m行,每行一个整数ans,表示你的答案。
样例1
输入
4 4
1 2 1
1 3 2
1 4 1
2 1
4 3
1 4
2 4
输出
3
3 
3 
2