#P4464. 第2题-全连接层INT8非对称量化实现
-
1000ms
Tried: 190
Accepted: 40
Difficulty: 6
所属公司 :
华为
时间 :2025年11月12日-AI方向
-
算法标签>机器学习算法
第2题-全连接层INT8非对称量化实现
解题思路
本题要求将输入向量与权重矩阵分别做 INT8 非对称量化(per-tensor),用量化后的整数直接做全连接(矩阵–向量乘),并用反量化后的结果评估与原始浮点计算之间的误差。
核心要点如下:
-
量化(asymmetric, INT8) 对张量 v(可为向量 x 或矩阵 W)做 per-tensor 量化:
-
尺度(scale)
$$\text{scale}_v = \frac{\max(v)-\min(v)}{255},\quad \text{若}\ \max(v)=\min(v)\ \text{则}\ \text{scale}_v=0 $$ -
量化到 [−128,127]
$$v_{\text{quant}}=\text{clamp}\Big(\text{round}\Big(\frac{v-\min(v)}{\text{scale}_v}\Big)-128,\,-128,\,127\Big) $$其中 round 采用就近取偶(Banker’s Rounding)。当 scalev=0 时,直接令 vquant=−128。
-
-
反量化(dequant)
$$v_{\text{dequant}} = (v_{\text{quant}}+128)\cdot \text{scale}_v + \min(v) $$当 scalev=0 时,令 vdequant=min(v)。
-
全连接层计算 设输入为 x∈Rn,权重 W∈Rm×n,输出
Y=x⋅W⊤∈Rm-
整数路径输出(第一行):用 xquant 与 Wquant 直接做整数点积,得到 m 个整数。不添加偏置。
-
误差评估路径:分别将 xquant,Wquant 反量化为 xdequant,Wdequant,再做浮点全连接得到 Ydequant。与原始浮点 Yfloat 做均方误差
$$\text{MSE}=\frac{1}{m}\sum_{i=0}^{m-1}\big(Y_{\text{float},i}-Y_{\text{dequant},i}\big)^2 $$输出时取 round(MSE×100000),此处按“四舍五入”(half-up)。
-
-
实现细节
- x 与 W 分别独立计算 min,max,scale(per-tensor 量化)。
- 量化时采用就近取偶;MSE 放大后采用四舍五入(half-up)。
- 矩阵乘法按行做点积;整数输出建议用较大整型累加避免溢出。
代码实现
Python
import sys
import math
def quantize_tensor(values):
"""对一维列表进行INT8非对称量化,返回(q, scale, vmin)"""
vmin = min(values)
vmax = max(values)
if vmax == vmin:
# scale为0:全量化为-128
return [-128] * len(values), 0.0, vmin
scale = (vmax - vmin) / 255.0
q = []
for v in values:
t = (v - vmin) / scale # 落在[0,255]
rq = round(t) # 就近取偶
iv = int(rq) - 128
if iv < -128:
iv = -128
elif iv > 127:
iv = 127
q.append(iv)
return q, scale, vmin
def dequantize_tensor(q, scale, vmin):
"""反量化一维列表"""
if scale == 0.0:
return [vmin] * len(q)
return [(qi + 128) * scale + vmin for qi in q]
def fc_int_output(xq, Wq, n):
"""使用量化后的整数做全连接:返回长度m的整数输出"""
m = len(Wq) // n
y = []
for i in range(m):
s = 0
base = i * n
for j in range(n):
s += xq[j] * Wq[base + j]
y.append(s)
return y
def fc_float_output(x, W, n):
"""浮点全连接:返回长度m的浮点输出"""
m = len(W) // n
y = []
for i in range(m):
s = 0.0
base = i * n
for j in range(n):
s += x[j] * W[base + j]
y.append(s)
return y
def round_half_up(x):
"""四舍五入到最近整数(正数half-up,MSE>=0安全)"""
return int(math.floor(x + 0.5))
def main():
data = sys.stdin.read().strip().split()
it = iter(data)
n = int(next(it))
x = [float(next(it)) for _ in range(n)]
m = int(next(it)); n2 = int(next(it)) # 题面保证维度合法
W = []
for _ in range(m):
for _ in range(n):
W.append(float(next(it)))
# 量化(x 与 W 分别 per-tensor)
xq, sx, xmin = quantize_tensor(x)
Wq, sw, wmin = quantize_tensor(W)
# 整数路径输出
y_int = fc_int_output(xq, Wq, n)
# 误差评估:反量化 -> 浮点全连接
x_d = dequantize_tensor(xq, sx, xmin)
W_d = dequantize_tensor(Wq, sw, wmin)
y_float = fc_float_output(x, W, n)
y_deq = fc_float_output(x_d, W_d, n)
# MSE × 100000 四舍五入
msz = len(y_float)
mse = sum((y_float[i] - y_deq[i]) ** 2 for i in range(msz)) / msz
mse_scaled = round_half_up(mse * 100000.0)
# 输出
print(' '.join(str(v) for v in y_int))
print(mse_scaled)
if __name__ == "__main__":
main()
Java
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.StringTokenizer;
// 要求:ACM风格;类名Main;不用快读库,使用常规BufferedReader+StringTokenizer
public class Main {
// 量化结果结构体
static class QuantRes {
int[] q;
double scale;
double vmin;
QuantRes(int[] q, double scale, double vmin) {
this.q = q; this.scale = scale; this.vmin = vmin;
}
}
// INT8 非对称量化(就近取偶:Math.rint)
static QuantRes quantize(double[] v) {
double vmin = v[0], vmax = v[0];
for (double val : v) { if (val < vmin) vmin = val; if (val > vmax) vmax = val; }
if (vmax == vmin) {
int[] q = new int[v.length];
for (int i = 0; i < v.length; i++) q[i] = -128;
return new QuantRes(q, 0.0, vmin);
}
double scale = (vmax - vmin) / 255.0;
int[] q = new int[v.length];
for (int i = 0; i < v.length; i++) {
double t = (v[i] - vmin) / scale;
long rq = Math.round(Math.rint(t)); // rint为就近取偶,round包一层得到long
int iv = (int) rq - 128;
if (iv < -128) iv = -128;
else if (iv > 127) iv = 127;
q[i] = iv;
}
return new QuantRes(q, scale, vmin);
}
// 反量化
static double[] dequantize(int[] q, double scale, double vmin) {
double[] r = new double[q.length];
if (scale == 0.0) {
for (int i = 0; i < q.length; i++) r[i] = vmin;
return r;
}
for (int i = 0; i < q.length; i++) r[i] = (q[i] + 128) * scale + vmin;
return r;
}
// 整数全连接:xq(1×n) 与 Wq(m×n扁平) -> m个整数
static long[] fcInt(int[] xq, int[] Wq, int n, int m) {
long[] y = new long[m];
for (int i = 0; i < m; i++) {
long s = 0L;
int base = i * n;
for (int j = 0; j < n; j++) s += (long) xq[j] * (long) Wq[base + j];
y[i] = s;
}
return y;
}
// 浮点全连接:x(1×n) 与 W(m×n扁平) -> m个浮点
static double[] fcFloat(double[] x, double[] W, int n, int m) {
double[] y = new double[m];
for (int i = 0; i < m; i++) {
double s = 0.0;
int base = i * n;
for (int j = 0; j < n; j++) s += x[j] * W[base + j];
y[i] = s;
}
return y;
}
// 四舍五入(half-up)到最近整数,MSE非负可直接使用
static long roundHalfUp(double v) {
return (long) Math.floor(v + 0.5);
}
public static void main(String[] args) throws Exception {
FastIn in = new FastIn();
int n = in.nextInt();
double[] x = new double[n];
for (int i = 0; i < n; i++) x[i] = in.nextDouble();
int m = in.nextInt();
int n2 = in.nextInt(); // 题面保证合法
double[] W = new double[m * n];
for (int i = 0; i < m; i++)
for (int j = 0; j < n; j++)
W[i * n + j] = in.nextDouble();
// 量化
QuantRes qx = quantize(x);
QuantRes qW = quantize(W);
// 整数输出
long[] yInt = fcInt(qx.q, qW.q, n, m);
// 反量化 -> 浮点全连接
double[] xDeq = dequantize(qx.q, qx.scale, qx.vmin);
double[] WDeq = dequantize(qW.q, qW.scale, qW.vmin);
double[] yFloat = fcFloat(x, W, n, m);
double[] yDeq = fcFloat(xDeq, WDeq, n, m);
// MSE × 100000 四舍五入
double mse = 0.0;
for (int i = 0; i < m; i++) {
double d = yFloat[i] - yDeq[i];
mse += d * d;
}
mse /= m;
long mseScaled = roundHalfUp(mse * 100000.0);
// 输出
StringBuilder sb = new StringBuilder();
for (int i = 0; i < m; i++) {
if (i > 0) sb.append(' ');
sb.append(yInt[i]);
}
System.out.println(sb.toString());
System.out.println(mseScaled);
}
// 简洁输入工具:BufferedReader + StringTokenizer
static class FastIn {
BufferedReader br;
StringTokenizer st;
FastIn() { br = new BufferedReader(new InputStreamReader(System.in)); }
String next() throws IOException {
while (st == null || !st.hasMoreElements()) {
String line = br.readLine();
if (line == null) return null;
st = new StringTokenizer(line);
}
return st.nextToken();
}
int nextInt() throws IOException { return Integer.parseInt(next()); }
double nextDouble() throws IOException { return Double.parseDouble(next()); }
}
}
C++
#include <bits/stdc++.h>
using namespace std;
// 就近取偶(Banker's rounding)
static long long round_half_even(double x) {
double f = floor(x);
double frac = x - f;
if (frac < 0.5) return (long long)f;
if (frac > 0.5) return (long long)f + 1;
// frac == 0.5,取偶
long long fl = (long long)f;
if ((fl % 2LL) == 0LL) return fl;
return fl + 1LL;
}
// 四舍五入(half-up),MSE非负
static long long round_half_up(double v) {
return (long long)floor(v + 0.5);
}
// 量化(返回 q, scale, vmin)
static void quantize(const vector<double>& v, vector<int>& q, double& scale, double& vmin) {
double vmax = v[0];
vmin = v[0];
for (double x : v) { if (x < vmin) vmin = x; if (x > vmax) vmax = x; }
if (vmax == vmin) {
scale = 0.0;
q.assign(v.size(), -128);
return;
}
scale = (vmax - vmin) / 255.0;
q.resize(v.size());
for (size_t i = 0; i < v.size(); ++i) {
double t = (v[i] - vmin) / scale; // in [0,255]
long long rq = round_half_even(t);
long long iv = rq - 128;
if (iv < -128) iv = -128;
else if (iv > 127) iv = 127;
q[i] = (int)iv;
}
}
// 反量化
static void dequantize(const vector<int>& q, double scale, double vmin, vector<double>& out) {
out.resize(q.size());
if (scale == 0.0) {
for (size_t i = 0; i < q.size(); ++i) out[i] = vmin;
return;
}
for (size_t i = 0; i < q.size(); ++i) out[i] = (q[i] + 128) * scale + vmin;
}
// 整数全连接:xq(1×n) 与 Wq(m×n扁平) -> m个整数(long long)
static void fc_int(const vector<int>& xq, const vector<int>& Wq, int n, int m, vector<long long>& y) {
y.assign(m, 0LL);
for (int i = 0; i < m; ++i) {
long long s = 0;
int base = i * n;
for (int j = 0; j < n; ++j) s += 1LL * xq[j] * Wq[base + j];
y[i] = s;
}
}
// 浮点全连接:x 与 W(m×n扁平)
static void fc_float(const vector<double>& x, const vector<double>& W, int n, int m, vector<double>& y) {
y.assign(m, 0.0);
for (int i = 0; i < m; ++i) {
double s = 0.0;
int base = i * n;
for (int j = 0; j < n; ++j) s += x[j] * W[base + j];
y[i] = s;
}
}
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
int n;
if (!(cin >> n)) return 0;
vector<double> x(n);
for (int i = 0; i < n; ++i) cin >> x[i];
int m, n2;
cin >> m >> n2;
vector<double> W(m * n);
for (int i = 0; i < m; ++i)
for (int j = 0; j < n; ++j)
cin >> W[i * n + j];
// 量化
vector<int> xq, Wq;
double sx, xmin, sw, wmin;
quantize(x, xq, sx, xmin);
quantize(W, Wq, sw, wmin);
// 整数输出
vector<long long> y_int;
fc_int(xq, Wq, n, m, y_int);
// 反量化并浮点全连接
vector<double> x_deq, W_deq, y_float, y_deq;
dequantize(xq, sx, xmin, x_deq);
dequantize(Wq, sw, wmin, W_deq);
fc_float(x, W, n, m, y_float);
fc_float(x_deq, W_deq, n, m, y_deq);
// MSE × 100000 四舍五入
double mse = 0.0;
for (int i = 0; i < m; ++i) {
double d = y_float[i] - y_deq[i];
mse += d * d;
}
mse /= m;
long long mse_scaled = round_half_up(mse * 100000.0);
// 输出
for (int i = 0; i < m; ++i) {
if (i) cout << ' ';
cout << y_int[i];
}
cout << "\n" << mse_scaled << "\n";
return 0;
}
题目内容
【背景】在移动设备部署深度学习模型时,浮点运算会消耗大量计算资源。通过 INT8 非对称量化,可将全连接层的浮点运算转化为整数运算,显著提高推理速度。实际应用中:
- 量化后模型大小缩小 4 倍(32bit→8bit)
- 整数运算指令比浮点指令快 2-4 倍
- 广泛应用于移动端 NLP 模型(如 BERT 最后一层分类头)
- 在物联网设备上可降低能耗并减少内存占用
【题目要求】请实现以下功能:
- 量化和全连接层计算:对输入向量 x 和权重矩阵 W 执行 INT8 非对称量化,使用量化后的整数值 xquant和Wquant 进行全连接层计算,输出计算结果。为简化起见,本题中全连接层不考虑偏置。
- 计算量化误差:对量化的整数进行反量化得到 xdequant和 Wdequant并进行全连接层计算,与原始浮点 x、W 的全连接层计算结果进行比较,计算两个全连接层输出之间的均方误差(MSE),并将 MSE × 100000 后四舍五入后输出。
【算法原理】
1、INT8 非对称量化:
1)尺度:scalev=(max(v)−min(v))/255,当max(v)==min(v),即张量 v 的所有值相等时,scalev=0。
2)量化,对张量 v(向量 x 或矩阵 W)进行量化得到vquant,量化后的整数区间为 [-128,127]:
$v_{quant} = clamp(round((v - min(v))/scale_v) - 128, -128,127)$ ,当scalev=0时量化结果为vquant=−128。
其中 round () 采用就近取偶。
$\text{round}(x)= \begin{cases} \lfloor x \rfloor, & \{x\} < \frac{1}{2}, \\ \lfloor x \rfloor + 1, & \{x\} > \frac{1}{2}, \\ 2 \cdot \lfloor \frac{x+1}{2} \rfloor, & \{x\} = \frac{1}{2}. \end{cases}$>
其中:{x}=x−⌊x⌋,⌊x⌋ 表示向下取整。
$\text{clamp}(t, lo, hi)= \begin{cases} lo, & t < lo \\ hi, & t > hi \\ t, & else \end{cases}$>
3)反量化,对 vquant 进行反量化后得到 vdequant:
$v_{\text{dequant}} = (v_{\text{quant}} + 128) \cdot \text{scale}_v + \min(v)$,当 scalev=0 时,反量化值 vdequant=min(v),即为原始输入的最小值。
2、全连接层计算,以输入向量x和权重矩阵W为例,全连接层输出Y。
Y=x⋅WT
3、量化误差,计算原始浮点输入的全连接层输出 Yfloat 和反量化数据的全连接层输出 Ydequant 之间的均方误差(MSE):
$MSE = \frac{1}{m} \sum_{i=0}^{m-1} (Y_{\text{float},i} - Y_{\text{dequant},i})^2$,m 为权重矩阵的行数。
输入描述
第一行: n (输入向量 x 的维度)第二行: n 个浮点数 (输入向量 x)第三行: m n (权重矩阵 W 的维度)接下来 m 行:每行 n 个浮点数 (权重矩阵 W)
输出描述
第一行: m 个整数 (使用量化数据 xquant和 Wquant计算的全连接层输出)
第二行: 1 个整数 (量化误差 MSE,注意是 MSE × 100000 后四舍五入输出整数)
样例1
输入
3
1.0 2.0 3.0
2 3
0.1 0.2 0.3
0.4 0.5 0.6
输出
13082 12929
0
说明
3 # n=3 (输入向量维度)
1.0 2.0 3.0 # x = [1.0, 2.0, 3.0]
2 3 # m=2, n=3 (权重矩阵 2×3)
0.1 0.2 0.3 # W 第 1 行 = [0.1, 0.2, 0.3]
0.4 0.5 0.6 # W 第 2 行 = [0.4, 0.5, 0.6]
量化输入向量 X: xquant= [-128, 0, 127]
量化权重矩阵 W: Wquant= [[-128, -77, -26], [25, 76, 127]]
量化域整数运算:输出第一行结果: 13082 12929
计算MSE
原始浮点输出:
Y_float [0] = 1.0×0.1 + 2.0×0.2 + 3.0×0.3 = 0.1 + 0.4 + 0.9 = 1.4
Y_float [1] = 1.0×0.4 + 2.0×0.5 + 3.0×0.6 = 0.4 + 1.0 + 1.8 = 3.2
反量化后: Y_dequant = Y_float
MSE: 输出第二行结果: 0
样例2
输入
7
0.3 -1.1 2.2 -3.3 4.4 -5.5 6.6
3 7
0.2 -0.3 0.4 -0.1 0.0 5 -0.6
-1.5 1.2 -0.9 0.6 -0.3 0.1 0
3 -2 1 -0.5 0.25 -0.125 0.0625
输出
-5476 -7406 8954
933
说明
7 # n=7 (输入向量维度)
0.3 -1.1 2.2 -3.3 4.4 -5.5 6.6 # x = [0.3, -1.1, 2.2, -3.3, 4.4, -5.5, 6.6]
3 7 # m=3, n=7 (权重矩阵 3×7)
0.2 -0.3 0.4 -0.1 0.05 -0.6 # W 第 1 行
-1.5 1.2 -0.9 0.6 -0.3 0.1 0 # W 第 2 行
3 -2 1 -0.5 0.25 -0.125 0.0625 # W 第 3 行
输出:
量化域整数运算输出: -5476 -7406 8954
MSE 输出: 933