#P3552. Problem 3: Cloud Storage Device Failure Prediction

- Time limit: 1000 ms
- Tried: 1269 / Accepted: 133 / Difficulty: 6
- Company: Huawei
- Date: September 3, 2025 (domestic, AI track)
- Algorithm tag: machine learning

Problem 3: Cloud Storage Device Failure Prediction
Solution Approach
1) Data cleaning (column-wise statistics ➜ row-wise replacement)
- Split each log line on commas: column 1 is the device ID, the last column is the label y ∈ {0, 1}, and the five columns after the ID are the features: write count, read count, average write latency (ms), average read latency (ms), device age (years). (If a line unexpectedly contains extra fields, take the first 5 numeric values as features and the last value as the label; this keeps Sample 1 compatible.)
- Missing values: a feature given as the string "NaN" is treated as missing and filled with the mean of that column's valid values (estimated from the training set only).
- Outlier correction: a value that violates the rules below is replaced with the median of that column's valid (in-range) values:
  - write/read count: < 0
  - average write/read latency: < 0 or > 1000
  - device age: < 0 or > 20
- The test set is cleaned with the training set's means/medians, so the treatment stays consistent. A worked example on Sample 1 follows this list.
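As a worked example of these rules on Sample 1: in the first training row (dev1,NaN,-50,NaN,-2.0,25,0), the valid column values across the training set are write counts {180, 200} (mean 190), read counts {80, 90} (median 85), write latencies {18.0, 20.0} (mean 19.0), read latencies {9.0, 800.0} (median 404.5), and ages {4, 5} (median 4.5). The missing write count and write latency are therefore filled with the column means 190 and 19.0, while the out-of-range read count (-50), read latency (-2.0), and age (25) are replaced with the column medians 85, 404.5, and 4.5, giving the cleaned feature vector [190, 85, 19.0, 404.5, 4.5].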
 
2) Model and training
- Use logistic regression with a bias term; the loss is the log loss.
- Optimization: batch gradient descent (Batch GD)
  - learning rate α = 0.01, 100 iterations;
  - parameters initialized to 0;
  - gradient: each iteration accumulates the gradient over all samples, then updates.
- Prediction: output 1 if sigmoid(z) ≥ 0.5, otherwise 0. The update rule is spelled out below.
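For reference, the per-iteration update implemented identically in all three reference solutions below (w_0 is the bias, N the number of training samples) can be written as:

z_i = w_0 + \sum_{j=1}^{5} w_j x_{ij}, \qquad p_i = \sigma(z_i) = \frac{1}{1 + e^{-z_i}}

w_0 \leftarrow w_0 - \frac{\alpha}{N} \sum_{i=1}^{N} (p_i - y_i), \qquad w_j \leftarrow w_j - \frac{\alpha}{N} \sum_{i=1}^{N} (p_i - y_i)\, x_{ij}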
3) Complexity analysis
- Let N be the number of training samples (N ≤ 100), d = 5 the number of features, and T = 100 the number of iterations:
  - mean/median statistics: O(N·d·log N) (the median is found by sorting; a selection algorithm would bring this down to linear)
  - training: O(T·N·d)
  - prediction: O(M·d) (M ≤ 10)
- Within this problem's data range, both time and memory are more than sufficient; a quick instantiation follows.
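Concretely, at the upper bounds this is on the order of T · N · d = 100 · 100 · 5 = 50,000 basic operations for training plus at most M · d = 10 · 5 = 50 for prediction, which is negligible within the 1000 ms limit.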
 
4) Boundaries and implementation details
- The label is always taken from the last column; the features are only the 5 numeric values right after the ID (compatible with Sample 1).
- If a column has no valid values at all (an extreme case), the median falls back to that column's mean, and failing that to 0.
- Numeric parsing ignores surrounding spaces; "NaN" (case-sensitive) is treated as missing. The sigmoid computation uses simple overflow protection (e.g., clipping z).
Reference Implementations
Python implementation
import sys, math
def parse_line(line):
    parts = [p.strip() for p in line.strip().split(',')]
    if not parts: return None
    id_ = parts[0]
    if len(parts) < 7:
        # too few fields: skip this line (the problem statement says this will not occur)
        return None
    # features: the first 5 numeric values after the ID; label: the last field
    feats_raw = parts[1:6]
    y_raw = parts[-1]
    def to_num(s):
        if s == "NaN": return None
        try:
            return float(s)
        except:
            return None
    x = [to_num(v) for v in feats_raw]
    # label comes from the last column; tolerate a floating-point spelling
    y = 0
    try:
        y = int(float(y_raw))
    except:
        y = 0
    return id_, x, y
# legal-range check for each feature column
def valid(col, v):
    if v is None: return False
    if col in (0,1):  # write/read counts
        return v >= 0
    if col in (2,3):  # latencies
        return 0 <= v <= 1000
    if col == 4:      # device age
        return 0 <= v <= 20
    return True
def median(vals):
    n = len(vals)
    if n == 0: return 0.0
    vals2 = sorted(vals)
    mid = n // 2
    if n % 2 == 1:
        return vals2[mid]
    else:
        return 0.5 * (vals2[mid - 1] + vals2[mid])
def sigmoid(z):
    # clip z for simple numerical stability (avoids overflow in exp)
    if z > 30: z = 30
    if z < -30: z = -30
    return 1.0 / (1.0 + math.exp(-z))
def clean_matrix(X, means, meds):
    # replace missing values with the column mean, out-of-range values with the column median
    n = len(X)
    d = len(X[0]) if n else 5
    out = []
    for i in range(n):
        row = []
        for j in range(d):
            v = X[i][j]
            if v is None:
                v = means[j]
            # out-of-range replacement
            if not valid(j, v):
                v = meds[j]
            row.append(v)
        out.append(row)
    return out
def main():
    data = sys.stdin.read().strip().splitlines()
    if not data:
        return
    it = 0
    # read N
    while it < len(data) and data[it].strip() == "":
        it += 1
    N = int(data[it].strip()); it += 1
    # read the training set
    trainX_raw, trainY = [], []
    for _ in range(N):
        while it < len(data) and data[it].strip() == "":
            it += 1
        id_, x, y = parse_line(data[it]); it += 1
        trainX_raw.append(x)
        trainY.append(y)
    # column-wise mean and median of the valid values (training set only)
    d = 5
    means = [0.0] * d
    meds  = [0.0] * d
    
    for j in range(d):
        valid_vals = [row[j] for row in trainX_raw if valid(j, row[j])]
        if valid_vals:
            means[j] = sum(valid_vals) / len(valid_vals)
            # median
            s = sorted(valid_vals)
            n = len(s)
            meds[j] = s[n//2] if n % 2 == 1 else 0.5 * (s[n//2 - 1] + s[n//2])
        else:
            # fallback when a column has no valid values at all
            means[j] = 0.0
            meds[j]  = 0.0
    # clean the training set
    trainX = clean_matrix(trainX_raw, means, meds)
    # read M
    while it < len(data) and data[it].strip() == "":
        it += 1
    M = int(data[it].strip()); it += 1
    # read the test set (the provided status column is ignored; it only exists for the input format)
    testX_raw = []
    for _ in range(M):
        while it < len(data) and data[it].strip() == "":
            it += 1
        parts = [p.strip() for p in data[it].strip().split(',')]
        it += 1
        feats_raw = parts[1:6]  # take the first 5 features
        def to_num(s):
            if s == "NaN": return None
            try:
                return float(s)
            except:
                return None
        testX_raw.append([to_num(v) for v in feats_raw])
    testX = clean_matrix(testX_raw, means, meds)
    # train logistic regression with batch gradient descent
    n = len(trainX)
    w = [0.0]*(d+1)  # w[0] is the bias
    alpha = 0.01
    T = 100
    for _ in range(T):
        g = [0.0]*(d+1)
        for i in range(n):
            z = w[0]
            for j in range(d):
                z += w[j+1] * trainX[i][j]
            p = sigmoid(z)
            diff = p - trainY[i]
            g[0] += diff
            for j in range(d):
                g[j+1] += diff * trainX[i][j]
        # parameter update using the averaged gradient
        for k in range(d+1):
            w[k] -= alpha * g[k] / n
    # predict
    out_lines = []
    for i in range(M):
        z = w[0]
        for j in range(d):
            z += w[j+1] * testX[i][j]
        p = sigmoid(z)
        pred = 1 if p >= 0.5 else 0
        out_lines.append(str(pred))
    print("\n".join(out_lines))
if __name__ == "__main__":
    main()
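A quick way to sanity-check the Python version locally (a hypothetical test harness, not part of the judged submission; it assumes the code above is saved as solution.py) is to substitute Sample 2's input for stdin and run the script:

import io, sys, runpy

# Sample 2 from the problem statement
sample2 = """3
dev1,100,50,20.1,10.2,2,0
dev2,150,80,25.3,NaN,3,1
dev3,120,60,22.4,15.0,1,0
1
dev_predict1,130,70,21.0,12.0,2,0
"""

sys.stdin = io.StringIO(sample2)                    # feed the sample instead of real stdin
runpy.run_path("solution.py", run_name="__main__")  # should print the expected output: 1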
Java implementation
import java.io.*;
import java.util.*;
public class Main {
    static boolean valid(int c, double v) {
        if (Double.isNaN(v)) return false;
        if (c == 0 || c == 1) return v >= 0;            // counts
        if (c == 2 || c == 3) return v >= 0 && v <= 1000; // latencies
        if (c == 4) return v >= 0 && v <= 20;           // age (years)
        return true;
    }
    static double median(List<Double> a) {
        if (a.isEmpty()) return 0.0;
        Collections.sort(a);
        int n = a.size();
        if (n % 2 == 1) return a.get(n/2);
        return (a.get(n/2 - 1) + a.get(n/2)) / 2.0;
    }
    static double sigmoid(double z) {
        if (z > 30) z = 30;
        if (z < -30) z = -30;
        return 1.0 / (1.0 + Math.exp(-z));
    }
    static Double toNum(String s) {
        s = s.trim();
        if (s.equals("NaN")) return Double.NaN;
        try { return Double.parseDouble(s); }
        catch (Exception e) { return Double.NaN; }
    }
    public static void main(String[] args) throws Exception {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in, "UTF-8"));
        List<String> lines = new ArrayList<>();
        for (String ln; (ln = br.readLine()) != null; ) lines.add(ln);
        int it = 0;
        while (it < lines.size() && lines.get(it).trim().isEmpty()) it++;
        int N = Integer.parseInt(lines.get(it).trim()); it++;
        int d = 5;
        List<double[]> trainXraw = new ArrayList<>();
        List<Integer> trainY = new ArrayList<>();
        // read the training data
        for (int i = 0; i < N; i++) {
            while (it < lines.size() && lines.get(it).trim().isEmpty()) it++;
            String[] parts = lines.get(it).split(",");
            it++;
            // features: the 5 fields after the ID; label: the last field
            Double[] x = new Double[d];
            for (int j = 0; j < d; j++) x[j] = toNum(parts[1 + j]);
            int y = 0;
            try { y = (int)Math.floor(Double.parseDouble(parts[parts.length - 1].trim()) + 1e-9); }
            catch (Exception e) { y = 0; }
            double[] row = new double[d];
            for (int j = 0; j < d; j++) row[j] = x[j] == null ? Double.NaN : x[j];
            trainXraw.add(row);
            trainY.add(y);
        }
        double[] mean = new double[d];
        double[] med  = new double[d];
        
        for (int j = 0; j < d; j++) {
            List<Double> validVals = new ArrayList<>();
            for (double[] r : trainXraw) {
                double v = r[j];
                if (!Double.isNaN(v) && valid(j, v)) validVals.add(v);
            }
            if (!validVals.isEmpty()) {
                double s = 0;
                for (double v : validVals) s += v;
                mean[j] = s / validVals.size();
                Collections.sort(validVals);
                int n = validVals.size();
                med[j] = (n % 2 == 1) ? validVals.get(n/2)
                                      : (validVals.get(n/2 - 1) + validVals.get(n/2)) / 2.0;
            } else {
                mean[j] = 0.0;
                med[j]  = 0.0;
            }
        }
        // clean the training set
        List<double[]> trainX = new ArrayList<>();
        for (double[] r : trainXraw) {
            double[] t = new double[d];
            for (int j = 0; j < d; j++) {
                double v = Double.isNaN(r[j]) ? mean[j] : r[j];
                if (!valid(j, v)) v = med[j];
                t[j] = v;
            }
            trainX.add(t);
        }
        // read M
        while (it < lines.size() && lines.get(it).trim().isEmpty()) it++;
        int M = Integer.parseInt(lines.get(it).trim()); it++;
        // test set
        List<double[]> testXraw = new ArrayList<>();
        for (int i = 0; i < M; i++) {
            while (it < lines.size() && lines.get(it).trim().isEmpty()) it++;
            String[] parts = lines.get(it).split(",");
            it++;
            double[] r = new double[d];
            for (int j = 0; j < d; j++) {
                Double v = toNum(parts[1 + j]);
                r[j] = (v == null) ? Double.NaN : v;
            }
            testXraw.add(r);
        }
        List<double[]> testX = new ArrayList<>();
        for (double[] r : testXraw) {
            double[] t = new double[d];
            for (int j = 0; j < d; j++) {
                double v = Double.isNaN(r[j]) ? mean[j] : r[j];
                if (!valid(j, v)) v = med[j];
                t[j] = v;
            }
            testX.add(t);
        }
        // train logistic regression with batch gradient descent
        double[] w = new double[d + 1]; // w[0] is the bias
        Arrays.fill(w, 0.0);
        double alpha = 0.01;
        int T = 100;
        int n = trainX.size();
        for (int t = 0; t < T; t++) {
            double[] g = new double[d + 1];
            Arrays.fill(g, 0.0);
            for (int i = 0; i < n; i++) {
                double[] x = trainX.get(i);
                double z = w[0];
                for (int j = 0; j < d; j++) z += w[j + 1] * x[j];
                double p = sigmoid(z);
                double diff = p - trainY.get(i);
                g[0] += diff;
                for (int j = 0; j < d; j++) g[j + 1] += diff * x[j];
            }
            for (int k = 0; k < d + 1; k++) w[k] -= alpha * g[k] / n;
        }
        // predict and output
        StringBuilder sb = new StringBuilder();
        for (double[] x : testX) {
            double z = w[0];
            for (int j = 0; j < d; j++) z += w[j + 1] * x[j];
            double p = sigmoid(z);
            int pred = p >= 0.5 ? 1 : 0;
            sb.append(pred).append('\n');
        }
        System.out.print(sb.toString());
    }
}
C++ implementation
#include <bits/stdc++.h>
using namespace std;
bool valid(int c, double v){
    if (isnan(v)) return false;
    if (c==0 || c==1) return v>=0;              // counts
    if (c==2 || c==3) return v>=0 && v<=1000;   // latencies
    if (c==4) return v>=0 && v<=20;             // age (years)
    return true;
}
double sigm(double z){
    if (z > 30) z = 30;
    if (z < -30) z = -30;
    return 1.0 / (1.0 + exp(-z));
}
double toNum(const string& s){
    if (s=="NaN") return numeric_limits<double>::quiet_NaN();
    try { return stod(s); }
    catch(...) { return numeric_limits<double>::quiet_NaN(); }
}
vector<string> split(const string& s, char d=','){
    vector<string> r; string cur;
    for(char c: s){
        if(c==d){ r.push_back(string(cur.begin(), find_if(cur.rbegin(), cur.rend(), [](char ch){return !isspace((unsigned char)ch);} ).base()));
                   size_t l=0; while(l<r.back().size() && isspace((unsigned char)r.back()[l])) l++; r.back()=r.back().substr(l);
                   cur.clear(); }
        else cur.push_back(c);
    }
    r.push_back(string(cur.begin(), find_if(cur.rbegin(), cur.rend(), [](char ch){return !isspace((unsigned char)ch);} ).base()));
    size_t l=0; while(l<r.back().size() && isspace((unsigned char)r.back()[l])) l++; r.back()=r.back().substr(l);
    return r;
}
double median(vector<double> v){
    if(v.empty()) return 0.0;
    sort(v.begin(), v.end());
    int n=v.size();
    if(n&1) return v[n/2];
    return 0.5*(v[n/2-1]+v[n/2]);
}
int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    string line;
    auto getline_nonempty = [&](string &out){
        while (getline(cin, out)){
        // the problem input should contain no blank lines; skipping them is just defensive
            if (!out.empty()) return true;
        }
        return false;
    };
    // read N
    if(!getline_nonempty(line)) return 0;
    int N = stoi(line);
    int d = 5;
    vector<array<double,5>> trainXraw;
    vector<int> trainY;
    for(int i=0;i<N;i++){
        getline_nonempty(line);
        auto parts = split(line, ',');
        array<double,5> x{};
        for(int j=0;j<d;j++) x[j] = toNum(parts[1+j]);
        int y=0;
        try { y = (int)floor(stod(parts.back())+1e-9); } catch(...) { y=0; }
        trainXraw.push_back(x);
        trainY.push_back(y);
    }
    array<double,5> mean{}, med{};
    for (int j = 0; j < d; ++j) {
        vector<double> vals;
        for (auto &r : trainXraw) {
            double v = r[j];
            if (!std::isnan(v) && valid(j, v)) vals.push_back(v);
        }
        if (!vals.empty()) {
            double s = accumulate(vals.begin(), vals.end(), 0.0);
            mean[j] = s / vals.size();
            sort(vals.begin(), vals.end());
            int n = (int)vals.size();
            med[j] = (n & 1) ? vals[n/2] : 0.5 * (vals[n/2 - 1] + vals[n/2]);
        } else {
            mean[j] = 0.0;
            med[j]  = 0.0;
        }
    }
    // clean the training set
    vector<array<double,5>> trainX;
    for(auto &r: trainXraw){
        array<double,5> t{};
        for(int j=0;j<d;j++){
            double v = isnan(r[j]) ? mean[j] : r[j];
            if(!valid(j, v)) v = med[j];
            t[j]=v;
        }
        trainX.push_back(t);
    }
    // read M
    getline_nonempty(line);
    int M = stoi(line);
    vector<array<double,5>> testX;
    for(int i=0;i<M;i++){
        getline_nonempty(line);
        auto parts = split(line, ',');
        array<double,5> r{};
        for(int j=0;j<d;j++){
            double v = toNum(parts[1+j]);
            if(isnan(v)) v = mean[j];
            if(!valid(j, v)) v = med[j];
            r[j]=v;
        }
        testX.push_back(r);
    }
    // train logistic regression with batch gradient descent
    vector<double> w(d+1, 0.0); // w[0] is the bias
    double alpha = 0.01;
    int T = 100;
    int n = (int)trainX.size();
    for(int it=0; it<T; ++it){
        vector<double> g(d+1, 0.0);
        for(int i=0;i<n;i++){
            double z = w[0];
            for(int j=0;j<d;j++) z += w[j+1]*trainX[i][j];
            double p = sigm(z);
            double diff = p - trainY[i];
            g[0] += diff;
            for(int j=0;j<d;j++) g[j+1] += diff * trainX[i][j];
        }
        for(int k=0;k<=d;k++) w[k] -= alpha * g[k] / n;
    }
    // predict
    for(auto &x: testX){
        double z = w[0];
        for(int j=0;j<d;j++) z += w[j+1]*x[j];
        double p = sigm(z);
        int pred = (p>=0.5)? 1:0;
        cout << pred << "\n";
    }
    return 0;
}
Problem Statement
In a cloud storage system, storage device failures must be predicted so that data can be migrated in advance. Each device log record contains:
device ID, write count, read count, average write latency (ms), average read latency (ms), device age (years), device status (0 = normal / 1 = failed)
You need to implement a device failure prediction system with the following functionality:
1. Data cleaning:
- Missing values are marked "NaN" and are filled with the mean of that field's valid values.
- Outlier ranges:
  write/read count: < 0
  average write/read latency: < 0 or > 1000
  device age: < 0 or > 20
  Outliers are replaced with the median of that field's valid values.
2. Logistic regression model:
- Train with batch gradient descent (Batch GD); each iteration uses all samples.
- Features: [write count, read count, average write latency, average read latency, device age]
- Label: device status
- Parameters: 100 iterations, learning rate α = 0.01, initial weights all 0
3. Prediction output:
Prediction result: 0 (normal) or 1 (failure)
Input Description
The first line is the number of training samples N (2 <= N <= 100).
The next N lines are training records, each containing: device ID, write count, read count, average write latency, average read latency, device age, status.
Line N+2 is the number of prediction samples M (1 <= M <= 10).
The following M lines are prediction records, each containing: device ID, write count, read count, average write latency, average read latency, device age, status.
Output Description
M lines of prediction results
Sample 1
Input
5
dev1,NaN,-50,NaN,-2.0,25,0
dev2,180,90,18.0,9.0,4,0
dev3,NaN,80,1500.0,800.0,NaN,0
dev4,-100,-50,-5.0,-2.0,-1,0
dev5,200,NaN,20.0,NaN,5,1
2
dev_predict1,80,40,NaN,2.0,2,0
dev_predict2,210,105,18.0,9.8,4,0
Output
0
0
Notes
1. The prediction data contains missing values ("NaN") and requires data cleaning.
2. M is 2, so the output has 2 lines: the first line is the prediction 0 for device "dev_predict1", and the second line is the prediction 0 for device "dev_predict2".
Sample 2
Input
3
dev1,100,50,20.1,10.2,2,0
dev2,150,80,25.3,NaN,3,1
dev3,120,60,22.4,15.0,1,0
1
dev_predict1,130,70,21.0,12.0,2,0
Output
1
Notes
The prediction for device "dev_predict1" is 1.
Hints
Linear combination z:
z = w_0 + \sum_{i=1}^{5} w_i x_i
Probability function P(y=1):
P(y=1) = \frac{1}{1 + e^{-z}}
Prediction rule:
Predict 1 if P(y=1) ≥ 0.5, otherwise 0.