#P3712. 第2题-大模型Attention模块开发
          
                        
                                    
                      
        
              - 
          
          
                      1000ms
            
          
                      Tried: 1235
            Accepted: 371
            Difficulty: 4
            
          
          
          
                       所属公司 : 
                              华为
                                
            
                        
              时间 :2025年9月17日-AI岗
                              
                      
          
 
- 
                        算法标签>模拟          
 
第2题-大模型Attention模块开发
解题思路
- 
按题意用“暴力模拟”完整走一遍计算图:
- 构造 X 为 n×m 的全 1;构造 W1、W2、W3 为 m×h 的上三角全 1。
 - 计算 Q=X·W1,K=X·W2,V=X·W3(普通三重循环矩阵乘法)。
 - 计算 M=(Q·K^T)/sqrt(h)。
 - 按“简化 softmax”把 M 的每一行做归一化:A[i][j]=M[i][j]/(该行元素和)。
 - 计算 Y=A·V。
 - 将 Y 全部元素求和,四舍五入输出整数。
 
 - 
算法类型:暴力模拟/矩阵运算。
 - 
由于 n、m、h < 100,直接模拟即可在时空限制内通过。
 
复杂度分析
- 
矩阵乘法开销:
- 计算 Q、K、V:O(n·m·h)
 - 计算 M=Q·K^T:O(n²·h)
 - 行归一化:O(n²)
 - 计算 Y=A·V:O(n²·h)
 
 - 
总时间复杂度:O(n·m·h + n²·h)。
 - 
空间复杂度:O(n·h + n²)(保存 Q、K、V、A 或 M 等中间结果)。
 
代码实现
Python
import sys
import ast
import numpy as np
def solve(n, m, h):
    # 1) 构造 X 全 1,W 上三角全 1
    X = np.ones((n, m), dtype=float)
    W = np.triu(np.ones((m, h), dtype=float))  # W1=W2=W3 相同
    # 2) 计算 Q, K, V(矩阵乘法)
    Q = X @ W
    K = X @ W
    V = X @ W
    # 3) 计算 M=(Q·K^T)/sqrt(h)
    M = (Q @ K.T) / np.sqrt(float(h))
    # 4) “简化 softmax”:按行除以行和
    row_sum = M.sum(axis=1, keepdims=True)
    A = M / (row_sum + 1e-12)
    # 5) 计算 Y=A·V 并求和
    Y = A @ V
    total = float(Y.sum())
    # 6) 四舍五入输出整数
    return int(np.rint(total))
def main():
    s = sys.stdin.read().strip()
    try:
        val = ast.literal_eval(s)
        if isinstance(val, (list, tuple)) and len(val) == 3:
            n, m, h = map(int, val)
        else:
            n, m, h = map(int, s.split())
    except Exception:
        n, m, h = map(int, s.split())
    print(solve(n, m, h))
if __name__ == "__main__":
    main()
Java
import java.util.*;
public class Main {
    static double[][] matmul(double[][] A, double[][] B){
        int n=A.length, m=B[0].length, k=B.length;
        double[][] C=new double[n][m];
        for(int i=0;i<n;i++)
            for(int t=0;t<k;t++){
                double v=A[i][t]; if(v==0) continue;
                for(int j=0;j<m;j++) C[i][j]+=v*B[t][j];
            }
        return C;
    }
    static double[][] trans(double[][] M){
        int n=M.length,m=M[0].length;
        double[][] T=new double[m][n];
        for(int i=0;i<n;i++) for(int j=0;j<m;j++) T[j][i]=M[i][j];
        return T;
    }
    static void rowNorm(double[][] M){
        for(double[] r:M){
            double s=0; for(double x:r) s+=x;
            if(s==0) continue; double inv=1.0/s;
            for(int j=0;j<r.length;j++) r[j]*=inv;
        }
    }
    public static void main(String[] args){
        Scanner sc=new Scanner(System.in);
        int n=sc.nextInt(), m=sc.nextInt(), h=sc.nextInt();
        double[][] X=new double[n][m];
        for(int i=0;i<n;i++) Arrays.fill(X[i],1.0);
        double[][] W=new double[m][h];
        for(int i=0;i<m;i++) for(int j=i;j<h;j++) W[i][j]=1.0;
        double[][] Q=matmul(X,W), K=Q, V=Q;
        double[][] M=matmul(Q, trans(K));
        double s=Math.sqrt(h);
        if(s!=0) for(int i=0;i<M.length;i++) for(int j=0;j<M[0].length;j++) M[i][j]/=s;
        rowNorm(M);
        double[][] Y=matmul(M,V);
        double total=0;
        for(double[] r:Y) for(double x:r) total+=x;
        System.out.println(Math.round(total));
        sc.close();
    }
}
C++
#include <bits/stdc++.h>
using namespace std;
using Mat = vector<vector<double>>;
Mat matmul(const Mat& A, const Mat& B) {
    int n=A.size(), m=B[0].size(), k=B.size();
    Mat C(n, vector<double>(m, 0.0));
    for (int i=0;i<n;++i)
        for (int t=0;t<k;++t) if (A[i][t]!=0.0) {
            double v=A[i][t];
            for (int j=0;j<m;++j) C[i][j]+=v*B[t][j];
        }
    return C;
}
Mat trans(const Mat& M){
    int n=M.size(), m=M[0].size();
    Mat T(m, vector<double>(n));
    for(int i=0;i<n;++i) for(int j=0;j<m;++j) T[j][i]=M[i][j];
    return T;
}
void row_norm(Mat& M){ // 简化 softmax:每行除以行和
    for (auto& row : M){
        double s=0; for(double x:row) s+=x;
        if (s==0) continue; double inv=1.0/s;
        for(double& x:row) x*=inv;
    }
}
int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int n,m,h; if(!(cin>>n>>m>>h)) return 0;
    // X: n×m 全1;W: m×h 上三角全1
    Mat X(n, vector<double>(m,1.0));
    Mat W(m, vector<double>(h,0.0));
    for(int i=0;i<m;++i) for(int j=i;j<h;++j) W[i][j]=1.0;
    Mat Q = matmul(X,W);       // W1=W2=W3,相同
    Mat K = Q, V = Q;
    Mat M = matmul(Q, trans(K));
    double s = sqrt((double)h);
    if (s!=0) for(auto& r:M) for(double& x:r) x/=s;
    row_norm(M);               // A
    Mat Y = matmul(M, V);
    double total=0; for(auto& r:Y) for(double x:r) total+=x;
    cout << llround(total) << "\n";
    return 0;
}
        题目内容
已知大模型常用的 Attention 模块定义如下:
$Y = \text{softmax}\left(\frac{QK^T}{\sqrt{h}}\right)V$
此处考虑二维情况,其中
$Q, K, V = XW_1, XW_2, XW_3 \in \mathbb{R}^{n \times h}, \quad X \in \mathbb{R}^{n \times m}, \quad W_1, W_2, W_3 \in \mathbb{R}^{m \times h}$
注意:
- 
为简便起见,所有输入初始化为全1矩阵,所有权重矩阵初始化为上三角全 1 矩阵。
 - 
对任意矩阵 ( M ) 的 softmax 计算简化为:
 
$\text{softmax}(M)_{ij} = \frac{M_{ij}}{M_i}, \quad M_i = \sum_j M_{ij}$
输入描述
输入为维度参数 n,m和h,参数间使用空格隔开,均为小于 100 的正整数
输出描述
输出为结果矩阵 Y∈Rn×h的所有元素之和,例如 15,输出在四舍五入后保留整数
样例1
输入
3 3 3
输出
18 
说明
$X = \begin{pmatrix} 1 & 1 & 1 \\ 1 & 1 & 1 \\ 1 & 1 & 1 \end{pmatrix}, \quad W_1, W_2, W_3 = \begin{pmatrix} 1 & 1 & 1 \\ 0 & 1 & 1 \\ 0 & 0 & 1 \end{pmatrix}$
$Q, K, V = \begin{pmatrix} 1 & 2 & 3 \\ 1 & 2 & 3 \\ 1 & 2 & 3 \end{pmatrix}, \quad Y = \begin{pmatrix} 1 & 2 & 3 \\ 1 & 2 & 3 \\ 1 & 2 & 3 \end{pmatrix}$
输出为:18
样例2
输入
2 3 1
输出
2
说明
$X = \begin{pmatrix} 1 & 1 & 1 \\ 1 & 1 & 1 \end{pmatrix}, \quad W_1, W_2, W_3 = \begin{pmatrix} 1 \\ 0 \\ 0 \end{pmatrix}$
$Q, K, V = \begin{pmatrix} 1 \\ 1 \end{pmatrix}, \quad Y = \begin{pmatrix} 1 \\ 1 \end{pmatrix}$
输出为:2
提示
输入参数不包含 0,为正整数