#P2794. 第3题-最长的路径
          
                        
                                    
                      
        
              - 
          
          
                      2000ms
            
          
                      Tried: 49
            Accepted: 6
            Difficulty: 8
            
          
          
          
                       所属公司 : 
                              美团
                                
            
                        
              时间 :2025年4月5日-开发岗
                              
                      
          
 
- 
                        算法标签>树          
 
第3题-最长的路径
题目分析
在一棵树中,给定两个点 x 和 y,要求计算这两点之间经过的最长简单路径的权值和。树上任意两个节点之间只有一条简单路径,因此我们可以从路径的性质出发,思考如何求解。
问题转化
通过题意可以得到以下两个关键点:
- 树的最长路径问题:树上两个节点之间的最长路径就是其路径权值的和。
 - 路径中经过两个指定点:给定两个节点 x 和 y,我们要求的是一条路径,路径上经过这两个节点,且路径的长度最大。
 
因此,我们的任务可以分为两部分:
- 计算树上任意两个节点之间的距离。
 - 计算经过两个节点的最长路径,这涉及到如何处理从 x 到 y 路径上的其他节点。
 
核心算法
- 
LCA(最低公共祖先)算法:可以使用二进制提升(Binary Lifting)算法来预处理 LCA 查询。预处理的时间复杂度为
O(n log n),查询的时间复杂度为O(log n)。 - 
DFS处理子树信息:在 DFS 中,我们维护以下信息:
- 根节点到 u 的路径。
 - 节点 u 的子树中的最大路径。
 - 节点 u 的子树中的第二大路径。
 - 节点 u 向上延伸的最大路径。
 
 - 
查询:对于每个查询,首先通过 LCA 获得两点之间的路径长度,然后分别计算路径两端的子树最大路径和其他路径的最大扩展。
 
时间复杂度分析
- 预处理阶段:DFS 遍历树的时间复杂度为 
O(n),LCA 的二进制提升预处理的时间复杂度为O(n log n)。 - 查询阶段:每个查询的 LCA 查询时间为 
O(log n),计算路径上其他节点的最长路径需要常数时间。因此每个查询的时间复杂度为O(log n)。 
整体时间复杂度为 O(n log n + m log n),其中 n 为节点数,m 为查询数。
代码实现
Python 代码
import sys
sys.setrecursionlimit(1 << 20)          # 递归深度上限
input = sys.stdin.readline
# ---------- 读入 ----------
n, m = map(int, input().split())
adj = [[] for _ in range(n + 1)]
for _ in range(n - 1):
    u, v, w = map(int, input().split())
    adj[u].append((v, w))
    adj[v].append((u, w))
LOG = (n).bit_length()                  # ⌈log2 n⌉
# ---------- 一些全局数组 ----------
depth  = [0]*(n + 1)                    # 深度
dist   = [0]*(n + 1)                    # 根到该点的距离
dp     = [[0]*LOG for _ in range(n + 1)]# 2^k 祖先
mx1    = [[0, -1] for _ in range(n + 1)]# [最长向下链, 次长向下链]
mx2    = [0]*(n + 1)                    # 向上的最长链(不在子树内)
# ---------- 1. 建立深度、距离和 1 级祖先 ----------
def dfs_parent(u: int, p: int) -> None:
    dp[u][0] = p
    for v, w in adj[u]:
        if v == p:
            continue
        depth[v]   = depth[u] + 1
        dist[v]    = dist[u]  + w
        dfs_parent(v, u)
dfs_parent(1, 0)
# ---------- 2. 倍增表 ----------
for k in range(1, LOG):
    for v in range(1, n + 1):
        dp[v][k] = dp[dp[v][k-1]][k-1]
# ---------- 3. 求每个结点向下的两条最长链 ----------
def dfs_down(u: int, p: int) -> None:
    best1, best2 = 0, -1
    for v, w in adj[u]:
        if v == p:
            continue
        dfs_down(v, u)
        cand = mx1[v][0] + w
        if cand > best1:
            best2, best1 = best1, cand
        elif cand > best2:
            best2 = cand
    mx1[u][0], mx1[u][1] = best1, best2
dfs_down(1, 0)
# ---------- 4. 求每个结点向上的最长链 ----------
def dfs_up(u: int, p: int, up_val: int) -> None:
    mx2[u] = up_val
    # 预处理本结点所有孩子的最长向下链的前两大
    top1, top2, child_top1 = -1, -1, -1
    for v, _ in adj[u]:
        if v == p:
            continue
        val = mx1[v][0]
        if val > top1:
            top2, top1, child_top1 = top1, val, v
        elif val > top2:
            top2 = val
    for v, w in adj[u]:
        if v == p:
            continue
        # 经过父方向的候选
        cand = up_val + w
        # 经过兄弟子树的候选
        if v == child_top1 and top2 != -1:
            cand = max(cand, top2 + w)
        elif v != child_top1 and top1 != -1:
            cand = max(cand, top1 + w)
        dfs_up(v, u, cand)
dfs_up(1, 0, 0)
# ---------- 5. 最近公共祖先 ----------
def lca(u: int, v: int) -> int:
    if depth[u] < depth[v]:
        u, v = v, u
    diff = depth[u] - depth[v]
    for k in range(LOG - 1, -1, -1):
        if diff >> k & 1:
            u = dp[u][k]
    if u == v:
        return u
    for k in range(LOG - 1, -1, -1):
        if dp[u][k] and dp[u][k] != dp[v][k]:
            u = dp[u][k]
            v = dp[v][k]
    return dp[u][0]
# ---------- 6. 处理询问 ----------
out_lines = []
for _ in range(m):
    x, y = map(int, input().split())
    z = lca(x, y)
    dis = dist[x] + dist[y] - 2*dist[z]
    if x == y:                          # 经过同一个点
        down1, down2 = mx1[x]
        down2 = max(down2, 0)           # 可能没有次大链
        if mx2[x] == 0:                 # 根节点或向上链为空
            ans = down1 + down2
        else:
            ans = max(mx2[x] + down1, down1 + down2)
    elif z == x or z == y:             # 两点在一条链上
        if depth[x] > depth[y]:        # x 在 y 的下方
            x, y = y, x
        ans = dis + mx2[x] + mx1[y][0]
        for v, w in adj[x]:
            if v == dp[x][0]:
                continue
            if lca(v, y) == x:        # y 不在 v 的子树内
                ans = max(ans, dis + mx1[v][0] + w + mx1[y][0])
    else:                             # 两点不在同一条链上
        ans = dis + mx1[x][0] + mx1[y][0]
    out_lines.append(str(ans))
sys.stdout.write("\n".join(out_lines))
C++
#include<bits/stdc++.h>
using namespace std;
int n, m;
const int N = 1e5 + 5;
const int MAXLOG = 18;
vector<array<int, 2>> e[N];
int mx2[N], dist[N];
array<int, 2> mx1[N];
int dp[N][20];
int depth[N];
void dfs(int u, int parent) {
  dp[u][0] = parent;
  depth[u] = depth[parent] + 1;
  for (auto [v,w] : e[u]) {
      if (v == parent) continue; // 不往回走
      dfs(v, u);
  }
}
void build_lca(int n) {
  for (int k = 1; k <= MAXLOG; k++) {
      for (int u = 1; u <= n; u++) {
          int ancestor = dp[u][k - 1];      
          dp[u][k] = dp[ancestor][k - 1];   
      }
  }
}
int lca(int u, int v) {
  if (depth[u] < depth[v]) swap(u, v); 
  int diff = depth[u] - depth[v];
  for (int k = 0; k <= MAXLOG; k++) {
      if (diff & (1 << k)) {
          u = dp[u][k];
      }
  }
  if (u == v) return u;
  for (int k = MAXLOG; k >= 0; k--) {
      if (dp[u][k] != dp[v][k]) {
          u = dp[u][k];
          v = dp[v][k];
      }
  }
  return dp[u][0];
}
void dfs0(int u, int fa) {
  vector<int> vals;
  mx1[u][0] = 0;
  mx1[u][1] = -1;// 最大值和次大值
  for (auto [v, w] : e[u]) {
    if (v == fa)continue;
    dist[v] = dist[u] + w;
    dfs0(v, u);
    vals.push_back(mx1[v][0] + w);
  }
  int sz = vals.size();
  if (sz == 0)return ;
  sort(vals.begin(), vals.end());
  mx1[u][0] = vals[sz - 1];
  if (sz >= 2)mx1[u][1] = vals[sz - 2];
}
void dfs1(int u, int fa, int up_mx) {
  mx2[u] = up_mx;
  vector<int> vals;
  for (auto [v, w] : e[u]) {
    if (v == fa)continue;
    vals.push_back(mx1[v][0]);
  }
  sort(vals.begin(), vals.end());
  int sz = vals.size();
  for (auto [v, w] : e[u]) {
    if (v == fa)continue;
    int mx = up_mx + w;
    for (int i = sz - 1; i >= max(sz - 2, 0); i--) {
      if (vals[i] != mx1[v][0]) {
        mx = max(mx, vals[i] + w);
      }
    }
    dfs1(v, u, mx);
  }
}
signed main () {
  cin >> n >> m;
  for (int i = 1; i <= n - 1; i++) {
    int u, v, c;
    cin >> u >> v >> c;
    e[u].push_back({v, c});
    e[v].push_back({u, c});
  }
  dfs(1, 0);
  dfs0(1, 0);
  dfs1(1, 0, 0);
  build_lca(n);
  for (int i = 0; i < m; i++) {
    int x, y;
    cin >> x >> y;
    int fa = lca(x, y);
    int dis = dist[x] + dist[y] - 2 * dist[fa];
    int ans = 0;
    if (x == y) {
      if (mx2[x] == 0) {
        ans = mx1[x][0] + (mx1[x][1] == -1 ? 0 : mx1[x][1]);
      }
      else {
        ans = max(mx2[x] + mx1[x][0], mx1[x][0] + mx1[x][1]);
      }
    }
    else if (fa == x || fa == y) {
      if (depth[x] > depth[y])swap(x, y);
      ans = dis + mx2[x] + mx1[y][0];
      for (auto [v, w] : e[x]) {
        if (v == dp[x][0])continue;
        if (lca(v, y) == x) {
          ans = max(ans, dis + mx1[v][0] + w + mx1[y][0]);
        }
      }
    }
    else {
      ans = dis + mx1[x][0] + mx1[y][0];
    }
    cout << ans << endl;
  }
  return 0;
}
Java代码
import java.util.*;
public class Main {
    static int n, m;
    static final int N = (int) 1e5 + 5;
    static final int MAXLOG = 18;
    static List<int[]>[] e = new ArrayList[N];
    static int[] mx2 = new int[N], dist = new int[N];
    static int[][] mx1 = new int[N][2];
    static int[][] dp = new int[N][20];
    static int[] depth = new int[N];
    // 深度优先搜索 (DFS) 计算祖先数组
    static void dfs(int u, int parent) {
        dp[u][0] = parent;
        depth[u] = depth[parent] + 1;
        for (int[] edge : e[u]) {
            int v = edge[0];
            if (v == parent) continue; // 不往回走
            dfs(v, u);
        }
    }
    static void build_lca(int n) {
        for (int k = 1; k <= MAXLOG; k++) {
            for (int u = 1; u <= n; u++) {
                int ancestor = dp[u][k - 1];      
                dp[u][k] = dp[ancestor][k - 1];   
            }
        }
    }
    // 计算最近公共祖先 (LCA)
    static int lca(int u, int v) {
        if (depth[u] < depth[v]){
            int tmp = u; u = v; v = tmp;
        }
        int diff = depth[u] - depth[v];
        for (int k = 0; k <= MAXLOG; k++) {
            if (((diff >> k) & 1) == 1) {
                u = dp[u][k];
            }
        }
        if (u == v) return u;
        for (int k = MAXLOG; k >= 0; k--) {
            if (dp[u][k] != dp[v][k]) {
                u = dp[u][k];
                v = dp[v][k];
            }
        }
        return dp[u][0];
    }
    // 计算每个结点向下的最大长度和次大长度
    static void dfs0(int u, int fa) {
        List<Integer> vals = new ArrayList<Integer>();
        mx1[u][0] = 0;
        mx1[u][1] = -1; // 最大值和次大值
        for (int[] edge : e[u]) {
            int v = edge[0], w = edge[1];
            if (v == fa) continue;
            dist[v] = dist[u] + w;
            dfs0(v, u);
            vals.add(mx1[v][0] + w);
        }
        int sz = vals.size();
        if (sz == 0) return;
        Collections.sort(vals);
        mx1[u][0] = vals.get(sz - 1);
        if (sz >= 2) mx1[u][1] = vals.get(sz - 2);
    }
    // 计算每个结点向上的最大长度
    static void dfs1(int u, int fa, int up_mx) {
        mx2[u] = up_mx;
        List<Integer> vals = new ArrayList<Integer>();
        for (int[] edge : e[u]) {
            int v = edge[0], w = edge[1];
            if (v == fa) continue;
            vals.add(mx1[v][0]);
        }
        Collections.sort(vals);
        int sz = vals.size();
        for (int[] edge : e[u]) {
            int v = edge[0], w = edge[1];
            if (v == fa) continue;
            int mx = up_mx + w;
            for (int i = sz - 1; i >= Math.max(sz - 2, 0); i--) {
                if (vals.get(i) != mx1[v][0]) {
                    mx = Math.max(mx, vals.get(i) + w);
                }
            }
            dfs1(v, u, mx);
        }
    }
    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in);
        n = scanner.nextInt();
        m = scanner.nextInt();
        // 初始化邻接表
        for (int i = 1; i <= n; i++) {
            e[i] = new ArrayList<int[]>();
        }
        // 输入树的边
        for (int i = 1; i <= n - 1; i++) {
            int u = scanner.nextInt();
            int v = scanner.nextInt();
            int c = scanner.nextInt();
            e[u].add(new int[]{v, c});
            e[v].add(new int[]{u, c});
        }
        dfs(1,0);
        dfs0(1, 0);
        dfs1(1, 0, 0);
        build_lca(n);
        // 处理查询
        for (int i = 0; i < m; i++) {
            int x = scanner.nextInt();
            int y = scanner.nextInt();
            int fa = lca(x, y);
            int dis = dist[x] + dist[y] - 2 * dist[fa];
            int ans = 0;
            if (x == y) {
                if (mx2[x] == 0) {
                    ans = mx1[x][0] + (mx1[x][1] == -1 ? 0 : mx1[x][1]);
                } else {
                    ans = Math.max(mx2[x] + mx1[x][0], mx1[x][0] + mx1[x][1]);
                }
            }
            else if (fa == x || fa == y) {
                if(depth[x] > depth[y]) {
                    int temp = x;
                    x = y;
                    y = temp;
                }
                ans = dis + mx2[x] + mx1[y][0];
                for (int[] edge : e[x]) {
                    int v = edge[0], w = edge[1];
                    if (v == dp[x][0])continue;
                    if (lca(v, y) == x) {
                        ans = Math.max(ans, dis + mx1[v][0] + w + mx1[y][0]);
                    }
                }
            }
            else {
                ans = dis + mx1[x][0] + mx1[y][0];
            }
            System.out.println(ans);
        }
        scanner.close();
    }
}
        题目内容
游游很喜欢树,这一天他在研究树上的路径,他遇到了一个难题,现在邀请你帮助他解决该问题。
在一棵树上,两个点并不一定能确定一条链,但是可以找到一条经过这两个点最长的一条链。
你有一棵n个点的树,树上每条边都有一个权值,定义一条简单路径的长度为这条简单路径上的边权和,对于给定的两个点x,y,你需要回答在树上经过这
两个点的最长简单路径是多少。
树上的路径从节点u到节点v的简单路径定义为从节点u出发,以节点v为终点,随意在树上走,不经过重复的点和边走出来的序列。可以证明,在树上,任意两个节点间有且仅有一条简单路径。
输入描述
第一行两个数n,m(1≤n,m≤105)。
接下来n−1行,每行3个数ui,vi,di(1≤ui,vi≤n,1≤di≤10),表示树的第i条边。,
接下来m行,每行2个数x,y,表示一次询问。
输出描述
共m行,每行一个整数ans,表示你的答案。
样例1
输入
4 4
1 2 1
1 3 2
1 4 1
2 1
4 3
1 4
2 4
输出
3
3 
3 
2