#P3528. 第二题-阈值最优的决策树
-
1000ms
Tried: 1073
Accepted: 206
Difficulty: 3
所属公司 :
华为
时间 :2025年9月5日-模拟赛-AI
-
算法标签>排序
第二题-阈值最优的决策树
思路
-
关键:因为特征是一维且阈值只按大小划分,所以只需在相邻不同特征值之间、以及两端考虑切分点。重复特征值的样本不可被阈值拆分,需整组一起在同侧。
-
做法:
-
将样本按特征 x 升序排序。
-
预处理前缀计数 prefL(i):排序后前 i 个(含 i)中标签等于 L 的数量。
-
预处理后缀计数 sufR(i):排序后从 i 到末尾中标签等于 R 的数量。
-
只在每个“特征值分组的末尾”作为切分(记该位置为 i)计算
correct(i)=prefL(i)+sufR(i+1)
同时考虑两端:
correct(−1)=sufR(0),correct(n−1)=prefL(n−1)
-
取最大正确数,准确率为该最大值除以 M。
-
-
复杂度:排序 O(MlogM),线性扫描与前后缀 O(M)。
C++
#include <bits/stdc++.h>
using namespace std;
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
int M;
if (!(cin >> M)) return 0;
vector<pair<long long,int>> a(M);
for (int i = 0; i < M; ++i) {
long long x; int y;
cin >> x >> y;
a[i] = {x, y};
}
int L, R;
cin >> L >> R;
sort(a.begin(), a.end(), [](const auto& p1, const auto& p2){
if (p1.first != p2.first) return p1.first < p2.first;
return p1.second < p2.second;
});
vector<int> prefL(M, 0), sufR(M, 0);
for (int i = 0; i < M; ++i) {
prefL[i] = (i ? prefL[i-1] : 0) + (a[i].second == L);
}
for (int i = M - 1; i >= 0; --i) {
sufR[i] = (i + 1 < M ? sufR[i+1] : 0) + (a[i].second == R);
}
long long best = 0;
// 阈值小于最小特征:全部走右子树
if (M > 0) best = max<long long>(best, sufR[0]);
// 在每个“特征值分组的末尾”作为切分
for (int i = 0; i + 1 < M; ++i) {
if (a[i].first != a[i+1].first) {
long long correct = prefL[i] + sufR[i+1];
if (correct > best) best = correct;
}
}
// 阈值大于等于最大特征:全部走左子树
if (M > 0) best = max<long long>(best, prefL[M-1]);
double acc = (M == 0) ? 0.0 : (double)best / (double)M;
cout.setf(std::ios::fixed);
cout << setprecision(3) << acc << "\n";
return 0;
}
Python
import sys
def main():
data = sys.stdin.read().strip().split()
if not data:
return
it = iter(data)
M = int(next(it))
a = []
for _ in range(M):
x = int(next(it)); y = int(next(it))
a.append((x, y))
L = int(next(it)); R = int(next(it))
a.sort(key=lambda p: (p[0], p[1]))
prefL = [0]*M
sufR = [0]*M
for i in range(M):
prefL[i] = (prefL[i-1] if i > 0 else 0) + (1 if a[i][1] == L else 0)
for i in range(M-1, -1, -1):
sufR[i] = (sufR[i+1] if i+1 < M else 0) + (1 if a[i][1] == R else 0)
best = 0
if M > 0:
best = max(best, sufR[0]) # 阈值在最小特征左侧
for i in range(M-1):
if a[i][0] != a[i+1][0]:
best = max(best, prefL[i] + sufR[i+1])
best = max(best, prefL[M-1]) # 阈值在最大特征右侧
acc = 0.0 if M == 0 else best / M
print(f"{acc:.3f}")
if __name__ == "__main__":
main()
Java
import java.io.*;
import java.util.*;
// 读入、排序、前后缀统计、仅在不同特征值的分界处计算正确数
public class Main {
static class Pair {
long x;
int y;
Pair(long x, int y) { this.x = x; this.y = y; }
}
public static void main(String[] args) throws Exception {
BufferedInputStream bis = new BufferedInputStream(System.in);
FastScanner fs = new FastScanner(bis);
Integer MM = fs.nextInt();
if (MM == null) return;
int M = MM;
List<Pair> a = new ArrayList<>(M);
for (int i = 0; i < M; i++) {
long x = fs.nextLong();
int y = fs.nextInt();
a.add(new Pair(x, y));
}
int L = fs.nextInt();
int R = fs.nextInt();
a.sort((p1, p2) -> {
if (p1.x != p2.x) return Long.compare(p1.x, p2.x);
return Integer.compare(p1.y, p2.y);
});
int[] prefL = new int[M];
int[] sufR = new int[M];
for (int i = 0; i < M; i++) {
prefL[i] = (i > 0 ? prefL[i-1] : 0) + (a.get(i).y == L ? 1 : 0);
}
for (int i = M - 1; i >= 0; i--) {
sufR[i] = (i + 1 < M ? sufR[i+1] : 0) + (a.get(i).y == R ? 1 : 0);
}
long best = 0;
if (M > 0) {
best = Math.max(best, sufR[0]); // 阈值在最小特征左侧
for (int i = 0; i + 1 < M; i++) {
if (a.get(i).x != a.get(i+1).x) {
long correct = (long)prefL[i] + (long)sufR[i+1];
if (correct > best) best = correct;
}
}
best = Math.max(best, prefL[M-1]); // 阈值在最大特征右侧
}
double acc = (M == 0) ? 0.0 : (double)best / (double)M;
System.out.printf(java.util.Locale.ROOT, "%.3f%n", acc);
}
// 简易高速读入
static class FastScanner {
private final InputStream in;
private final byte[] buffer = new byte[1 << 16];
private int ptr = 0, len = 0;
FastScanner(InputStream is) { this.in = is; }
private int read() throws IOException {
if (ptr >= len) {
len = in.read(buffer);
ptr = 0;
if (len <= 0) return -1;
}
return buffer[ptr++];
}
String next() throws IOException {
StringBuilder sb = new StringBuilder();
int c;
while ((c = read()) <= ' ') {
if (c == -1) return null;
}
do {
sb.append((char)c);
c = read();
} while (c > ' ');
return sb.toString();
}
Integer nextInt() throws IOException {
String s = next();
return s == null ? null : Integer.parseInt(s);
}
Long nextLong() throws IOException {
String s = next();
return s == null ? null : Long.parseLong(s);
}
}
}
题目内容
决策树生成算法递归地产生决策树,直到不能继续下去为止,在这个过程中,最关键的是确定每个节点的阈值。一种传统方法是划分之后,需要使得数据集的熵减最大化。
而小明同学面对的问题是一个基座问题:只有一个特征的数据集的二分类问题。如果构建出二叉树,那么将形如:”一个根节点配两个儿子节点“的结构。
在这种问题下,小明希望通过一个更加简洁的策略来获得结果:枚举阈值,得到验证集上的最优的准确率。请问给定验证集,需要设定怎样的阈值使得准确率最大化。
请输出小A通过他提出的寻找阈值的策略,在验证集上可以达到的最优 准确率
输入描述
第一行一个整数M(1<=M<=105) 表示验证集条数
随后 M 行为验证集特征和 label,每行 2 个整数,第一个数为该条数据的特征,最后一个整数为该条数据的 label∈[0,1] 。
随后一行两个整数L,R 分别代表左子树和右子数的label , L=R且L,R∈[0,1]
输出描述
第一行,一个浮点数,为验证集可达到的最优准确率,四舍五入保留小数点后 3 位。
样例1
输入
5
1 0
2 0
3 1
4 0
5 0
0 1
输出
0.800
说明
设定阈值=5 , 那么所有样本进入左子树,被归类为0 , 准确率为54=0.8 , 注意保留三位小数
样例2
输入
5
1 1
2 1
3 1
4 0
5 0
1 0
输出
1.000
说明
设定阈值=3 , 那么样本1,2,3进入左子树,被归类为1 , 样本4,5,6 进入右子树,被归类为0 , 全部分类正确,准确率为1
提示
本题准确率的计算方法为:总样本个数预测正确的样本个数
Related
In following contests: