#P2903. 第4题-小美的结点树
          
                        
                                    
                      
        
              - 
          
          
                      1000ms
            
          
                      Tried: 30
            Accepted: 5
            Difficulty: 9
            
          
          
          
                       所属公司 : 
                              美团
                                
            
                        
              时间 :2025年4月26日-算法岗
                              
                      
          
 
- 
                        算法标签>树          
 
第4题-小美的结点树
题目理解
给定一个树,树的节点上有权值。对于每次操作,指定两节点 u 和 v,我们需要从节点 u 到节点 v 的路径上,按路径的顺序给节点增加一定的权值。增加的权值是从 x 开始,依次递增。
思路与方法
1. 路径计算
在树上,任意两点之间有且仅有一条简单路径。因此,我们首先需要能够高效地计算出两点之间的路径。这个问题的关键是如何快速找到路径,并在路径上更新权值。
1.1 树的表示与深度优先搜索(DFS)
树可以通过邻接表来表示。为了能快速求出任意两点 u 和 v 之间的路径,我们可以先进行一次深度优先搜索(DFS),记录每个节点的父节点和深度(即从根节点到当前节点的边数)。
1.2 路径的求解
通过DFS,我们可以得到任意两个节点之间的路径。对于任意两个节点 u 和 v,其路径可以通过它们的最小公共祖先(LCA,Lowest Common Ancestor)来计算。具体来说,路径可以分为从 u 到 LCA 和从 LCA 到 v 两部分。为了快速找到LCA,可以利用二进制提升(Binary Lifting)技术,这样可以在 O(logn) 的时间内求解任意两个节点的LCA。
1.3 路径上的更新
对于每次操作,给定 u,v,x,我们可以先找到路径上的所有节点,然后根据题意依次更新这些节点的权值。在这里,我们考虑到更新路径时可能会有重复的操作,因此直接在路径上进行操作会比较慢。为了解决这个问题,可以采用区间更新技巧:利用差分数组的思想,对路径上每个节点进行增量操作。
2. 操作处理
在每次操作中,我们需要沿着路径更新节点的权值。具体来说,对于节点 u 到节点 v 之间的每一条边,我们希望在这条路径上给对应的节点增加特定的值。为了避免重复计算,可以采用树的路径更新技巧,利用差分数组进行区间的增量操作。
3. 复杂度分析
- DFS 的时间复杂度为 O(n),用来计算父节点和深度。
 - 预处理LCA的时间复杂度为O(nlogn)
 - LCA 查询 在每次操作中可以通过二进制提升方法在 O(logn) 时间内完成。
 - 对于每次操作,我们需要进行路径更新,由于路径上的节点数量至多为 O(n),因此每次操作的更新时间复杂度为 O(n)。
 
总体复杂度为 O((n+q)logn),其中 n 为节点数,q 为操作数。
代码实现
Python 代码
import sys
sys.setrecursionlimit(10**7)
input = sys.stdin.readline
MAX_N = 100005
MAX_LIFT = 20  # 2^20 ≈ 10^6
# 树的邻接表
adj = [[] for _ in range(MAX_N)]
# LCA 父表 / 深度
parent = [[-1] * MAX_LIFT for _ in range(MAX_N)]
depth = [-1] * MAX_N
# 初始权值
a = [0] * MAX_N
# 差分数组
diff0 = [0] * MAX_N   # 常数项差分
diff1 = [0] * MAX_N   # depth 系数差分
def dfs_build(u, p):
    """预处理 depth 和 parent[u][0]"""
    depth[u] = depth[p] + 1
    parent[u][0] = p
    for v in adj[u]:
        if v == p: continue
        dfs_build(v, u)
def init_lca(n):
    depth[0] = -1
    dfs_build(1, 0)
    for j in range(1, MAX_LIFT):
        for u in range(1, n+1):
            pu = parent[u][j-1]
            parent[u][j] = parent[pu][j-1] if pu != -1 else -1
def lca(u, v):
    if depth[u] < depth[v]:
        u, v = v, u
    # 把 u 抬到和 v 一样深
    d = depth[u] - depth[v]
    for j in range(MAX_LIFT):
        if d >> j & 1:
            u = parent[u][j]
    if u == v:
        return u
    # 一起跳
    for j in reversed(range(MAX_LIFT)):
        if parent[u][j] != parent[v][j]:
            u = parent[u][j]
            v = parent[v][j]
    return parent[u][0]
def apply_query(u, v, x):
    """把一次 u→v 的等差加 x,x+1,... 拆成两段差分更新"""
    w = lca(u, v)
    # 上行段 u→w (逆序,slope = -1)
    C_up = x + depth[u]
    diff0[u]    += C_up
    if parent[w][0] != 0:
        diff0[parent[w][0]] -= C_up
    diff1[u]    += -1
    if parent[w][0] != 0:
        diff1[parent[w][0]] -= -1
    # 下行段 w→v (正序,slope = +1),但要排除 w 本身
    if v != w:
        C_down = x + depth[u] - 2*depth[w]
        diff0[v]  += C_down
        diff0[w]  -= C_down
        diff1[v]  += 1
        diff1[w]  -= 1
def dfs_acc(u, p):
    """后序把差分从子节点累加上来"""
    for v in adj[u]:
        if v == p: continue
        dfs_acc(v, u)
        diff0[u] += diff0[v]
        diff1[u] += diff1[v]
def main():
    n, q = map(int, input().split())
    # 读初值
    vals = list(map(int, input().split()))
    for i, v in enumerate(vals, start=1):
        a[i] = v
    # 读边
    for _ in range(n-1):
        u, v = map(int, input().split())
        adj[u].append(v)
        adj[v].append(u)
    # LCA 预处理
    init_lca(n)
    # 处理 q 次差分标记
    for _ in range(q):
        u, v, x = map(int, input().split())
        apply_query(u, v, x)
    # 一次后序把 diff0,diff1 累加
    dfs_acc(1, 0)
    # 输出答案
    res = [0]*n
    for u in range(1, n+1):
        # 每个节点实际加 = diff0[u] + diff1[u] * depth[u]
        res[u-1] = a[u] + diff0[u] + diff1[u] * depth[u]
    print(" ".join(map(str,res)))
if __name__ == "__main__":
    main()
Java 代码
import java.io.*;
import java.util.*;
public class Main {
    static final int MAX_LIFT = 20;
    static int n, q;
    static List<Integer>[] adj;
    static int[][] parent;
    static int[] depth;
    static long[] a, diff0, diff1;
    public static void main(String[] args) throws IOException {
        BufferedReader in = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st = new StringTokenizer(in.readLine());
        n = Integer.parseInt(st.nextToken());
        q = Integer.parseInt(st.nextToken());
        adj = new ArrayList[n+1];
        for (int i = 1; i <= n; i++) adj[i] = new ArrayList<>();
        parent = new int[n+1][MAX_LIFT];
        depth  = new int[n+1];
        a      = new long[n+1];
        diff0  = new long[n+1];
        diff1  = new long[n+1];
        st = new StringTokenizer(in.readLine());
        for (int i = 1; i <= n; i++) {
            a[i] = Long.parseLong(st.nextToken());
        }
        for (int i = 0; i < n-1; i++) {
            st = new StringTokenizer(in.readLine());
            int u = Integer.parseInt(st.nextToken());
            int v = Integer.parseInt(st.nextToken());
            adj[u].add(v);
            adj[v].add(u);
        }
        initLCA();
        for (int i = 0; i < q; i++) {
            st = new StringTokenizer(in.readLine());
            int u = Integer.parseInt(st.nextToken());
            int v = Integer.parseInt(st.nextToken());
            long x = Long.parseLong(st.nextToken());
            applyQuery(u, v, x);
        }
        dfsAcc(1, 0);
        StringBuilder sb = new StringBuilder();
        for (int i = 1; i <= n; i++) {
            long add = diff0[i] + diff1[i] * depth[i];
            sb.append(a[i] + add).append(i==n?'\n':' ');
        }
        System.out.print(sb);
    }
    static void dfsBuild(int u, int p) {
        parent[u][0] = p;
        depth[u] = (p == 0 ? 0 : depth[p] + 1);
        for (int v : adj[u]) {
            if (v == p) continue;
            dfsBuild(v, u);
        }
    }
    static void initLCA() {
        dfsBuild(1, 0);
        for (int j = 1; j < MAX_LIFT; j++) {
            for (int u = 1; u <= n; u++) {
                int pu = parent[u][j-1];
                parent[u][j] = (pu == 0 ? 0 : parent[pu][j-1]);
            }
        }
    }
    static int lca(int u, int v) {
        if (depth[u] < depth[v]) {
            int t = u; u = v; v = t;
        }
        int d = depth[u] - depth[v];
        for (int j = 0; j < MAX_LIFT; j++) {
            if ((d & (1<<j)) != 0) {
                u = parent[u][j];
            }
        }
        if (u == v) return u;
        for (int j = MAX_LIFT-1; j >= 0; j--) {
            if (parent[u][j] != parent[v][j]) {
                u = parent[u][j];
                v = parent[v][j];
            }
        }
        return parent[u][0];
    }
    static void applyQuery(int u, int v, long x) {
        int w = lca(u, v);
        // 上行段 u→w (slope = -1)
        long Cup = x + depth[u];
        diff0[u]   += Cup;
        if (parent[w][0] != 0) diff0[parent[w][0]] -= Cup;
        diff1[u]   += -1;
        if (parent[w][0] != 0) diff1[parent[w][0]] -= -1;
        // 下行段 w→v, 排除 w
        if (v != w) {
            long Cdown = x + depth[u] - 2L * depth[w];
            diff0[v]   += Cdown;
            diff0[w]   -= Cdown;
            diff1[v]   +=  1;
            diff1[w]   -=  1;
        }
    }
    static void dfsAcc(int u, int p) {
        for (int v : adj[u]) {
            if (v == p) continue;
            dfsAcc(v, u);
            diff0[u] += diff0[v];
            diff1[u] += diff1[v];
        }
    }
}
C++ 代码
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
const int MAX_LIFT = 20;
int n, q;
vector<vector<int>> adj;
vector<vector<int>> parent;
vector<int> depth;
vector<ll> a, diff0, diff1;
void dfs_build(int u, int p) {
    parent[u][0] = p;
    depth[u] = (p == 0 ? 0 : depth[p] + 1);
    for (int v : adj[u]) {
        if (v == p) continue;
        dfs_build(v, u);
    }
}
void init_lca() {
    dfs_build(1, 0);
    for (int j = 1; j < MAX_LIFT; ++j) {
        for (int u = 1; u <= n; ++u) {
            int pu = parent[u][j-1];
            parent[u][j] = (pu == 0 ? 0 : parent[pu][j-1]);
        }
    }
}
int lca(int u, int v) {
    if (depth[u] < depth[v]) swap(u,v);
    int d = depth[u] - depth[v];
    for (int j = 0; j < MAX_LIFT; ++j) {
        if (d & (1<<j)) u = parent[u][j];
    }
    if (u == v) return u;
    for (int j = MAX_LIFT-1; j >= 0; --j) {
        if (parent[u][j] != parent[v][j]) {
            u = parent[u][j];
            v = parent[v][j];
        }
    }
    return parent[u][0];
}
void apply_query(int u, int v, ll x) {
    int w = lca(u, v);
    // 上行段 u→w (slope = -1)
    ll Cup = x + depth[u];
    diff0[u]   += Cup;
    if (parent[w][0] != 0) diff0[parent[w][0]] -= Cup;
    diff1[u]   += -1;
    if (parent[w][0] != 0) diff1[parent[w][0]] -= -1;
    // 下行段 w→v, 排除 w
    if (v != w) {
        ll Cdown = x + depth[u] - 2LL*depth[w];
        diff0[v]   += Cdown;
        diff0[w]   -= Cdown;
        diff1[v]   +=  1;
        diff1[w]   -=  1;
    }
}
void dfs_acc(int u, int p) {
    for (int v : adj[u]) {
        if (v == p) continue;
        dfs_acc(v, u);
        diff0[u] += diff0[v];
        diff1[u] += diff1[v];
    }
}
int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n >> q;
    adj.assign(n+1, {});
    parent.assign(n+1, vector<int>(MAX_LIFT, 0));
    depth.assign(n+1, 0);
    a.assign(n+1, 0);
    diff0.assign(n+1, 0);
    diff1.assign(n+1, 0);
    for (int i = 1; i <= n; ++i) {
        cin >> a[i];
    }
    for (int i = 0, u, v; i < n-1; ++i) {
        cin >> u >> v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    init_lca();
    for (int i = 0, u, v; i < q; ++i) {
        ll x;
        cin >> u >> v >> x;
        apply_query(u, v, x);
    }
    dfs_acc(1, 0);
    // 输出结果
    for (int i = 1; i <= n; ++i) {
        ll add = diff0[i] + diff1[i] * (ll)depth[i];
        cout << (a[i] + add) << (i==n?'\n':' ');
    }
    return 0;
}
        题目内容
小美有一棵 n 个结点的树,树上第 i 个结点的权值为 ai 。
现在她定义树上任意两点 u,v 的距离为 dist(u,v) ,即树上两点间简单路径的边数。
现在她提出 9 次操作,每次操作给定三个整数 u,v,x ,她准备从 u 出发,把 u→v 简单路径上的结点权值,按节点在路径上出现的先后顺序,依次加上 x,x+1,x+2,...,x+dist(u,v) 。请你输出操作后所有结点的权值。
从节点 u 到节点 v 的简单路径定义为从节点 u 出发,以节点 v 为终点,随意在树上走,不经过重复的点和边走出来的序列。可以证明,在树上,任意两个节点间有且仅有一条简单路径。
输入描述
第一行两个整数 n,q(2≤n,q≤105) ,表示树的结点总数和操作次数。
第二行 n 个整数,第 i 个整数 ai(1≤ai≤106) 表示树上第 i 个结点的初始权值。
接下来 n−1 行,每行两个整数 u,v(1≤u,v≤n;u=v) 表示 u,v 之间存在一条无向边。
接下来 q 行,每行三个整数 u,v,x(1≤u,v≤n,1≤x≤106) 含义如题面所示。
输出描述
输出一行,共 n 个整数,以空格隔开,表示操作后每个结点的权值。
样例1
输入
3 9
1 1 1
1 2
1 3
1 1 1
2 2 1
3 3 1
2 3 1
3 2 1
1 3 1
3 1 1
1 2 1
2 1 1
输出
12 9 9