#P4497. 分组查询注意力
-
1000ms
Tried: 36
Accepted: 17
Difficulty: 5
分组查询注意力
步骤 1:线性投影(注意 K/V 输出维度已压缩)
用矩阵乘法计算:
- Q=XWQ+bQ
- K=XWK+bK
- V=XWV+bV
其中:
- Q 形状为 (T,dmodel)
- K,V 形状为 (T,hkdk),其中 dk=dmodel/hq
步骤 2:拆分 Query 为 hq 个 head
把 Q 的最后一维按 hq 均分成 hq 段,每段长度 dk:
- reshape 为 (T,hq,dk)
- 转置为 (hq,T,dk)
步骤 3:拆分 Key/Value 为 hk 个 head(不再组内切子块)
把 K,V reshape 为:
- K→(T,hk,dk),再转置为 (hk,T,dk)
- V→(T,hk,dk),再转置为 (hk,T,dk)
步骤 4:逐 head 计算注意力(每组共享同一个 K/V head)
令 groupsize=hq/hk,对每个 Query head i:
-
计算共享索引:
g(i)=⌊groupsizei⌋ -
取 Qi∈(T,dk)、Kg(i)∈(T,dk)、Vg(i)∈(T,dk)
-
打分矩阵:
scores=dkQiKg(i)⊤ -
稳定 softmax(对每行):
scores -= scores.max(axis=1, keepdims=True)attn = exp(scores) / exp(scores).sum(axis=1, keepdims=True)
-
输出:
Hi=attn⋅Vg(i)
步骤 5:拼接所有 head
把 Hi 按 head 维拼回 (T,dmodel)。
步骤 6:输出层
O=HWO+bOPython 参考实现
import numpy as np
from typing import List
class Solution:
def grouped_query_attention(
self,
X: List[List[float]],
W_Q: List[List[float]],
b_Q: List[float],
W_K: List[List[float]],
b_K: List[float],
W_V: List[List[float]],
b_V: List[float],
h_q: int,
h_k: int,
W_O: List[List[float]],
b_O: List[float],
) -> List[List[float]]:
"""
标准分组查询注意力(GQA)。
关键要点:
- d_k = d_model / h_q
- group_size = h_q / h_k
- K/V head 数为 h_k(压缩),每个 head 维度仍为 d_k
- 每 group_size 个 Query heads 共享同一个 K/V head:
g(i) = floor(i / group_size)
"""
Xn = np.asarray(X, dtype=np.float64)
WQ = np.asarray(W_Q, dtype=np.float64)
WK = np.asarray(W_K, dtype=np.float64)
WV = np.asarray(W_V, dtype=np.float64)
WO = np.asarray(W_O, dtype=np.float64)
bQ = np.asarray(b_Q, dtype=np.float64)
bK = np.asarray(b_K, dtype=np.float64)
bV = np.asarray(b_V, dtype=np.float64)
bO = np.asarray(b_O, dtype=np.float64)
T, d_model = Xn.shape
if d_model % h_q != 0:
raise ValueError("d_model must be divisible by h_q")
if not (1 <= h_k < h_q):
raise ValueError("require 1 <= h_k < h_q")
if h_q % h_k != 0:
raise ValueError("require h_q divisible by h_k")
d_k = d_model // h_q
group_size = h_q // h_k
inv_sqrt = 1.0 / np.sqrt(d_k)
# WK/WV 输出维度为 h_k*d_k
hkdk = h_k * d_k
if WK.shape != (d_model, hkdk) or WV.shape != (d_model, hkdk):
raise ValueError("W_K/W_V must have shape (d_model, h_k*d_k)")
if bK.shape != (hkdk,) or bV.shape != (hkdk,):
raise ValueError("b_K/b_V must have length h_k*d_k")
# 1) 线性投影
Q = Xn @ WQ + bQ # (T, d_model)
K = Xn @ WK + bK # (T, h_k*d_k)
V = Xn @ WV + bV # (T, h_k*d_k)
# 2) Q 拆成 h_q 个 head: (h_q, T, d_k)
Qh = Q.reshape(T, h_q, d_k).transpose(1, 0, 2)
# 3) K/V 拆成 h_k 个 head: (h_k, T, d_k)
Kh = K.reshape(T, h_k, d_k).transpose(1, 0, 2)
Vh = V.reshape(T, h_k, d_k).transpose(1, 0, 2)
# 4) 逐 head 注意力
Hh = np.empty((h_q, T, d_k), dtype=np.float64)
for i in range(h_q):
g = i // group_size
Qi = Qh[i] # (T, d_k)
Ki = Kh[g] # (T, d_k)
Vi = Vh[g] # (T, d_k)
scores = (Qi @ Ki.T) * inv_sqrt
scores = scores - scores.max(axis=1, keepdims=True)
attn = np.exp(scores)
attn = attn / attn.sum(axis=1, keepdims=True)
Hh[i] = attn @ Vi
# 5) 拼接 head: (T, d_model)
H = Hh.transpose(1, 0, 2).reshape(T, d_model)
# 6) 输出层
O = H @ WO + bO
return O.tolist()
题目描述
设输入序列为 X∈RT×dmodel。
给定:
- Query head 数:hq
- Key/Value head 数(分组数):hk,满足 1≤hk<hq
- 每个 Query head 的维度:dk=hqdmodel
并要求:
- dmodelmodhq=0
- hqmodhk=0
你需要实现分组查询注意力(Grouped Query Attention, GQA)。其核心特点是:Q 有 hq 个 head,而 K/V 只有 hk 个 head,因此多个 Query head 共享同一组 K/V。
1. 线性投影(K/V 的投影维度被压缩)
Q=XWQ+bQ K=XWK+bK V=XWV+bV其中参数形状为:
- WQ∈Rdmodel×dmodel,bQ∈Rdmodel
- WK∈Rdmodel×(hkdk),bK∈Rhkdk
- WV∈Rdmodel×(hkdk),bV∈Rhkdk
因此:
- Q∈RT×dmodel
- K,V∈RT×(hkdk)
2. 多头拆分
将 Q 拆分为 hq 个 Query heads:
Q→(Q0,…,Qhq−1),Qi∈RT×dk将 K,V 拆分为 hk 个 heads:
K→(K0,…,Khk−1),Kj∈RT×dk V→(V0,…,Vhk−1),Vj∈RT×dk3. Query head 到 K/V head 的映射
令每组包含的 Query head 数为:
s=hkhq对第 i 个 Query head,其对应的 K/V head 索引为:
g(i)=⌊si⌋,g(i)∈0,…,hk−14. 分组注意力计算
对每个 Query head i,使用共享的 Kg(i),Vg(i):
Hi=softmax(dkQiKg(i)⊤)Vg(i)其中 softmax 沿最后一维归一化。
5. 拼接所有 Query heads 的输出
H=concat(H0,…,Hhq−1)∈RT×(hqdk)=RT×dmodel6. 输出层线性变换
O=HWO+bO其中 WO∈Rdmodel×dmodel,bO∈Rdmodel,输出 O∈RT×dmodel。
输入参数
- X:形状为 T×dmodel
- WQ:形状为 dmodel×dmodel,bQ:长度为 dmodel
- WK,WV:形状为 dmodel×(hkdk)
- bK,bV:长度为 hkdk
- hq:Query head 数
- hk:Key/Value head 数(分组数),满足 1≤hk<hq 且 hqmodhk=0
- WO:形状为 dmodel×dmodel,bO:长度为 dmodel
返回值
- O:形状为 T×dmodel
示例
输入:
X =
[[1, 0, 1, 2],
[0, 1, 1, 0]]
W_Q =
[[1, 0, 0, 1],
[0, 1, 1, 0],
[1, 0, 1, 0],
[0, 1, 0, 1]]
b_Q = [0, 0, 0, 0]
W_K =
[[1, 1],
[0, 1],
[1, 0],
[0, 1]]
b_K = [0, 0]
W_V =
[[1, 0],
[1, 1],
[0, 1],
[0, 0]]
b_V = [0, 0]
h_q = 2
h_k = 1
W_O =
[[1, 0, 1, 0],
[0, 1, 0, 1],
[1, 0, 0, 1],
[0, 1, 1, 0]]
b_O = [0, 0, 0, 0]
输出:
O =
[[2.00, 2.02, 2.01, 2.01],
[2.00, 2.30, 2.20, 2.11]]
提示
- 输入范围: −1000≤X[i,j]≤1000
- 权重范围: −10≤WQ, WK, WV, WO≤10
- 偏置范围: −5≤bQ, bK, bV, bO≤5