#P3552. Problem 3: Cloud Storage Device Failure Prediction
1000ms
Tried: 1307
Accepted: 139
Difficulty: 6
Company: Huawei
Date: September 3, 2025 - Domestic - AI
Algorithm tags > Machine learning algorithms
Problem 3: Cloud Storage Device Failure Prediction
Solution Approach
1) Data cleaning (column-wise statistics ➜ row-wise replacement)
- Split each log line on commas: column 1 is the device ID and the last column is the label y ∈ {0, 1}; the 5 columns in between are the features, in order: write count, read count, average write latency (ms), average read latency (ms), device age (years). (If a line unexpectedly contains extra fields, take the first 5 numeric values after the ID as features and the last field as the label, for compatibility with Sample 1.)
- Missing-value imputation: the string "NaN" in a feature field marks a missing value; fill it with the mean of that column's valid values (estimated from the training set only).
- Outlier correction: flag values as outliers by the rules below and replace them with the median of that column's valid (in-range) values:
  - write/read count: < 0
  - average write/read latency: < 0 or > 1000
  - device age: < 0 or > 20
- The test set is processed with the training-set means/medians, so cleaning stays consistent (a minimal sketch of this pass follows this list).
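As a quick illustration, here is a minimal sketch of this cleaning pass in Python, assuming the features have already been parsed into 5-element rows with None standing in for "NaN"; the helper names is_valid, column_stats and clean are illustrative and do not come from the reference implementations below.

def is_valid(col, v):
    # Range rules from the problem statement; None (missing) never counts as valid.
    if v is None:
        return False
    if col in (0, 1):          # write/read count
        return v >= 0
    if col in (2, 3):          # average write/read latency (ms)
        return 0 <= v <= 1000
    return 0 <= v <= 20        # device age (years)

def column_stats(rows):
    # Per-column mean and median, computed over valid training values only.
    means, medians = [], []
    for j in range(5):
        vals = sorted(row[j] for row in rows if is_valid(j, row[j]))
        if vals:
            mid = len(vals) // 2
            med = vals[mid] if len(vals) % 2 else 0.5 * (vals[mid - 1] + vals[mid])
            means.append(sum(vals) / len(vals))
            medians.append(med)
        else:
            means.append(0.0)   # fallback when a column has no valid value at all
            medians.append(0.0)
    return means, medians

def clean(rows, means, medians):
    # Missing -> column mean, then anything still out of range -> column median.
    cleaned = []
    for row in rows:
        fixed = []
        for j, v in enumerate(row):
            if v is None:
                v = means[j]
            if not is_valid(j, v):
                v = medians[j]
            fixed.append(v)
        cleaned.append(fixed)
    return cleaned

The reference implementations below apply the same two steps in the same order: fill missing values with the column mean first, then replace anything still out of range with the column median.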
2) Model and training
- Logistic regression with a bias term; the loss is the log loss.
- Optimization: batch gradient descent (Batch GD)
  - learning rate α = 0.01, 100 iterations;
  - parameters initialized to 0;
  - gradient: in each iteration, accumulate the gradient over all samples, then update (the update rule is written out below).
- Prediction: output 1 if sigmoid(z) ≥ 0.5, otherwise 0.
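For reference, here is the per-iteration batch update this describes, written out explicitly. The reference implementations below average the accumulated gradient over the N training samples before updating, which only rescales the effective learning rate:

z_i = w_0 + \sum_{j=1}^{5} w_j x_{ij}, \qquad p_i = \sigma(z_i) = \frac{1}{1 + e^{-z_i}}
w_k \leftarrow w_k - \frac{\alpha}{N} \sum_{i=1}^{N} (p_i - y_i)\, x_{ik} \qquad (k = 0, \dots, 5;\ x_{i0} := 1)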
3) Complexity analysis
- With N training samples (N ≤ 100), d = 5 features, and T = 100 iterations:
  - mean/median statistics: O(N·d·log N) (sorting for the median; a selection algorithm would make it linear)
  - training: O(T·N·d)
  - prediction: O(M·d), with M ≤ 10
- Within this problem's data range, both time and memory are more than sufficient.
4) Edge cases and implementation details
- The label is always the last column; the features are only the first 5 numeric values after the ID (compatible with Sample 1).
- If a column has no valid values at all (an extreme case), the median falls back to that column's mean, and failing that to 0.
- Strip surrounding whitespace when converting numbers; "NaN" (case-sensitive) is treated as missing. The sigmoid computation uses simple overflow protection (e.g. clamping z), as sketched below.
Reference Implementations
Python implementation
import sys, math

def parse_line(line):
    parts = [p.strip() for p in line.strip().split(',')]
    if not parts: return None
    id_ = parts[0]
    if len(parts) < 7:
        # Not enough fields; skip (does not occur per the problem statement)
        return None
    # Features: the first 5 values right after the ID; label: the last field
    feats_raw = parts[1:6]
    y_raw = parts[-1]
    def to_num(s):
        if s == "NaN": return None
        try:
            return float(s)
        except:
            return None
    x = [to_num(v) for v in feats_raw]
    # Label comes from the last column; tolerate floating-point notation
    y = 0
    try:
        y = int(float(y_raw))
    except:
        y = 0
    return id_, x, y

# Valid-range check per column
def valid(col, v):
    if v is None: return False
    if col in (0, 1):  # write/read count
        return v >= 0
    if col in (2, 3):  # latency
        return 0 <= v <= 1000
    if col == 4:       # age in years
        return 0 <= v <= 20
    return True

def median(vals):
    n = len(vals)
    if n == 0: return 0.0
    vals2 = sorted(vals)
    mid = n // 2
    if n % 2 == 1:
        return vals2[mid]
    else:
        return 0.5 * (vals2[mid - 1] + vals2[mid])

def sigmoid(z):
    # Simple numerical-stability guard
    if z > 30: z = 30
    if z < -30: z = -30
    return 1.0 / (1.0 + math.exp(-z))

def clean_matrix(X, means, meds):
    # Replace missing -> column mean; out-of-range -> column median
    n = len(X)
    d = len(X[0]) if n else 5
    out = []
    for i in range(n):
        row = []
        for j in range(d):
            v = X[i][j]
            if v is None:
                v = means[j]
            # Outlier replacement
            if not valid(j, v):
                v = meds[j]
            row.append(v)
        out.append(row)
    return out

def main():
    data = sys.stdin.read().strip().splitlines()
    if not data:
        return
    it = 0
    # Read N
    while it < len(data) and data[it].strip() == "":
        it += 1
    N = int(data[it].strip()); it += 1
    # Read the training set
    trainX_raw, trainY = [], []
    for _ in range(N):
        while it < len(data) and data[it].strip() == "":
            it += 1
        id_, x, y = parse_line(data[it]); it += 1
        trainX_raw.append(x)
        trainY.append(y)
    # -- Per-column mean and median of the "valid" values (training set only) --
    d = 5
    means = [0.0] * d
    meds = [0.0] * d
    for j in range(d):
        valid_vals = [row[j] for row in trainX_raw if valid(j, row[j])]
        if valid_vals:
            means[j] = sum(valid_vals) / len(valid_vals)
            # Median
            s = sorted(valid_vals)
            n = len(s)
            meds[j] = s[n//2] if n % 2 == 1 else 0.5 * (s[n//2 - 1] + s[n//2])
        else:
            # Fallback when a column has no valid value at all
            means[j] = 0.0
            meds[j] = 0.0
    # Clean the training set
    trainX = clean_matrix(trainX_raw, means, meds)
    # Read M
    while it < len(data) and data[it].strip() == "":
        it += 1
    M = int(data[it].strip()); it += 1
    # Read the test set (its status column is part of the input format but is ignored)
    testX_raw = []
    for _ in range(M):
        while it < len(data) and data[it].strip() == "":
            it += 1
        parts = [p.strip() for p in data[it].strip().split(',')]
        it += 1
        feats_raw = parts[1:6]  # the first 5 features
        def to_num(s):
            if s == "NaN": return None
            try:
                return float(s)
            except:
                return None
        testX_raw.append([to_num(v) for v in feats_raw])
    testX = clean_matrix(testX_raw, means, meds)
    # Train logistic regression (batch GD)
    n = len(trainX)
    w = [0.0]*(d+1)  # w[0] is the bias
    alpha = 0.01
    T = 100
    for _ in range(T):
        g = [0.0]*(d+1)
        for i in range(n):
            z = w[0]
            for j in range(d):
                z += w[j+1] * trainX[i][j]
            p = sigmoid(z)
            diff = p - trainY[i]
            g[0] += diff
            for j in range(d):
                g[j+1] += diff * trainX[i][j]
        # Parameter update (averaged gradient)
        for k in range(d+1):
            w[k] -= alpha * g[k] / n
    # Predict
    out_lines = []
    for i in range(M):
        z = w[0]
        for j in range(d):
            z += w[j+1] * testX[i][j]
        p = sigmoid(z)
        pred = 1 if p >= 0.5 else 0
        out_lines.append(str(pred))
    print("\n".join(out_lines))

if __name__ == "__main__":
    main()
Java implementation
import java.io.*;
import java.util.*;
public class Main {
static boolean valid(int c, double v) {
if (Double.isNaN(v)) return false;
if (c == 0 || c == 1) return v >= 0; // counts
if (c == 2 || c == 3) return v >= 0 && v <= 1000; // latency
if (c == 4) return v >= 0 && v <= 20; // age in years
return true;
}
static double median(List<Double> a) {
if (a.isEmpty()) return 0.0;
Collections.sort(a);
int n = a.size();
if (n % 2 == 1) return a.get(n/2);
return (a.get(n/2 - 1) + a.get(n/2)) / 2.0;
}
static double sigmoid(double z) {
if (z > 30) z = 30;
if (z < -30) z = -30;
return 1.0 / (1.0 + Math.exp(-z));
}
static Double toNum(String s) {
s = s.trim();
if (s.equals("NaN")) return Double.NaN;
try { return Double.parseDouble(s); }
catch (Exception e) { return Double.NaN; }
}
public static void main(String[] args) throws Exception {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in, "UTF-8"));
List<String> lines = new ArrayList<>();
for (String ln; (ln = br.readLine()) != null; ) lines.add(ln);
int it = 0;
while (it < lines.size() && lines.get(it).trim().isEmpty()) it++;
int N = Integer.parseInt(lines.get(it).trim()); it++;
int d = 5;
List<double[]> trainXraw = new ArrayList<>();
List<Integer> trainY = new ArrayList<>();
// Read the training data
for (int i = 0; i < N; i++) {
while (it < lines.size() && lines.get(it).trim().isEmpty()) it++;
String[] parts = lines.get(it).split(",");
it++;
// Features: the first 5 fields after the ID; label: the last field
Double[] x = new Double[d];
for (int j = 0; j < d; j++) x[j] = toNum(parts[1 + j]);
int y = 0;
try { y = (int)Math.floor(Double.parseDouble(parts[parts.length - 1].trim()) + 1e-9); }
catch (Exception e) { y = 0; }
double[] row = new double[d];
for (int j = 0; j < d; j++) row[j] = x[j] == null ? Double.NaN : x[j];
trainXraw.add(row);
trainY.add(y);
}
double[] mean = new double[d];
double[] med = new double[d];
for (int j = 0; j < d; j++) {
List<Double> validVals = new ArrayList<>();
for (double[] r : trainXraw) {
double v = r[j];
if (!Double.isNaN(v) && valid(j, v)) validVals.add(v);
}
if (!validVals.isEmpty()) {
double s = 0;
for (double v : validVals) s += v;
mean[j] = s / validVals.size();
Collections.sort(validVals);
int n = validVals.size();
med[j] = (n % 2 == 1) ? validVals.get(n/2)
: (validVals.get(n/2 - 1) + validVals.get(n/2)) / 2.0;
} else {
mean[j] = 0.0;
med[j] = 0.0;
}
}
// Clean the training set
List<double[]> trainX = new ArrayList<>();
for (double[] r : trainXraw) {
double[] t = new double[d];
for (int j = 0; j < d; j++) {
double v = Double.isNaN(r[j]) ? mean[j] : r[j];
if (!valid(j, v)) v = med[j];
t[j] = v;
}
trainX.add(t);
}
// Read M
while (it < lines.size() && lines.get(it).trim().isEmpty()) it++;
int M = Integer.parseInt(lines.get(it).trim()); it++;
// Test set
List<double[]> testXraw = new ArrayList<>();
for (int i = 0; i < M; i++) {
while (it < lines.size() && lines.get(it).trim().isEmpty()) it++;
String[] parts = lines.get(it).split(",");
it++;
double[] r = new double[d];
for (int j = 0; j < d; j++) {
Double v = toNum(parts[1 + j]);
r[j] = (v == null) ? Double.NaN : v;
}
testXraw.add(r);
}
List<double[]> testX = new ArrayList<>();
for (double[] r : testXraw) {
double[] t = new double[d];
for (int j = 0; j < d; j++) {
double v = Double.isNaN(r[j]) ? mean[j] : r[j];
if (!valid(j, v)) v = med[j];
t[j] = v;
}
testX.add(t);
}
// Train logistic regression (batch GD)
double[] w = new double[d + 1]; // w[0] is the bias
Arrays.fill(w, 0.0);
double alpha = 0.01;
int T = 100;
int n = trainX.size();
for (int t = 0; t < T; t++) {
double[] g = new double[d + 1];
Arrays.fill(g, 0.0);
for (int i = 0; i < n; i++) {
double[] x = trainX.get(i);
double z = w[0];
for (int j = 0; j < d; j++) z += w[j + 1] * x[j];
double p = sigmoid(z);
double diff = p - trainY.get(i);
g[0] += diff;
for (int j = 0; j < d; j++) g[j + 1] += diff * x[j];
}
for (int k = 0; k < d + 1; k++) w[k] -= alpha * g[k] / n;
}
// Predict and output
StringBuilder sb = new StringBuilder();
for (double[] x : testX) {
double z = w[0];
for (int j = 0; j < d; j++) z += w[j + 1] * x[j];
double p = sigmoid(z);
int pred = p >= 0.5 ? 1 : 0;
sb.append(pred).append('\n');
}
System.out.print(sb.toString());
}
}
C++ implementation
#include <bits/stdc++.h>
using namespace std;
bool valid(int c, double v){
if (isnan(v)) return false;
if (c==0 || c==1) return v>=0; // counts
if (c==2 || c==3) return v>=0 && v<=1000; // latency
if (c==4) return v>=0 && v<=20; // age in years
return true;
}
double sigm(double z){
if (z > 30) z = 30;
if (z < -30) z = -30;
return 1.0 / (1.0 + exp(-z));
}
double toNum(const string& s){
if (s=="NaN") return numeric_limits<double>::quiet_NaN();
try { return stod(s); }
catch(...) { return numeric_limits<double>::quiet_NaN(); }
}
vector<string> split(const string& s, char d=','){
vector<string> r; string cur;
for(char c: s){
if(c==d){ r.push_back(string(cur.begin(), find_if(cur.rbegin(), cur.rend(), [](char ch){return !isspace((unsigned char)ch);} ).base()));
size_t l=0; while(l<r.back().size() && isspace((unsigned char)r.back()[l])) l++; r.back()=r.back().substr(l);
cur.clear(); }
else cur.push_back(c);
}
r.push_back(string(cur.begin(), find_if(cur.rbegin(), cur.rend(), [](char ch){return !isspace((unsigned char)ch);} ).base()));
size_t l=0; while(l<r.back().size() && isspace((unsigned char)r.back()[l])) l++; r.back()=r.back().substr(l);
return r;
}
double median(vector<double> v){
if(v.empty()) return 0.0;
sort(v.begin(), v.end());
int n=v.size();
if(n&1) return v[n/2];
return 0.5*(v[n/2-1]+v[n/2]);
}
int main(){
ios::sync_with_stdio(false);
cin.tie(nullptr);
string line;
auto getline_nonempty = [&](string &out){
while (getline(cin, out)){
// Blank lines should not appear in the input; skipping them is just defensive
if (!out.empty()) return true;
}
return false;
};
// Read N
if(!getline_nonempty(line)) return 0;
int N = stoi(line);
int d = 5;
vector<array<double,5>> trainXraw;
vector<int> trainY;
for(int i=0;i<N;i++){
getline_nonempty(line);
auto parts = split(line, ',');
array<double,5> x{};
for(int j=0;j<d;j++) x[j] = toNum(parts[1+j]);
int y=0;
try { y = (int)floor(stod(parts.back())+1e-9); } catch(...) { y=0; }
trainXraw.push_back(x);
trainY.push_back(y);
}
array<double,5> mean{}, med{};
for (int j = 0; j < d; ++j) {
vector<double> vals;
for (auto &r : trainXraw) {
double v = r[j];
if (!std::isnan(v) && valid(j, v)) vals.push_back(v);
}
if (!vals.empty()) {
double s = accumulate(vals.begin(), vals.end(), 0.0);
mean[j] = s / vals.size();
sort(vals.begin(), vals.end());
int n = (int)vals.size();
med[j] = (n & 1) ? vals[n/2] : 0.5 * (vals[n/2 - 1] + vals[n/2]);
} else {
mean[j] = 0.0;
med[j] = 0.0;
}
}
// Clean the training set
vector<array<double,5>> trainX;
for(auto &r: trainXraw){
array<double,5> t{};
for(int j=0;j<d;j++){
double v = isnan(r[j]) ? mean[j] : r[j];
if(!valid(j, v)) v = med[j];
t[j]=v;
}
trainX.push_back(t);
}
// Read M
getline_nonempty(line);
int M = stoi(line);
vector<array<double,5>> testX;
for(int i=0;i<M;i++){
getline_nonempty(line);
auto parts = split(line, ',');
array<double,5> r{};
for(int j=0;j<d;j++){
double v = toNum(parts[1+j]);
if(isnan(v)) v = mean[j];
if(!valid(j, v)) v = med[j];
r[j]=v;
}
testX.push_back(r);
}
// Train logistic regression (batch GD)
vector<double> w(d+1, 0.0); // w[0] is the bias
double alpha = 0.01;
int T = 100;
int n = (int)trainX.size();
for(int it=0; it<T; ++it){
vector<double> g(d+1, 0.0);
for(int i=0;i<n;i++){
double z = w[0];
for(int j=0;j<d;j++) z += w[j+1]*trainX[i][j];
double p = sigm(z);
double diff = p - trainY[i];
g[0] += diff;
for(int j=0;j<d;j++) g[j+1] += diff * trainX[i][j];
}
for(int k=0;k<=d;k++) w[k] -= alpha * g[k] / n;
}
// Predict
for(auto &x: testX){
double z = w[0];
for(int j=0;j<d;j++) z += w[j+1]*x[j];
double p = sigm(z);
int pred = (p>=0.5)? 1:0;
cout << pred << "\n";
}
return 0;
}
Problem Statement
In a cloud storage system, storage device failures must be predicted so that data can be migrated ahead of time. Each device log contains:
device ID, write count, read count, average write latency (ms), average read latency (ms), device age (years), device status (0 = normal / 1 = faulty)
You need to implement a device failure prediction system with the following functionality:
1. Data cleaning:
- Missing values are marked "NaN" and filled with the mean of that field's valid values
- Outlier ranges:
  write/read count: < 0
  average write/read latency: < 0 or > 1000
  device age: < 0 or > 20
  Outliers are replaced with the median of that field's valid values
2. Logistic regression model:
- Trained with batch gradient descent (Batch GD); each iteration uses all samples
- Features: [write count, read count, average write latency, average read latency, device age]
- Label: device status
- Parameters: 100 iterations, learning rate α = 0.01, initial weights all 0
3. Prediction output:
Prediction result: 0 (normal) or 1 (faulty)
Input
The first line is the number of training records N (2 <= N <= 100).
Starting from the second line, N consecutive lines of training data; each record contains: device ID, write count, read count, average write latency, average read latency, device age, status.
Line N+2 is the number of prediction records M (1 <= M <= 10).
Starting from line N+3, M consecutive lines of prediction data; each record contains: device ID, write count, read count, average write latency, average read latency, device age, status.
Output
M lines of prediction results.
Sample 1
Input
5
dev1,NaN,-50,NaN,-2.0,25,0
dev2,180,90,18.0,9.0,4,0
dev3,NaN,80,1500.0,800.0,NaN,0
dev4,-100,-50,-5.0,-2.0,-1,0
dev5,200,NaN,20.0,NaN,5,1
2
dev_predict1,80,40,NaN,2.0,2,0
dev_predict2,210,105,18.0,9.8,4,0
Output
0
0
Explanation
1. The prediction data contains the missing value "NaN", so data cleaning is required.
2. M is 2, so the output has 2 lines: the first line is the prediction result 0 for device "dev_predict1", and the second line is the prediction result 0 for device "dev_predict2".
Sample 2
Input
3
dev1,100,50,20.1,10.2,2,0
dev2,150,80,25.3,NaN,3,1
dev3,120,60,22.4,15.0,1,0
1
dev_predict1,130,70,21.0,12.0,2,0
Output
1
Explanation
The prediction result for device "dev_predict1" is 1.
Hint
Linear combination z:
z = w_0 + \sum_{i=1}^{5} w_i x_i
Probability P(y = 1):
P(y = 1) = \frac{1}{1 + e^{-z}}
Prediction rule:
\hat{y} = 1 \text{ if } P(y = 1) \ge 0.5, \text{ otherwise } \hat{y} = 0