题目描述
给定 Query 序列
XQ∈RTQ×dmodel
和 Key/Value 序列
XK∈RTK×dmodel
请你实现交叉注意力(Cross Attention)的完整计算流程。
1. 线性变换
Q=XQWQ+bQ,K=XKWK+bK,V=XKWV+bV
2. 计算注意力分数
S=dkQK⊤
3. 对每个 Query 行进行 softmax
A=softmax(S)
4. 得到加权输出
H=AV
5. 输出层线性变换
O=HWO+bO
输入参数
X_Q:形状为 TQ×dmodel 的 Query 序列
X_K:形状为 TK×dmodel 的 Key/Value 序列
W_Q, W_K, W_V:形状为 dmodel×dmodel 的线性投影矩阵
b_Q, b_K, b_V:长度为 dmodel 的偏置向量
W_O:形状为 dmodel×dmodel 的输出层权重
b_O:长度为 dmodel 的输出偏置
返回值
O:形状为 TQ×dmodel 的最终交叉注意力输出矩阵
示例
输入:
X_Q =
[[1, 0, 1, 0],
[0, 1, 0, 1]]
X_K =
[[1, 1, 0, 0],
[0, 1, 1, 0]]
W_Q =
[[1, 0, 1, 0],
[0, 1, 0, 1],
[1, 0, 0, 1],
[0, 1, 1, 0]]
b_Q = [0, 0, 0, 0]
W_K =
[[1, 1, 0, 0],
[0, 1, 1, 0],
[1, 0, 0, 1],
[0, 0, 1, 1]]
b_K = [0, 0, 0, 0]
W_V =
[[1, 0, 0, 1],
[1, 1, 0, 0],
[0, 1, 1, 0],
[0, 0, 1, 1]]
b_V = [0, 0, 0, 0]
W_O =
[[1, 0, 1, 0],
[0, 1, 0, 1],
[1, 0, 0, 1],
[0, 1, 1, 0]]
b_O = [0, 0, 0, 0]
输出:
O =
[[2.00, 2.00, 1.76, 2.24],
[2.00, 2.00, 2.24, 1.76]]
提示
- 输入序列范围:
−1000≤XQ[i,j], XK[i,j]≤1000
- 权重矩阵范围:
−10≤WQ, WK, WV, WO≤10
- 偏置项范围:
−5≤bQ, bK, bV, bO≤5
- softmax 输出满足:
0≤A[i,j]≤1
∑jA[i,j]=1
- 输出 O 可为任意实数矩阵