#P4548. 第3题-模型INT8量化计算
-
1000ms
Tried: 33
Accepted: 10
Difficulty: 7
所属公司 :
华为
时间 :2026年1月21日-AI方向
-
算法标签>机器学习算法
第3题-模型INT8量化计算
解题思路
本题实现一种常见的线性层 INT8 量化推理流程:对输入矩阵 X 做 per-token(按行)量化,对权重矩阵 W 做 per-channel(按列)量化,然后在 INT32 累加域中完成矩阵乘法,最后用缩放因子反量化回 FP32 输出。
核心算法分为 4 步:
-
计算缩放因子(scale)
-
对 X 的每一行 i(一个 token):
- 取该行绝对值最大值 maxAbsX(i)
- scale:sX[i] = maxAbsX(i) / 127
-
对 W 的每一列 j(一个 channel):
- 取该列绝对值最大值 maxAbsW(j)
- scale:sW[j] = maxAbsW(j) / 127
-
特殊情况:若 maxAbs 为 0,则该行/列全为 0,为避免除 0,可令 scale=1(量化值也全为 0,最终输出仍为 0,不影响正确性)。
-
-
按 scale 做 INT8 量化(对称量化、零点为 0)
- 量化:Q = clip(round(val / scale), -127, 127)
- round 必须与 Python round 一致(奇进偶弃,ties-to-even)。
- clip 将结果限制在 [-127, 127],确保可用 int8 表示(题面要求是 -127~127)。
-
INT8×INT8 → INT32 的矩阵乘法
-
对每个输出元素 (i, j):
- Y_int32[i][j] = Σk QX[i][k] * QW[k][j]
-
用 INT32 累加避免溢出(本题 K≤128,范围足够安全)。
-
-
反量化得到 FP32 输出
- Y_fp32[i][j] = Y_int32[i][j] * sX[i] * sW[j]
- 最后按要求输出并保留两位小数(format(num, ".2f") 风格),每行用单个空格分隔。
实现要点
- per-token:对 X 按行计算 scale,量化时每行使用自己的 sX[i]。
- per-channel:对 W 按列计算 scale,量化时每列使用自己的 sW[j]。
- rounding:Python 直接用 round;Java 用 Math.rint;C++ 用 std::nearbyint(默认 FE_TONEAREST 即 ties-to-even)。
代码实现
import sys
def quantize_int8(value, scale):
# Python round() 默认就是“奇进偶弃”(ties-to-even)
q = int(round(value / scale))
if q > 127:
return 127
if q < -127:
return -127
return q
def int8_quant_matmul(X, W, M, K, N):
# 1) 计算 per-token scale:sX[i]
sX = [0.0] * M
QX = [[0] * K for _ in range(M)]
for i in range(M):
max_abs = 0.0
for k in range(K):
v = X[i][k]
av = v if v >= 0 else -v
if av > max_abs:
max_abs = av
scale = max_abs / 127.0
if scale == 0.0:
scale = 1.0 # 全 0 行,避免除 0
sX[i] = scale
for k in range(K):
QX[i][k] = quantize_int8(X[i][k], scale)
# 2) 计算 per-channel scale:sW[j]
sW = [0.0] * N
QW = [[0] * N for _ in range(K)]
for j in range(N):
max_abs = 0.0
for k in range(K):
v = W[k][j]
av = v if v >= 0 else -v
if av > max_abs:
max_abs = av
scale = max_abs / 127.0
if scale == 0.0:
scale = 1.0 # 全 0 列,避免除 0
sW[j] = scale
for k in range(K):
QW[k][j] = quantize_int8(W[k][j], scale)
# 3) INT32 矩阵乘法 + 4) 反量化
Y = [[0.0] * N for _ in range(M)]
for i in range(M):
for j in range(N):
acc = 0
for k in range(K):
acc += QX[i][k] * QW[k][j] # INT32 累加
Y[i][j] = acc * sX[i] * sW[j]
return Y
def main():
data = sys.stdin.buffer.read().split()
idx = 0
M = int(data[idx]); idx += 1
K = int(data[idx]); idx += 1
X = [[0.0] * K for _ in range(M)]
for i in range(M):
for k in range(K):
X[i][k] = float(data[idx]); idx += 1
K2 = int(data[idx]); idx += 1
N = int(data[idx]); idx += 1
# 默认输入合法,K2 应等于 K
W = [[0.0] * N for _ in range(K)]
for k in range(K):
for j in range(N):
W[k][j] = float(data[idx]); idx += 1
Y = int8_quant_matmul(X, W, M, K, N)
out_lines = []
for i in range(M):
row = []
for j in range(N):
row.append(format(Y[i][j], ".2f"))
out_lines.append(" ".join(row))
sys.stdout.write("\n".join(out_lines))
if __name__ == "__main__":
main()
#include <iostream>
#include <vector>
#include <cmath> // nearbyint
using namespace std;
// 按 Python round 逻辑:ties-to-even,可用 nearbyint 实现(默认舍入模式为最近偶数)
static int quantize_int8(double value, double scale) {
double qd = value / scale;
long long q = (long long) nearbyint(qd); // 0.5 情况取偶数
if (q > 127) return 127;
if (q < -127) return -127;
return (int)q;
}
static vector<vector<double>> int8_quant_matmul(
const vector<vector<double>>& X,
const vector<vector<double>>& W,
int M, int K, int N
) {
// 1) per-token scale
vector<double> sX(M, 0.0);
vector<vector<int>> QX(M, vector<int>(K, 0));
for (int i = 0; i < M; i++) {
double max_abs = 0.0;
for (int k = 0; k < K; k++) {
double v = X[i][k];
double av = v >= 0 ? v : -v;
if (av > max_abs) max_abs = av;
}
double scale = max_abs / 127.0;
if (scale == 0.0) scale = 1.0; // 全 0 行避免除 0
sX[i] = scale;
for (int k = 0; k < K; k++) {
QX[i][k] = quantize_int8(X[i][k], scale);
}
}
// 2) per-channel scale
vector<double> sW(N, 0.0);
vector<vector<int>> QW(K, vector<int>(N, 0));
for (int j = 0; j < N; j++) {
double max_abs = 0.0;
for (int k = 0; k < K; k++) {
double v = W[k][j];
double av = v >= 0 ? v : -v;
if (av > max_abs) max_abs = av;
}
double scale = max_abs / 127.0;
if (scale == 0.0) scale = 1.0; // 全 0 列避免除 0
sW[j] = scale;
for (int k = 0; k < K; k++) {
QW[k][j] = quantize_int8(W[k][j], scale);
}
}
// 3) INT32 乘法累加 + 4) 反量化
vector<vector<double>> Y(M, vector<double>(N, 0.0));
for (int i = 0; i < M; i++) {
for (int j = 0; j < N; j++) {
int acc = 0; // INT32 累加
for (int k = 0; k < K; k++) {
acc += QX[i][k] * QW[k][j];
}
Y[i][j] = (double)acc * sX[i] * sW[j];
}
}
return Y;
}
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
int M, K;
cin >> M >> K;
vector<vector<double>> X(M, vector<double>(K, 0.0));
for (int i = 0; i < M; i++) {
for (int k = 0; k < K; k++) {
cin >> X[i][k];
}
}
int K2, N;
cin >> K2 >> N;
// 默认输入合法,K2 应等于 K
vector<vector<double>> W(K, vector<double>(N, 0.0));
for (int k = 0; k < K; k++) {
for (int j = 0; j < N; j++) {
cin >> W[k][j];
}
}
vector<vector<double>> Y = int8_quant_matmul(X, W, M, K, N);
for (int i = 0; i < M; i++) {
for (int j = 0; j < N; j++) {
if (j) cout << ' ';
cout.setf(std::ios::fixed);
cout.precision(2);
cout << Y[i][j];
}
if (i + 1 < M) cout << '\n';
}
return 0;
}
import java.util.Scanner;
public class Main {
// 按 Python round 逻辑:ties-to-even,可用 Math.rint 实现
private static int quantizeInt8(double value, double scale) {
double qd = value / scale;
int q = (int) Math.rint(qd); // 四舍五入到最近整数,0.5 时取偶数
if (q > 127) return 127;
if (q < -127) return -127;
return q;
}
private static double[][] int8QuantMatmul(double[][] X, double[][] W, int M, int K, int N) {
// 1) per-token scale
double[] sX = new double[M];
int[][] QX = new int[M][K];
for (int i = 0; i < M; i++) {
double maxAbs = 0.0;
for (int k = 0; k < K; k++) {
double v = X[i][k];
double av = v >= 0 ? v : -v;
if (av > maxAbs) maxAbs = av;
}
double scale = maxAbs / 127.0;
if (scale == 0.0) scale = 1.0; // 全 0 行避免除 0
sX[i] = scale;
for (int k = 0; k < K; k++) {
QX[i][k] = quantizeInt8(X[i][k], scale);
}
}
// 2) per-channel scale
double[] sW = new double[N];
int[][] QW = new int[K][N];
for (int j = 0; j < N; j++) {
double maxAbs = 0.0;
for (int k = 0; k < K; k++) {
double v = W[k][j];
double av = v >= 0 ? v : -v;
if (av > maxAbs) maxAbs = av;
}
double scale = maxAbs / 127.0;
if (scale == 0.0) scale = 1.0; // 全 0 列避免除 0
sW[j] = scale;
for (int k = 0; k < K; k++) {
QW[k][j] = quantizeInt8(W[k][j], scale);
}
}
// 3) INT32 乘法累加 + 4) 反量化
double[][] Y = new double[M][N];
for (int i = 0; i < M; i++) {
for (int j = 0; j < N; j++) {
int acc = 0; // INT32 累加
for (int k = 0; k < K; k++) {
acc += QX[i][k] * QW[k][j];
}
Y[i][j] = acc * sX[i] * sW[j];
}
}
return Y;
}
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int M = sc.nextInt();
int K = sc.nextInt();
double[][] X = new double[M][K];
for (int i = 0; i < M; i++) {
for (int k = 0; k < K; k++) {
X[i][k] = sc.nextDouble();
}
}
int K2 = sc.nextInt();
int N = sc.nextInt();
// 默认输入合法,K2 应等于 K
double[][] W = new double[K][N];
for (int k = 0; k < K; k++) {
for (int j = 0; j < N; j++) {
W[k][j] = sc.nextDouble();
}
}
double[][] Y = int8QuantMatmul(X, W, M, K, N);
StringBuilder sb = new StringBuilder();
for (int i = 0; i < M; i++) {
for (int j = 0; j < N; j++) {
if (j > 0) sb.append(' ');
sb.append(String.format("%.2f", Y[i][j]));
}
if (i + 1 < M) sb.append('\n');
}
System.out.print(sb.toString());
sc.close();
}
}
题目内容
在深度学习模型部署中,为降低内存占用与计算延迟,常采用量化技术将FP32参数转换为低精度格式(如INT8)。本题要求实现一种INT8量化策略,对输入矩阵X与权重矩阵W分别进行per−token(按行)和per−channel(按列)量化,计算量化后矩阵相乘(带scales)的结果并输出。
原始矩阵:
-
输入矩阵X∈RM×K(按行视为M个Token)
-
权重矩阵W∈RK×N(按列视为N个Channel)
量化参数:
- sX(i):第i个Token的缩放因子(per−token)
- sW(j):第j个Channel的缩放因子(per−channel)
2. 量化公式
(1) Per-Token 量化(输入矩阵 X)
-
缩放因子计算:
sX(i)=127maxk∈[1,K]∣Xi,k∣,∀i∈[1,M] -
量化公式:
QX(i,k)=clip(round(sX(i)Xi,k),−127,127) -
反量化恢复:
X^i,k=QX(i,k)⋅sX(i)
注意:
- round:四舍五入取整函数【按Python round()函数处理逻辑,奇进偶弃舍入法。即非0.5的情况:遵循常规的四舍五入规则。0.5情况:舍入到最接近的偶数。】
- clip(x,−127,127):将超过区间的值截断到[−127,127]范围内,超过127则为127,小于−127同理。
(2) Per−Channel 量化(权重矩阵 W)
- 缩放因子计算:
- 量化公式:
- 反量化恢复:
3.量化矩阵乘法
(1) INT8 矩阵乘法(量化后计算)
Yint32(i,j)=∑k=1KQX(i,k)⋅QW(k,j)
- 输出为 INT32 矩阵(避免溢出)。
(2) 反量化到 FP32
Yfp32(i,j)=Yint32(i,j)⋅sX(i)⋅sW(j)
- 组合缩放因子:每个元素的缩放因子为 sX(i)⋅sW(j)。
输入描述
输入为矩阵
X∈RM×K 输入矩阵
W∈RK×N 权重矩阵
其中
0<M,K,N<=128
−1000000.0<= 矩阵元素值 <=1000000.0
标准输入方式读入,按照X、W顺序读入,前两个数字为矩阵大小。
输出描述
输出量化后矩阵相乘结果(带scale计算),并四舍五入到小数点后两位小数【建议使用Python format(num, '.2f')处理】。注意:每行结果以单个空格间隔,头尾不要有多余空格。
样例1
输入
2 2
1.0 0.5
-0.5 2.0
2 2
0.8 -0.3
1.2 1.5
输出
1.41 0.46
2.00 3.15
说明
样例2
输入
2 3
1.2 -2.3 3.4
-4.5 5.6 -6.7
3 2
0.8 -1.9
2.0 -2.1
3.2 3.3
输出
7.27 13.78
-13.92 -25.37
说明
以上输入对应如下矩阵内容:
输入矩阵 X(2∗3,M=2 Tokens)
X=[1.2−4.5−2.35.63.4−6.7]权重矩阵 W(3∗2,N=2 Channels)
W=0.82.03.2−1.9−2.13.3对应如下矩阵结果:
outfp32=[7.27−13.9213.78−25.37]