#P4906. 第3题-多头注意力掩码计算
-
2000ms
Tried: 822
Accepted: 232
Difficulty: 6
所属公司 :
华为
时间 :2026年5月9日-AI方向
第3题-多头注意力掩码计算
解题思路
- 对每个 Batch 和每个 Head 独立处理一个 S×S 的得分矩阵,输出顺序与输入顺序一致。
- 对于第 b 个 Batch,设有效长度为 Lb。若行索引 i≥Lb,说明该查询位置本身是 Padding,整行输出全 0。
- 否则该行的有效列必须同时满足因果掩码和 Padding 掩码,即 0≤j≤i 且 j<Lb。
- 对有效位置计算位置惩罚后的值:vj=Scorei,j−(i−j)×(s+1)
问题背景
在大模型的训练和推理中,为了提高 GPU 的效率,我们通常将多个不同长度的句子打包成一个 Batch:
Batch(批处理):不同句子长度不一,短句子后面会补零(Padding);
Multi-Head(多头):向量维度会被拆分成多个“头”,每个头独立计算注意力,以捕捉不同层面的语义;
双重掩码:在计算时,必须同时满足因果掩码(不看未来)和 Padding 掩码(不看填充的废话);
位置惩罚:对未被掩码的位置进行偏置惩罚。如果所有的“头”都使用相同的惩罚力度,那么它们关注的范围就会高度重合,造成计算资源的浪费。因此惩罚因子随着“头”的索引而变化。本题中惩罚因子 slope=s+1,其中 s 为头索引(从0开始)。
下面将给定一个 Batch 的注意力得分张量,默认该张量已经过缩放(Score=QKT/dk),请针对每批次和头,实现其掩码注入和归一化过程。
主要步骤
-
因果掩码:我们需要对计算出的得分矩阵进行掩码处理(对于一个给定的 n×n 得分矩阵 Score,其上三角位置(不包括对角)的值要替换为
-1e9,其他位置信息不变),防止模型看到“未来”的信息; -
Padding 掩码注入:给定每个序列的有效长度
L,若列索引 j≥L(列索引从0开始),则 Scorei,j=−1e9(即该 Token 是填充的,不应被关注); -
位置惩罚:注入距离偏置,对未被掩码的位置(即下三角位置),在其原始得分的基础上减去距离偏置:
-
偏置计算:Biasi,j=(i−j)×slope
-
更新公式:Scorei,j′=Scorei,j−Biasi,j
-
-
归一化:对处理后的矩阵每一行进行 Safe Softmax 归一化:
-
对于每一行 i,先找到该行中非掩码位置集合 validi;
-
找到 i 行的最大值 Mi;
-
对于 i 行 j 列的元素 Scorei,j′,如果 j∈/validi,则 Softmax(Scorei,j′)=0;如果 j∈validi,则进行
Softmax 计算:Softmax(Scorei,j′)=∑j∈validieScorei,j′−MieScorei,j′−Mi
-
注意:当 Scorei,j′=−1e9 时,视为 ex=0。
解答要求
输入格式
第一行:三个整数 B(批数量),H(“头”数量),S(得分矩阵维度)。
第二行:B个整数,代表每个 Batch 的有效序列长度 Lb(最小值为0)。
接下来 B×H×S 行,每行 S 个浮点数,数据按 B→H→S 顺序排列,是经过缩放之后的得分矩阵 Score。
输出
输出处理后的张量,每行 S 个浮点数,保留 4位小数。
样例1
输入:
2 2 3
3 2
1.0 1.0 1.0
1.0 1.0 1.0
1.0 1.0 1.0
1.0 1.0 1.0
1.0 1.0 1.0
1.0 1.0 1.0
2.0 2.0 2.0
2.0 2.0 2.0
2.0 2.0 2.0
2.0 2.0 2.0
2.0 2.0 2.0
2.0 2.0 2.0
输出:
1.0000 0.0000 0.0000
0.2689 0.7311 0.0000
0.0900 0.2447 0.6652
1.0000 0.0000 0.0000
0.1192 0.8808 0.0000
0.0159 0.1173 0.8668
1.0000 0.0000 0.0000
0.2689 0.7311 0.0000
0.0000 0.0000 0.0000
1.0000 0.0000 0.0000
0.1192 0.8808 0.0000
0.0000 0.0000 0.0000
解释:第一行输入三个整数2,2,3,分别说明:
当前注意力被分为2个batch,拆分成了2个“头”,最后数字3说明得分矩阵为 3 × 3,即需要输入4个 3 × 3 的矩阵;
第二行输入的数字为每个batch的有效长度组成的向量,根据第一行的输入可知,当前batch数量为2,因此为一个2维向量,表示第一个batch的有效长度为3,第二个batch的有效长度为2。
后面8行为之前交代的4个 3 × 3 的得分矩阵。
同样的,输出也是经过掩码变化后的4个 3 × 3 得分矩阵。
样例2
输入:
2 2 2
2 1
10.0 10.0
10.0 10.0
10.0 10.0
10.0 10.0
5.0 5.0
5.0 5.0
5.0 5.0
5.0 5.0
输出:
1.0000 0.0000
0.2689 0.7311
1.0000 0.0000
0.1192 0.8808
1.0000 0.0000
0.0000 0.0000
1.0000 0.0000
0.0000 0.0000
解释:第一行输入三个整数2,2,2,分别说明:
当前注意力被分为2个batch,拆分成了2个“头”,最后数字2说明得分矩阵为 2 × 2,即需要输入4个 2 × 2 的矩阵;
第二行输入的数字为每个batch的有效长度组成的向量,根据第一行的输入可知,当前batch数量为2,因此为一个2维向量,表示第一个batch的有效长度为2,第二个batch的有效长度为1。
后面8行为之前交代的4个 2 × 2 的得分矩阵。
同样的,输出也是经过掩码变化后的4个 2 × 2 得分矩阵。