#P4538. 第2题-基于混淆矩阵,推导分类模型的核心评估指标
-
1000ms
Tried: 59
Accepted: 19
Difficulty: 5
所属公司 :
华为
时间 :2025年1月7日-AI方向
-
算法标签>机器学习算法
第2题-基于混淆矩阵,推导分类模型的核心评估指标
解题思路
本题要求在多分类场景下,基于混淆矩阵(Confusion Matrix)的四种计数(TP/FP/FN/TN)计算每个类别的 Precision、Recall、F1 Score,并按给定权重 weights 做加权平均后输出整体指标。
相关算法与核心公式(逐类 one-vs-rest 方式统计):
-
构建每个类别的 TP / FP / FN 对于类别
c(c从0到K-1,K = len(weights)):TP[c]:pred[i] == c且trueY[i] == c的样本数FP[c]:pred[i] == c且trueY[i] != c的样本数FN[c]:pred[i] != c且trueY[i] == c的样本数 (TN不影响这三个指标,可不算)
-
逐类计算指标(注意分母为 0 的情况按题意置 0)
Precision[c] = TP[c] / (TP[c] + FP[c]),若分母为 0,则为 0Recall[c] = TP[c] / (TP[c] + FN[c]),若分母为 0,则为 0F1[c] = 2 * Precision[c] * Recall[c] / (Precision[c] + Recall[c]),若分母为 0,则为 0
-
加权平均(Weighted Average)
Precision = Σ weights[c] * Precision[c]Recall = Σ weights[c] * Recall[c]F1Score = Σ weights[c] * F1[c]
实现方法:
-
先开长度为
K的数组TP/FP/FN全为 0 -
扫一遍样本
(pred[i], trueY[i]),只更新与这两个类别相关的统计:- 若
pred[i] == trueY[i]:TP[pred[i]]++ - 否则:
FP[pred[i]]++且FN[trueY[i]]++
- 若
-
再遍历
c=0..K-1计算逐类指标并累加权重得到最终结果 -
按要求保留 2 位小数输出
代码实现
Python
import sys
def compute_weighted_metrics(pred, true_y, weights):
k = len(weights)
tp = [0] * k
fp = [0] * k
fn = [0] * k
for p, t in zip(pred, true_y):
if p == t:
tp[p] += 1
else:
fp[p] += 1
fn[t] += 1
wp = wr = wf1 = 0.0
for c in range(k):
denom_p = tp[c] + fp[c]
denom_r = tp[c] + fn[c]
precision = tp[c] / denom_p if denom_p != 0 else 0.0
recall = tp[c] / denom_r if denom_r != 0 else 0.0
denom_f1 = precision + recall
f1 = (2.0 * precision * recall / denom_f1) if denom_f1 != 0 else 0.0
wp += weights[c] * precision
wr += weights[c] * recall
wf1 += weights[c] * f1
return wp, wr, wf1
def main():
data = sys.stdin.read().strip().splitlines()
pred = list(map(int, data[0].split()))
true_y = list(map(int, data[1].split()))
weights = list(map(float, data[2].split()))
p, r, f1 = compute_weighted_metrics(pred, true_y, weights)
print(f"{p:.2f} {r:.2f} {f1:.2f}")
if __name__ == "__main__":
main()
Java
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.util.StringTokenizer;
public class Main {
private static int[] parseIntArray(String line) {
StringTokenizer st = new StringTokenizer(line);
int n = st.countTokens();
int[] a = new int[n];
for (int i = 0; i < n; i++) {
a[i] = Integer.parseInt(st.nextToken());
}
return a;
}
private static double[] parseDoubleArray(String line) {
StringTokenizer st = new StringTokenizer(line);
int n = st.countTokens();
double[] a = new double[n];
for (int i = 0; i < n; i++) {
a[i] = Double.parseDouble(st.nextToken());
}
return a;
}
// 题面功能:计算加权 Precision / Recall / F1
private static double[] computeWeightedMetrics(int[] pred, int[] trueY, double[] weights) {
int k = weights.length;
int[] tp = new int[k];
int[] fp = new int[k];
int[] fn = new int[k];
for (int i = 0; i < pred.length; i++) {
int p = pred[i];
int t = trueY[i];
if (p == t) {
tp[p]++;
} else {
fp[p]++;
fn[t]++;
}
}
double wp = 0.0, wr = 0.0, wf1 = 0.0;
for (int c = 0; c < k; c++) {
int denomP = tp[c] + fp[c];
int denomR = tp[c] + fn[c];
double precision = (denomP != 0) ? (tp[c] * 1.0 / denomP) : 0.0;
double recall = (denomR != 0) ? (tp[c] * 1.0 / denomR) : 0.0;
double denomF1 = precision + recall;
double f1 = (denomF1 != 0.0) ? (2.0 * precision * recall / denomF1) : 0.0;
wp += weights[c] * precision;
wr += weights[c] * recall;
wf1 += weights[c] * f1;
}
return new double[]{wp, wr, wf1};
}
public static void main(String[] args) throws Exception {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
String line1 = br.readLine();
String line2 = br.readLine();
String line3 = br.readLine();
int[] pred = parseIntArray(line1);
int[] trueY = parseIntArray(line2);
double[] weights = parseDoubleArray(line3);
double[] ans = computeWeightedMetrics(pred, trueY, weights);
System.out.printf("%.2f %.2f %.2f%n", ans[0], ans[1], ans[2]);
}
}
C++
#include <bits/stdc++.h>
using namespace std;
static vector<int> parseIntList(const string& line) {
stringstream ss(line);
vector<int> a;
int x;
while (ss >> x) a.push_back(x);
return a;
}
static vector<double> parseDoubleList(const string& line) {
stringstream ss(line);
vector<double> a;
double x;
while (ss >> x) a.push_back(x);
return a;
}
// 题面功能:计算加权 Precision / Recall / F1
static tuple<double, double, double> computeWeightedMetrics(
const vector<int>& pred, const vector<int>& trueY, const vector<double>& weights
) {
int k = (int)weights.size();
vector<int> tp(k, 0), fp(k, 0), fn(k, 0);
for (int i = 0; i < (int)pred.size(); i++) {
int p = pred[i];
int t = trueY[i];
if (p == t) {
tp[p]++;
} else {
fp[p]++;
fn[t]++;
}
}
double wp = 0.0, wr = 0.0, wf1 = 0.0;
for (int c = 0; c < k; c++) {
int denomP = tp[c] + fp[c];
int denomR = tp[c] + fn[c];
double precision = (denomP != 0) ? (1.0 * tp[c] / denomP) : 0.0;
double recall = (denomR != 0) ? (1.0 * tp[c] / denomR) : 0.0;
double denomF1 = precision + recall;
double f1 = (denomF1 != 0.0) ? (2.0 * precision * recall / denomF1) : 0.0;
wp += weights[c] * precision;
wr += weights[c] * recall;
wf1 += weights[c] * f1;
}
return {wp, wr, wf1};
}
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
string line1, line2, line3;
getline(cin, line1);
getline(cin, line2);
getline(cin, line3);
vector<int> pred = parseIntList(line1);
vector<int> trueY = parseIntList(line2);
vector<double> weights = parseDoubleList(line3);
auto [p, r, f1] = computeWeightedMetrics(pred, trueY, weights);
cout << fixed << setprecision(2) << p << " " << r << " " << f1 << "\n";
return 0;
}
题目内容
混淆矩阵是分类模型评估中的核心工具,尤其适用于多分类问题。它以矩阵形式展示每个类别的真正例 (TP)、假正例 (FP)、真负例 (TN) 和假负例 (FN)。基于混淆矩阵,可以计算精确率 (Precision)、召回率 (Recall) 和 F1 分数 (F1 Score) 等指标,
详细说明:
1.混淆矩阵扩展:
- 在多分类问题中,混淆矩阵扩展为 N×N 矩阵(N 为类别数),每个类别对应一行和一列,分别记录 TP、FP、FN 和 TN 。
2.指标计算:
-
精确率(Precision):对于每个类别,计算 TP/(TP+FP) 。若某个类别的 TP+FP 为零,则该类别的精确率为 0 。
-
召回率(Recall):对于每个类别,计算 TP/(TP+FN) 。若某个类别的 TP+FN 为零,则该类别的召回率为 0 。
-
F1 分数:每个类别的 F1 分数为 2∗ (Precision∗Recall)/(Precision+Recall) 。若 Precision 和 Recall 均为零,则 F1 分数为零。
3.加权计算:
- 使用 weights 参数对每个类别的 Precision、Recall 和 F1 Score 进行加权平均,得到整体的评估结果。
推导过程:
1.构建混淆矩阵:
-
初始化一个 N×N 的零矩阵,N 为类别数。
-
遍历每个样本的预测结果和真实标签,更新混淆矩阵中的 TP、FP、FN 和 TN 。
2.计算每个类别的指标:
-
对于每个类别,计算 TP、FP、FN 和 TN 。
-
根据公式计算 Precision、Recal 和 F1 Score,处理分母为零的情况。
3.加权平均计算:
- 使用提供的权重,对每个类别的 Precision、Recall 和 F1 Score 进行加权平均,得到整体评估结果。
输入描述
-
输入数据分为三行,分别对应以下内容:
-
第一行是模型的预测结果 pred,表示每个样本的预测类别,用空格分隔。
-
第二行是真实标签 trueY,表示第一行中每个样本的真实类别,用空格分隔。
-
第三行是每个类别的权重 weights,用空格分隔,第一个数据表示类别 0 的权重,第二个数据表示类别 1 的权重,以此类推。
-
说明:pred 和 trueY 的长度相等;样本类别是从 0 开始的整数 weights 的长度等于样本类别个数;weights 的值为非负数且总和为 1 。
输出描述
-
输出结果为一行,包含三个评估指标,用空格分隔:
-
precision:精确率,计算方式为每个类别的 TP/(TP+FP),并根据权重计算加权平均。
-
recall:召回率,计算方式为每个类别的 TP/(TP+FN),并根据权重计算加权平均。
-
f1Score:F1 分数,每个类别的 F1 分数为精确率和召回率的调和平均,再根据权重计算加权平均。
-
输出结果保留 2 位小数,不足 2 位时补零。
样例1
输入
0 0 0 1 1 2 2 2 2
0 1 2 0 1 2 2 2 2
0.25 0.25 0.5
输出
0.71 0.65 0.67
说明
输入解释:
当前样本的预测类别 pred 为 0 0 0 1 1 2 2 2 2
上述样本的真实类别 trueY 为 0 1 2 0 1 2 2 2 2
类别 0、1、2 的权重 weights 分别为 0.25 0.25 0.5
输出解释:
precision:0.71
recall:0.65
f1Score:0.67
样例2
输入
0 0 1 1
0 0 0 1
0.5 0.5
输出
0.75 0.83 0.73
说明
输入解释:
当前输入样本的预测类别 pred 为 0 0 1 1
当前样本的真实类别 trueY 为 0 0 0 1
类别 0、1 的权重 weights 分别为 0.5 0.5
输出解释:
precision:0.75
recall:0.83
f1Score:0.73