#P3479. 第2题-标签样本数量
-
1000ms
Tried: 2070
Accepted: 526
Difficulty: 5
所属公司 :
华为
时间 :2025年8月27日-国内-AI
-
算法标签>机器学习算法
第2题-标签样本数量
解题思路
核心步骤
-
读入参数:k,m,n,s;读入待分类样本向量 q(维度 n);读入 m 条样本(前 n 列为特征,最后一列为标签)。
-
计算距离:对每个样本 x,计算与 q 的欧氏距离
d(q,x)=i=1∑n(qi−xi)2为了效率与不影响排序,可直接用平方距离(省去开方,单调性一致)。
-
排序取前 k:按距离从小到大排序,取前 k 个邻居。
-
投票与并列规则:统计前 k 个邻居的标签频次,找出最高频数。若有多个标签并列第一,则在这几个标签中,选择距离最近的那个邻居的标签(即在已排序的前 k 邻居中,从前往后找到第一个其标签属于“并列集合”的样本)。
-
输出:输出最终预测标签与在前 k 中该标签出现的次数,格式:“label count”。
正确性说明
- 归一化保证各维度量纲一致,欧氏距离可直接比较。
- 使用平方距离与开方距离等价于排序目的。
- 并列处理遵循题意“序列第一(最近邻)优先”。
复杂度分析
- 距离计算:O(m⋅n)
- 排序:O(mlogm)
- 统计投票:O(k)
- 总复杂度:O(mlogm+m⋅n),在 m≤200,n≤5 的限制下完全可行。
- 额外空间:存距离与样本索引 O(m)。
Python
import sys
from collections import Counter
def main():
# 读入所有标记,适配行内/换行混排
tokens = sys.stdin.read().strip().split()
it = iter(tokens)
# 基本参数
k = int(next(it)); m = int(next(it)); n = int(next(it)); s = int(next(it)) # s未直接使用
# 待分类样本 q
q = [float(next(it)) for _ in range(n)]
# 读入 m 个样本(n 个特征 + 1 个标签)
X = []
y = []
for _ in range(m):
row = [float(next(it)) for __ in range(n + 1)]
X.append(row[:n])
# 标签以 float 给出,输出需要整数格式
y.append(int(row[-1]))
# 计算平方欧氏距离,保存 (dist2, idx)
dists = []
for i in range(m):
xi = X[i]
# 平方距离即可用于排序
dist2 = 0.0
for j in range(n):
diff = q[j] - xi[j]
dist2 += diff * diff
dists.append((dist2, i))
# 按距离升序排序
dists.sort(key=lambda t: t[0])
# 取前 k 个邻居的索引与标签
top_idx = [dists[i][1] for i in range(min(k, m))]
top_labels = [y[i] for i in top_idx]
# 统计频次
cnt = Counter(top_labels)
max_freq = max(cnt.values())
# 找出并列第一的标签集合
tie_labels = {lab for lab, c in cnt.items() if c == max_freq}
# 若并列,按距离顺序选择第一个属于并列集合的邻居的标签
# dists 已整体排序,这里只需在前 k 中寻找
chosen = None
for i in range(min(k, m)):
lab = y[dists[i][1]]
if lab in tie_labels:
chosen = lab
break
# 输出:标签 与 在前 k 中该标签出现次数
print(chosen, cnt[chosen])
if __name__ == '__main__':
main()
Java
import java.io.*;
import java.util.*;
public class Main {
static class Pair {
double d2; int idx;
Pair(double d2, int idx){ this.d2 = d2; this.idx = idx; }
}
public static void main(String[] args) throws Exception {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
List<String> toks = new ArrayList<>();
for (String line; (line = br.readLine()) != null; ) {
line = line.trim();
if (line.isEmpty()) continue;
String[] a = line.split("\\s+");
Collections.addAll(toks, a);
}
int p = 0;
int k = Integer.parseInt(toks.get(p++));
int m = Integer.parseInt(toks.get(p++));
int n = Integer.parseInt(toks.get(p++));
int s = Integer.parseInt(toks.get(p++)); // 未直接使用
double[] q = new double[n];
for (int i = 0; i < n; i++) q[i] = Double.parseDouble(toks.get(p++));
double[][] X = new double[m][n];
int[] y = new int[m];
for (int i = 0; i < m; i++) {
for (int j = 0; j < n; j++) X[i][j] = Double.parseDouble(toks.get(p++));
y[i] = (int)Math.round(Double.parseDouble(toks.get(p++)));
}
List<Pair> ds = new ArrayList<>(m);
for (int i = 0; i < m; i++) {
double d2 = 0;
for (int j = 0; j < n; j++) {
double diff = q[j] - X[i][j];
d2 += diff * diff;
}
ds.add(new Pair(d2, i));
}
ds.sort(Comparator.comparingDouble(o -> o.d2));
int kk = Math.min(k, m);
Map<Integer, Integer> freq = new HashMap<>();
for (int i = 0; i < kk; i++) {
int lab = y[ds.get(i).idx];
freq.put(lab, freq.getOrDefault(lab, 0) + 1);
}
int maxFreq = 0;
for (int c : freq.values()) maxFreq = Math.max(maxFreq, c);
Set<Integer> tie = new HashSet<>();
for (Map.Entry<Integer,Integer> e : freq.entrySet())
if (e.getValue() == maxFreq) tie.add(e.getKey());
int ansLab = -1;
for (int i = 0; i < kk; i++) {
int lab = y[ds.get(i).idx];
if (tie.contains(lab)) { ansLab = lab; break; }
}
System.out.println(ansLab + " " + freq.get(ansLab));
}
}
C++
#include <bits/stdc++.h>
using namespace std;
struct PairD {
double d2; int idx;
bool operator<(const PairD& o) const { return d2 < o.d2; }
};
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
// 读到 EOF,按空白分隔
vector<string> tok;
string s;
while (cin >> s) tok.push_back(s);
if (tok.empty()) return 0;
size_t p = 0;
int k = stoi(tok[p++]);
int m = stoi(tok[p++]);
int n = stoi(tok[p++]);
int sc = stoi(tok[p++]); // 未直接使用
vector<double> q(n);
for (int i = 0; i < n; ++i) q[i] = stod(tok[p++]);
vector<vector<double>> X(m, vector<double>(n));
vector<int> y(m);
for (int i = 0; i < m; ++i) {
for (int j = 0; j < n; ++j) X[i][j] = stod(tok[p++]);
y[i] = (int)llround(stod(tok[p++])); // 标签以浮点给出
}
vector<PairD> ds; ds.reserve(m);
for (int i = 0; i < m; ++i) {
double d2 = 0.0;
for (int j = 0; j < n; ++j) {
double diff = q[j] - X[i][j];
d2 += diff * diff;
}
ds.push_back({d2, i});
}
sort(ds.begin(), ds.end());
int kk = min(k, m);
unordered_map<int,int> freq;
freq.reserve(kk * 2 + 1);
for (int i = 0; i < kk; ++i) {
int lab = y[ds[i].idx];
++freq[lab];
}
int maxFreq = 0;
for (auto &e : freq) maxFreq = max(maxFreq, e.second);
// 并列集合
unordered_set<int> tie;
for (auto &e : freq) if (e.second == maxFreq) tie.insert(e.first);
// 最近邻优先打破并列
int ansLab = -1;
for (int i = 0; i < kk; ++i) {
int lab = y[ds[i].idx];
if (tie.count(lab)) { ansLab = lab; break; }
}
cout << ansLab << " " << freq[ansLab] << "\n";
return 0;
}
题目内容
KNN 算法的核心思想是,如果一个样本在特征空间中的 K 个最相邻的样本中的大多数属于某一个类别,则该样本也属于这个类别,并具有这个类别上样本的特性。请按照下面的步理,实现 KNN 算法。
KNN 算法说明:
计算待分类点到其他样本点的距离;
通过距离进行排序,选择距离最小的 K 个点;提取这 K 个临近点的类别,根据少数服从多数的原则,将占比最多的那个标签赋值给待分类样本点的 label 。
本题说明:
1、给定数据集中,默认每一类标签都存在数据,不存在某类型数量为 0 的场景;
2、为消除不同特征权重问题,给出数据均已做好归一化处理,并保留两位小数;
3、出现并列第一的情形时,取并列第一的样本中,最近邻居的标签返回;
4、距离函数定义为: dx,y=∑i=1n(xi−yi)2。
输入描述
第 1 行:k m n s :k 代表每次计算时选取的最近邻居个数(不大于 20 ),m 代表样本数量(不大于 200 ),n 代表样本维度(不包括标签,不大于 5 ),s 代表类别个数(不于 5 );
第 2 行:待分类样本
第 3 行~第 m+2 行:m 个样本,每一行 n+1 列,最后一列为类别标签 label
输出描述
输出待分类样本的类别标签及距离最小的 K 个点中的该标签样本数量
样例1
输入
3 10 2 3
0.81 0.64
0.19 0.2 1.0
0.18 0.14 0.0
0.76 0.58 1.0
0.4 0.16 1.0
0.98 0.85 0.0
0.42 0.97 1.0
0.75 0.26 1.0
0.24 0.06 1.0
0.97 0.8 0.0
0.21 0.1 2.0
输出
0 2
说明
第 1 行输入说明输入了 m=10 个样本,每个样本有 n=2 个维度的数据(去除最后一列标签),共有 s=3 种类别
第 2 行输入待分类样本的 n 维数据
从第 3 行到第 12 行的前两列数据为输入的 m=10 个样本,每个样本有 n=2 个维度的数据+最后一列的标签数据
待分类样本 [0.81 0.64] 最近的前 k=3 个邻居分别为:[0.76 0.58],[0.98 0.85],[0.97 0.8] ,分别有 2 个 0 号标签和 1 个 1 号标签 0 号标签占多,返回 0 以及标签 0 的样本数量 2
样例2
输入
6 10 2 4
0.78 0.63
0.57 0.07 1.0
0.5 0.13 1.0
0.83 0.07 3.0
0.27 0.87 3.0
0.81 0.44 2.0
0.21 0.73 3.0
0.45 0.91 1.0
0.12 0.22 2.0
0.25 0.48 0.0
0.54 0.87 1.0
输出
1 2
说明
本样例的距离最小的 6 个样本中,标签 1 和标签 3 出现次数都是 2 次,并列第一;虽然 [0.8 0.44] 距离样本最近,但其标签 2 不是出现最多的,排除在下一轮统计样本中此时需要从标签 1 和标签 3 中的样本中,选取距离最近的 [0.54 0.87] 的标签 1 作为返回值,并同时返回标签 1 的样本数量 2 。