#P1445. 第5题-树上染色
          
                        
                                    
                      
        
              - 
          
          
                      1000ms
            
          
                      Tried: 937
            Accepted: 177
            Difficulty: 8
            
          
          
          
                       所属公司 : 
                              美团
                                
            
                        
              时间 :2023年8月12日
                              
                      
          
 
- 
                        算法标签>动态规划          
 
第5题-树上染色
思路:树形DP
题意类似的题:P1431-塔子哥的树
树形DP 问题。 考虑DP的状态定义:
- dp[i][0] 表示以 i 为子树,不选择 i 这个节点进行染色,i 这棵子树可以染色的结点最大数量
 - dp[i][1] 表示以 i 为子树,对 i 这个节点进行染色,i 这棵子树可以染色的结点最大数量
 
状态转移方程为:
- 
$dp[i][0] = \sum\limits_{j\in son(i)} \max(dp[j][0], dp[j][1]) $
即对于 i 的所有儿子节点 j ,取 dp[j][0] 和 dp[j][1] 的较大值。
 - 
$dp[i][1] = \max\limits_{j\in son(i)}(dp[i][0]-\max(dp[j][0], dp[j][1])+dp[j][0]+2)$
这里需要满足 a[i]×a[j] 是一个完全平方数 首先,由于我们只可以对一个节点染色一次,所以我们选择一个 a[i]×a[j] 为完全平方数的 j ,将这个 i 和 j 同时染为红色。
与 dp[i][0] 不同的是,其他的儿子都是取 max(dp[xxx][0],dp[xxx][1]) ,而 j 是取 dp[j][0]+2 。
转移到 dp[i][1] 就是 $dp[i][1] = \sum\limits_{xxx \in son(i),xxx\neq j}\max(dp[xxx][0],dp[xxx][1])+dp[j][0]+2$
即 $dp[i][1] = dp[i][0]-\max{(dp[j][0],dp[j][1])}+dp[j][0]+2$
 
时间复杂度:O(n)
代码
C++
#include <bits/stdc++.h>
using namespace std;
int main()
{
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int n;
    cin >> n;
    vector<int> a(n);
    for (int i = 0; i < n; ++i) cin >> a[i];
    vector<vector<int>> g(n);
    for (int i = 1; i < n; ++i) {
        int u, v;
        cin >> u >> v;
        u--; v--;
        g[u].emplace_back(v);
        g[v].emplace_back(u);
    }
    // dp[i][0] 表示以 i 为子树,不选择 i 这个节点进行染色
    // dp[i][1] 表示以 i 为子树,选择 i 这个节点进行染色
    vector<vector<int>> dp(n, vector<int>(2, 0));
    function<void(int,int)> dfs = [&](int u, int fa) {
        for (int v: g[u]) {
            if (v == fa) continue;
            dfs(v, u);
            dp[u][0] += max(dp[v][0], dp[v][1]);
        }
        for (int v: g[u]) {
            if (v == fa) continue;
            long long val = 1ll * a[v] * a[u];
            long long sq = sqrt(val + 0.5);
            if (sq * sq != val) continue;
            dp[u][1] = max(dp[u][1], (dp[u][0] - max(dp[v][0], dp[v][1])) + dp[v][0] + 2);
        }
    };
    dfs(0, -1);
    cout << max(dp[0][0], dp[0][1]) << "\n";
    return 0;
}
python
import math
def dfs(u, fa):
    global dp, g, a
    for v in g[u]:
        if v == fa:
            continue
        dfs(v, u)
        dp[u][0] += max(dp[v][0], dp[v][1])
    for v in g[u]:
        if v == fa:
            continue
        val = a[v] * a[u]
        sq = int(math.sqrt(val + 0.5))
        if sq * sq != val:
            continue
        dp[u][1] = max(dp[u][1], (dp[u][0] - max(dp[v][0], dp[v][1])) + dp[v][0] + 2)
n = int(input())
a = list(map(int, input().split()))
g = [[] for _ in range(n)]
for _ in range(n - 1):
    u, v = map(int, input().split())
    u -= 1
    v -= 1
    g[u].append(v)
    g[v].append(u)
dp = [[0, 0] for _ in range(n)]
dfs(0, -1)
print(max(dp[0][0], dp[0][1]))
Java
import java.util.ArrayList;
import java.util.List;
import java.util.Scanner;
public class Main {
    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in);
        int n = scanner.nextInt();
        List<Integer> a = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            a.add(scanner.nextInt());
        }
        List<List<Integer>> g = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            g.add(new ArrayList<>());
        }
        for (int i = 1; i < n; i++) {
            int u = scanner.nextInt() - 1;
            int v = scanner.nextInt() - 1;
            g.get(u).add(v);
            g.get(v).add(u);
        }
        int[][] dp = new int[n][2];
        dfs(0, -1, g, a, dp);
        System.out.println(Math.max(dp[0][0], dp[0][1]));
    }
    static void dfs(int u, int fa, List<List<Integer>> g, List<Integer> a, int[][] dp) {
        for (int v : g.get(u)) {
            if (v == fa) continue;
            dfs(v, u, g, a, dp);
            dp[u][0] += Math.max(dp[v][0], dp[v][1]);
        }
        for (int v : g.get(u)) {
            if (v == fa) continue;
            long val = 1L * a.get(v) * a.get(u);
            long sq = (long) Math.sqrt(val + 0.5);
            if (sq * sq != val) continue;
            dp[u][1] = Math.max(dp[u][1], (dp[u][0] - Math.max(dp[v][0], dp[v][1])) + dp[v][0] + 2);
        }
    }
}
        题目内容
给定一棵树,每个节点都有一个权值以及最开始是白色。
定义操作A:
选择两个有边直接相连的节点,可以将两个节点同时染红.当且仅当 他们都是白色
但是这样的题目太过简单,所以我们定义一个更复杂的操作B:
在满足操作A的条件下 两个节点的权值的乘积也需要是x∗x的形式 , x≥1
现在允许执行操作若干次操作B。问这棵树最多能够得到红色节点?
输入描述
第一行输入一个正整数n,代表节点的数量。
第二行输入n个正整数ai,代表每个节点的权值。
接下来的n−1行,每行输入两个正整教u,v,代表节点u和节点v有一条边连接
1≤n≤105
1≤ai≤109
1≤u,v≤n
输出描述
输出一个整数表示最多可以染红的节点数量。
样例1
输入
3
3 5 7
1 2
2 3
输出
0
样例2
输入
3
5 5 5
1 2
2 3
输出
2
说明
可以染红第二个和第三个节点。或者可以染红第一个和第二个节点。这样都是染红两个节点。
而根据规则,你无法同时染红1,2,3节点。