#P4275. 第3题-基于空间连续块的稀疏注意力机制
-
1000ms
Tried: 119
Accepted: 27
Difficulty: 6
所属公司 :
华为
时间 :2025年10月22日-AI方向
-
算法标签>稀疏注意力机制
第3题-基于空间连续块的稀疏注意力机制
解题思路
长序列下,将历史 token 划分为固定大小的空间连续块,每块做均值池化后经一个两层 MLP 压缩成向量,再用固定查询向量 q=1 与压缩结果做打分。得到的压缩分数序列 A=(a1,…,am) 之后,需要将其划分为恰好两个连续非空子数组,使两段和的最小值最大。整体可分为两部分:
- 数值构造:分块 + 池化 + MLP + 打分
-
设序列长度为 n、维度为 d、块大小为 b,块数 m=⌈n/b⌉。
-
第 k 个块的均值池化
$$h_k=\frac{1}{|B_k|}\sum_{x\in B_k} x\in\mathbb{R}^d $$ -
两层 MLP(隐藏维度为 1):
$$t_k=W_1\cdot h_k+b_1,\quad r_k=\sigma(t_k)=\max(0,t_k) $$ck=W2⋅rk+b2∈Rd其中 $W_1\in\mathbb{R}^{1\times d},\, W_2\in\mathbb{R}^{d\times 1},\, b_1=2,\, b_2=1$(标量,按维度广播)。
-
因为 q=1,注意力得分
$$a_k=\frac{q\cdot c_k}{\sqrt d}=\frac{\sum_{i=1}^{d} c_k^{(i)}}{\sqrt d} $$顺序得到 A。
- 最优划分:前缀和 + 贪心
- 目标是 $\max_{1\le s\le m-1} \min\Big(\sum_{i=1}^{s}a_i,\sum_{i=s+1}^{m}a_i\Big)$。
- 记总和 T=∑i=1mai,前缀和 Ps=∑i=1sai。显然最优 s 使得两段尽量“均衡”,即 Ps 最接近 T/2。
- 实现上只需一次线性扫描:维护前缀和,逐个计算 min(Ps,T−Ps) 的最大值即可。 这是典型的前缀和 + 单遍扫描贪心,时间 O(m),优于通用的“二分答案 + 可行性判断”。
最终答案为 S=maxsmin(⋅),题目要求输出 round(100⋅S) 的整数(四舍五入,保留两位小数的整数化)。
复杂度分析
-
时间复杂度:
- 计算所有块均值与 MLP:遍历每个 token 各维度,O(n⋅d)。
- 计算打分并寻找最优切分点:O(m),其中 m=⌈n/b⌉≤n。
- 总计 O(n⋅d)。
-
空间复杂度:
- 保存一块的中间向量与常量参数,外加得分序列(或仅累计总和与前缀),为 O(d+m),可降到 O(d)(边算边累计,不必存整列)。
代码实现
Python
import sys
import math
import numpy as np
# 核心功能:根据题意计算最终整数化得分
def solve(n: int, d: int, b: int, X: np.ndarray, W1: np.ndarray, W2: np.ndarray) -> int:
m = (n + b - 1) // b # 块数
A = [] # 压缩注意力得分序列
sqrt_d = math.sqrt(d)
# 逐块计算 a_k
for k in range(m):
start = k * b
end = min((k + 1) * b, n)
block = X[start:end] # 该块的所有 token,形状 (len, d)
# 平均池化 h_k
h_k = block.mean(axis=0)
# 两层 MLP:t = W1·h + b1,r = ReLU(t),c = W2*r + b2(逐维加1)
t = float(W1.dot(h_k)) + 2.0
r = max(0.0, t)
c = W2 * r + 1.0 # 广播加 1
a_k = float(c.sum()) / sqrt_d
A.append(a_k)
# 线性扫描寻找最优切分点,使 min(左和, 右和) 最大
T = sum(A)
best = -1e100
pref = 0.0
for s in range(1, m): # 必须切成两个非空段
pref += A[s - 1]
best = max(best, min(pref, T - pref))
S = best
return int(round(S * 100.0))
def main():
data = sys.stdin.read().strip().split()
it = iter(data)
# 读入 n d b
n = int(next(it)); d = int(next(it)); b = int(next(it))
# 读入 n 行,每行 d 个浮点
xs = [ [float(next(it)) for _ in range(d)] for _ in range(n) ]
X = np.array(xs, dtype=float)
# 读入 W1, W2(各 d 个数)
W1 = np.array([float(next(it)) for _ in range(d)], dtype=float)
W2 = np.array([float(next(it)) for _ in range(d)], dtype=float)
ans = solve(n, d, b, X, W1, W2)
print(ans)
if __name__ == "__main__":
main()
Java
import java.util.*;
public class Main {
static long roundHalfToEven(double x) {
final double EPS = 1e-12;
double r = Math.floor(x); // 向下取整
double frac = x - r; // [0,1) 的小数部分
if (frac < 0.5 - EPS) return (long) r;
if (frac > 0.5 + EPS) return (long) (r + 1.0);
// frac 约等于 0.5:五取偶
long ri = (long) r; // floor(x)
return (ri % 2 == 0) ? ri : (ri + 1);
}
// 核心功能:计算最终整数化得分
static long solve(int n, int d, int b, double[][] X, double[] W1, double[] W2) {
int m = (n + b - 1) / b;
double[] A = new double[m];
double sqrt_d = Math.sqrt(d);
// 逐块计算 a_k
for (int k = 0; k < m; k++) {
int start = k * b;
int end = Math.min((k + 1) * b, n);
int len = end - start;
// 平均池化 h_k
double[] hk = new double[d];
Arrays.fill(hk, 0.0);
for (int i = start; i < end; i++) {
for (int j = 0; j < d; j++) {
hk[j] += X[i][j];
}
}
for (int j = 0; j < d; j++) hk[j] /= len;
// 两层 MLP
double t = 0.0;
for (int j = 0; j < d; j++) t += W1[j] * hk[j];
t += 2.0;
double r = Math.max(0.0, t);
double sumc = 0.0;
for (int j = 0; j < d; j++) sumc += (W2[j] * r + 1.0);
double ak = sumc / sqrt_d;
A[k] = ak;
}
// 单遍扫描找最优切分
double T = 0.0;
for (double v : A) T += v;
double best = -1e100;
double pref = 0.0;
for (int s = 1; s <= m - 1; s++) {
pref += A[s - 1];
double cur = Math.min(pref, T - pref);
if (cur > best) best = cur;
}
double S = best;
return roundHalfToEven(S * 100.0);
}
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
// 读入 n d b
int n = sc.nextInt();
int d = sc.nextInt();
int b = sc.nextInt();
// 读入 X
double[][] X = new double[n][d];
for (int i = 0; i < n; i++) {
for (int j = 0; j < d; j++) {
X[i][j] = sc.nextDouble();
}
}
// 读入 W1, W2
double[] W1 = new double[d];
double[] W2 = new double[d];
for (int j = 0; j < d; j++) W1[j] = sc.nextDouble();
for (int j = 0; j < d; j++) W2[j] = sc.nextDouble();
long ans = solve(n, d, b, X, W1, W2);
System.out.println(ans);
}
}
C++
#include <bits/stdc++.h>
using namespace std;
long long round_half_to_even(double x) {
const double EPS = 1e-12;
double r = floor(x); // 向下取整
double frac = x - r; // [0,1) 的小数部分
if (frac < 0.5 - EPS) return (long long)r;
if (frac > 0.5 + EPS) return (long long)(r + 1.0);
// frac 约等于 0.5:五取偶
long long ri = (long long)r; // floor(x) 的整数部分
return (ri % 2 == 0) ? ri : (ri + 1);
}
// 核心功能:计算最终整数化得分
long long solve(int n, int d, int b,
const vector<vector<double>>& X,
const vector<double>& W1,
const vector<double>& W2) {
int m = (n + b - 1) / b;
vector<double> A(m, 0.0);
double sqrt_d = sqrt((double)d);
// 逐块计算 a_k
for (int k = 0; k < m; ++k) {
int start = k * b;
int end = min((k + 1) * b, n);
int len = end - start;
// 平均池化 h_k
vector<double> hk(d, 0.0);
for (int i = start; i < end; ++i) {
for (int j = 0; j < d; ++j) {
hk[j] += X[i][j];
}
}
for (int j = 0; j < d; ++j) hk[j] /= (double)len;
// 两层 MLP:t = W1·h + 2,r = ReLU(t),c = W2*r + 1
double t = 0.0;
for (int j = 0; j < d; ++j) t += W1[j] * hk[j];
t += 2.0;
double r = max(0.0, t);
double sumc = 0.0;
for (int j = 0; j < d; ++j) sumc += (W2[j] * r + 1.0);
double ak = sumc / sqrt_d;
A[k] = ak;
}
// 单遍扫描找到最优切分
double T = 0.0;
for (double v : A) T += v;
double best = -1e100;
double pref = 0.0;
for (int s = 1; s <= m - 1; ++s) {
pref += A[s - 1];
double cur = min(pref, T - pref);
if (cur > best) best = cur;
}
double S = best;
return round_half_to_even(S * 100.0);
}
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
int n, d, b;
if (!(cin >> n >> d >> b)) return 0;
vector<vector<double>> X(n, vector<double>(d));
for (int i = 0; i < n; ++i) {
for (int j = 0; j < d; ++j) {
cin >> X[i][j];
}
}
vector<double> W1(d), W2(d);
for (int j = 0; j < d; ++j) cin >> W1[j];
for (int j = 0; j < d; ++j) cin >> W2[j];
long long ans = solve(n, d, b, X, W1, W2);
cout << ans << "\n";
return 0;
}
题目内容
在大语言模型推理过程中,随着上下文长度增加,标准 Attention 的计算开销以 O(n2) 增长,成为性能瓶颈。为提升长序列处理效率,提出一种基于空间连续块的稀疏注意力机制。
具体流程如下:
-
一个长度为 n 的历史 token 序列,每个 token 表示为 1 个 d 维特征向量 xj∈Rd
。按固定块大小b,将序列划分为 m=ceil(n/b)个空间连续块(最后一个块可不满)
B1,B2,...,Bm,其中:Bk=[x(k−1)b,...,xmin(kn)−1]
-
对每个块 Bk:
(1) 计算平均池化向量:$\mathbf{h}_k = \frac{1}{B_k} \sum_{x \in B_k} \mathbf{x}$
(2) 使用一个两层多层感知机(MLP)进行非线性压缩(隐藏维度dl=1):$\mathbf{c}_k = W_2 \cdot \sigma(W_1 \cdot \mathbf{h}_k + b_1) + b_2$
其中:
① W1∈R1×d,W2∈Rd×1,输出 ck∈Rd
②b1=2,b2=1
③σ(x)=max(0,x)(即 ReLU 激活函数)
-
给定查询向量q∈Rd(题目中固定为全 1 向量:qi=1),计算每个压缩块的注意力得分:
$a_k = \frac{\mathbf{q} \cdot \mathbf{c}_k}{\sqrt{d}}$
得到压缩块注意力得分序列 A=(a1,a2,...,am)
-
将序列 A 划分为恰好 2 个连续非空子数组,目标是最大化这两个子数组和中的最小值 S 。
-
最终输出该最大化的最小值 S 的整数化得分,该子数组对应的 token 块将跳过细粒度 attention 计算,实现稀疏推理。
其中,整数化得分即 S 乘以 100 后四舍五入得到的整数,以实现保留两位小数精度的整数化表示:
round(100⋅S)
输入描述
第 1 行:n d b,以空格分隔,分别为序列长度、token 向量维度、块大小
接下来 n 行:每行 d 个数,以空格分隔,表示 xi
倒数第 2 行:d 个数,以空格分隔,表示 W1
最后 1 行:d 个数,以空格分隔,表示 W2
约束条件:
1≤n≤1000
1≤b≤n
1≤d≤100
所有向量非零
输出描述
返回一个整数,即上述步骤 5 的整数化得分
样例1
输入
3 1 1
2.0
4.0
6.0
1.0
2.0
输出
1700
说明
①分块:B1=[2.0],B2=[4.0],B3=[6.0]
②平均池化:h1=[2.0],h2=[4.0],h3=[6.0]
③MLP 压缩:c1=[9.0],c2=[13.0],c3=[17.0]
④注意力得分:A=[9,13,17]
⑤划分为 2 个连续非空子数组,最大化min(sum):
[9]∣[13,17] → 和:9,30→min=9
[9,13]∣[17] → 和:22,17→min=17→S=17,输出 1700
样例2
输入
3 2 1
2.0 1.0
3.0 2.0
4.0 3.0
1.0 0.5
2.0 1.0
输出
1732
说明
①分块:B1=[2.0,1.0],B2=[3.0,2.0],B3=[4.0,3.0]
②平均池化:h1=[2.0,1.0],h2=[3.0,2.0],h3=[4.0,3.0]
③MLP 压缩:c1=[10.0,5.5],c2=[13.0,7.0],c3=[16.0,8.5]
④注意力得分:$A = [\frac{15.5}{\sqrt{2}}, \frac{20.0}{\sqrt{2}}, \frac{24.5}{\sqrt{2}}]$
⑤划分为 2 个连续非空子数组,最大化(min(sum)):
$[\frac{15.5}{\sqrt{2}}] , [\frac{20.0}{\sqrt{2}}, \frac{24.5}{\sqrt{2}}] → $和:215.5,244.5 →min= 215.5
$[\frac{15.5}{\sqrt{2}}, \frac{20.0}{\sqrt{2}}] , [\frac{24.5}{\sqrt{2}}] →$ 和:235.5,224.5 $→ min = \frac{24.5}{\sqrt{2}}→ S = \frac{24.5}{\sqrt{2}}$,输出 1732