#P4518. 第2题-基于剪枝的神经网络模型压缩
-
1000ms
Tried: 107
Accepted: 21
Difficulty: 5
所属公司 :
华为
时间 :2025年12月3日-AI方向
-
算法标签>机器学习算法
第2题-基于剪枝的神经网络模型压缩
解题思路
题目本质是:
- 先对权重矩阵 (W) 做按行的结构化剪枝(整行去掉),根据每行的 L1 范数决定保留与删除哪些输入特征;
- 然后用剪枝后的矩阵 (X')、(W') 做线性变换并预测类别。
1. 剪枝部分
-
对于每一行 (i)(对应第 (i) 个输入特征)计算:
∣Wi,:∣∗1=∑∗j=0c−1∣Wij∣ -
剪枝行数:
k=⌊ratio×d⌋如果 (ratio > 0) 且 (k = 0),则令 (k = 1),保证至少剪掉 1 行。
-
将所有行按 L1 范数从小到大排序,取前 (k) 行作为“要删除”的行。
-
对应删除:
- 在权重矩阵中删除这些行,得到 (W')(维度 (d−k)×c)。
- 在输入矩阵中删除对应的列,得到 (X')(维度 n×(d−k))。
实现上可以这样做:
- 先计算
row_norms[i],长度为d。 - 建立一个下标数组
idx = [0, 1, ..., d-1],按照row_norms[idx[i]]升序排序。 - 前
k个下标加入一个“剪掉集合”,剩下的就是要保留的行(列)。
2. 前向计算与预测
剪枝后,线性部分为:
$$h = X' W' \quad (n \times (d-k)) \cdot ((d-k) \times c) = n \times c$$第 (i) 行第 (j) 列:
$$h_{ij} = \sum_{t=0}^{d-k-1} X'*{i,t} \cdot W'*{t,j}$$逻辑上,题目给出了 softmax:
$$y = \text{softmax}(h), \quad \text{label}*i = \arg\max_j y*{ij}$$但注意:
softmax 是对每个分量做单调递增变换,且不改变各维之间的大小关系。 因此:
$$\arg\max_j \text{softmax}(h_{i,:}) = \arg\max_j h_{ij}$$
也就是说,为了得到预测标签,我们完全可以跳过 softmax 的显式计算,直接对每行的 (h) 做 argmax 即可。 这样实现更简单,也避免不必要的指数运算和数值问题。
步骤总结:
- 用保留的特征索引,计算每个样本到每个类别的加权和
h[i][j]。 - 对每行
i找到最大值所在的列下标j,即为该样本的预测标签。 - 将所有 label 以空格分隔输出一行。
代码实现
Python
import sys
def structured_pruning_prediction(n, d, c, X, W, ratio):
# 1. 计算每行 L1 范数
row_norms = []
for i in range(d):
s = 0.0
for j in range(c):
s += abs(W[i][j])
row_norms.append(s)
# 2. 计算剪枝行数 k
k = int(ratio * d) # 向下取整
if ratio > 0 and k == 0:
k = 1 # 至少剪一行
# 3. 找到 L1 范数最小的 k 行(要剪掉的行)
indices = list(range(d))
indices.sort(key=lambda idx: row_norms[idx]) # 按范数从小到大排序
prune_set = set(indices[:k]) # 要剪掉的下标集合
# 4. 构造保留的特征索引(按原顺序)
keep_indices = [i for i in range(d) if i not in prune_set]
kept_d = len(keep_indices)
# 5. 使用保留特征计算 h = X' W'
# 为了节省空间,这里不显式构造 X'、W',直接根据 keep_indices 访问原矩阵
labels = []
for i in range(n):
# scores[j] 表示样本 i 对类别 j 的线性得分 h_ij
scores = [0.0] * c
for t in range(kept_d):
feat_idx = keep_indices[t]
x_val = X[i][feat_idx]
if x_val == 0.0:
continue
for j in range(c):
scores[j] += x_val * W[feat_idx][j]
# 6. 对每行 scores 求 argmax,得到预测标签
max_j = 0
max_val = scores[0]
for j in range(1, c):
if scores[j] > max_val:
max_val = scores[j]
max_j = j
labels.append(max_j)
return labels
def main():
data = sys.stdin.read().strip().split()
if not data:
return
ptr = 0
n = int(data[ptr]); ptr += 1
d = int(data[ptr]); ptr += 1
c = int(data[ptr]); ptr += 1
# 读取 X 矩阵
X = []
for _ in range(n):
row = []
for _ in range(d):
row.append(float(data[ptr]))
ptr += 1
X.append(row)
# 读取 W 矩阵
W = []
for _ in range(d):
row = []
for _ in range(c):
row.append(float(data[ptr]))
ptr += 1
W.append(row)
# 读取 ratio
ratio = float(data[ptr]); ptr += 1
labels = structured_pruning_prediction(n, d, c, X, W, ratio)
print(" ".join(str(x) for x in labels))
if __name__ == "__main__":
main()
Java
import java.util.*;
public class Main {
// 按行剪枝并进行预测的函数
public static int[] structuredPruningPrediction(int n, int d, int c,
double[][] X, double[][] W,
double ratio) {
// 1. 计算每行 L1 范数
double[] rowNorms = new double[d];
for (int i = 0; i < d; i++) {
double s = 0.0;
for (int j = 0; j < c; j++) {
s += Math.abs(W[i][j]);
}
rowNorms[i] = s;
}
// 2. 计算剪枝行数 k
int k = (int) Math.floor(ratio * d);
if (ratio > 0 && k == 0) {
k = 1; // 至少剪一行
}
// 3. 找到 L1 范数最小的 k 行
Integer[] idx = new Integer[d];
for (int i = 0; i < d; i++) {
idx[i] = i;
}
Arrays.sort(idx, new Comparator<Integer>() {
@Override
public int compare(Integer a, Integer b) {
if (rowNorms[a] < rowNorms[b]) return -1;
else if (rowNorms[a] > rowNorms[b]) return 1;
else return 0;
}
});
boolean[] prune = new boolean[d];
for (int i = 0; i < k; i++) {
prune[idx[i]] = true;
}
// 4. 构造保留特征的下标数组
int keptCount = d - k;
int[] keepIndices = new int[keptCount];
int pos = 0;
for (int i = 0; i < d; i++) {
if (!prune[i]) {
keepIndices[pos++] = i;
}
}
// 5. 计算 h = X' W' 并直接做 argmax
int[] labels = new int[n];
for (int i = 0; i < n; i++) {
double[] scores = new double[c];
// 线性计算:只用保留的特征
for (int t = 0; t < keptCount; t++) {
int featIdx = keepIndices[t];
double xVal = X[i][featIdx];
if (xVal == 0.0) continue;
for (int j = 0; j < c; j++) {
scores[j] += xVal * W[featIdx][j];
}
}
// 6. 对 scores 求 argmax
int bestIdx = 0;
double bestVal = scores[0];
for (int j = 1; j < c; j++) {
if (scores[j] > bestVal) {
bestVal = scores[j];
bestIdx = j;
}
}
labels[i] = bestIdx;
}
return labels;
}
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
if (!sc.hasNext()) {
sc.close();
return;
}
int n = sc.nextInt();
int d = sc.nextInt();
int c = sc.nextInt();
// 读取 X 矩阵
double[][] X = new double[n][d];
for (int i = 0; i < n; i++) {
for (int j = 0; j < d; j++) {
X[i][j] = sc.nextDouble();
}
}
// 读取 W 矩阵
double[][] W = new double[d][c];
for (int i = 0; i < d; i++) {
for (int j = 0; j < c; j++) {
W[i][j] = sc.nextDouble();
}
}
// 读取 ratio
double ratio = sc.nextDouble();
sc.close();
int[] labels = structuredPruningPrediction(n, d, c, X, W, ratio);
// 输出结果,一行空格分隔
StringBuilder sb = new StringBuilder();
for (int i = 0; i < labels.length; i++) {
if (i > 0) sb.append(' ');
sb.append(labels[i]);
}
System.out.println(sb.toString());
}
}
C++
#include <bits/stdc++.h>
using namespace std;
// 按行结构化剪枝并预测的函数
vector<int> structuredPruningPrediction(int n, int d, int c,
const vector<vector<double>>& X,
const vector<vector<double>>& W,
double ratio) {
// 1. 计算每行 L1 范数
vector<double> rowNorms(d, 0.0);
for (int i = 0; i < d; ++i) {
double s = 0.0;
for (int j = 0; j < c; ++j) {
s += fabs(W[i][j]);
}
rowNorms[i] = s;
}
// 2. 计算剪枝行数 k
int k = static_cast<int>(floor(ratio * d));
if (ratio > 0 && k == 0) {
k = 1; // 至少剪一行
}
// 3. 找到 L1 范数最小的 k 行
vector<int> idx(d);
for (int i = 0; i < d; ++i) idx[i] = i;
sort(idx.begin(), idx.end(), [&](int a, int b) {
return rowNorms[a] < rowNorms[b];
});
vector<bool> prune(d, false);
for (int i = 0; i < k; ++i) {
prune[idx[i]] = true;
}
// 4. 构造保留特征的下标
vector<int> keepIndices;
keepIndices.reserve(d - k);
for (int i = 0; i < d; ++i) {
if (!prune[i]) keepIndices.push_back(i);
}
int keptCount = (int)keepIndices.size();
// 5. 计算 h = X' W' 并求 argmax
vector<int> labels(n);
for (int i = 0; i < n; ++i) {
vector<double> scores(c, 0.0);
// 线性计算
for (int t = 0; t < keptCount; ++t) {
int featIdx = keepIndices[t];
double xVal = X[i][featIdx];
if (xVal == 0.0) continue;
for (int j = 0; j < c; ++j) {
scores[j] += xVal * W[featIdx][j];
}
}
// 6. 求 argmax
int bestIdx = 0;
double bestVal = scores[0];
for (int j = 1; j < c; ++j) {
if (scores[j] > bestVal) {
bestVal = scores[j];
bestIdx = j;
}
}
labels[i] = bestIdx;
}
return labels;
}
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
int n, d, c;
if (!(cin >> n >> d >> c)) {
return 0;
}
// 读取 X 矩阵
vector<vector<double>> X(n, vector<double>(d));
for (int i = 0; i < n; ++i) {
for (int j = 0; j < d; ++j) {
cin >> X[i][j];
}
}
// 读取 W 矩阵
vector<vector<double>> W(d, vector<double>(c));
for (int i = 0; i < d; ++i) {
for (int j = 0; j < c; ++j) {
cin >> W[i][j];
}
}
// 读取 ratio
double ratio;
cin >> ratio;
vector<int> labels = structuredPruningPrediction(n, d, c, X, W, ratio);
// 输出一行,空格分隔
for (int i = 0; i < (int)labels.size(); ++i) {
if (i > 0) cout << ' ';
cout << labels[i];
}
cout << '\n';
return 0;
}
题目内容
在端侧设备部署神经网络模型时,需解决模型参数量过大的问题。本题目要求实现神经网络模型的结构化剪枝,通过移除冗余输入通道降低模型复杂度,同时保持分类性能。给定输入矩阵 X 、模型权重 W 以及剪枝比例 ratio,对 W 进行结构化剪枝,并使用剪枝后的结果计算模型预测结果。以下是相关计算流程及指标定义说明。
-
输入矩阵: X (维度: n×d, n 为样本数,d 为输入特征数)
-
权重矩阵: W (维度:d×c,c 为输出类别数)
-
计算过程:
- 线性变换: h=XW(维度:n×c)
- Softmax 激活 : y=softmax(h)(输出概率分布)
- 预测标签: label=arg max(y)
剪枝目标 : 对权重矩阵 W 按行剪枝(移除整行权重),剪枝率为 ratio,剪枝指标为 L1 范数。
提示:
1、y=softmax(h) 其中 $y_{i j}=\frac{\exp \left(h_{i j}-\max \left(h_{i}\right)\right)}{\sum_{j} \exp \left(h_{i j}-\max \left(h_{i}\right)\right)}$ ,按行计算概率分布,每个元系减去最大值防止外溢,
2、labeli=argmax(yi) 按行计算,输出每行最大值对应的列下标,范围为 [0,c) 。
1.剪枝定义
-
按行剪枝 : 移除权重矩阵 W 中不重要的行(对应输入特征),保留重要行。
-
物理意义 : 移除对输出影响较小的输入特征,压缩模型输入维度。
-
剪枝后维度:
- 权重矩阵 W′ 维度:(d−k)×c ( k 为剪枝行数)。
- 输入矩阵 X′ 维度: n×(d−k) (需移除对应特征列)。
2.剪枝指标
-
L1 范数:对权重矩阵 W 的每一行计算绝对值之和。
- 第 i 行的 L1 范数: $\left\|W_{i,:}\right\|_{1}=\sum_{j=1}^{c}\left|W_{i j}\right|$
-
剪枝规则 : 保留 L1 范数较大的行(重要性高),移除 L1 范数较小的行(重要性低)。
3.剪枝步骤
1.计算每行 L1 范数:$row\_norms=[\left\|W_{0,:}\right\|_{1},\left\|W_{i,:}\right\|_{1},...\left\|W_{d-1,:}\right\|_{1}]$
2.确定剪枝行数: k=⌊ratio×d⌋ (需剪掉的行数)
3.选择 L1 范数最小的 k 行移除,得到剪枝后权重矩阵 W′ 。
4.调整输入矩阵 X :移除与剪枝行对应的列,得到 X’ 。
说明:
k=⌊ratio×d⌋
表示 k 是向下取整后的结果。如果 ratio>0 并且向下取整后 k 为 0 ,则取 k 为 1 (至少剪枝 1 行)
输入描述
输入内容如下:
第一行三个整数:n d c
接下来 n 行,每行 d 个浮点数:X 矩阵
接下来 d 行,每行 c 个浮点数:W 矩阵
最后一行:剪枝率 ratio
输入范围:
1、1<=n,d,c<=64
2、0<=ratio<=1.0
输出描述
输出为使用剪枝后矩阵计算得到的预测 label 结果。
样例1
输入
4 5 2
1.89 1.88 0.87 0.19 0.62
0.75 0.75 1.45 0.24 0.65
1.26 0.4 0.69 0.54 0.93
0.11 0.61 0.25 1.47 1.96
0.89 2.44
0.97 2.61
2.24 0.72
1.64 0.38
2.29 0.69
0.3
输出
1 0 1 0
说明
样例2
输入
2 2 2
1.0 2.0
3.0 4.0
0.1 0.2
0.3 0.4
0.5
输出
1 1
说明
表示 X 矩阵为:
1.0 2.0
3.0 4.0
W 矩阵为:
0.1 0.2
0.3 0.4
剪枝率 ratio 为 0.5