#P2735. 第2题-小红的结点树
          
                        
                                    
                      
        
              - 
          
          
                      1000ms
            
          
                      Tried: 12
            Accepted: 4
            Difficulty: 6
            
          
          
          
                       所属公司 : 
                              阿里
                                
            
                        
              时间 :2025年3月23日-阿里云(算法岗)
                              
                      
          
 
- 
                        算法标签>动态规划          
 
第2题-小红的结点树
题解思路
本题是树形动态规划(Tree DP)的应用,目标是统计树上所有连通块中,节点权值为奇数的结点数量之和。树的结构保证了任何连通块都是一个子树的“子集”。我们可以通过对每个子树进行动态规划来解决这个问题。
1. 定义状态
设 C[u] 表示以节点 u 为根的子树的连通块数量(包括节点 u 自身以及所有包含该节点的连通块)。我们可以通过子树的组合来递归地计算这个值。
- C[u] 的计算公式为:
其中 children(u) 表示节点 u 的所有子节点。这个式子表示所有子树的连通块可以通过选择或不选择每个子树中的某些节点来构成。 
此外,设 D[u] 表示以节点 u 为根的子树中,所有连通块内部的奇数权值节点数量之和。
- D[u] 的计算公式为:
其中 f(u) 是一个指示函数,表示节点 u 的权值是否为奇数,定义为:

 
2. 计算方法
我们采用深度优先搜索(DFS)来遍历整棵树,并在每次递归时计算每个节点的 C[u] 和 D[u]。对于每个节点 u,我们需要计算所有子节点的 C[v] 和 D[v],并通过前缀积和后缀积来高效计算去掉某个子节点后的乘积。
3. 最终结果
最后,我们将所有节点的 D[u] 累加起来,得到最终结果。即所有连通块中奇数权值节点数量之和: 答案=∑u=1nD[u]
并对 109+7 取模。
复杂度分析
- 每个节点会被 DFS 访问一次,计算每个节点的子树信息的时间复杂度为 O(n)。
 - 因此,整体时间复杂度为 O(n),适合 n≤105 的规模。
 
代码实现
C++代码
#include <iostream>
#include <vector>
using namespace std;
 
const long long MOD = 1000000007;
 
int n;
vector<int> a;
vector<vector<int>> adj;
vector<long long> C, D; // C[u]: 以 u 为根且必包含 u 的连通块个数
                       // D[u]: 以 u 为根且必包含 u 的连通块中奇数权值节点数量之和
 
// DFS:递归计算每个节点的 C[u] 和 D[u]
void dfs(int u, int parent) {
    // 收集 u 的所有子节点(排除父节点)
    vector<int> children;
    for (int v : adj[u]) {
        if (v == parent) continue;
        dfs(v, u);
        children.push_back(v);
    }
    int m = children.size();
    long long prod = 1;
    // 计算 prod = ∏ (1 + C[v]),表示 u 子树所有子连通块的选择方式
    for (int i = 0; i < m; i++) {
        int v = children[i];
        prod = (prod * (C[v] + 1)) % MOD;
    }
    C[u] = prod; // u 为根的连通块总数
 
    // 计算前缀积和后缀积,方便求去掉某个子节点因子的乘积
    vector<long long> prefix(m + 1, 1), suffix(m + 1, 1);
    for (int i = 0; i < m; i++) {
        int v = children[i];
        prefix[i + 1] = (prefix[i] * (C[v] + 1)) % MOD;
    }
    for (int i = m - 1; i >= 0; i--) {
        int v = children[i];
        suffix[i] = (suffix[i + 1] * (C[v] + 1)) % MOD;
    }
    // f(u)=1 当 a[u] 为奇数,否则为0
    long long f = (a[u] % 2);
    // 节点 u 自身的贡献:f(u) * (∏ (1 + C[v]))
    long long sum = (f * prod) % MOD;
    // 累加各子树贡献:对于每个子节点 v,其贡献为 D[v] 乘以除去 v 的其他子节点的选择方式
    for (int i = 0; i < m; i++) {
        int v = children[i];
        long long rest = (prefix[i] * suffix[i + 1]) % MOD;
        sum = (sum + D[v] * rest) % MOD;
    }
    D[u] = sum;
}
 
int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
 
    cin >> n;
    a.resize(n + 1);
    adj.resize(n + 1);
    C.resize(n + 1, 0);
    D.resize(n + 1, 0);
    for (int i = 1; i <= n; i++){
        cin >> a[i];
    }
    for (int i = 1; i <= n - 1; i++){
        int u, v;
        cin >> u >> v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    dfs(1, 0);
    long long ans = 0;
    // 将所有节点作为连通块的最低点时统计的 D[u] 累加即为答案
    for (int i = 1; i <= n; i++){
        ans = (ans + D[i]) % MOD;
    }
    cout << ans << "\n";
    return 0;
}
Python代码
MOD = 10**9 + 7
import sys
sys.setrecursionlimit(10**6)
 
# 输入节点数
n = int(input())
# 注意:将下标从1开始,故在最前面插入一个占位值
a = [0] + list(map(int, input().split()))
# 构造邻接表
adj = [[] for _ in range(n + 1)]
for _ in range(n - 1):
    u, v = map(int, input().split())
    adj[u].append(v)
    adj[v].append(u)
 
# C[u]: 以 u 为根且必包含 u 的连通块个数
# D[u]: 以 u 为根且必包含 u 的连通块中奇数权值节点数量之和
C = [0] * (n + 1)
D = [0] * (n + 1)
 
def dfs(u, parent):
    children = []
    for v in adj[u]:
        if v == parent:
            continue
        dfs(v, u)
        children.append(v)
    m = len(children)
    prod = 1
    # 计算 ∏ (1 + C[v])
    for v in children:
        prod = (prod * (C[v] + 1)) % MOD
    C[u] = prod
    # 预处理前缀积和后缀积
    prefix = [1] * (m + 1)
    suffix = [1] * (m + 1)
    for i in range(m):
        v = children[i]
        prefix[i + 1] = (prefix[i] * (C[v] + 1)) % MOD
    for i in range(m - 1, -1, -1):
        v = children[i]
        suffix[i] = (suffix[i + 1] * (C[v] + 1)) % MOD
    # f(u) = 1 当 a[u] 为奇数,否则为 0
    f = a[u] % 2
    total = (f * prod) % MOD
    for i in range(m):
        v = children[i]
        rest = (prefix[i] * suffix[i + 1]) % MOD
        total = (total + D[v] * rest) % MOD
    D[u] = total
 
dfs(1, 0)
# 累加所有节点贡献即为答案
ans = sum(D[1:]) % MOD
print(ans)
Java代码
import java.io.*;
import java.util.*;
 
public class Main {
    static final long MOD = 1000000007;
    static int n;
    static int[] a;
    static ArrayList<Integer>[] adj;
    // C[u]: 以 u 为根且必包含 u 的连通块个数
    // D[u]: 以 u 为根且必包含 u 的连通块中奇数权值节点数量之和
    static long[] C, D;
 
    // DFS递归函数,parent表示父节点
    public static void dfs(int u, int parent) {
        ArrayList<Integer> children = new ArrayList<>();
        for (int v : adj[u]) {
            if (v == parent) continue;
            dfs(v, u);
            children.add(v);
        }
        int m = children.size();
        long prod = 1;
        // 计算 prod = ∏ (1 + C[v])
        for (int v : children) {
            prod = (prod * (C[v] + 1)) % MOD;
        }
        C[u] = prod;
 
        // 预处理前缀积和后缀积,方便去除某个子节点的因子
        long[] prefix = new long[m + 1];
        long[] suffix = new long[m + 1];
        prefix[0] = 1;
        for (int i = 0; i < m; i++) {
            int v = children.get(i);
            prefix[i + 1] = (prefix[i] * (C[v] + 1)) % MOD;
        }
        suffix[m] = 1;
        for (int i = m - 1; i >= 0; i--) {
            int v = children.get(i);
            suffix[i] = (suffix[i + 1] * (C[v] + 1)) % MOD;
        }
 
        // f(u) = 1 当 a[u] 为奇数,否则为 0
        long f = (a[u] % 2);
        long total = (f * prod) % MOD;
        for (int i = 0; i < m; i++) {
            int v = children.get(i);
            long rest = (prefix[i] * suffix[i + 1]) % MOD;
            total = (total + (D[v] * rest) % MOD) % MOD;
        }
        D[u] = total;
    }
 
    public static void main(String[] args) throws IOException {
        // 使用 BufferedReader 提高输入效率
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        n = Integer.parseInt(br.readLine());
        a = new int[n + 1];
        C = new long[n + 1];
        D = new long[n + 1];
 
        String[] parts = br.readLine().split(" ");
        for (int i = 1; i <= n; i++){
            a[i] = Integer.parseInt(parts[i - 1]);
        }
 
        // 初始化邻接表
        adj = new ArrayList[n + 1];
        for (int i = 1; i <= n; i++){
            adj[i] = new ArrayList<>();
        }
        for (int i = 1; i < n; i++){
            String[] edge = br.readLine().split(" ");
            int u = Integer.parseInt(edge[0]);
            int v = Integer.parseInt(edge[1]);
            adj[u].add(v);
            adj[v].add(u);
        }
 
        dfs(1, 0);
 
        long ans = 0;
        // 累加所有节点的 D[u]
        for (int i = 1; i <= n; i++){
            ans = (ans + D[i]) % MOD;
        }
        System.out.println(ans);
    }
}
        题目内容
小红拿到一棵n个结点的树,第i个点的权值为ai。 现在,你需要求解,对于全部的连通块,它们内部中结点权值为奇数的结点数量之和是多少。由于答案可能很大,请将答案对(109+7)取模后输出。
她定义对于树上的两个点,如果它们相连,则称他们位于同一个连通块里。特别地,一个单独的点也可以构成一个连通块。连通块的大小即为连通块中节点的数量。
输入描述
第一行一个整数n(1≤n≤105),表示结点个数。
第二行n个整数,第i个数为ai(1≤ai≤109),表示一次权值。
接下来n−1行,每行两个整数u,v(1≤u,v≤n,u=v),表示u,v 之间存在一条无向边。
输出描述
一个整数,表示所有连通块内部中结点权值为奇数的结点数量之和,结果对109+7取模。
样例1
输入
4
1 2 3 1
1 2
1 3
2 4
输出
14
说明
在这个样例中,树的形状如下图所示。一共可以分割得到10个不同的连通块:
- 
{3}号点构成的连通块,包含一个权值为奇数的点;
 - 
{1}号点构成的连通块,包含一个权值为奇数的点;
 - 
{2}号点构成的连通块,包含零个权值为奇数的点;
 - 
{4}号点构成的连通块,包含一个权值为奇数的点;
 - 
{1,3}号点构成的连通块,包含两个权值为奇数的点;
 - 
{1,2}号点构成的连通块,包含一个权值为奇数的点;
 - 
{2,4}号点构成的连通块,包含一个权值为奇数的点;
 - 
{3,1,2}号点构成的连通块,包含两个权值为奇数的点;
 - 
{1,2,4}号点构成的连通块,包含两个权值为奇数的点;
 - 
{3,1,2,4}号点构成的连通块,包含三个权值为奇数的点;
综上,答案为1+1+0+1+2+1+1+2+2+3=14。