#P4497. 分组查询注意力
-
1000ms
Tried: 3
Accepted: 2
Difficulty: 5
分组查询注意力
算法步骤
步骤 1:线性投影
用矩阵乘法计算:
- Q=XWQ+bQ
- K=XWK+bK
- V=XWV+bV
此时 Q,K,V 形状都是 (T,dmodel)。
步骤 2:拆分 Query 为 hq 个 head
把 Q 的最后一维按 hq 均分成 hq 段,每段长度 dk:
- 先 reshape 为 (T,hq,dk)
- 再转置为 (hq,T,dk),方便按 head 索引
步骤 3:拆分 Key/Value 为 hk 个 group,再切子块
每个 group 覆盖 group_size=hq/hk 个 Query head。
- K reshape 成 (T,hk,dgroup) 再转成 (hk,T,dgroup)
- 再把最后一维拆成 (group_size,dk),得到 (hk,T,group_size,dk)
- 再转置为 (hk,group_size,T,dk)
V 同理得到 (hk,group_size,T,dk)
这样我们就可以用 (g, r) 直接取到与某 head 对齐的 Kg,r,Vg,r。
步骤 4:逐 head 计算注意力
对每个 head i:
-
计算 g=i//group_size, r=i%group_size
-
取 Qi∈(T,dk)、Kg,r∈(T,dk)、Vg,r∈(T,dk)
-
打分矩阵:
scores=(QiKg,r⊤)/dk -
稳定 softmax(对每行):
scores -= scores.max(axis=1, keepdims=True)attn = exp(scores) / exp(scores).sum(axis=1, keepdims=True)
-
输出:
Hi=attn⋅Vg,r
步骤 5:拼接所有 head
把 Hi 按 head 维拼回 (T,dmodel)。
步骤 6:输出层
O=HWO+bOPython 参考实现
import numpy as np
from typing import List
class Solution:
def grouped_query_attention(
self,
X: List[List[float]],
W_Q: List[List[float]],
b_Q: List[float],
W_K: List[List[float]],
b_K: List[float],
W_V: List[List[float]],
b_V: List[float],
h_q: int,
h_k: int,
W_O: List[List[float]],
b_O: List[float],
) -> List[List[float]]:
"""
分组查询注意力(GQA)。
关键要点:
- d_k = d_model / h_q
- group_size = h_q / h_k
- 每个 group 的 K/V 宽度是 d_group = d_k * group_size
- 为了与每个 Q head 对齐计算注意力,将每个 group 的 K/V 再切成 group_size 个子块,每个子块宽度 d_k
"""
Xn = np.asarray(X, dtype=np.float64)
WQ = np.asarray(W_Q, dtype=np.float64)
WK = np.asarray(W_K, dtype=np.float64)
WV = np.asarray(W_V, dtype=np.float64)
WO = np.asarray(W_O, dtype=np.float64)
bQ = np.asarray(b_Q, dtype=np.float64)
bK = np.asarray(b_K, dtype=np.float64)
bV = np.asarray(b_V, dtype=np.float64)
bO = np.asarray(b_O, dtype=np.float64)
T, d_model = Xn.shape
d_k = d_model // h_q
group_size = h_q // h_k
d_group = d_k * group_size
inv_sqrt = 1.0 / np.sqrt(d_k)
# 1) 线性投影
Q = Xn @ WQ + bQ
K = Xn @ WK + bK
V = Xn @ WV + bV
# 2) Q 拆成 h_q 个 head: (h_q, T, d_k)
Qh = Q.reshape(T, h_q, d_k).transpose(1, 0, 2)
# 3) K/V 拆成 h_k 个 group,再在组内切成 group_size 个子块: (h_k, group_size, T, d_k)
Kg = (
K.reshape(T, h_k, d_group)
.transpose(1, 0, 2)
.reshape(h_k, T, group_size, d_k)
.transpose(0, 2, 1, 3)
)
Vg = (
V.reshape(T, h_k, d_group)
.transpose(1, 0, 2)
.reshape(h_k, T, group_size, d_k)
.transpose(0, 2, 1, 3)
)
# 4) 逐 head 注意力
Hh = np.empty((h_q, T, d_k), dtype=np.float64)
for i in range(h_q):
g = i // group_size
r = i % group_size
Qi = Qh[i] # (T, d_k)
Ki = Kg[g, r] # (T, d_k)
Vi = Vg[g, r] # (T, d_k)
scores = (Qi @ Ki.T) * inv_sqrt
scores = scores - scores.max(axis=1, keepdims=True)
attn = np.exp(scores)
attn = attn / attn.sum(axis=1, keepdims=True)
Hh[i] = attn @ Vi
# 5) 拼接 head: (T, d_model)
H = Hh.transpose(1, 0, 2).reshape(T, d_model)
# 6) 输出层
O = H @ WO + bO
return O.tolist()
题目描述
设输入序列:
X∈RT×dmodel给定:
-
Query head 数:
hq -
Key/Value 分组数:
hk,hk<hq -
每个 Query head 的维度:
dk=hqdmodel
你需要实现 分组查询注意力(Grouped Query Attention, GQA),流程如下。
1. 线性投影
$$Q = X W_Q + b_Q,\quad K = X W_K + b_K,\quad V = X W_V + b_V$$2. 拆分 Q 为多个 Query heads
将 Q 均匀拆分成 hq 个 Query heads:
Qi∈RT×dk3. 将 K 和 V 拆分为 h_k 组(group)
每个组的维度为:
dgroup=dk⋅hkhq拆分得到:
$$K_j,\ V_j \in \mathbb{R}^{T \times d_{\text{group}}},\quad j=1,\dots,h_k$$4. 每个 Query head 选择其对应的 K/V 组
令组索引:
$$g(i) = \left\lfloor \frac{i}{h_q/h_k} \right\rfloor$$对每个 Query head 计算注意力:
$$H_i = \text{softmax}\left(\frac{Q_i K_{g(i)}^\top}{\sqrt{d_k}}\right)V_{g(i)}$$5. 拼接所有 Query head 的输出
H=concat(H1,…,Hhq)6. 输出层线性变换
O=HWO+bO输入参数
X:形状为 T×dmodel 的输入序列W_Q, W_K, W_V:形状为 dmodel×dmodel 的投影矩阵b_Q, b_K, b_V:长度为 dmodel 的偏置项h_q:Query head 数h_k:Key/Value 分组数,需满足 hk<hqW_O:形状为 dmodel×dmodel 的输出层矩阵b_O:长度为 dmodel 的输出偏置
返回值
O:形状为 T×dmodel 的最终输出序列
示例
输入:
X =
[[1, 0, 1, 2],
[0, 1, 1, 0]]
W_Q =
[[1, 0, 0, 1],
[0, 1, 1, 0],
[1, 0, 1, 0],
[0, 1, 0, 1]]
b_Q = [0, 0, 0, 0]
W_K =
[[1, 1, 0, 0],
[0, 1, 1, 0],
[1, 0, 0, 1],
[0, 0, 1, 1]]
b_K = [0, 0, 0, 0]
W_V =
[[1, 0, 0, 1],
[1, 1, 0, 0],
[0, 1, 1, 0],
[0, 0, 1, 1]]
b_V = [0, 0, 0, 0]
W_O =
[[1, 0, 1, 0],
[0, 1, 0, 1],
[1, 0, 0, 1],
[0, 1, 1, 0]]
b_O = [0, 0, 0, 0]
输出:
O =
[[3.46, 4.49, 3.99, 3.96],
[3.76, 3.0, 2.5, 4.26]]
提示
- 输入范围: −1000≤X[i,j]≤1000
- 权重范围: −10≤WQ, WK, WV, WO≤10
- 偏置范围: −5≤bQ, bK, bV, bO≤5
- head 数要求: 1≤hk<hq≤dmodel 且 dmodelmodhq=0
- softmax 输出满足: 0≤softmax(zi)≤1,且 ∑isoftmax(zi)=1