本题要求将输入向量与权重矩阵分别做 INT8 非对称量化(per-tensor),用量化后的整数直接做全连接(矩阵–向量乘),并用反量化后的结果评估与原始浮点计算之间的误差。
核心要点如下:
【背景】在移动设备部署深度学习模型时,浮点运算会消耗大量计算资源。通过 INT8 非对称量化,可将全连接层的浮点运算转化为整数运算,显著提高推理速度。实际应用中:
【题目要求】请实现以下功能:
【算法原理】
1、INT8 非对称量化:
1)尺度:scalev=(max(v)−min(v))/255,当max(v)==min(v),即张量 v 的所有值相等时,scalev=0。
2)量化,对张量 v(向量 x 或矩阵 W)进行量化得到vquant,量化后的整数区间为 [-128,127]:
vquant=clamp(round((v−min(v))/scalev)−128,−128,127) ,当scalev=0时量化结果为vquant=−128。
其中 round () 采用就近取偶。
$\text{round}(x)= \begin{cases} \lfloor x \rfloor, & \{x\} < \frac{1}{2}, \\ \lfloor x \rfloor + 1, & \{x\} > \frac{1}{2}, \\ 2 \cdot \lfloor \frac{x+1}{2} \rfloor, & \{x\} = \frac{1}{2}. \end{cases}$
其中:{x}=x−⌊x⌋,⌊x⌋ 表示向下取整。
clamp(t,lo,hi)=⎩⎨⎧lo,hi,t,t<lot>hielse
3)反量化,对 vquant 进行反量化后得到 vdequant:
vdequant=(vquant+128)⋅scalev+min(v),当 scalev=0 时,反量化值 vdequant=min(v),即为原始输入的最小值。
2、全连接层计算,以输入向量x和权重矩阵W为例,全连接层输出Y。
Y=x⋅WT
3、量化误差,计算原始浮点输入的全连接层输出 Yfloat 和反量化数据的全连接层输出 Ydequant 之间的均方误差(MSE):
MSE=m1∑i=0m−1(Yfloat,i−Ydequant,i)2,m 为权重矩阵的行数。
第一行: n (输入向量 x 的维度)第二行: n 个浮点数 (输入向量 x)第三行: m n (权重矩阵 W 的维度)接下来 m 行:每行 n 个浮点数 (权重矩阵 W)
第一行: m 个整数 (使用量化数据 xquant和 Wquant计算的全连接层输出)
第二行: 1 个整数 (量化误差 MSE,注意是 MSE × 100000 后四舍五入输出整数)
输入
3
1.0 2.0 3.0
2 3
0.1 0.2 0.3
0.4 0.5 0.6
输出
13082 12929
0
说明
3 # n=3 (输入向量维度)
1.0 2.0 3.0 # x = [1.0, 2.0, 3.0]
2 3 # m=2, n=3 (权重矩阵 2×3)
0.1 0.2 0.3 # W 第 1 行 = [0.1, 0.2, 0.3]
0.4 0.5 0.6 # W 第 2 行 = [0.4, 0.5, 0.6]
量化输入向量 X: xquant= [-128, 0, 127]
量化权重矩阵 W: Wquant= [[-128, -77, -26], [25, 76, 127]]
量化域整数运算:输出第一行结果: 13082 12929
计算MSE
原始浮点输出:
Y_float [0] = 1.0×0.1 + 2.0×0.2 + 3.0×0.3 = 0.1 + 0.4 + 0.9 = 1.4
Y_float [1] = 1.0×0.4 + 2.0×0.5 + 3.0×0.6 = 0.4 + 1.0 + 1.8 = 3.2
反量化后: Y_dequant = Y_float
MSE: 输出第二行结果: 0
输入
7
0.3 -1.1 2.2 -3.3 4.4 -5.5 6.6
3 7
0.2 -0.3 0.4 -0.1 0 0.5 -0.6
-1.5 1.2 -0.9 0.6 -0.3 0.1 0
3 -2 1 -0.5 0.25 -0.125 0.0625
输出
-5476 -7406 8954
933
说明
7 # n=7 (输入向量维度)
0.3 -1.1 2.2 -3.3 4.4 -5.5 6.6 # x = [0.3, -1.1, 2.2, -3.3, 4.4, -5.5, 6.6]
3 7 # m=3, n=7 (权重矩阵 3×7)
0.2 -0.3 0.4 -0.1 0 0.5 -0.6 # W 第 1 行
-1.5 1.2 -0.9 0.6 -0.3 0.1 0 # W 第 2 行
3 -2 1 -0.5 0.25 -0.125 0.0625 # W 第 3 行
输出:
量化域整数运算输出: -5476 -7406 8954
MSE 输出: 933