题目描述
设输入序列为 X∈RT×dmodel。
给定:
- Query head 数:hq
- Key/Value head 数(分组数):hk,满足 1≤hk<hq
- 每个 Query head 的维度:dk=hqdmodel
并要求:
- dmodelmodhq=0
- hqmodhk=0
你需要实现分组查询注意力(Grouped Query Attention, GQA)。其核心特点是:Q 有 hq 个 head,而 K/V 只有 hk 个 head,因此多个 Query head 共享同一组 K/V。
1. 线性投影(K/V 的投影维度被压缩)
Q=XWQ+bQ
K=XWK+bK
V=XWV+bV
其中参数形状为:
- WQ∈Rdmodel×dmodel,bQ∈Rdmodel
- WK∈Rdmodel×(hkdk),bK∈Rhkdk
- WV∈Rdmodel×(hkdk),bV∈Rhkdk
因此:
- Q∈RT×dmodel
- K,V∈RT×(hkdk)
2. 多头拆分
将 Q 拆分为 hq 个 Query heads:
Q→(Q0,…,Qhq−1),Qi∈RT×dk
将 K,V 拆分为 hk 个 heads:
K→(K0,…,Khk−1),Kj∈RT×dk
V→(V0,…,Vhk−1),Vj∈RT×dk
3. Query head 到 K/V head 的映射
令每组包含的 Query head 数为:
s=hkhq
对第 i 个 Query head,其对应的 K/V head 索引为:
g(i)=⌊si⌋,g(i)∈0,…,hk−1
4. 分组注意力计算
对每个 Query head i,使用共享的 Kg(i),Vg(i):
Hi=softmax(dkQiKg(i)⊤)Vg(i)
其中 softmax 沿最后一维归一化。
5. 拼接所有 Query heads 的输出
H=concat(H0,…,Hhq−1)∈RT×(hqdk)=RT×dmodel
6. 输出层线性变换
O=HWO+bO
其中 WO∈Rdmodel×dmodel,bO∈Rdmodel,输出 O∈RT×dmodel。
输入参数
- X:形状为 T×dmodel
- WQ:形状为 dmodel×dmodel,bQ:长度为 dmodel
- WK,WV:形状为 dmodel×(hkdk)
- bK,bV:长度为 hkdk
- hq:Query head 数
- hk:Key/Value head 数(分组数),满足 1≤hk<hq 且 hqmodhk=0
- WO:形状为 dmodel×dmodel,bO:长度为 dmodel
返回值
- O:形状为 T×dmodel
示例
输入:
X =
[[1, 0, 1, 2],
[0, 1, 1, 0]]
W_Q =
[[1, 0, 0, 1],
[0, 1, 1, 0],
[1, 0, 1, 0],
[0, 1, 0, 1]]
b_Q = [0, 0, 0, 0]
W_K =
[[1, 1],
[0, 1],
[1, 0],
[0, 1]]
b_K = [0, 0]
W_V =
[[1, 0],
[1, 1],
[0, 1],
[0, 0]]
b_V = [0, 0]
h_q = 2
h_k = 1
W_O =
[[1, 0, 1, 0],
[0, 1, 0, 1],
[1, 0, 0, 1],
[0, 1, 1, 0]]
b_O = [0, 0, 0, 0]
输出:
O =
[[2.00, 2.02, 2.01, 2.01],
[2.00, 2.30, 2.20, 2.11]]
提示
- 输入范围:
−1000≤X[i,j]≤1000
- 权重范围:
−10≤WQ, WK, WV, WO≤10
- 偏置范围:
−5≤bQ, bK, bV, bO≤5