#P3480. 第3题-F1值最优的决策树剪枝
          
                        
                                    
                      
        
              - 
          
          
                      1000ms
            
          
                      Tried: 1098
            Accepted: 169
            Difficulty: 6
            
          
          
          
                       所属公司 : 
                              华为
                                
            
                        
              时间 :2025年8月27日-国内-AI
                              
                      
          
 
- 
                        算法标签>树          
 
第3题-F1值最优的决策树剪枝
题解思路
方法思路
- 问题分析:决策树可能过拟合训练数据,需要通过剪枝提高泛化能力。对于每个节点,考虑将其转换为叶节点后的F1值,选择最优方案。
 - 算法选择:使用深度优先搜索(DFS)后序遍历决策树。对于每个节点,计算:
- 保留子树时的F1值(递归处理左右子树)
 - 剪枝为叶节点时的F1值
 
 - 关键操作:比较两种方案的F1值,选择较大的一个,实现贪心剪枝。
 - 复杂度分析:每个节点处理一次,每次处理需要遍历验证数据子集。时间复杂度为O(N*M),其中N为节点数,M为验证集大小。
 
解题代码
Python代码
import sys
def main():
    data = sys.stdin.read().split()
    it = iter(data)
    
    n = int(next(it)); m = int(next(it)); k = int(next(it))
    
    # 读取节点信息
    nodes = {}
    for i in range(1, n + 1):
        left_id = int(next(it))
        right_id = int(next(it))
        feature = int(next(it))
        threshold = int(next(it))
        label = int(next(it))
        nodes[i] = {
            'left': left_id,
            'right': right_id,
            'feature': feature,
            'threshold': threshold,
            'label': label,
            'is_leaf': left_id == 0 and right_id == 0
        }
    
    # 读取验证数据
    validation_data = []
    for _ in range(m):
        features = [float(next(it)) for _ in range(k)]
        true_label = int(next(it))
        validation_data.append((features, true_label))
    
    def evaluate_with_label(pred_label, data_subset):
        tp = fp = fn = 0
        for _, true_label in data_subset:
            if pred_label == 1 and true_label == 1:
                tp += 1
            elif pred_label == 1 and true_label == 0:
                fp += 1
            elif pred_label == 0 and true_label == 1:
                fn += 1
        precision = tp / (tp + fp) if tp + fp > 0 else 0
        recall = tp / (tp + fn) if tp + fn > 0 else 0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0
        return tp, fp, fn, f1
    def prune(node_id, data_subset):
        node = nodes[node_id]
        # 作为叶子评估(剪枝为叶)
        tp_leaf, fp_leaf, fn_leaf, f1_leaf = evaluate_with_label(node['label'], data_subset)
        if node['is_leaf'] or not data_subset:
            return tp_leaf, fp_leaf, fn_leaf, f1_leaf
        # 分割数据
        left_data, right_data = [], []
        for features, true_label in data_subset:
            if features[node['feature'] - 1] <= node['threshold']:
                left_data.append((features, true_label))
            else:
                right_data.append((features, true_label))
        left_stat = prune(node['left'], left_data)
        right_stat = prune(node['right'], right_data)
        left_tp, left_fp, left_fn, _ = left_stat
        right_tp, right_fp, right_fn, _ = right_stat
        tp_sub = left_tp + right_tp
        fp_sub = left_fp + right_fp
        fn_sub = left_fn + right_fn
        precision_sub = tp_sub / (tp_sub + fp_sub) if tp_sub + fp_sub > 0 else 0
        recall_sub = tp_sub / (tp_sub + fn_sub) if tp_sub + fn_sub > 0 else 0
        f1_sub = 2 * precision_sub * recall_sub / (precision_sub + recall_sub) if precision_sub + recall_sub > 0 else 0
        # 选更优方案
        if f1_leaf > f1_sub:
            return tp_leaf, fp_leaf, fn_leaf, f1_leaf
        else:
            return tp_sub, fp_sub, fn_sub, f1_sub
    
    # 从根节点开始剪枝
    _, _, _, best_f1 = prune(1, validation_data)
    print("{:.6f}".format(best_f1))
if __name__ == "__main__":
    main()
Java代码
import java.io.*;
import java.util.*;
public class Main {
    static class Node {
        int l, r, f, th, label;
        boolean isLeaf() { return l==0 && r==0; }
    }
    static class Stat {
        int tp=0, fp=0, fn=0;
        double f1() {
            double P = (tp+fp)>0 ? (double)tp/(tp+fp) : 0.0;
            double R = (tp+fn)>0 ? (double)tp/(tp+fn) : 0.0;
            return (P+R)>0 ? 2.0*P*R/(P+R) : 0.0;
        }
        void add(Stat o){ tp+=o.tp; fp+=o.fp; fn+=o.fn; }
    }
    static Node[] nodes;
    static int[][] X;
    static int[] Y;
    // 将一批样本用某个pred_label作为叶子预测
    static Stat evalWithLabel(int predLabel, List<Integer> idx){
        Stat s = new Stat();
        for(int t: idx){
            if(predLabel==1 && Y[t]==1) s.tp++;
            else if(predLabel==1 && Y[t]==0) s.fp++;
            else if(predLabel==0 && Y[t]==1) s.fn++;
        }
        return s;
    }
    static Stat prune(int id, List<Integer> idx){
        Node nd = nodes[id];
        // 方案A:把当前节点剪为叶子
        Stat leaf = evalWithLabel(nd.label, idx);
        if (nd.isLeaf() || idx.isEmpty()) return leaf;
        // 数据分流
        ArrayList<Integer> L = new ArrayList<>(idx.size());
        ArrayList<Integer> Rr = new ArrayList<>(idx.size());
        for(int t: idx){
            // nd.f在非叶必>=1
            if (X[t][nd.f-1] <= nd.th) L.add(t);
            else Rr.add(t);
        }
        // 方案B:保留子树;若子节点缺失(编号为0),用“当前节点label作叶子”的方式替代
        Stat keepLeft  = prune(nd.l, L);
        Stat keepRight = prune(nd.r, Rr);
        Stat keep = new Stat(); keep.add(keepLeft); keep.add(keepRight);
        return (leaf.f1() > keep.f1()) ? leaf : keep;
    }
    public static void main(String[] args) throws Exception {
        FastScanner fs = new FastScanner(System.in);
        Integer N = fs.nextInt(); if (N==null) return;
        int M = fs.nextInt(), K = fs.nextInt();
        nodes = new Node[N+1];
        for(int i=1;i<=N;i++){
            Node nd = new Node();
            nd.l = fs.nextInt(); nd.r = fs.nextInt();
            nd.f = fs.nextInt(); nd.th = fs.nextInt(); nd.label = fs.nextInt();
            nodes[i] = nd;
        }
        X = new int[M][K]; Y = new int[M];
        for(int i=0;i<M;i++){
            for(int j=0;j<K;j++) X[i][j] = fs.nextInt();
            Y[i] = fs.nextInt();
        }
        ArrayList<Integer> all = new ArrayList<>(M);
        for(int i=0;i<M;i++) all.add(i);
        double best = prune(1, all).f1();
        System.out.printf(Locale.US, "%.6f%n", best);
    }
    // 轻量输入
    static class FastScanner {
        private final InputStream in;
        private final byte[] buffer = new byte[1<<16];
        private int ptr=0, len=0;
        FastScanner(InputStream is){ in=is; }
        private int read() throws IOException {
            if (ptr>=len){ len=in.read(buffer); ptr=0; if(len<=0) return -1; }
            return buffer[ptr++];
        }
        Integer nextInt() throws IOException {
            int c, sgn=1, x=0;
            do { c=read(); if(c==-1) return null; } while(c<=32);
            if(c=='-'){ sgn=-1; c=read(); }
            for(; c>32; c=read()) x = x*10 + (c-'0');
            return x*sgn;
        }
    }
}
C++代码
#include <bits/stdc++.h>
using namespace std;
struct Node {
    int l, r, f, th, label;
    bool is_leaf() const { return l == 0 && r == 0; }
};
struct Stat {
    int tp=0, fp=0, fn=0;
    double f1() const {
        double P = (tp+fp)? (double)tp/(tp+fp) : 0.0;
        double R = (tp+fn)? (double)tp/(tp+fn) : 0.0;
        return (P+R>0)? 2.0*P*R/(P+R) : 0.0;
    }
};
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int N,M,K; 
    if(!(cin>>N>>M>>K)) return 0;
    vector<Node> nodes(N+1);
    for(int i=1;i<=N;i++){
        cin>>nodes[i].l>>nodes[i].r>>nodes[i].f>>nodes[i].th>>nodes[i].label;
    }
    vector<vector<int>> X(M, vector<int>(K));
    vector<int> Y(M);
    for(int i=0;i<M;i++){
        for(int j=0;j<K;j++) cin>>X[i][j];
        cin>>Y[i];
    }
    function<Stat(int, const vector<int>&)> prune = [&](int id, const vector<int>& idx)->Stat{
        const auto &nd = nodes[id];
        // 评估剪为叶子
        Stat leaf;
        for(int t: idx){
            int pred = nd.label;
            if(pred==1 && Y[t]==1) leaf.tp++;
            else if(pred==1 && Y[t]==0) leaf.fp++;
            else if(pred==0 && Y[t]==1) leaf.fn++;
        }
        if(nd.is_leaf() || idx.empty()) return leaf;
        // 保留子树
        vector<int> L, Rr;
        L.reserve(idx.size()); Rr.reserve(idx.size());
        for(int t: idx){
            if (X[t][nd.f-1] <= nd.th) L.push_back(t);
            else Rr.push_back(t);
        }
        Stat ls = prune(nd.l, L), rs = prune(nd.r, Rr);
        Stat keep; keep.tp = ls.tp+rs.tp; keep.fp = ls.fp+rs.fp; keep.fn = ls.fn+rs.fn;
        // 选择更优
        return (leaf.f1() > keep.f1()) ? leaf : keep;
    };
    vector<int> all(M); iota(all.begin(), all.end(), 0);
    double best = prune(1, all).f1();
    cout.setf(std::ios::fixed); cout<<setprecision(6)<<best<<"\n";
    return 0;
}
        题目内容
决策树生成算法递归地产生决策树,直到不能继续下去为止,这样产生的树往往对训练数据的分类很准确,但对未知的测试数据的分类却没有那么准确,即出现过拟合现象。
在决策树学习中将已生成的树进行简化的过程称为剪枝。具体地,剪枝从已生成的树上裁掉一些子树或叶节点,并将其根节点或父节点作为新的叶节点,从而简化分类树模型。
小A希望通过决策树的方法解决一个二分类任务。在该二分类的任务中,标签 1 是正分类,标签 0 是负分类。现在小A已经训练了一个未剪枝的二分类的决策树。他希望对该决策树进行剪枝,能够在验证集上达到最优的 F1 值。
给定一个二叉树为待剪枝的二分类决策树,每个节点有 3 个参数 fi、thi、labeli 。当节点非叶节点时,fi、thi 表示该节点决策应用的特征编号和阈值。在数据的第 fi 个特征小于等于 thi 时决策走左节点,大于 thi 时走右节点。决策树的预测通过该规则推理到叶节点时,叶节点的 labeli 为该条数据的预测结果。
请输出小A通过剪枝在验证集上可以达到的最优 F1 值。
输入描述
第一行为一个 N、M、K 。其中,N(1<=N<=100) 表示决策树的节点个数。M(1<=M<=300) 表示验证集条数。K(1<=K<=100) 表示每条验证集特征个数。
随后 N 行,第 i 行表示第 i 个节点,根节点编号为 1 ,每行包括 5 个整数 li、ri、fi、thi、labeli 。其中 li、ri 分别表示节点的左右子节点编号 (0<=liri<=100) 。若 li=0、ri=0 则表示无子节点,不存在只有一个子节点的情况。当节点非叶节点时,fi、thi 表示该节点的特征编号和阔值,否则 fi、thi 为 0 。labeli 表示当该节点作为叶节点时的分类结果( labeli 取值为 0 或 1 )。
随后 M 行为验证集特征和 label,每行 K+1 个整数,前 K 个整数为该条数据的特征,最后一个整数位该条数据的 label 。
输出描述
请输出一个浮点数,为验证集可达到的最优 F1 值,四舍五入保留小数点后 6 位。
样例1
输入
7 3 2
2 3 1 50 0
4 5 2 50 0
6 7 2 50 1
0 0 0 0 0
0 0 0 0 1
0 0 0 0 0
0 0 0 0 1
30 60 1
30 30 1
60 30 1
输出
0.800000
说明
原始决策树为

第一条数据的终止节点为 5 ,节点 predictlabel 为 1 ,预测正确。
第二条数据的终止节点为 4 ,节点 predictlabel 为 0 ,预测错误。
第三条数据的终止节点为 6 ,节点 predictlabel 为 0 ,预测错误。
Precision 为 1 , Recall 为 1/3,F1 Score 为 1/2 。
决策树可将节点 6、7 裁剪掉,裁剪后的决策树为:

第一条数据的终止节点为 5 ,节点 predictlabel 为 1 ,预测正确。
第二条数据的终止节点为 4 ,节点 predictlabel 为 0 ,预测错误。
第三条数据的终止节点为 3 ,节点 predictlabel 为 1 ,预测正确。
Precision 为 1,Recall 为 2/3,F1 Score 为 4/5=0.800000 。
样例2
输入
7 3 3
2 3 3 87 1
0 0 1 3 0
4 5 1 38 1
0 0 2 8 1
6 7 2 94 1
0 0 2 44 1
0 0 2 9 0
30 78 73 0
73 99 99 1
72 3 2 0
输出
1.000000
提示
F1 值计算公式:
F1=2∗(Precision∗Recall)/(Precision+Recall)