#P2695. 第2题-SVM核矩阵
-
1000ms
Tried: 100
Accepted: 15
Difficulty: 7
所属公司 :
美团
时间 :2025年3月15日-算法岗
-
算法标签>模拟
第2题-SVM核矩阵
题解
题面描述
给定一个数据集,其输入格式为带中括号的二维列表,形如
[[1,2],
[3,4],
[5,6]]
每个内部列表代表一个样本的特征向量;随后输入一行字符串表示核函数类型,可能的取值为 ′linear′、′polynomial′、′rbf′。
若核函数类型为 ′polynomial′ 或 ′rbf′,则还需要额外输入一行参数:
- 对于 ′polynomial′,输入一个整数参数 d(多项式次数);
- 对于 ′rbf′,输入一个浮点数参数 γ(高斯核参数)。
要求根据以下公式计算核矩阵:
- 线性核:K(x,y)=xTy
- 多项式核:K(x,y)=(γ⋅xTy+r)d (默认 γ=1,r=0,d=2)
- 高斯核(RBF):K(x,y)=exp(−γ∣x−y∣2) (默认γ=0.5)
计算得到的核矩阵为一个 N×N 的二维列表
思路
一段话总结:本题要求从带中括号的多行字符串中解析出二维数据集,并根据输入的核函数类型(线性、多项式或 RBF)以及相应参数,通过手动实现向量点积和欧氏距离平方的计算,构造出支持向量机核矩阵
-
输入解析
- 使用类似 Python 中的 ast.literaleval(或手动解析)将字符串转换为二维列表。
- 随后读取核函数类型,若输入时带有引号则去除;若为 ′polynomial′ 或 ′rbf′,再读取额外参数。
-
核矩阵计算
- 对于数据集中每个样本向量 xi 和 xj,根据核函数类型采用不同公式计算核值。
- 线性核直接计算点积:xiTxj。
- 多项式核计算:先计算点积,再执行 K(xi,xj)=(γ⋅xiTxj+r)d;(这里默认 γ=1,r=0)。
- RBF 核计算:先计算向量之间的欧氏距离平方,再计算指数函数:K(xi,xj)=exp(−γ∣xi−xj∣2)。
-
输出
- 将计算好的核矩阵以二维列表形式输出
cpp
#include <iostream>
#include <sstream>
#include <vector>
#include <string>
#include <algorithm> // 包含 count 函数所需头文件
#include <cmath>
#include <iomanip>
using namespace std;
// 函数:读取整块输入并解析出矩阵字符串和剩余输入行
void extractInput(string &matrixStr, vector<string> &restLines) {
string line;
vector<string> lines;
while(getline(cin, line)) {
lines.push_back(line);
}
// 利用计数法提取矩阵部分(从第一个 '[' 开始,到所有 '[' 与 ']' 平衡)
int bracketCount = 0;
bool matrixStarted = false;
stringstream matrixSS;
size_t i = 0;
for(; i < lines.size(); i++){
string s = lines[i];
if(!matrixStarted && s.find("[") != string::npos) {
matrixStarted = true;
}
if(matrixStarted) {
// 统计该行中 '[' 与 ']' 的数量
bracketCount += count(s.begin(), s.end(), '[');
bracketCount -= count(s.begin(), s.end(), ']');
matrixSS << s;
if(bracketCount == 0) {
i++;
break;
}
}
}
matrixStr = matrixSS.str();
// 剩余的行保存到 restLines 中
for(; i < lines.size(); i++){
if(!lines[i].empty())
restLines.push_back(lines[i]);
}
}
vector<vector<double>> parseMatrix(const string &matrixStr) {
// 去掉最外层的 "[[" 与 "]]"
size_t start = matrixStr.find("[[");
size_t end = matrixStr.rfind("]]");
string inner = matrixStr.substr(start + 2, end - start - 2);
// inner 形如 "1,2],[3,4],[5,6"
vector<vector<double>> data;
stringstream ss(inner);
string rowStr;
while(getline(ss, rowStr, ']')) {
// 去除可能出现的 "[" 和逗号前后的空格
size_t pos = rowStr.find("[");
if(pos != string::npos) {
rowStr = rowStr.substr(pos + 1);
}
if(rowStr.empty()) continue;
vector<double> row;
stringstream rowStream(rowStr);
string numStr;
while(getline(rowStream, numStr, ',')) {
if(numStr.empty()) continue;
row.push_back(stod(numStr));
}
if(!row.empty()) data.push_back(row);
}
return data;
}
double dotProduct(const vector<double>& a, const vector<double>& b) {
double result = 0.0;
for (size_t i = 0; i < a.size(); i++) {
result += a[i] * b[i];
}
return result;
}
double euclideanDistanceSquared(const vector<double>& a, const vector<double>& b) {
double sum = 0.0;
for (size_t i = 0; i < a.size(); i++) {
double diff = a[i] - b[i];
sum += diff * diff;
}
return sum;
}
int main(){
ios::sync_with_stdio(false);
cin.tie(nullptr);
string matrixStr;
vector<string> restLines;
extractInput(matrixStr, restLines);
// 解析数据集
vector<vector<double>> data = parseMatrix(matrixStr);
int N = data.size();
// 读取核函数类型
string kernelType;
if(!restLines.empty()) {
kernelType = restLines[0];
// 如果带引号则去掉
if(kernelType.front()=='\'' && kernelType.back()=='\'')
kernelType = kernelType.substr(1, kernelType.size()-2);
} else {
cerr << "未提供核函数类型" << endl;
return 1;
}
// 默认参数:多项式核默认 $d=2$,RBF 核默认 $\gamma=0.5$,多项式核中的 $r=0$
int d = 2;
double gamma = 0.5;
double r = 0.0;
if(kernelType == "polynomial" && restLines.size() >= 2) {
d = stoi(restLines[1]);
} else if(kernelType == "rbf" && restLines.size() >= 2) {
gamma = stod(restLines[1]);
}
// 构造核矩阵
vector<vector<double>> K(N, vector<double>(N, 0.0));
for (int i = 0; i < N; i++){
for (int j = 0; j < N; j++){
double val = 0.0;
if(kernelType == "linear"){
// 线性核:$K(x,y)=x^T y$
val = dotProduct(data[i], data[j]);
} else if(kernelType == "polynomial"){
// 多项式核:$K(x,y)=(\gamma\cdot x^T y + r)^d$
double dp = dotProduct(data[i], data[j]);
double tmp = 1.0 * dp + r; // 此处默认 $\gamma=1$,$r=0$
val = pow(tmp, d);
} else if(kernelType == "rbf"){
// RBF核:$K(x,y)=\exp(-\gamma\|x-y\|^2)$
double distSq = euclideanDistanceSquared(data[i], data[j]);
val = exp(-gamma * distSq);
}
// 根据核函数类型选择保留位数:
// 高斯核保留两位小数,其它核保留整数
if(kernelType == "rbf") {
val = round(val * 100.0) / 100.0;
} else {
val = round(val);
}
K[i][j] = val;
}
}
// 输出核矩阵,格式为 [[...],[...],...]
cout << "[";
for (int i = 0; i < N; i++){
cout << "[";
for (int j = 0; j < N; j++){
if(kernelType == "rbf") {
cout << fixed << setprecision(2) << K[i][j];
} else {
// 输出整数
cout << static_cast<int>(K[i][j]);
}
if(j < N - 1) cout << ",";
}
cout << "]";
if(i < N - 1) cout << ",";
}
cout << "]";
return 0;
}
python
import sys
import ast
import math
def extract_input():
"""
读取所有输入行,并提取矩阵字符串和剩余行。
利用计数法从第一个 '[' 开始,到所有 '[' 与 ']' 平衡为止。
返回 (matrix_str, rest_lines)
"""
lines = sys.stdin.read().splitlines()
matrix_lines = []
bracket_count = 0
matrix_started = False
i = 0
while i < len(lines):
line = lines[i].strip()
if not matrix_started and '[' in line:
matrix_started = True
if matrix_started:
bracket_count += line.count('[')
bracket_count -= line.count(']')
matrix_lines.append(line)
if bracket_count == 0:
i += 1
break
i += 1
matrix_str = " ".join(matrix_lines)
rest_lines = [l.strip() for l in lines[i:] if l.strip() != ""]
return matrix_str, rest_lines
def parse_matrix(matrix_str):
"""
将矩阵字符串转换为二维列表。
这里直接使用 ast.literal_eval 解析字符串。
"""
try:
data = ast.literal_eval(matrix_str)
except Exception as e:
print("数据集格式错误:", e)
sys.exit(1)
return data
def dot_product(x, y):
"""计算两个向量的点积 $x^T y$"""
return sum(a * b for a, b in zip(x, y))
def euclidean_distance_squared(x, y):
"""计算两个向量的欧氏距离平方 $\|x-y\|^2$"""
return sum((a - b)**2 for a, b in zip(x, y))
def main():
matrix_str, rest_lines = extract_input()
data = parse_matrix(matrix_str)
N = len(data)
if not rest_lines:
print("未提供核函数类型")
sys.exit(1)
kernel_type = rest_lines[0]
# 如果核函数类型带引号,则去除
if kernel_type.startswith("'") and kernel_type.endswith("'"):
kernel_type = kernel_type[1:-1]
# 默认参数:多项式核默认 $d=2$,RBF核默认 $\gamma=0.5$,多项式核中的 $r=0$
d = 2
gamma = 0.5
r = 0.0
if kernel_type == "polynomial" and len(rest_lines) >= 2:
try:
d = int(rest_lines[1])
except Exception as e:
print("多项式核次数参数错误:", e)
sys.exit(1)
elif kernel_type == "rbf" and len(rest_lines) >= 2:
try:
gamma = float(rest_lines[1])
except Exception as e:
print("RBF核参数 $\gamma$ 错误:", e)
sys.exit(1)
# 构造核矩阵
K = []
for i in range(N):
row = []
for j in range(N):
if kernel_type == "linear":
# 线性核: $K(x,y)=x^T y$
val = dot_product(data[i], data[j])
elif kernel_type == "polynomial":
# 多项式核: $K(x,y)=(\gamma\cdot x^T y + r)^d$
dp = dot_product(data[i], data[j])
tmp = 1.0 * dp + r # 此处默认 $\gamma=1$,$r=0$
val = tmp ** d
elif kernel_type == "rbf":
# RBF核: $K(x,y)=\exp(-\gamma\|x-y\|^2)$
dist_sq = euclidean_distance_squared(data[i], data[j])
val = math.exp(-gamma * dist_sq)
else:
print("未知的核函数类型:", kernel_type)
sys.exit(1)
# 根据核函数类型选择保留位数:
# 高斯核保留两位小数,其它核保留整数
if kernel_type == "rbf":
val = round(val, 2)
else:
val = round(val)
row.append(val)
K.append(row)
# 输出核矩阵,格式为 [[...],[...],...]
output = "["
for i in range(N):
output += "["
for j in range(N):
if kernel_type == "rbf":
output += f"{K[i][j]:.2f}"
else:
output += f"{int(K[i][j])}"
if j < N - 1:
output += ","
output += "]"
if i < N - 1:
output += ","
output += "]"
print(output)
if __name__ == "__main__":
main()
java
import java.util.*;
import java.io.*;
import java.text.*;
public class Main {
// 读取所有输入行
public static List<String> readInput() throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
List<String> lines = new ArrayList<>();
String line;
while ((line = br.readLine()) != null && !line.isEmpty()) {
lines.add(line);
}
return lines;
}
// 从输入行中提取矩阵字符串和剩余行(利用计数法判断 '[' 与 ']' 的平衡)
public static String extractMatrix(List<String> lines, List<String> restLines) {
StringBuilder matrixSB = new StringBuilder();
int bracketCount = 0;
boolean matrixStarted = false;
int i = 0;
for(; i < lines.size(); i++){
String s = lines.get(i).trim();
if(!matrixStarted && s.contains("[")) {
matrixStarted = true;
}
if(matrixStarted){
bracketCount += countChar(s, '[');
bracketCount -= countChar(s, ']');
matrixSB.append(s);
if(bracketCount == 0){
i++;
break;
}
}
}
for(; i < lines.size(); i++){
if(!lines.get(i).trim().isEmpty())
restLines.add(lines.get(i).trim());
}
return matrixSB.toString();
}
// 辅助方法:统计字符串中某字符的个数
public static int countChar(String s, char c) {
int count = 0;
for(char ch : s.toCharArray()){
if(ch == c) count++;
}
return count;
}
// 解析矩阵字符串,返回二维列表(List<List<Double>>)
public static List<List<Double>> parseMatrix(String matrixStr) {
// 去掉最外层的 "[[" 与 "]]"
int start = matrixStr.indexOf("[[");
int end = matrixStr.lastIndexOf("]]");
if(start == -1 || end == -1) {
return new ArrayList<>();
}
String inner = matrixStr.substring(start + 2, end);
// inner 形如 "1,2],[3,4],[5,6"
List<List<Double>> data = new ArrayList<>();
String[] rows = inner.split("\\],\\[");
for(String rowStr : rows){
String[] nums = rowStr.split(",");
List<Double> row = new ArrayList<>();
for(String num : nums){
if(!num.trim().isEmpty()){
row.add(Double.parseDouble(num.trim()));
}
}
if(!row.isEmpty()) data.add(row);
}
return data;
}
// 计算两个向量的点积 $x^T y$
public static double dotProduct(List<Double> a, List<Double> b) {
double sum = 0.0;
for (int i = 0; i < a.size(); i++){
sum += a.get(i) * b.get(i);
}
return sum;
}
// 计算两个向量的欧氏距离平方 $\|x-y\|^2$
public static double euclideanDistanceSquared(List<Double> a, List<Double> b) {
double sum = 0.0;
for (int i = 0; i < a.size(); i++){
double diff = a.get(i) - b.get(i);
sum += diff * diff;
}
return sum;
}
public static void main(String[] args) throws IOException {
List<String> lines = readInput();
if(lines.isEmpty()){
System.out.println("未输入数据");
return;
}
List<String> restLines = new ArrayList<>();
String matrixStr = extractMatrix(lines, restLines);
List<List<Double>> data = parseMatrix(matrixStr);
int N = data.size();
// 读取核函数类型
String kernelType = "";
if(!restLines.isEmpty()){
kernelType = restLines.get(0);
// 去除引号
if(kernelType.startsWith("'") && kernelType.endsWith("'")){
kernelType = kernelType.substring(1, kernelType.length()-1);
}
} else {
System.out.println("未提供核函数类型");
return;
}
// 默认参数:多项式核默认 $d=2$,RBF核默认 $\gamma=0.5$,多项式核中的 $r=0$
int d = 2;
double gamma = 0.5;
double r = 0.0;
if(kernelType.equals("polynomial") && restLines.size() >= 2){
d = Integer.parseInt(restLines.get(1));
} else if(kernelType.equals("rbf") && restLines.size() >= 2){
gamma = Double.parseDouble(restLines.get(1));
}
double[][] K = new double[N][N];
for(int i = 0; i < N; i++){
for(int j = 0; j < N; j++){
double val = 0.0;
if(kernelType.equals("linear")){
// 线性核:$K(x,y)=x^T y$
val = dotProduct(data.get(i), data.get(j));
} else if(kernelType.equals("polynomial")){
// 多项式核:$K(x,y)=(\gamma\cdot x^T y + r)^d$
double dp = dotProduct(data.get(i), data.get(j));
double tmp = 1.0 * dp + r; // 此处默认 $\gamma=1$,$r=0$
val = Math.pow(tmp, d);
} else if(kernelType.equals("rbf")){
// RBF核:$K(x,y)=\exp(-\gamma\|x-y\|^2)$
double distSq = euclideanDistanceSquared(data.get(i), data.get(j));
val = Math.exp(-gamma * distSq);
} else {
System.out.println("未知的核函数类型: " + kernelType);
return;
}
// 根据核函数类型选择保留位数:
// 高斯核保留两位小数,其它核保留整数
if(kernelType.equals("rbf")){
val = Math.round(val * 100.0) / 100.0;
} else {
val = Math.round(val);
}
K[i][j] = val;
}
}
// 输出核矩阵,格式为 [[...],[...],...]
StringBuilder output = new StringBuilder();
output.append("[");
for(int i = 0; i < N; i++){
output.append("[");
for(int j = 0; j < N; j++){
if(kernelType.equals("rbf")){
output.append(String.format("%.2f", K[i][j]));
} else {
output.append(String.format("%d", (int)K[i][j]));
}
if(j < N - 1)
output.append(",");
}
output.append("]");
if(i < N - 1)
output.append(",");
}
output.append("]");
System.out.println(output.toString());
}
}
题目内容
实现一个计算支持向量机(SVM)核矩阵的系统,具体要求如下:
-
读取输入数据集,由数据点组成,每个数据点是一个特征向量。
-
读取核函数类型,可以是'linear’(线性核)、'polynomial'(多项式核)或'rbf'(高斯核)。
-
根据核函数类型计算核矩阵:
-
线性核:计算两个向量的内积。
-
多项式核:计算形式为:
K(x,y)=(γxTy+r)d 的核函数,其中默认参数为γ=1、r=0、d=2。
-
高斯核(RBF核):计算形式为K(x,y)=exp(−r∣∣x−y∣∣2)的核函数,其中 默认参数为γ=0.5。
4.输出核矩阵,对于高斯核矩阵每个元素保留两位小数,使用round(x,2)
(不允许使用numpy库、sklearn库)。
输入描述
输入包括:
-
一个数据集,表示为一个二维列表,每个内部列表代表一个数据点的特征向量。
-
一个字符串,表示核函数类型'linear'、'polynomial'或'rbf'。
-
如果核函数为'polynomial'或'rbf',可能需要额外的参数:
-
对于'polynomial',需要整数参数d,表示多项式的次数。
-
对于'rbf',需要浮点数参数gamma,表示高斯核的参数
输出描述
输出为一个二维列表,表示计算得到的核矩阵,对于高斯核矩阵每个元素保留两位小数,使用round(x,2)
补充说明
- 线性核的计算公式:K(x,y)=xTy
- 多项式核的计算公式(默认参数γ=1,r=0,d=2): K(x,y)=(γxTy+r)d
- 高斯核(RBF核)的计算公式(默认参数γ=0.5) : K(x,y)=exp(−γ∣∣x−y∣∣2)
- 在实现中,必要时可以自行设置默认参数。
样例1
输入
[[1,2],
[3,4],
[5,6]]
'linear'
输出
[[5,11,17],[11,25,39],[17,39,61]]