#P3712. 第2题-大模型Attention模块开发
-
1000ms
Tried: 1340
Accepted: 390
Difficulty: 4
所属公司 :
华为
时间 :2025年9月17日-AI岗
-
算法标签>模拟
第2题-大模型Attention模块开发
解题思路
-
按题意用“暴力模拟”完整走一遍计算图:
- 构造 X 为 n×m 的全 1;构造 W1、W2、W3 为 m×h 的上三角全 1。
- 计算 Q=X·W1,K=X·W2,V=X·W3(普通三重循环矩阵乘法)。
- 计算 M=(Q·K^T)/sqrt(h)。
- 按“简化 softmax”把 M 的每一行做归一化:A[i][j]=M[i][j]/(该行元素和)。
- 计算 Y=A·V。
- 将 Y 全部元素求和,四舍五入输出整数。
-
算法类型:暴力模拟/矩阵运算。
-
由于 n、m、h < 100,直接模拟即可在时空限制内通过。
复杂度分析
-
矩阵乘法开销:
- 计算 Q、K、V:O(n·m·h)
- 计算 M=Q·K^T:O(n²·h)
- 行归一化:O(n²)
- 计算 Y=A·V:O(n²·h)
-
总时间复杂度:O(n·m·h + n²·h)。
-
空间复杂度:O(n·h + n²)(保存 Q、K、V、A 或 M 等中间结果)。
代码实现
Python
import sys
import ast
import numpy as np
def solve(n, m, h):
# 1) 构造 X 全 1,W 上三角全 1
X = np.ones((n, m), dtype=float)
W = np.triu(np.ones((m, h), dtype=float)) # W1=W2=W3 相同
# 2) 计算 Q, K, V(矩阵乘法)
Q = X @ W
K = X @ W
V = X @ W
# 3) 计算 M=(Q·K^T)/sqrt(h)
M = (Q @ K.T) / np.sqrt(float(h))
# 4) “简化 softmax”:按行除以行和
row_sum = M.sum(axis=1, keepdims=True)
A = M / (row_sum + 1e-12)
# 5) 计算 Y=A·V 并求和
Y = A @ V
total = float(Y.sum())
# 6) 四舍五入输出整数
return int(np.rint(total))
def main():
s = sys.stdin.read().strip()
try:
val = ast.literal_eval(s)
if isinstance(val, (list, tuple)) and len(val) == 3:
n, m, h = map(int, val)
else:
n, m, h = map(int, s.split())
except Exception:
n, m, h = map(int, s.split())
print(solve(n, m, h))
if __name__ == "__main__":
main()
Java
import java.util.*;
public class Main {
static double[][] matmul(double[][] A, double[][] B){
int n=A.length, m=B[0].length, k=B.length;
double[][] C=new double[n][m];
for(int i=0;i<n;i++)
for(int t=0;t<k;t++){
double v=A[i][t]; if(v==0) continue;
for(int j=0;j<m;j++) C[i][j]+=v*B[t][j];
}
return C;
}
static double[][] trans(double[][] M){
int n=M.length,m=M[0].length;
double[][] T=new double[m][n];
for(int i=0;i<n;i++) for(int j=0;j<m;j++) T[j][i]=M[i][j];
return T;
}
static void rowNorm(double[][] M){
for(double[] r:M){
double s=0; for(double x:r) s+=x;
if(s==0) continue; double inv=1.0/s;
for(int j=0;j<r.length;j++) r[j]*=inv;
}
}
public static void main(String[] args){
Scanner sc=new Scanner(System.in);
int n=sc.nextInt(), m=sc.nextInt(), h=sc.nextInt();
double[][] X=new double[n][m];
for(int i=0;i<n;i++) Arrays.fill(X[i],1.0);
double[][] W=new double[m][h];
for(int i=0;i<m;i++) for(int j=i;j<h;j++) W[i][j]=1.0;
double[][] Q=matmul(X,W), K=Q, V=Q;
double[][] M=matmul(Q, trans(K));
double s=Math.sqrt(h);
if(s!=0) for(int i=0;i<M.length;i++) for(int j=0;j<M[0].length;j++) M[i][j]/=s;
rowNorm(M);
double[][] Y=matmul(M,V);
double total=0;
for(double[] r:Y) for(double x:r) total+=x;
System.out.println(Math.round(total));
sc.close();
}
}
C++
#include <bits/stdc++.h>
using namespace std;
using Mat = vector<vector<double>>;
Mat matmul(const Mat& A, const Mat& B) {
int n=A.size(), m=B[0].size(), k=B.size();
Mat C(n, vector<double>(m, 0.0));
for (int i=0;i<n;++i)
for (int t=0;t<k;++t) if (A[i][t]!=0.0) {
double v=A[i][t];
for (int j=0;j<m;++j) C[i][j]+=v*B[t][j];
}
return C;
}
Mat trans(const Mat& M){
int n=M.size(), m=M[0].size();
Mat T(m, vector<double>(n));
for(int i=0;i<n;++i) for(int j=0;j<m;++j) T[j][i]=M[i][j];
return T;
}
void row_norm(Mat& M){ // 简化 softmax:每行除以行和
for (auto& row : M){
double s=0; for(double x:row) s+=x;
if (s==0) continue; double inv=1.0/s;
for(double& x:row) x*=inv;
}
}
int main(){
ios::sync_with_stdio(false);
cin.tie(nullptr);
int n,m,h; if(!(cin>>n>>m>>h)) return 0;
// X: n×m 全1;W: m×h 上三角全1
Mat X(n, vector<double>(m,1.0));
Mat W(m, vector<double>(h,0.0));
for(int i=0;i<m;++i) for(int j=i;j<h;++j) W[i][j]=1.0;
Mat Q = matmul(X,W); // W1=W2=W3,相同
Mat K = Q, V = Q;
Mat M = matmul(Q, trans(K));
double s = sqrt((double)h);
if (s!=0) for(auto& r:M) for(double& x:r) x/=s;
row_norm(M); // A
Mat Y = matmul(M, V);
double total=0; for(auto& r:Y) for(double x:r) total+=x;
cout << llround(total) << "\n";
return 0;
}
题目内容
已知大模型常用的 Attention 模块定义如下:
$Y = \text{softmax}\left(\frac{QK^T}{\sqrt{h}}\right)V$
此处考虑二维情况,其中
$Q, K, V = XW_1, XW_2, XW_3 \in \mathbb{R}^{n \times h}, \quad X \in \mathbb{R}^{n \times m}, \quad W_1, W_2, W_3 \in \mathbb{R}^{m \times h}$
注意:
-
为简便起见,所有输入初始化为全1矩阵,所有权重矩阵初始化为上三角全 1 矩阵。
-
对任意矩阵 ( M ) 的 softmax 计算简化为:
$\text{softmax}(M)_{ij} = \frac{M_{ij}}{M_i}, \quad M_i = \sum_j M_{ij}$
输入描述
输入为维度参数 n,m和h,参数间使用空格隔开,均为小于 100 的正整数
输出描述
输出为结果矩阵 Y∈Rn×h的所有元素之和,例如 15,输出在四舍五入后保留整数
样例1
输入
3 3 3
输出
18
说明
$X = \begin{pmatrix} 1 & 1 & 1 \\ 1 & 1 & 1 \\ 1 & 1 & 1 \end{pmatrix}, \quad W_1, W_2, W_3 = \begin{pmatrix} 1 & 1 & 1 \\ 0 & 1 & 1 \\ 0 & 0 & 1 \end{pmatrix}$
$Q, K, V = \begin{pmatrix} 1 & 2 & 3 \\ 1 & 2 & 3 \\ 1 & 2 & 3 \end{pmatrix}, \quad Y = \begin{pmatrix} 1 & 2 & 3 \\ 1 & 2 & 3 \\ 1 & 2 & 3 \end{pmatrix}$
输出为:18
样例2
输入
2 3 1
输出
2
说明
$X = \begin{pmatrix} 1 & 1 & 1 \\ 1 & 1 & 1 \end{pmatrix}, \quad W_1, W_2, W_3 = \begin{pmatrix} 1 \\ 0 \\ 0 \end{pmatrix}$
$Q, K, V = \begin{pmatrix} 1 \\ 1 \end{pmatrix}, \quad Y = \begin{pmatrix} 1 \\ 1 \end{pmatrix}$
输出为:2
提示
输入参数不包含 0,为正整数