题目要求实现单头 Self-Attention(自注意力)机制的前向计算。给定序列长度 n 和特征维度 d,输入矩阵 X∈Rn×d 以及三个权重矩阵 Wq,Wk,Wv∈Rd×d。
计算过程分为四个步骤:
线性映射: 通过矩阵乘法计算 Query, Key, Value 矩阵: Q=XWq,K=XWk,V=XWv 结果矩阵 Q,K,V 的形状均为 n×d。
缩放点积注意力得分: 计算 Q 和 KT 的乘积,并除以缩放因子 d: S=dQKT 其中 S 的形状为 n×n。Sij 表示第 i 个 token 对第 j 个 token 的关注度得分。
归一化 (Softmax): 对 S 的每一行进行 Softmax 操作,得到注意力权重矩阵 A: $A_{ij} = \text{softmax}(S_{i \cdot})_j = \frac{e^{S_{ij}}}{\sum_{k=1}^n e^{S_{ik}}}$ 为了数值稳定性,通常在计算指数前,将每行减去该行的最大值。
加权求和: 利用权重矩阵 A 对 V 进行加权求和,得到最终输出 O: O=AV O 的形状为 n×d。
class Solution:
def selfAttention(self, n: int, d: int, X: List[List[float]], Wq: List[List[float]], Wk: List[List[float]], Wv: List[List[float]]) -> List[List[str]]:
# 辅助函数:矩阵乘法
def matmul(A, B):
rows_A, cols_A = len(A), len(A[0])
rows_B, cols_B = len(B), len(B[0])
# 结果矩阵初始化
C = [[0.0] * cols_B for _ in range(rows_A)]
for i in range(rows_A):
for k in range(cols_A):
val = A[i][k]
if val == 0: continue
for j in range(cols_B):
C[i][j] += val * B[k][j]
return C
# 辅助函数:矩阵转置
def transpose(A):
return [[A[j][i] for j in range(len(A))] for i in range(len(A[0]))]
# 1. 线性映射 Q, K, V
Q = matmul(X, Wq)
K = matmul(X, Wk)
V = matmul(X, Wv)
# 2. 缩放点积 S = (Q @ K.T) / sqrt(d)
K_T = transpose(K)
S = matmul(Q, K_T)
scale = 1.0 / math.sqrt(d)
for i in range(n):
for j in range(n):
S[i][j] *= scale
# 3. Softmax
A = []
for i in range(n):
row = S[i]
max_val = max(row) # 数值稳定性处理
exps = [math.exp(x - max_val) for x in row]
sum_exps = sum(exps)
A.append([e / sum_exps for e in exps])
# 4. 输出 O = A @ V
O = matmul(A, V)
# 格式化输出:保留4位小数,不足补0
res = []
for row in O:
res.append([f"{x:.4f}" for x in row])
return res
给定序列长度 n、特征维度 d 的输入矩阵 X(大小为 n × d),以及三组权重矩阵 Wq, Wk, Wv(大小均为 d × d),请你实现单头 Self-Attention 的前向计算,并返回 Attention 输出矩阵 O。
Attention 计算定义如下:
线性映射:
Q = X · WqK = X · WkV = X · Wv缩放点积得分:
S = Q · K^T / sqrt(d)
其中 K^T 为 K 的转置。行级 softmax 得到注意力权重:
A = softmax(S)
对每一行独立做 softmax。输出:
O = A · V请返回 O(形状为 n × d)。
你的答案与标准答案误差不超过 1e-4 即视为正确。
n: 序列长度d: 特征维度X: n × d 浮点矩阵Wq, Wk, Wv: d × d 浮点矩阵O,大小为 n × d1 <= n <= 501 <= d <= 50X[i][j], Wq[i][j], Wk[i][j], Wv[i][j] 为浮点数-3 <= X[i][j], Wq[i][j], Wk[i][j], Wv[i][j] <=3 输入:
n = 2 , d = 2
X = [[1,0],[0,1]]
Wq = [[1,0],[0,1]]
Wk = [[1,0],[0,1]]
Wv = [[1,2],[3,2]]
输出:
O = [[1.6605,2.0000],[2.3395,2.0000]]
解释:
Q=K=X,先算 Q·K^T/sqrt(2) 得到注意力分数,softmax 后加权 V 得到输出。