#P4572. 第3题-动态区间的多项式岭回归
-
1000ms
Tried: 8
Accepted: 2
Difficulty: 9
所属公司 :
华为
时间 :2026年3月4日-AI方向
-
算法标签>数学
第3题-动态区间的多项式岭回归
解题思路
题目要求对每个缺失点 pos,只使用它左右两侧、被“最近缺失点”截断后的真实数据区间来训练模型并预测。
-
定位缺失位置 读入 N 行数据,数值表示真实流量;字符串
Missing_i表示第i个缺失值,记录其所在天序号pos[i](1-based)。 -
按题意构造训练区间(动态变化)
- 前向区间
[left_start, pos-1]:从pos-1向前找,遇到第一个任意缺失值位置k就停止,则left_start = k+1;若没遇到缺失则left_start = 1。 - 后向区间
[pos+1, right_end]:从pos+1向后找,遇到第一个任意缺失值位置k就停止,则right_end = k-1;若没遇到缺失则right_end = N。
- 训练二阶多项式岭回归(算法:Ridge Regression)
- 训练样本来自上述两段区间内的所有真实记录
(x, y),其中x为天序号,y为流量值。 - 模型:
y_hat = β2*x^2 + β1*x + β0 - 设计矩阵每行:
[x^2, x, 1] - 岭回归闭式解(题面给定):
β = (X^T X + λI)^(-1) X^T y,其中λ = 0.1,I为 3×3 单位阵。 - 因为只有 3 个系数,直接累加得到
A = X^T X + λI(3×3)和b = X^T y(3×1),再用高斯消元解线性方程Aβ=b。
- 输出
按
Missing_1 ... Missing_M顺序输出:Missing_i: xx.xx(保留两位小数)。
复杂度分析
设对某个缺失值可用训练点数为 K(最多约 N)。
- 时间复杂度:对每个缺失值扫描左右边界 O(N),累加矩阵 O(K),解 3×3 方程 O(1),总计
O(M*N)(N=200,M≤30 很小)。 - 空间复杂度:存储 N 个数据与缺失位置,
O(N)。
代码实现
import sys
LAMBDA = 0.1
def solve_3x3(A, b):
# 高斯消元解 3x3 线性方程 A*x=b(A 会被改写)
n = 3
# 增广矩阵
M = [A[i][:] + [b[i]] for i in range(n)]
for col in range(n):
# 选主元(简单起见找绝对值最大行)
pivot = col
for r in range(col, n):
if abs(M[r][col]) > abs(M[pivot][col]):
pivot = r
M[col], M[pivot] = M[pivot], M[col]
# 归一化
div = M[col][col]
for j in range(col, n + 1):
M[col][j] /= div
# 消元
for r in range(n):
if r == col:
continue
factor = M[r][col]
if factor == 0:
continue
for j in range(col, n + 1):
M[r][j] -= factor * M[col][j]
return [M[i][n] for i in range(n)] # 解向量
def predict_for_pos(pos, is_missing, values, N):
# 1) 找 left_start
j = pos - 1
while j >= 1 and (not is_missing[j]):
j -= 1
left_start = 1 if j < 1 else j + 1
# 2) 找 right_end
j = pos + 1
while j <= N and (not is_missing[j]):
j += 1
right_end = N if j > N else j - 1
# 3) 累加 A=X^T X + λI, b=X^T y
A = [[0.0] * 3 for _ in range(3)]
b = [0.0] * 3
def add_point(x, y):
x2 = x * x
feat = [x2, x, 1.0]
# A += feat^T * feat
for r in range(3):
for c in range(3):
A[r][c] += feat[r] * feat[c]
# b += feat^T * y
for r in range(3):
b[r] += feat[r] * y
# 左区间
for x in range(left_start, pos):
add_point(float(x), values[x])
# 右区间
for x in range(pos + 1, right_end + 1):
add_point(float(x), values[x])
# 加 λI
for i in range(3):
A[i][i] += LAMBDA
beta2, beta1, beta0 = solve_3x3(A, b)
x = float(pos)
return beta2 * x * x + beta1 * x + beta0
def main():
data = sys.stdin.read().strip().split()
if not data:
return
M = int(data[0])
N = int(data[1])
arr = data[2:2 + N]
is_missing = [False] * (N + 1)
values = [0.0] * (N + 1)
pos_of = [0] * (M + 1) # pos_of[i] = 位置
for i in range(1, N + 1):
s = arr[i - 1]
if s.startswith("Missing_"):
is_missing[i] = True
idx = int(s.split("_")[1])
pos_of[idx] = i
else:
values[i] = float(s)
out_lines = []
for i in range(1, M + 1):
pos = pos_of[i]
yhat = predict_for_pos(pos, is_missing, values, N)
out_lines.append(f"Missing_{i}: {yhat:.2f}")
sys.stdout.write("\n".join(out_lines))
if __name__ == "__main__":
main()
#include <bits/stdc++.h>
using namespace std;
static const double LAMBDA = 0.1;
// 高斯消元解 3x3 线性方程 A*x=b
static array<double,3> solve3x3(double A[3][3], double b[3]) {
double M[3][4];
for(int i=0;i<3;i++){
for(int j=0;j<3;j++) M[i][j]=A[i][j];
M[i][3]=b[i];
}
for(int col=0; col<3; col++){
// 选主元
int pivot = col;
for(int r=col; r<3; r++){
if (fabs(M[r][col]) > fabs(M[pivot][col])) pivot = r;
}
for(int j=col; j<4; j++) swap(M[col][j], M[pivot][j]);
// 归一化
double div = M[col][col];
for(int j=col; j<4; j++) M[col][j] /= div;
// 消元
for(int r=0; r<3; r++){
if(r==col) continue;
double factor = M[r][col];
if(factor==0) continue;
for(int j=col; j<4; j++){
M[r][j] -= factor * M[col][j];
}
}
}
return {M[0][3], M[1][3], M[2][3]};
}
static void addPoint(double A[3][3], double b[3], double x, double y){
double x2 = x*x;
double feat[3] = {x2, x, 1.0};
// A += feat^T * feat
for(int r=0;r<3;r++){
for(int c=0;c<3;c++){
A[r][c] += feat[r]*feat[c];
}
}
// b += feat^T * y
for(int r=0;r<3;r++){
b[r] += feat[r]*y;
}
}
static double predictForPos(int pos, const vector<bool>& isMissing,
const vector<double>& values, int N){
// 1) 找 left_start
int j = pos - 1;
while(j>=1 && !isMissing[j]) j--;
int leftStart = (j<1) ? 1 : (j+1);
// 2) 找 right_end
j = pos + 1;
while(j<=N && !isMissing[j]) j++;
int rightEnd = (j>N) ? N : (j-1);
// 3) 累加 A=X^T X + λI, b=X^T y
double A[3][3] = {{0,0,0},{0,0,0},{0,0,0}};
double b[3] = {0,0,0};
for(int x=leftStart; x<=pos-1; x++){
addPoint(A, b, (double)x, values[x]);
}
for(int x=pos+1; x<=rightEnd; x++){
addPoint(A, b, (double)x, values[x]);
}
// 加 λI
for(int i=0;i<3;i++) A[i][i] += LAMBDA;
auto beta = solve3x3(A, b);
double xx = (double)pos;
return beta[0]*xx*xx + beta[1]*xx + beta[2];
}
int main(){
ios::sync_with_stdio(false);
cin.tie(nullptr);
int M, N;
cin >> M >> N;
vector<bool> isMissing(N+1,false);
vector<double> values(N+1,0.0);
vector<int> posOf(M+1,0);
for(int i=1;i<=N;i++){
string s;
cin >> s;
if(s.rfind("Missing_", 0) == 0){
isMissing[i] = true;
int idx = stoi(s.substr(s.find('_')+1));
posOf[idx] = i;
}else{
values[i] = stod(s);
}
}
cout.setf(std::ios::fixed);
cout << setprecision(2);
for(int i=1;i<=M;i++){
int pos = posOf[i];
double ans = predictForPos(pos, isMissing, values, N);
cout << "Missing_" << i << ": " << ans << "\n";
}
return 0;
}
import java.io.*;
import java.util.*;
public class Main {
static final double LAMBDA = 0.1;
// 高斯消元解 3x3 线性方程 A*x=b
static double[] solve3x3(double[][] A, double[] b) {
int n = 3;
double[][] M = new double[n][n + 1];
for (int i = 0; i < n; i++) {
System.arraycopy(A[i], 0, M[i], 0, n);
M[i][n] = b[i];
}
for (int col = 0; col < n; col++) {
// 选主元
int pivot = col;
for (int r = col; r < n; r++) {
if (Math.abs(M[r][col]) > Math.abs(M[pivot][col])) pivot = r;
}
double[] tmp = M[col];
M[col] = M[pivot];
M[pivot] = tmp;
// 归一化
double div = M[col][col];
for (int j = col; j <= n; j++) M[col][j] /= div;
// 消元
for (int r = 0; r < n; r++) {
if (r == col) continue;
double factor = M[r][col];
if (factor == 0) continue;
for (int j = col; j <= n; j++) {
M[r][j] -= factor * M[col][j];
}
}
}
return new double[]{M[0][n], M[1][n], M[2][n]};
}
static double predictForPos(int pos, boolean[] isMissing, double[] values, int N) {
// 1) 找 left_start
int j = pos - 1;
while (j >= 1 && !isMissing[j]) j--;
int leftStart = (j < 1) ? 1 : (j + 1);
// 2) 找 right_end
j = pos + 1;
while (j <= N && !isMissing[j]) j++;
int rightEnd = (j > N) ? N : (j - 1);
// 3) 累加 A=X^T X + λI, b=X^T y
double[][] A = new double[3][3];
double[] b = new double[3];
// 加一个训练点
for (int x = leftStart; x <= pos - 1; x++) {
addPoint(A, b, x, values[x]);
}
for (int x = pos + 1; x <= rightEnd; x++) {
addPoint(A, b, x, values[x]);
}
// 加 λI
for (int i = 0; i < 3; i++) A[i][i] += LAMBDA;
double[] beta = solve3x3(A, b);
double xx = pos;
return beta[0] * xx * xx + beta[1] * xx + beta[2];
}
static void addPoint(double[][] A, double[] b, double x, double y) {
double x2 = x * x;
double[] feat = new double[]{x2, x, 1.0};
// A += feat^T * feat
for (int r = 0; r < 3; r++) {
for (int c = 0; c < 3; c++) {
A[r][c] += feat[r] * feat[c];
}
}
// b += feat^T * y
for (int r = 0; r < 3; r++) {
b[r] += feat[r] * y;
}
}
public static void main(String[] args) throws Exception {
Scanner sc = new Scanner(System.in);
int M = sc.nextInt();
int N = sc.nextInt();
boolean[] isMissing = new boolean[N + 1];
double[] values = new double[N + 1];
int[] posOf = new int[M + 1];
for (int i = 1; i <= N; i++) {
String s = sc.next();
if (s.startsWith("Missing_")) {
isMissing[i] = true;
int idx = Integer.parseInt(s.substring(s.indexOf('_') + 1));
posOf[idx] = i;
} else {
values[i] = Double.parseDouble(s);
}
}
StringBuilder sb = new StringBuilder();
for (int i = 1; i <= M; i++) {
int pos = posOf[i];
double ans = predictForPos(pos, isMissing, values, N);
sb.append("Missing_").append(i).append(": ")
.append(String.format(Locale.US, "%.2f", ans));
if (i != M) sb.append('\n');
}
System.out.print(sb.toString());
}
}
题目内容
某大型互联网公司的数据中心,记录了其核心服务在连续 N=200 天内的出口总流量(单位 GB,取值范围通常在 100.00 到 500.00 之间)已按时间顺序给出。由于监控系统维护,数据中有 M 个( M 的范围为 20 到 30 )缺失值,按顺序标记为 Missing_1,Missing_2,...,Missing_M。
已知这些缺失值保证不会出现在第 1 天和最后 1 天(即首尾两条记录一定存在)
任务描述
你需要为每一个缺失值,通过其邻近的、动态变化的真实数据区间,建立一个 二阶多项式岭回归(2nd Ridge Regression)模型进行预测。
区间定义
对位于全局序号 pos 的缺失值:
前向区间[left_start,pos−1]:从 pos−1 开始向前(向着第 1 天)寻找,遇到的第一个原始缺失值(Missing_1...Missing_M 中的任意一个)的后一天。如果到第 1 天仍未碰到任何原始缺失值,则前向区间的起始点 left_start 为第 1 天。
后向区间[pos+1,right_end]:从 pos+1 开始向后(向着第 N 天)寻找,遇到的第一个原始缺失值的前一天。如果到第 N 天仍未碰到任何原始缺失值则后向区间的结束点 right_end 为第 N 天。
算法与公式
取上述两个区间内的所有真实记录作为训练集 (x,y),其中:
x 为日期序号 (1,2,...,N)
y 为对应的流量值
你需要使用岭回归求解二阶多项式模型求岭回归模型
Latex: \hat{y} = \beta_2 x^2 + \beta_1 x + \beta_0
岭回归的解可以通过以下矩阵公式计算:
β=(XTX+λI)−1XTy ,
Latex:\beta=(X^T X + \lambda)^-1 X^T y
其中:
-
β 是一个 3×1 的列向量,包含需要求解的系数 [β_2 β_1 β_0]。
-
X 是一个 n×3 的设计矩阵,其中 n 是训练集的数据点数量。对于训练集中的每一个日期序号,矩阵 X 中对应的一行为 [xi2,xi,1]。
-
y 是一个 n×1 的列向量,包含训练集中 n 个观测点的流量值。
-
X^T 是 X 的转置矩阵。
-
(lambda) 是正则化参数,在本题中统一设为 lambda=0.1 。
-
1 是一个 3×3 的单位矩阵。
-
(.)^{-1} 表示矩阵求逆。
输入描述
-
第 1 行:两个整数 M 和 N,由空格分隔。第一个参数 M 指的是缺失值的总数(范围为 20 到 30 );第二个参数 N 是指后续的数据行数(固定为 200)
-
第 2 行到第 N+1 行:每行包含一个值,该值可以是
- 一个浮点数,代表当日的真实流量值。
- 一个字符串,格式为 Missing_i (其中 i 从 1 到 M),代表当日的流量数据缺失。
输出描述
共 M 行,严格按照 Missing_1,Missing_2,...,Missing_M 的顺序输出。
每行格式为 Missing_i:xxx.xx,即标签、冒号、空格、预测值。预测值要求保留两位小数。
样例1
输入
20 200
140.36
146.38
167.91
162.64
181.99
166.79
Missing_1
156.46
175.24
165.52
157.71
Missing_2
158.26
169.09
142.55
151.18
148.18
Missing_3
140.23
146.42
135.47
Missing_4
130.90
138.79
133.65
129.18
151.72
142.50
133.01
157.68
Missing_5
157.02
169.40
168.70
178.77
160.13
174.77
174.48
162.20
167.09
181.81
160.76
172.85
167.83
167.38
164.35
140.30
160.63
143.56
142.56
133.02
133.61
Missing_6
130.73
143.76
146.32
136.02
Missing_7
151.45
143.21
147.88
164.99
176.53
177.58
163.27
Missing_8
163.40
167.27
182.03
189.90
175.84
181.42
171.06
160.66
161.04
159.17
156.67
Missing_9
140.52
153.13
135.72
153.66
136.88
143.00
Missing_10
147.52
136.38
152.19
Missing_11
140.37
151.19
155.24
Missing_12
176.08
166.01
174.35
186.10
189.84
Missing_13
167.38
180.46
184.17
167.70
158.32
170.87
159.46
152.25
164.62
159.22
160.63
155.92
132.63
146.97
128.47
133.05
134.12
145.20
161.01
153.34
152.31
160.25
157.89
162.57
159.33
188.02
188.42
Missing_14
190.36
172.49
179.07
186.54
174.78
189.76
179.46
169.32
Missing_15
166.40
174.29
147.45
140.39
166.35
150.74
133.56
158.77
140.73
153.93
136.37
143.02
168.03
162.22
173.28
176.61
159.22
173.93
179.96
169.60
178.89
190.53
202.52
200.04
187.90
Missing_16
184.13
193.93
170.60
183.11
178.36
170.28
174.84
160.06
169.08
159.11
Missing_17
140.99
148.42
156.97
144.91
144.61
169.12
152.68
176.46
Missing_18
165.14
170.70
171.10
182.38
181.63
196.53
Missing_19
180.73
182.49
192.30
184.48
178.30
192.26
193.45
188.63
Missing_20
176.12
173.62
输出
Missing_1: 175.81
Missing_2: 168.12
Missing_3: 150.08
Missing_4: 138.62
Missing_5: 158.61
Missing_6: 141.85
Missing_7: 146.87
Missing_8: 166.18
Missing_9: 155.56
Missing_10: 144.75
Missing_11: 147.16
Missing_12: 160.65
Missing_13: 166.72
Missing_14: 169.88
Missing_15: 166.19
Missing_16: 174.53
Missing_17: 164.11
Missing_18: 167.63
Missing_19: 181.34
Missing_20: 182.15
说明
Missing_1:175.81
Missing_2:168.12
Missing_3:150.08
Missing_4:138.62
Missing_5:158.61
Missing_6:141.85
Missing_7:146.87
Missing_8:166.18
Missing_9:155.56
Missing_10:144.75
Missing_11:147.16
Missing_12:160.65
Missing_13:166.72
Missing_14:169.88
Missing_15:166.19
Missing_16:174.53
Missing_17:164.11
Missing_18:167.63
Missing_19:181.34
Missing_20:182.15