#P4447. Problem 2 - Training and Updating a Medical Diagnosis Model
-
1000ms
Tried: 15
Accepted: 9
Difficulty: 5
Company: Huawei
Date: November 6, 2025 (International Students, AI Track)
-
Algorithm tags > Machine learning algorithms
Problem 2 - Training and Updating a Medical Diagnosis Model
Solution Approach
The task is to implement a minimal sequence classifier: a symptom sequence of length L (each step a D-dimensional vector) is passed through two linear maps (both bias-free MLP layers) to produce a K-dimensional prediction at every step; averaging over the sequence dimension gives the final K-class output, and the loss is MSE. One SGD step is then performed on this single sample, and the two updated weight matrices are printed.
- Model and forward pass
  - Let the input at step t be the vector $x_t \in \mathbb{R}^D$.
  - First MLP layer (no bias):
$$h_t \;=\; x_t W_{\text{mlp}},\quad W_{\text{mlp}}\in\mathbb{R}^{D\times D} $$
  - Classification layer (no bias):
$$p_t \;=\; h_t W_{\text{cls}},\quad W_{\text{cls}}\in\mathbb{R}^{D\times K} $$
  - Average over the sequence as the final output (no softmax needed):
$$\hat{y} \;=\; \frac{1}{L}\sum_{t=1}^{L} p_t \;\in\; \mathbb{R}^{K} $$
  A short vectorized sketch of this forward pass is given below.
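For intuition, the whole forward pass collapses to two matrix products followed by a mean over the sequence. A minimal vectorized sketch (assuming NumPy; the function name `forward` and the shape conventions are illustrative, not part of the problem statement):

```python
import numpy as np

def forward(X, Wmlp, Wcls):
    """X: (L, D) symptom sequence, Wmlp: (D, D), Wcls: (D, K) -> y_hat of shape (K,)."""
    H = X @ Wmlp           # per-step hidden states, shape (L, D)
    P = H @ Wcls           # per-step class scores, shape (L, K)
    return P.mean(axis=0)  # average over the sequence dimension
```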
- Loss function: mean squared error,
$$L_{\text{mse}} \;=\; \frac{1}{K}\sum_{i=1}^{K}\left(\hat{y}_i - y_i\right)^2 $$
- Backpropagation and gradients
  - Let
$$g \;=\; \frac{\partial L}{\partial \hat{y}} \;=\; \frac{2}{K}(\hat{y}-y) \;\in\; \mathbb{R}^{K} $$
  - Since $\hat{y} = \frac{1}{L}\sum_t p_t$,
$$\frac{\partial L}{\partial p_t} \;=\; \frac{1}{L}g \quad(\forall t) $$
  - Classification layer:
$$\frac{\partial L}{\partial W_{\text{cls}}} \;=\; \sum_{t} h_t^\top \left(\frac{1}{L}g\right) \;=\; \left(\frac{1}{L}\sum_t h_t\right)^\top g \;\;\in\mathbb{R}^{D\times K} $$
  With $\bar{h} = \frac{1}{L}\sum_t h_t$, this is the outer product $\partial L/\partial W_{\text{cls}} = \bar{h}^\top g$.
  - First layer:
$$\frac{\partial L}{\partial h_t} \;=\; \left(\frac{1}{L}g\right) W_{\text{cls}}^\top \;=\; \frac{1}{L}\, (W_{\text{cls}}\,g) \;\in\; \mathbb{R}^{D} $$
$$\frac{\partial L}{\partial W_{\text{mlp}}} \;=\; \sum_t x_t^\top \left(\frac{1}{L}\,W_{\text{cls}}\,g\right) \;=\; \left(\frac{1}{L}\sum_t x_t\right)^\top (W_{\text{cls}}\,g) \;\;\in\mathbb{R}^{D\times D} $$
  With $\bar{x} = \frac{1}{L}\sum_t x_t$ and $v = W_{\text{cls}}\,g$, this is the outer product $\partial L/\partial W_{\text{mlp}} = \bar{x}^\top v$.
  These closed-form gradients are easy to verify numerically, as in the sketch below.
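Because the loss is a simple quadratic in the weights, the outer-product formulas above can be validated against finite differences. A minimal spot-check sketch (assuming NumPy; the random shapes, seed, and step size `1e-6` are arbitrary choices for illustration):

```python
import numpy as np

def loss_fn(X, Wmlp, Wcls, y):
    y_hat = (X @ Wmlp @ Wcls).mean(axis=0)   # forward pass
    return np.mean((y_hat - y) ** 2)         # MSE over the K classes

rng = np.random.default_rng(0)
L_, D, K = 4, 3, 2
X = rng.normal(size=(L_, D)); y = rng.normal(size=K)
Wmlp = rng.normal(size=(D, D)); Wcls = rng.normal(size=(D, K))

# Analytic gradients from the derivation above
y_hat = (X @ Wmlp @ Wcls).mean(axis=0)
g = 2.0 / K * (y_hat - y)                     # dL/dy_hat
dWcls = np.outer((X @ Wmlp).mean(axis=0), g)  # h_bar^T g
dWmlp = np.outer(X.mean(axis=0), Wcls @ g)    # x_bar^T (Wcls g)

# Finite-difference check on one entry of Wmlp
eps = 1e-6
Wp = Wmlp.copy(); Wp[0, 1] += eps
num = (loss_fn(X, Wp, Wcls, y) - loss_fn(X, Wmlp, Wcls, y)) / eps
print(num, dWmlp[0, 1])  # the two values should differ only at the ~1e-6 level
```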
- SGD update
$$W_{\text{mlp}} \leftarrow W_{\text{mlp}} - \eta\,\bar{x}^\top v, \qquad W_{\text{cls}} \leftarrow W_{\text{cls}} - \eta\,\bar{h}^\top g $$
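Putting the pieces together, one training step is just two outer-product updates. A compact self-contained sketch (again assuming NumPy; `sgd_step` is an illustrative helper, not part of the required I/O format):

```python
import numpy as np

def sgd_step(X, Wmlp, Wcls, y, eta):
    """One SGD update on a single sample; returns (y_hat, mse, Wmlp_new, Wcls_new)."""
    y_hat = (X @ Wmlp @ Wcls).mean(axis=0)
    mse = np.mean((y_hat - y) ** 2)
    g = 2.0 / len(y) * (y_hat - y)                # dL/dy_hat
    dWcls = np.outer((X @ Wmlp).mean(axis=0), g)  # h_bar^T g
    dWmlp = np.outer(X.mean(axis=0), Wcls @ g)    # x_bar^T (Wcls g)
    return y_hat, mse, Wmlp - eta * dWmlp, Wcls - eta * dWcls
```

Note that both gradients are computed from the old weights before either matrix is updated, which matches the reference implementations below.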
- Output format
  - Line 1: $\hat{y}$ (K values, 2 decimal places, comma-separated)
  - Line 2: the MSE (one value, 2 decimal places)
  - Line 3: the updated $W_{\text{mlp}}$ (D×D values flattened row by row, 2 decimal places)
  - Line 4: the updated $W_{\text{cls}}$ (D×K values flattened row by row, 2 decimal places)
Complexity Analysis
- Forward: computing all $h_t$ and $p_t$ costs $O(L \cdot D^2)$ (first layer) plus $O(L \cdot D \cdot K)$ (classification layer).
- Backward: $g$ in $O(K)$; $\bar{h}$ in $O(L \cdot D)$; $\partial L/\partial W_{\text{cls}}$ in $O(D \cdot K)$; $v = W_{\text{cls}} g$ in $O(D \cdot K)$; $\bar{x}$ in $O(L \cdot D)$; $\partial L/\partial W_{\text{mlp}}$ in $O(D^2)$.
- Total time complexity: $O(L \cdot D^2 + L \cdot D \cdot K)$, which is trivially fast under the constraints (all parameters ≤ 10).
- Space complexity: dominated by the input, the weights, and the intermediates, i.e. $O(L \cdot D + D^2 + D \cdot K)$.
Code Implementation
Python
import sys
import ast
# Safely parse an input line such as "1,2,3" into a list
def parse_line(line: str):
    return list(ast.literal_eval("[" + line.strip() + "]"))
# Forward pass plus one SGD update; returns (y_hat, loss, Wmlp_new, Wcls_new)
def solve_once(L, D, K, eta, y_true, seq_flat, Wmlp_flat, Wcls_flat):
    # Restore shapes
    X = [seq_flat[i*D:(i+1)*D] for i in range(L)]      # L x D
    Wmlp = [Wmlp_flat[i*D:(i+1)*D] for i in range(D)]  # D x D (row-major)
    Wcls = [Wcls_flat[i*K:(i+1)*K] for i in range(D)]  # D x K (row-major)
    # Forward: h_t = x_t @ Wmlp, p_t = h_t @ Wcls
    H_sum = [0.0]*D
    P_avg = [0.0]*K
    for t in range(L):
        x = X[t]  # length D
        # h = x @ Wmlp
        h = [0.0]*D
        for j in range(D):
            s = 0.0
            for d in range(D):
                s += x[d] * Wmlp[d][j]
            h[j] = s
        # Accumulate into H_sum
        for j in range(D):
            H_sum[j] += h[j]
        # p = h @ Wcls
        p = [0.0]*K
        for k in range(K):
            s = 0.0
            for j in range(D):
                s += h[j] * Wcls[j][k]
            p[k] = s
        # Accumulate for the average (divide by L at the end)
        for k in range(K):
            P_avg[k] += p[k]
    P_avg = [v / L for v in P_avg]  # prediction \hat{y}
    # Compute the MSE
    loss = 0.0
    for k in range(K):
        diff = (P_avg[k] - y_true[k])
        loss += diff * diff
    loss /= K
    # Backward: g = (2/K) * (y_hat - y_true)
    g = [(2.0 / K) * (P_avg[k] - y_true[k]) for k in range(K)]
    # dL/dWcls = ( (sum_t h_t)/L )^T * g (outer product)
    H_bar = [v / L for v in H_sum]  # D
    dWcls = [[H_bar[j] * g[k] for k in range(K)] for j in range(D)]
    # v = Wcls @ g (D)
    v = [0.0]*D
    for j in range(D):
        s = 0.0
        for k in range(K):
            s += Wcls[j][k] * g[k]
        v[j] = s
    # dL/dWmlp = ( (sum_t x_t)/L )^T * v (outer product)
    X_sum = [0.0]*D
    for t in range(L):
        for d in range(D):
            X_sum[d] += X[t][d]
    X_bar = [v_ / L for v_ in X_sum]  # D
    dWmlp = [[X_bar[i] * v[j] for j in range(D)] for i in range(D)]
    # SGD update
    for i in range(D):
        for j in range(D):
            Wmlp[i][j] -= eta * dWmlp[i][j]
    for j in range(D):
        for k in range(K):
            Wcls[j][k] -= eta * dWcls[j][k]
    # Flatten the weights and return
    Wmlp_new = [Wmlp[i][j] for i in range(D) for j in range(D)]
    Wcls_new = [Wcls[j][k] for j in range(D) for k in range(K)]
    return P_avg, loss, Wmlp_new, Wcls_new
def fmt_line(arr):
    return ",".join(f"{x:.2f}" for x in arr)
def main():
    lines = [line for line in sys.stdin if line.strip() != ""]
    # Read the five input lines
    L, D, K, eta = parse_line(lines[0])
    y_true = parse_line(lines[1])
    seq_flat = parse_line(lines[2])
    Wmlp_flat = parse_line(lines[3])
    Wcls_flat = parse_line(lines[4])
    y_hat, loss, Wmlp_new, Wcls_new = solve_once(L, D, K, eta, y_true, seq_flat, Wmlp_flat, Wcls_flat)
    print(fmt_line(y_hat))
    print(f"{loss:.2f}")
    print(fmt_line(Wmlp_new))
    print(fmt_line(Wcls_new))
if __name__ == "__main__":
    main()
Java
import java.io.*;
import java.util.*;
/**
 * ACM-style main class: read the input, run one forward pass, compute the MSE and one SGD update, and print the results
*/
public class Main {
    // Parse a line such as "1,2,3" into a double array
static double[] parseLine(String line) {
String s = line.trim().replace(",", " ");
String[] parts = s.split("\\s+");
double[] res = new double[parts.length];
for (int i = 0; i < parts.length; i++) res[i] = Double.parseDouble(parts[i]);
return res;
}
    // Outer product: a(D)^T * b(M) -> D x M
static double[][] outer(double[] a, double[] b) {
int D = a.length, M = b.length;
double[][] out = new double[D][M];
for (int i = 0; i < D; i++) {
for (int j = 0; j < M; j++) {
out[i][j] = a[i] * b[j];
}
}
return out;
}
public static void main(String[] args) throws Exception {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        // Read the five input lines
double[] line1 = parseLine(br.readLine());
int L = (int) line1[0];
int D = (int) line1[1];
int K = (int) line1[2];
double eta = line1[3];
double[] yTrue = parseLine(br.readLine());
double[] seqFlat = parseLine(br.readLine());
double[] WmlpFlat = parseLine(br.readLine());
double[] WclsFlat = parseLine(br.readLine());
        // Rebuild the matrices
double[][] X = new double[L][D];
for (int i = 0; i < L; i++) {
System.arraycopy(seqFlat, i * D, X[i], 0, D);
}
double[][] Wmlp = new double[D][D];
for (int i = 0; i < D; i++) {
System.arraycopy(WmlpFlat, i * D, Wmlp[i], 0, D);
}
double[][] Wcls = new double[D][K];
for (int i = 0; i < D; i++) {
System.arraycopy(WclsFlat, i * K, Wcls[i], 0, K);
}
        // Forward pass
double[] Hsum = new double[D];
double[] Pavg = new double[K];
for (int t = 0; t < L; t++) {
// h = x @ Wmlp
double[] h = new double[D];
for (int j = 0; j < D; j++) {
double s = 0.0;
for (int d0 = 0; d0 < D; d0++) s += X[t][d0] * Wmlp[d0][j];
h[j] = s;
}
for (int j = 0; j < D; j++) Hsum[j] += h[j];
// p = h @ Wcls
double[] p = new double[K];
for (int k = 0; k < K; k++) {
double s = 0.0;
for (int j = 0; j < D; j++) s += h[j] * Wcls[j][k];
p[k] = s;
}
for (int k = 0; k < K; k++) Pavg[k] += p[k];
}
for (int k = 0; k < K; k++) Pavg[k] /= L;
// MSE
double loss = 0.0;
for (int k = 0; k < K; k++) {
double diff = Pavg[k] - yTrue[k];
loss += diff * diff;
}
loss /= K;
// g = (2/K)*(y_hat - y_true)
double[] g = new double[K];
for (int k = 0; k < K; k++) g[k] = (2.0 / K) * (Pavg[k] - yTrue[k]);
// dWcls = (Hsum/L)^T * g
double[] Hbar = new double[D];
for (int j = 0; j < D; j++) Hbar[j] = Hsum[j] / L;
double[][] dWcls = outer(Hbar, g);
// v = Wcls @ g (D)
double[] v = new double[D];
for (int j = 0; j < D; j++) {
double s = 0.0;
for (int k = 0; k < K; k++) s += Wcls[j][k] * g[k];
v[j] = s;
}
// dWmlp = (Xsum/L)^T * v
double[] Xsum = new double[D];
for (int t = 0; t < L; t++) for (int d0 = 0; d0 < D; d0++) Xsum[d0] += X[t][d0];
double[] Xbar = new double[D];
for (int d0 = 0; d0 < D; d0++) Xbar[d0] = Xsum[d0] / L;
double[][] dWmlp = outer(Xbar, v);
        // SGD update
for (int i = 0; i < D; i++) for (int j = 0; j < D; j++) Wmlp[i][j] -= eta * dWmlp[i][j];
for (int j = 0; j < D; j++) for (int k = 0; k < K; k++) Wcls[j][k] -= eta * dWcls[j][k];
        // Output
StringBuilder sb = new StringBuilder();
        // Line 1: prediction
for (int k = 0; k < K; k++) {
sb.append(String.format(java.util.Locale.US, "%.2f", Pavg[k]));
if (k + 1 < K) sb.append(",");
}
System.out.println(sb.toString());
        // Line 2: MSE
System.out.println(String.format(java.util.Locale.US, "%.2f", loss));
        // Line 3: Wmlp (row-major)
sb.setLength(0);
for (int i = 0; i < D; i++) {
for (int j = 0; j < D; j++) {
sb.append(String.format(java.util.Locale.US, "%.2f", Wmlp[i][j]));
if (!(i == D - 1 && j == D - 1)) sb.append(",");
}
}
System.out.println(sb.toString());
        // Line 4: Wcls (row-major, D x K)
sb.setLength(0);
for (int j = 0; j < D; j++) {
for (int k = 0; k < K; k++) {
sb.append(String.format(java.util.Locale.US, "%.2f", Wcls[j][k]));
if (!(j == D - 1 && k == K - 1)) sb.append(",");
}
}
System.out.println(sb.toString());
}
}
C++
#include <bits/stdc++.h>
using namespace std;
// Parse a line such as "1,2,3" into a double array
static vector<double> parseLine(const string& line) {
string s;
s.reserve(line.size());
for (char c : line) s.push_back(c == ',' ? ' ' : c);
stringstream ss(s);
vector<double> res; double x;
while (ss >> x) res.push_back(x);
return res;
}
// Outer product: a(D)^T * b(M) -> D x M
static vector<vector<double>> outer(const vector<double>& a, const vector<double>& b) {
int D = (int)a.size(), M = (int)b.size();
vector<vector<double>> out(D, vector<double>(M, 0.0));
for (int i = 0; i < D; ++i)
for (int j = 0; j < M; ++j)
out[i][j] = a[i] * b[j];
return out;
}
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
string line;
    // Read the five input lines
getline(cin, line);
vector<double> v1 = parseLine(line);
int L = (int)v1[0], D = (int)v1[1], K = (int)v1[2];
double eta = v1[3];
getline(cin, line);
vector<double> yTrue = parseLine(line);
getline(cin, line);
vector<double> seqFlat = parseLine(line);
getline(cin, line);
vector<double> WmlpFlat = parseLine(line);
getline(cin, line);
vector<double> WclsFlat = parseLine(line);
    // Rebuild the matrices
vector<vector<double>> X(L, vector<double>(D, 0.0));
for (int i = 0; i < L; ++i)
for (int j = 0; j < D; ++j)
X[i][j] = seqFlat[i*D + j];
vector<vector<double>> Wmlp(D, vector<double>(D, 0.0));
for (int i = 0; i < D; ++i)
for (int j = 0; j < D; ++j)
Wmlp[i][j] = WmlpFlat[i*D + j];
vector<vector<double>> Wcls(D, vector<double>(K, 0.0));
for (int i = 0; i < D; ++i)
for (int k = 0; k < K; ++k)
Wcls[i][k] = WclsFlat[i*K + k];
    // Forward pass
vector<double> Hsum(D, 0.0), Pavg(K, 0.0);
for (int t = 0; t < L; ++t) {
// h = x @ Wmlp
vector<double> h(D, 0.0);
for (int j = 0; j < D; ++j) {
double s = 0.0;
for (int d0 = 0; d0 < D; ++d0) s += X[t][d0] * Wmlp[d0][j];
h[j] = s;
}
for (int j = 0; j < D; ++j) Hsum[j] += h[j];
// p = h @ Wcls
vector<double> p(K, 0.0);
for (int k = 0; k < K; ++k) {
double s = 0.0;
for (int j = 0; j < D; ++j) s += h[j] * Wcls[j][k];
p[k] = s;
}
for (int k = 0; k < K; ++k) Pavg[k] += p[k];
}
for (int k = 0; k < K; ++k) Pavg[k] /= L;
// MSE
double loss = 0.0;
for (int k = 0; k < K; ++k) {
double diff = Pavg[k] - yTrue[k];
loss += diff * diff;
}
loss /= K;
// g = (2/K)*(y_hat - y_true)
vector<double> g(K, 0.0);
for (int k = 0; k < K; ++k) g[k] = (2.0 / K) * (Pavg[k] - yTrue[k]);
// dWcls = (Hsum/L)^T * g
vector<double> Hbar(D, 0.0);
for (int j = 0; j < D; ++j) Hbar[j] = Hsum[j] / L;
vector<vector<double>> dWcls = outer(Hbar, g);
// v = Wcls @ g
vector<double> v(D, 0.0);
for (int j = 0; j < D; ++j) {
double s = 0.0;
for (int k = 0; k < K; ++k) s += Wcls[j][k] * g[k];
v[j] = s;
}
// dWmlp = (Xsum/L)^T * v
vector<double> Xsum(D, 0.0), Xbar(D, 0.0);
for (int t = 0; t < L; ++t)
for (int d0 = 0; d0 < D; ++d0)
Xsum[d0] += X[t][d0];
for (int d0 = 0; d0 < D; ++d0) Xbar[d0] = Xsum[d0] / L;
vector<vector<double>> dWmlp = outer(Xbar, v);
    // SGD update
for (int i = 0; i < D; ++i)
for (int j = 0; j < D; ++j)
Wmlp[i][j] -= eta * dWmlp[i][j];
for (int j = 0; j < D; ++j)
for (int k = 0; k < K; ++k)
Wcls[j][k] -= eta * dWcls[j][k];
    // Output: four lines
cout.setf(std::ios::fixed); cout<<setprecision(2);
    // Line 1: prediction
for (int k = 0; k < K; ++k) {
if (k) cout << ",";
cout << Pavg[k];
}
cout << "\n";
    // Line 2: MSE
cout << loss << "\n";
    // Line 3: Wmlp (row-major)
for (int i = 0; i < D; ++i) {
for (int j = 0; j < D; ++j) {
if (i || j) cout << ",";
cout << Wmlp[i][j];
}
}
cout << "\n";
    // Line 4: Wcls (row-major)
for (int j = 0; j < D; ++j) {
for (int k = 0; k < K; ++k) {
if (j || k) cout << ",";
cout << Wcls[j][k];
}
}
cout << "\n";
return 0;
}
Problem Statement
A smart-healthcare platform is developing an AI-based automatic disease-diagnosis assistant. The system analyzes the symptom questionnaires a patient fills out to help doctors quickly decide which of three classes the patient belongs to: healthy, cold, or pneumonia. Before a visit, each patient fills in a questionnaire covering multiple symptoms (such as cough, fever, sore throat); the symptom entries are embedded as feature vectors, forming a symptom sequence of length L with feature dimension D per symptom. After preprocessing, these discrete symptom features are fed into the diagnosis system. The system uses one MLP layer for feature mapping, followed by another MLP layer as the classifier that outputs the predicted probability of each class; for simplicity, the outputs do not need softmax normalization, and the MLP layers have no bias terms. Implement the following:
- Forward inference: output the predicted probabilities (K values; e.g. K=3 corresponds to the probabilities of healthy / cold / pneumonia), taking the average over the symptom-sequence dimension as the output.
- Loss computation: output the MSE loss, defined as
$$L_{\text{mse}}=\frac{1}{K} \sum_{i=1}^{K}\left(y_{i}-\hat{y}_{i}\right)^{2}$$
where K is the number of classes, $y_i$ is the true probability, and $\hat{y}_i$ is the predicted probability.
- Weight update: output the weights after a single backpropagation step. The update uses the SGD optimizer, defined as
$$W_{\text{new}}=W_{\text{old}}-\eta \nabla_{W} L$$
where η is the learning rate.
Input Description
Line 1: sequence length L∈[1,10], feature dimension D∈[1,10], number of classes K∈[2,5], learning rate η∈[0,1]
Line 2: true probabilities, K numbers
Line 3: input sequence, L×D numbers
Line 4: MLP parameters Wmlp, D×D numbers
Line 5: classification-layer parameters Wcls, D×K numbers
Output Description
Line 1: predicted probabilities for the K classes
Line 2: MSE loss, 1 number
Line 3: updated MLP parameters Wmlp, D×D numbers
Line 4: updated classification-layer parameters Wcls, D×K numbers
Note: values are separated by commas, and all outputs are rounded to 2 decimal places.
Sample 1
Input
4,2,5,1.0
0.10,0.20,0.30,0.25,0.15
0.0,1.0,-1.5,2.5,3.0,-0.5,0.7,0.3
0.6,-0.4,0.2,0.9
0.5,0.1,-0.3,0.8,0.0,-0.2,0.4,0.6,-0.5,1.0
Output
0.14,0.26,0.16,0.13,0.52
0.04
0.61,-0.48,0.21,0.78
0.49,0.09,-0.27,0.82,-0.07,-0.21,0.39,0.63,-0.48,0.92
Explanation
Input:
Line 1: sequence length L=4, feature dimension D=2, number of classes K=5, learning rate η=1.0
Line 2: the true label probabilities of the five classes are 0.10, 0.20, 0.30, 0.25, 0.15
Line 3: the input sequence, 4×2=8 numbers
Line 4: MLP parameters Wmlp, 2×2 numbers
Line 5: classification-layer parameters Wcls, 2×5 numbers
Output:
Line 1: the predicted probabilities of the five classes are 0.14, 0.26, 0.16, 0.13, 0.52
Line 2: the MSE loss is 0.04
Line 3: the updated MLP parameters Wmlp
Line 4: the updated classification-layer parameters Wcls
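As a quick sanity check, sample 1's prediction line can be reproduced directly from the derivation (a minimal sketch assuming NumPy, independent of the reference implementations above):

```python
import numpy as np

# Sample 1 data: L=4, D=2, K=5
X = np.array([[0.0, 1.0], [-1.5, 2.5], [3.0, -0.5], [0.7, 0.3]])            # L x D
Wmlp = np.array([[0.6, -0.4], [0.2, 0.9]])                                   # D x D
Wcls = np.array([[0.5, 0.1, -0.3, 0.8, 0.0], [-0.2, 0.4, 0.6, -0.5, 1.0]])   # D x K

y_hat = (X @ Wmlp @ Wcls).mean(axis=0)
print(y_hat)  # ≈ [0.143, 0.2585, 0.165, 0.13475, 0.5225] -> 0.14,0.26,0.16,0.13,0.52 in the sample
```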
Sample 2
Input
2,2,3,0,1
1.0,0.0,0.0
1.0,2.0,3.0,4.0
1.0,1.0,1.0,1.0
1.0,0.0,0.0,1.0,0.0,0.0
Output
5.00,0.00,0.00
5.33
0.47,-0.53,-0.80,0.20
0.47,0.00,0.00,0.20,0.00,0.00
Explanation
Input:
Line 1: sequence length L=2, feature dimension D=2, number of classes K=3, learning rate η=0.1
Line 2: the true label probabilities of the three classes are 1.0, 0.0, 0.0
Line 3: the input sequence, 2×2=4 numbers
Line 4: MLP parameters Wmlp, 2×2 numbers
Line 5: classification-layer parameters Wcls, 2×3 numbers
Output:
Line 1: the predicted probabilities of the three classes are 5.00, 0.00, 0.00
Line 2: the MSE loss is 5.33
Line 3: the updated MLP parameters Wmlp
Line 4: the updated classification-layer parameters Wcls