#P3843. 第3题-Masked Multi-Head Self-Attention 实现
1000ms
Tried: 691
Accepted: 100
Difficulty: 7
所属公司:华为
时间:2025年9月28日-AI方向
算法标签:缩放点积注意力、多头拆分/拼接、因果遮罩
第3题-Masked Multi-Head Self-Attention 实现
解题思路
本题要求手写「带因果掩码」的多头自注意力(Decoder 常用),输入为:
- num_heads:头数
- X:形状 [batch_size, seq_len, d_model]
- Q, K, V, W_O:均为形状 [d_model, d_model] 的线性投影矩阵(对应 $W_Q, W_K, W_V, W_O$)

整体流程(Scaled Dot-Product Attention + 多头并行):
- 线性映射生成 Q/K/V:$Q = XW_Q,\; K = XW_K,\; V = XW_V$,维度均为 [B, S, d_model]。
- 分头:将最后一维 d_model 均分为 num_heads 个头,每头维度 d_k = d_model / num_heads,并重排为 $Q_h, K_h, V_h \in [B, H, S, d_k]$。
- 每头计算注意力分数:
$$\text{scores} = \frac{Q_h K_h^\top}{\sqrt{d_k}}\quad\in [B,H,S,S]$$
- 因果掩码(防未来信息泄露):构造下三角 Mask([S,S],下三角为 1,上三角为 0),广播到 [B,H,S,S],将上三角(不允许关注的)位置置为 $-\infty$:
$$\text{masked\_scores} = \text{where}(mask=0,\; -\infty,\; \text{scores})$$
- Softmax 得注意力权重(按最后一维 S 做归一化;数值稳定:减去行最大值):$\alpha = \text{softmax}(\text{masked\_scores})$
- 聚合得到每头输出:$\text{head} = \alpha V_h \in [B,H,S,d_k]$
- 拼接各头并做输出投影:先将各头在 d_k 维拼接回 d_model($[B,S,H\cdot d_k] = [B,S,d_{model}]$),再乘以 $W_O$:
$$\text{output} = \text{concat(heads)}\; W_O \quad\in [B,S,d_{model}]$$
- 输出格式:题目要求保留两位小数,并以 List 形式输出(即常见的嵌套列表),需要把 ndarray/数组转换为列表。
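下面以样例1的维度(B=2、S=3、d_model=2、num_heads=2,故 d_k=1)给出一个只核对各中间张量形状的小片段(数值随机生成,仅看 shape,非评测代码):

import numpy as np

B, S, D, H = 2, 3, 2, 2          # 对应样例1:batch=2, seq_len=3, d_model=2, num_heads=2
d_k = D // H                     # 每头维度 = 1
X = np.random.randn(B, S, D)
W = np.random.randn(D, D)        # W_Q/W_K/W_V/W_O 形状相同,这里用同一个随机矩阵示意

Q = X @ W                                               # [B,S,D]
Qh = Q.reshape(B, S, H, d_k).transpose(0, 2, 1, 3)      # [B,H,S,d_k]
scores = Qh @ Qh.transpose(0, 1, 3, 2) / np.sqrt(d_k)   # [B,H,S,S]
print(Q.shape, Qh.shape, scores.shape)                  # (2, 3, 2) (2, 2, 3, 1) (2, 2, 3, 3)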
变长序列:常见做法是先对批次内对齐(padding),然后结合 因果掩码 与 padding mask(本题未给出padding mask输入),本实现提供标准因果掩码;若有padding,可在同维度位置再叠加一个padding掩码(将padding位置设为 −∞)。
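若确实存在 padding,下面是一个把 padding 掩码与因果掩码叠加的示意(本题输入不含有效长度 lengths,该参数仅为假设):

import numpy as np

def build_mask(lengths, S):
    """lengths 为各序列的有效长度(假设的输入),返回 [B,1,S,S] 的 0/1 掩码,可广播到 [B,H,S,S]"""
    causal = np.tril(np.ones((S, S)))                                    # [S,S] 因果掩码
    valid = (np.arange(S)[None, :] < np.array(lengths)[:, None]) * 1.0   # [B,S] 有效位置为1
    padding = valid[:, :, None] * valid[:, None, :]                      # [B,S,S] 双方均有效才保留
    return (causal[None, :, :] * padding)[:, None, :, :]

mask = build_mask([3, 2], S=3)
# 后续与正文一致:scores = np.where(mask == 1, scores, -np.inf)
# 注意:被整行 padding 掉的位置在 softmax 前全为 -inf,实际使用时可把这些行的输出再置零
print(mask[1, 0])   # 第二条序列只有前2个位置有效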
复杂度分析
设批次 B、序列长度 S、模型维度 D=d_model、头数 H、每头维度 d_k=D/H。
- 时间复杂度
  - 线性映射:X 分别乘 W_Q/W_K/W_V,各为 O(B*S*D^2)。
  - 注意力分数 QK^T:每头 O(S^2*d_k),总计 O(B*H*S^2*d_k) = O(B*S^2*D)。
  - 乘 V 聚合:同阶 O(B*S^2*D)。
  - 输出投影乘 W_O:O(B*S*D^2)。
  - 综合为 O(B*S*D^2 + B*S^2*D),与标准 Transformer 一致。
- 空间复杂度:主要存储 Q/K/V、注意力分数与权重,为 O(B*S*D + B*H*S^2)。
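作为参考,下面的小片段按上式粗略估算一次前向的乘加次数(忽略常数与 softmax 开销,仅为示意,非题目要求):

def attention_flops(B, S, D, H):
    """按 O(B*S*D^2 + B*S^2*D) 粗略估算前向乘加次数(忽略系数)"""
    d_k = D // H
    proj = 4 * B * S * D * D          # W_Q/W_K/W_V/W_O 四次线性映射
    scores = B * H * S * S * d_k      # Q @ K^T
    aggregate = B * H * S * S * d_k   # 注意力权重乘 V
    return proj + scores + aggregate

print(attention_flops(B=2, S=3, D=2, H=2))   # 样例1规模:96 + 36 + 36 = 168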
代码实现
Python
# 题意:读入 "num_heads;X;Q;K;V;W_O"(用分号分隔),实现因果掩码多头自注意力
# 要求:输出为 List(嵌套列表),保留两位小数
import sys
import numpy as np
from ast import literal_eval
def to_str(arr):
"""递归把嵌套 list 转成字符串,数值固定两位小数且无引号;把 -0.00 规整为 0.00"""
if isinstance(arr, list):
return "[" + ", ".join(to_str(x) for x in arr) + "]"
else:
# 数值分支
v = float(arr)
s = f"{v:.2f}"
# 规整 -0.00 -> 0.00
if s == "-0.00":
s = "0.00"
return s
def softmax_stable(x, axis=-1):
# 数值稳定 softmax
m = np.max(x, axis=axis, keepdims=True)
ex = np.exp(x - m)
return ex / np.sum(ex, axis=axis, keepdims=True)
def multi_head_self_attention(X, WQ, WK, WV, WO, num_heads):
B, S, D = X.shape
assert D % num_heads == 0, "d_model 必须能被 num_heads 整除"
d_k = D // num_heads
# 1) 线性映射
Q = X @ WQ # [B,S,D]
K = X @ WK
V = X @ WV
# 2) 分头 -> [B,H,S,d_k]
def split_heads(t):
t = t.reshape(B, S, num_heads, d_k) # [B,S,H,d_k]
return np.transpose(t, (0, 2, 1, 3)) # [B,H,S,d_k]
Qh, Kh, Vh = split_heads(Q), split_heads(K), split_heads(V)
# 3) 注意力分数 [B,H,S,S]
# scores[b,h,i,j] = Qh[b,h,i,:] dot Kh[b,h,j,:] / sqrt(d_k)
# 利用矩阵乘法: (B,H,S,d_k) x (B,H,d_k,S) -> (B,H,S,S)
scores = (Qh @ np.transpose(Kh, (0,1,3,2))) / np.sqrt(d_k)
# 4) 因果掩码:允许关注自己及之前位置 => 下三角为1,其余为0
mask = np.tril(np.ones((S, S), dtype=np.float32)) # [S,S]
mask = mask[None, None, :, :] # [1,1,S,S] 广播到 [B,H,S,S]
scores = np.where(mask == 1, scores, -np.inf)
# 5) softmax
attn = softmax_stable(scores, axis=-1) # [B,H,S,S]
# 6) 加权求和
heads = attn @ Vh # [B,H,S,d_k]
# 7) 拼回 + 输出投影
heads = np.transpose(heads, (0, 2, 1, 3)) # [B,S,H,d_k]
concat = heads.reshape(B, S, D) # [B,S,D]
out = concat @ WO # [B,S,D]
return out
def main():
raw = sys.stdin.read().strip()
# 按分号分割:num_heads;X;Q;K;V;W_O
parts = [p.strip() for p in raw.split(';')]
if len(parts) != 6:
raise ValueError("输入应包含6段参数:num_heads;X;Q;K;V;W_O")
num_heads = int(parts[0])
X = np.array(literal_eval(parts[1]), dtype=float)
WQ = np.array(literal_eval(parts[2]), dtype=float)
WK = np.array(literal_eval(parts[3]), dtype=float)
WV = np.array(literal_eval(parts[4]), dtype=float)
WO = np.array(literal_eval(parts[5]), dtype=float)
out = multi_head_self_attention(X, WQ, WK, WV, WO, num_heads)
out = np.around(out, 2) # 保留两位小数
# 转为嵌套列表输出
print(to_str(out.tolist()))
if __name__ == "__main__":
main()
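本地自测时,可临时用样例1的输入替代标准输入来调用 main(仅为示意,提交时无需包含):

import io, sys
sample1 = ("2;"
           "[[[ 1.92, 1.48], [0.67, -1.23], [0.35, -0.68]], [[-1.11, 0.09], [-0.3, -0.39], [-0.59, -0.06]]];"
           "[[1.0, 2.0], [2.0, 2.0]];[[1.0, 1.0], [2.0, 2.0]];[[1.0, 1.0], [2.0, 2.0]];[[1.0, 1.0], [2.0, 2.0]]")
sys.stdin = io.StringIO(sample1)   # 用样例1的输入模拟 stdin
main()                             # 预期输出与下文"样例1"的输出一致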
C++
#include <bits/stdc++.h>
using namespace std;
// 解析二维数组,如 [[1,2],[3,4]]
static vector<vector<double>> parse2D(const string &s) {
vector<vector<double>> res;
vector<double> row;
int depth = 0;
int n = s.size();
for (int i = 0; i < n; ++i) {
char c = s[i];
if (c == '[') {
depth++;
if (depth == 2) row.clear();
} else if (c == ']') {
if (depth == 2) {
if (!row.empty()) res.push_back(row);
row.clear();
}
depth--;
} else {
// 读取数字
if (isdigit(c) || c=='-' || c=='+' || c=='.' || c=='e' || c=='E') {
int j = i;
while (j < n && (isdigit(s[j]) || s[j]=='-' || s[j]=='+' || s[j]=='.' || s[j]=='e' || s[j]=='E')) j++;
double val = stod(s.substr(i, j - i));
row.push_back(val);
i = j - 1;
}
}
}
return res;
}
// 解析三维数组,如 [[[...],[...]], [[...],[...]]]
static vector<vector<vector<double>>> parse3D(const string &s) {
vector<vector<vector<double>>> res;
vector<vector<double>> mat;
vector<double> row;
int depth = 0;
int n = s.size();
for (int i = 0; i < n; ++i) {
char c = s[i];
if (c == '[') {
depth++;
if (depth == 2) mat.clear();
if (depth == 3) row.clear();
} else if (c == ']') {
if (depth == 3) {
if (!row.empty()) mat.push_back(row);
row.clear();
} else if (depth == 2) {
if (!mat.empty()) res.push_back(mat);
mat.clear();
}
depth--;
} else {
// 读取数字
if (isdigit(c) || c=='-' || c=='+' || c=='.' || c=='e' || c=='E') {
int j = i;
while (j < n && (isdigit(s[j]) || s[j]=='-' || s[j]=='+' || s[j]=='.' || s[j]=='e' || s[j]=='E')) j++;
double val = stod(s.substr(i, j - i));
row.push_back(val);
i = j - 1;
}
}
}
return res;
}
// X[b][t][d] * W[d][d] -> Y[b][t][d]
static vector<vector<vector<double>>> matmul3D2D(
const vector<vector<vector<double>>> &X,
const vector<vector<double>> &W) {
int B = (int)X.size();
int T = (int)X[0].size();
int D = (int)X[0][0].size();
vector<vector<vector<double>>> Y(B, vector<vector<double>>(T, vector<double>(D, 0.0)));
for (int b = 0; b < B; ++b) {
for (int t = 0; t < T; ++t) {
for (int j = 0; j < D; ++j) {
double sum = 0.0;
for (int k = 0; k < D; ++k) sum += X[b][t][k] * W[k][j];
Y[b][t][j] = sum;
}
}
}
return Y;
}
// 打印三维数组,保留两位小数
static void print3D(const vector<vector<vector<double>>> &A) {
cout << "[";
for (int b = 0; b < (int)A.size(); ++b) {
if (b) cout << ", ";
cout << "[";
for (int i = 0; i < (int)A[b].size(); ++i) {
if (i) cout << ", ";
cout << "[";
for (int j = 0; j < (int)A[b][i].size(); ++j) {
if (j) cout << ", ";
double v = A[b][i][j];
if (fabs(v) < 0.005) v = 0.0; // 避免-0.00
cout.setf(std::ios::fixed); cout<<setprecision(2)<<v;
}
cout << "]";
}
cout << "]";
}
cout << "]\n";
}
int main() {
// 读取整份输入(可能包含空格与换行)
std::ostringstream oss;
string line;
while (std::getline(cin, line)) {
oss << line;
}
string all = oss.str();
// 以分号拆分为6段:num_heads ; X ; W_Q ; W_K ; W_V ; W_O
vector<string> parts;
{
string cur;
for (char c : all) {
if (c == ';') {
parts.push_back(cur);
cur.clear();
} else {
cur.push_back(c);
}
}
if (!cur.empty()) parts.push_back(cur);
}
if (parts.size() != 6) return 0;
// 解析num_heads
int num_heads = 0;
{
string s = parts[0];
// 用替换字符+输入流:将非数字转空格
for (char &c : s) if (!(isdigit(c) || c=='-' || c=='+')) c = ' ';
istringstream iss(s);
iss >> num_heads;
}
// 解析X与四个权重
auto X = parse3D(parts[1]);
auto WQ = parse2D(parts[2]);
auto WK = parse2D(parts[3]);
auto WV = parse2D(parts[4]);
auto WO = parse2D(parts[5]);
int B = (int)X.size();
int T = (int)X[0].size();
int D = (int)X[0][0].size();
int H = num_heads;
int d_k = D / H;
// 1) 生成Q/K/V
auto Q = matmul3D2D(X, WQ);
auto K = matmul3D2D(X, WK);
auto V = matmul3D2D(X, WV);
// 2) 重排为多头 [B,H,T,d_k]
auto to_heads = [&](const vector<vector<vector<double>>> &A){
vector<vector<vector<vector<double>>>> Ah(
B, vector<vector<vector<double>>>(H, vector<vector<double>>(T, vector<double>(d_k, 0.0)))
);
for (int b = 0; b < B; ++b)
for (int t = 0; t < T; ++t)
for (int d = 0; d < D; ++d) {
int h = d / d_k, r = d % d_k;
Ah[b][h][t][r] = A[b][t][d];
}
return Ah;
};
auto Qh = to_heads(Q);
auto Kh = to_heads(K);
auto Vh = to_heads(V);
// 3-6) 注意力(带因果mask)
double inv_sqrt = 1.0 / sqrt((double)d_k);
vector<vector<vector<vector<double>>>> Ah( // attention后的值 [B,H,T,d_k]
B, vector<vector<vector<double>>>(H, vector<vector<double>>(T, vector<double>(d_k, 0.0)))
);
for (int b = 0; b < B; ++b) {
for (int h = 0; h < H; ++h) {
// 预计算 scores [T][T]
vector<vector<double>> scores(T, vector<double>(T, 0.0));
for (int i = 0; i < T; ++i) {
for (int j = 0; j < T; ++j) {
double dot = 0.0;
for (int r = 0; r < d_k; ++r) dot += Qh[b][h][i][r] * Kh[b][h][j][r];
scores[i][j] = dot * inv_sqrt;
}
}
// softmax with causal mask
for (int i = 0; i < T; ++i) {
double maxv = -1e100;
for (int j = 0; j < T; ++j) {
if (j > i) scores[i][j] = -1e9; // mask未来位置
if (scores[i][j] > maxv) maxv = scores[i][j];
}
double sumexp = 0.0;
vector<double> p(T, 0.0);
for (int j = 0; j < T; ++j) {
double e = exp(scores[i][j] - maxv);
p[j] = e;
sumexp += e;
}
// attention * V
for (int r = 0; r < d_k; ++r) {
double acc = 0.0;
for (int j = 0; j < T; ++j) {
double w = (sumexp == 0.0 ? 0.0 : p[j] / sumexp);
acc += w * Vh[b][h][j][r];
}
Ah[b][h][i][r] = acc;
}
}
}
}
// 7) 拼接头 -> [B,T,D]
vector<vector<vector<double>>> concat(B, vector<vector<double>>(T, vector<double>(D, 0.0)));
for (int b = 0; b < B; ++b)
for (int t = 0; t < T; ++t)
for (int h = 0; h < H; ++h)
for (int r = 0; r < d_k; ++r)
concat[b][t][h*d_k + r] = Ah[b][h][t][r];
// 线性投影 WO
auto Y = matmul3D2D(concat, WO);
// 输出
print3D(Y);
return 0;
}
Java
import java.io.*;
import java.util.*;
public class Main {
// 解析二维数组 [[...],[...]]
static List<List<Double>> parse2D(String s) {
List<List<Double>> res = new ArrayList<>();
List<Double> row = new ArrayList<>();
int depth = 0, n = s.length();
for (int i = 0; i < n; i++) {
char c = s.charAt(i);
if (c == '[') {
depth++;
if (depth == 2) row = new ArrayList<>();
} else if (c == ']') {
if (depth == 2) {
if (!row.isEmpty()) res.add(row);
row = new ArrayList<>();
}
depth--;
} else {
if (Character.isDigit(c) || c=='-' || c=='+' || c=='.' || c=='e' || c=='E') {
int j = i;
while (j < n) {
char cj = s.charAt(j);
if (Character.isDigit(cj) || cj=='-' || cj=='+' || cj=='.' || cj=='e' || cj=='E') j++;
else break;
}
double val = Double.parseDouble(s.substring(i, j));
row.add(val);
i = j - 1;
}
}
}
return res;
}
// 解析三维数组 [[[...],[...]], [[...],[...]]]
static List<List<List<Double>>> parse3D(String s) {
List<List<List<Double>>> res = new ArrayList<>();
List<List<Double>> mat = new ArrayList<>();
List<Double> row = new ArrayList<>();
int depth = 0, n = s.length();
for (int i = 0; i < n; i++) {
char c = s.charAt(i);
if (c == '[') {
depth++;
if (depth == 2) mat = new ArrayList<>();
if (depth == 3) row = new ArrayList<>();
} else if (c == ']') {
if (depth == 3) {
if (!row.isEmpty()) mat.add(row);
row = new ArrayList<>();
} else if (depth == 2) {
if (!mat.isEmpty()) res.add(mat);
mat = new ArrayList<>();
}
depth--;
} else {
if (Character.isDigit(c) || c=='-' || c=='+' || c=='.' || c=='e' || c=='E') {
int j = i;
while (j < n) {
char cj = s.charAt(j);
if (Character.isDigit(cj) || cj=='-' || cj=='+' || cj=='.' || cj=='e' || cj=='E') j++;
else break;
}
double val = Double.parseDouble(s.substring(i, j));
row.add(val);
i = j - 1;
}
}
}
return res;
}
// X[b][t][d] * W[d][d]
static double[][][] matmul3D2D(double[][][] X, double[][] W) {
int B = X.length, T = X[0].length, D = X[0][0].length;
double[][][] Y = new double[B][T][D];
for (int b = 0; b < B; b++) {
for (int t = 0; t < T; t++) {
for (int j = 0; j < D; j++) {
double sum = 0.0;
for (int k = 0; k < D; k++) sum += X[b][t][k] * W[k][j];
Y[b][t][j] = sum;
}
}
}
return Y;
}
static String fmt2(double v) {
if (Math.abs(v) < 0.005) v = 0.0; // 避免-0.00
return String.format(java.util.Locale.US, "%.2f", v);
}
public static void main(String[] args) throws Exception {
// 读取整份输入
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
StringBuilder sbAll = new StringBuilder();
String line;
while ((line = br.readLine()) != null) sbAll.append(line);
String all = sbAll.toString();
// 拆分为6段:num_heads ; X ; WQ ; WK ; WV ; WO
List<String> parts = new ArrayList<>();
{
StringBuilder cur = new StringBuilder();
for (int i = 0; i < all.length(); i++) {
char c = all.charAt(i);
if (c == ';') {
parts.add(cur.toString());
cur.setLength(0);
} else {
cur.append(c);
}
}
if (cur.length() > 0) parts.add(cur.toString());
}
if (parts.size() != 6) return;
// 解析num_heads:替换非数字为空格+输入流风格
int numHeads = 0;
{
String s = parts.get(0);
StringBuilder t = new StringBuilder();
for (int i = 0; i < s.length(); i++) {
char c = s.charAt(i);
if (Character.isDigit(c) || c=='-' || c=='+') t.append(c);
else t.append(' ');
}
Scanner sc = new Scanner(t.toString());
if (sc.hasNextInt()) numHeads = sc.nextInt();
}
// 解析X与权重
List<List<List<Double>>> Xlist = parse3D(parts.get(1));
List<List<Double>> WQl = parse2D(parts.get(2));
List<List<Double>> WKl = parse2D(parts.get(3));
List<List<Double>> WVl = parse2D(parts.get(4));
List<List<Double>> WOl = parse2D(parts.get(5));
int B = Xlist.size();
int T = Xlist.get(0).size();
int D = Xlist.get(0).get(0).size();
int H = numHeads;
int dk = D / H;
// 转为原生数组
double[][][] X = new double[B][T][D];
for (int b = 0; b < B; b++)
for (int t2 = 0; t2 < T; t2++)
for (int d = 0; d < D; d++)
X[b][t2][d] = Xlist.get(b).get(t2).get(d);
double[][] WQ = new double[D][D], WK = new double[D][D], WV = new double[D][D], WO = new double[D][D];
for (int i = 0; i < D; i++)
for (int j = 0; j < D; j++) {
WQ[i][j] = WQl.get(i).get(j);
WK[i][j] = WKl.get(i).get(j);
WV[i][j] = WVl.get(i).get(j);
WO[i][j] = WOl.get(i).get(j);
}
// 1) Q/K/V
double[][][] Q = matmul3D2D(X, WQ);
double[][][] K = matmul3D2D(X, WK);
double[][][] V = matmul3D2D(X, WV);
// 2) 重排到多头 [B][H][T][dk]
double[][][][] Qh = new double[B][H][T][dk];
double[][][][] Kh = new double[B][H][T][dk];
double[][][][] Vh = new double[B][H][T][dk];
for (int b = 0; b < B; b++)
for (int t2 = 0; t2 < T; t2++)
for (int d = 0; d < D; d++) {
int h = d / dk, r = d % dk;
Qh[b][h][t2][r] = Q[b][t2][d];
Kh[b][h][t2][r] = K[b][t2][d];
Vh[b][h][t2][r] = V[b][t2][d];
}
// 3-6) 注意力(带因果mask)
double invSqrt = 1.0 / Math.sqrt((double)dk);
double[][][][] Ah = new double[B][H][T][dk];
for (int b = 0; b < B; b++) {
for (int h = 0; h < H; h++) {
// scores [T][T]
double[][] scores = new double[T][T];
for (int i = 0; i < T; i++) {
for (int j = 0; j < T; j++) {
double dot = 0.0;
for (int r = 0; r < dk; r++) dot += Qh[b][h][i][r] * Kh[b][h][j][r];
scores[i][j] = dot * invSqrt;
}
}
// softmax per i with causal mask
for (int i = 0; i < T; i++) {
double maxv = -1e100;
for (int j = 0; j < T; j++) {
if (j > i) scores[i][j] = -1e9; // mask
if (scores[i][j] > maxv) maxv = scores[i][j];
}
double[] p = new double[T];
double sumexp = 0.0;
for (int j = 0; j < T; j++) {
double e = Math.exp(scores[i][j] - maxv);
p[j] = e;
sumexp += e;
}
for (int r = 0; r < dk; r++) {
double acc = 0.0;
for (int j = 0; j < T; j++) {
double w = (sumexp == 0.0 ? 0.0 : p[j] / sumexp);
acc += w * Vh[b][h][j][r];
}
Ah[b][h][i][r] = acc;
}
}
}
}
// 7) 拼接多头 -> [B][T][D]
double[][][] concat = new double[B][T][D];
for (int b2 = 0; b2 < B; b2++)
for (int t2 = 0; t2 < T; t2++)
for (int h = 0; h < H; h++)
for (int r = 0; r < dk; r++)
concat[b2][t2][h*dk + r] = Ah[b2][h][t2][r];
// 线性投影 WO
double[][][] Y = matmul3D2D(concat, WO);
// 输出 List 形式
StringBuilder out = new StringBuilder();
out.append("[");
for (int b2 = 0; b2 < B; b2++) {
if (b2 > 0) out.append(", ");
out.append("[");
for (int t2 = 0; t2 < T; t2++) {
if (t2 > 0) out.append(", ");
out.append("[");
for (int d2 = 0; d2 < D; d2++) {
if (d2 > 0) out.append(", ");
out.append(fmt2(Y[b2][t2][d2]));
}
out.append("]");
}
out.append("]");
}
out.append("]\n");
System.out.print(out.toString());
}
}
题目内容
在Transformer模型中,Multi-Head Self-Attention是核心组件,用于捕捉序列中的依赖关系。你需要从头实现一个Masked Multi-Head Self-Attention函数,支持自注意力(即queries、keys和values来自同一输入序列),并处理掩码(mask)以防止未来位置的信息泄露(常见于Decoder中)。
具体要求:
- 支持多头注意力:将注意力机制并行分成多个"头",每个头学习不同的注意力模式,增强模型对多维度特征的捕捉能力。
- 计算过程:
- 生成Q、K、V矩阵 对输入序列X(维度:[batch_size, seq_len, d_model])通过3个线性层分别生成查询(Query, Q)、键(Key, K)、值(Value, V)矩阵:$Q = X \cdot W_Q,\; K = X \cdot W_K,\; V = X \cdot W_V$,其中 $W_Q, W_K, W_V \in \mathbb{R}^{d_{model} \times d_{model}}$。
- 将Q、K、V拆分为多个头 将Q、K、V分割为num_heads个并行的子矩阵(每个头的维度为d_k = d_model / num_heads)。 分割后维度为[batch_size, num_heads, seq_len, d_k]。
- 对于每个头,计算注意力分数:attention_scores = (Q·K^T) / sqrt(d_k)。
- 若提供mask(一个 (batch_size, seq_len, seq_len) 的布尔数组,其中True表示需要掩码的位置),则将被掩码位置的注意力分数设置为负无穷(-inf),以确保softmax后对应权重为0。掩码后的分数记为masked_scores。
- 对掩码后的分数应用softmax得到注意力权重。 softmax_scores=softmax(masked_scores)。
- 计算注意力输出:attention=softmax_scores · V。
- 拼接多头输出,并通过一个线性投影得到最终结果:$output = \text{concat}(attention_1, ..., attention_{num\_heads}) \cdot W_O$,其中 $W_O \in \mathbb{R}^{d_{model} \times d_{model}}$ 是可学习参数,输出维度为 [batch_size, seq_len, d_model]。
注意: 1、需处理批次(batch_size > 1)和变长序列。
2、输入参数以分号分隔。第一个参数为多头数量num_heads;第二个参数为输入X;第三个参数为Q矩阵;第四个参数为K矩阵;第五个参数为V矩阵;第六个参数为 W_O。
3、输出为List,需要将np.ndarray转为List
输入描述
以";"分隔,分别为 num_heads, X, Q、K、V,WO
输出描述
输出为最终结果 output,保留两位小数,并且为 List。
样例1
输入
2;[[[ 1.92, 1.48], [0.67, -1.23], [0.35, -0.68]], [[-1.11, 0.09], [-0.3, -0.39], [-0.59, -0.06]]];[[1.0, 2.0], [2.0, 2.0]];[[1.0, 1.0], [2.0, 2.0]];[[1.0, 1.0], [2.0, 2.0]];[[1.0, 1.0], [2.0, 2.0]]
输出
[[[14.64, 14.64], [-5.36, -5.36], [-4.44, -4.44]], [[-2.79, -2.79], [-3.04, -3.04], [-2.79, -2.79]]]
样例2
输入
2;[[[ 1.92, 1.48], [0.67, -1.23], [0.35, -0.68]], [[-1.11, 0.09], [-0.3, -0.39], [-0.59, -0.06]]];[[1.0,1.0], [2.0, 2.0]];[[1.0, 1.0], [2.0, 2.0]];[[1.0, 1.0], [2.0, 2.0]];[[1.0, 1.0], [2.0, 2.0]]
输出
[[[14.64, 14.64], [-5.37, -5.37], [-4.62, -4.62]], [[-2.79, -2.79], [-3.03, -3.03], [-2.77, -2.77]]]
提示
- 手动实现softmax:exp_scores = np.exp(scores - np.max(scores, axis=-1, keepdims=True));softmax = exp_scores / np.sum(exp_scores, axis=-1, keepdims=True)。确保数值稳定性,减去每行最大值
- 使用np.around(np.ndarray, 2)将输出保留2位小数
- 通过下三角矩阵实现序列掩码mask,确保每个位置只能关注自身及之前的位置。下三角为1,上三角为0。
- 处理-inf:可以使用np.where(mask == 0, -np.inf, attention_scores)
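把以上几条提示串起来,一个最小可运行的示意如下(分数矩阵为随意假设的数值,仅演示掩码与数值稳定 softmax 的写法):

import numpy as np

scores = np.array([[0.5, 1.0, -0.2],
                   [0.1, 0.3,  0.8],
                   [1.2, -0.4, 0.6]])            # 假设的 3x3 注意力分数
mask = np.tril(np.ones_like(scores))             # 下三角为1,上三角为0
masked = np.where(mask == 0, -np.inf, scores)    # 未来位置置为 -inf
exp_scores = np.exp(masked - np.max(masked, axis=-1, keepdims=True))
softmax = exp_scores / np.sum(exp_scores, axis=-1, keepdims=True)
print(np.around(softmax, 2))                     # 每行和为1,且只在 j<=i 处非零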