#P4498. 交叉注意力
-
1000ms
Tried: 2
Accepted: 2
Difficulty: 5
交叉注意力
算法步骤
-
输入数据:
-
接收两个张量:
X_Q:形状为 TQ×dmodel 的 Query 序列,TQ 表示查询序列的长度,dmodel 表示特征维度。X_K:形状为 TK×dmodel 的 Key/Value 序列,TK 表示键值序列的长度。
-
-
线性变换:
-
对输入的 Query 和 Key/Value 序列进行线性变换,使用权重矩阵和偏置向量:
-
计算 Query:
Q=XQWQ+bQ其中 WQ 的形状为 dmodel×dmodel,bQ 的形状为 dmodel。
-
计算 Key:
K=XKWK+bK其中 WK 和 bK 的形状同样为 dmodel×dmodel 和 dmodel。
-
计算 Value:
V=XKWV+bV其中 WV 和 bV 的形状也为 dmodel×dmodel 和 dmodel。
-
-
-
计算注意力分数:
-
通过计算 Query 和 Key 之间的点积来获取注意力分数:
-
计算方法为:
S=dkQK⊤其中 dk 是 Key 的特征维度(通常与 dmodel 相等)。这里的缩放因子 dk 是为了防止在进行 softmax 时注意力分数过大而导致的梯度消失。
-
-
-
softmax 归一化:
-
对计算得到的注意力分数进行 softmax 归一化,得到注意力权重:
A=softmax(S) -
计算 softmax 的流程为:
-
首先计算指数:
eij=eSij -
然后归一化:
Aij=∑keikeij -
这样得到的 A 的每一行和为 1,表示每个 Q 对所有 K 的权重分配。
-
-
-
计算加权结果:
-
使用注意力权重 A 对 Value 进行加权求和,得到加权后的输出:
H=AV -
这里的 H 是 TQ×dmodel 的张量,表示加权后的结果。
-
-
输出层线性变换:
-
最后通过线性变换将加权结果映射到输出,用线性变换的方式得到最终输出:
O=HWO+bO -
其中 WO 的形状为 dmodel×dmodel,而 bO 的形状为 dmodel。
-
Python 实现
import numpy as np
from typing import List
class Solution:
def cross_attention(self,
X_Q: np.ndarray,
X_K: np.ndarray,
W_Q: np.ndarray,
b_Q: np.ndarray,
W_K: np.ndarray,
b_K: np.ndarray,
W_V: np.ndarray,
b_V: np.ndarray,
W_O: np.ndarray,
b_O: np.ndarray) -> np.ndarray:
"""计算交叉注意力机制的输出"""
# 1. 线性变换
Q = X_Q @ W_Q + b_Q # 计算 Q
K = X_K @ W_K + b_K # 计算 K
V = X_K @ W_V + b_V # 计算 V
# 2. 计算注意力分数
d_k = K.shape[1] # d_k 是特征维度
S = Q @ K.T / np.sqrt(d_k) # S 为注意力分数
# 3. softmax 归一化
A = self.softmax(S) # 计算注意力权重
# 4. 计算加权结果
H = A @ V # H 为加权后的输出
# 5. 输出层线性变换
O = H @ W_O + b_O # 最终输出 O
return O
def softmax(self, x: np.ndarray) -> np.ndarray:
"""计算softmax"""
e_x = np.exp(x - np.max(x, axis=-1, keepdims=True)) # 为稳定性减去最大值
return e_x / e_x.sum(axis=-1, keepdims=True) # 返回归一化后的概率分布
题目描述
给定 Query 序列
XQ∈RTQ×dmodel和 Key/Value 序列
XK∈RTK×dmodel请你实现交叉注意力(Cross Attention)的完整计算流程。
1. 线性变换
$$Q = X_Q W_Q + b_Q,\quad K = X_K W_K + b_K,\quad V = X_K W_V + b_V$$2. 计算注意力分数
S=dkQK⊤3. 对每个 Query 行进行 softmax
A=softmax(S)4. 得到加权输出
H=AV5. 输出层线性变换
O=HWO+bO输入参数
X_Q:形状为 TQ×dmodel 的 Query 序列X_K:形状为 TK×dmodel 的 Key/Value 序列W_Q, W_K, W_V:形状为 dmodel×dmodel 的线性投影矩阵b_Q, b_K, b_V:长度为 dmodel 的偏置向量W_O:形状为 dmodel×dmodel 的输出层权重b_O:长度为 dmodel 的输出偏置
返回值
O:形状为 TQ×dmodel 的最终交叉注意力输出矩阵
示例
输入:
X_Q =
[[1, 0, 1, 0],
[0, 1, 0, 1]]
X_K =
[[1, 1, 0, 0],
[0, 1, 1, 0]]
W_Q =
[[1, 0, 1, 0],
[0, 1, 0, 1],
[1, 0, 0, 1],
[0, 1, 1, 0]]
b_Q = [0, 0, 0, 0]
W_K =
[[1, 1, 0, 0],
[0, 1, 1, 0],
[1, 0, 0, 1],
[0, 0, 1, 1]]
b_K = [0, 0, 0, 0]
W_V =
[[1, 0, 0, 1],
[1, 1, 0, 0],
[0, 1, 1, 0],
[0, 0, 1, 1]]
b_V = [0, 0, 0, 0]
W_O =
[[1, 0, 1, 0],
[0, 1, 0, 1],
[1, 0, 0, 1],
[0, 1, 1, 0]]
b_O = [0, 0, 0, 0]
输出:
O =
[[2.00, 2.00, 1.76, 2.24],
[2.00, 2.00, 2.24, 1.76]]
提示
- 输入序列范围: −1000≤XQ[i,j], XK[i,j]≤1000
- 权重矩阵范围: −10≤WQ, WK, WV, WO≤10
- 偏置项范围: −5≤bQ, bK, bV, bO≤5
- softmax 输出满足: 0≤A[i,j]≤1 ∑jA[i,j]=1
- 输出 O 可为任意实数矩阵