#P2700. 第3题-小苯数组
-
1000ms
Tried: 33
Accepted: 4
Difficulty: 6
所属公司 :
阿里
时间 :2025年3月15日-阿里淘天(开发岗)
-
算法标签>双指针
第3题-小苯数组
题解
题面描述
给定一个由n个整数组成的数组{a1,a2,…,an}。小苯对数组中各个区间的乘积很感兴趣,并提出了q个询问:对于每个询问给定一个长度len,要求计算数组中所有长度为len的区间的乘积之和,定义区间(l,r)的乘积为 f(l,r)=al×al+1×⋯×ar 但有个特殊要求:如果某个区间的乘积f(l,r)>109,则直接将其视为0(不计入总和)。
思路分析
由于n和q的总和均可达到3×105,故必须设计一个预处理算法,将所有可能的区间长度len的答案预先求出,然后对每个询问直接输出对应结果。
注意到如果一个区间的乘积超过109就不计入和,因此只有乘积不大的区间需要累加。
我们可以将原数组按照值为0的元素划分为若干个区间,因为任何包含0的区间乘积为0(且不会超界),不必单独处理。对于每个非0的区间,我们分两种情况讨论:
-
全部为1的区间
由于所有元素都是1,任意区间乘积都为1,且不会超过109。设该区间长度为L,则该区间中长度为len的子区间个数为L−len+1,它们对答案的贡献即为L−len+1。 -
非全部为1的区间
此时区间中至少存在一个大于1的数,乘积会随区间长度指数增长,故对于每个起点我们可以用二分查找或双指针确定从该起点开始最大的区间长度L′使得乘积不超过109。
具体做法为:- 令局部数组记为seg,长度为L,预先构造前缀乘积数组P(令P[0]=1,对1≤k≤L有P[k]=min(P[k−1]×seg[k−1],109+1),若超过109则记为109+1,表示已经超界)。
- 对于每个起点i(0≤i<L),利用二分查找确定最大的r(i+1≤r≤L),满足P[r]≤109×P[i]. 这样起点i合法的区间个数为r−i。而对于长度为l(1≤l≤r−i)的区间,其乘积为P[i]P[i+l]. 将这些乘积累加到对应的答案ans[l]中即可。
对于整个数组,我们对每个非0区间分别累加答案,最后对每个询问直接输出预处理好的ans[len]。
注意:包含0的区间其乘积为0,不用累加。
由于对于非全部为1的区间,由于数组中存在大于1的元素,其合法的子区间长度往往很短(乘积会很快超过109),故内层循环不会过慢;而对于全1的区间,我们采用公式直接计算,从而避免O(L2)的遍历。
C++
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll LIMIT = 1000000000;
int main(){
ios::sync_with_stdio(false);
cin.tie(nullptr);
int T;
cin >> T;
while(T--){
int n, q;
cin >> n >> q;
vector<ll> a(n);
for (int i = 0; i < n; i++) {
cin >> a[i];
}
// ans[len] 表示所有区间长度为 len 的答案(1-indexed)
vector<ll> ans(n+1, 0);
int i = 0;
while(i < n){
if(a[i] == 0){
// 包含 0 的区间乘积为 0,直接跳过
i++;
continue;
}
// 处理非0段
int j = i;
while(j < n && a[j] != 0) j++;
int lenSeg = j - i;
vector<ll> seg(a.begin()+i, a.begin()+j);
// 判断是否全为 1
bool allOnes = true;
for(auto &num : seg){
if(num != 1){
allOnes = false;
break;
}
}
if(allOnes){
// 对于全 1 的段,长度为 L 的段中,长度为 len 的子区间个数为 (L - len + 1),乘积均为 1
for (int L = 1; L <= lenSeg; L++){
ans[L] += (ll)(lenSeg - L + 1);
}
} else {
// 对于非全 1 的段,对每个起点单独累乘,直到乘积超过 LIMIT
for (int start = 0; start < lenSeg; start++){
ll prod = 1;
for (int end = start; end < lenSeg; end++){
// 计算子区间 [start, end]
// 注意:乘积可能超界,超过 LIMIT 则直接跳出循环
if(prod > LIMIT / seg[end]){
// 若乘积超过 LIMIT,则后续子区间均无效
break;
}
prod *= seg[end];
if(prod > LIMIT) break; // 虽然乘积仍在计算中,但超过 LIMIT 后视为 0
int len = end - start + 1;
ans[len] += prod;
}
}
}
i = j;
}
// 输出询问答案
for (int k = 0; k < q; k++){
int len;
cin >> len;
cout << ans[len] << "\n";
}
}
return 0;
}
Python
# -*- coding: utf-8 -*-
import sys
LIMIT = 10**9
def main():
input = sys.stdin.readline
T = int(input())
for _ in range(T):
n, q = map(int, input().split())
a = list(map(int, input().split()))
ans = [0]*(n+1) # ans[len] 表示所有长度为 len 的区间答案,1-indexed
i = 0
while i < n:
if a[i] == 0:
# 包含 0 的区间乘积为 0,跳过
i += 1
continue
j = i
while j < n and a[j] != 0:
j += 1
lenSeg = j - i
seg = a[i:j]
# 判断是否全为 1
if all(x == 1 for x in seg):
# 全 1 段,贡献直接为 (L - len + 1)
for L in range(1, lenSeg+1):
ans[L] += (lenSeg - L + 1)
else:
# 非全 1 段,对每个起点单独累乘,直到乘积超过 LIMIT
for start in range(lenSeg):
prod = 1
for end in range(start, lenSeg):
# 判断乘积是否会超过 LIMIT
if prod > LIMIT // seg[end]:
break
prod *= seg[end]
if prod > LIMIT:
break
length = end - start + 1
ans[length] += prod
i = j
out = []
for _ in range(q):
L = int(input())
out.append(str(ans[L]))
sys.stdout.write("\n".join(out) + "\n")
if __name__ == '__main__':
main()
Java
import java.io.*;
import java.util.*;
public class Main {
static final long LIMIT = 1000000000L;
public static void main(String[] args) throws IOException{
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
PrintWriter out = new PrintWriter(new BufferedOutputStream(System.out));
int T = Integer.parseInt(br.readLine());
while(T-- > 0){
String[] parts = br.readLine().split(" ");
int n = Integer.parseInt(parts[0]);
int q = Integer.parseInt(parts[1]);
String[] arrStr = br.readLine().split(" ");
long[] a = new long[n];
for(int i = 0; i < n; i++){
a[i] = Long.parseLong(arrStr[i]);
}
// ans[len] 表示所有长度为 len 的区间答案,1-indexed
long[] ans = new long[n+1];
int i = 0;
while(i < n){
if(a[i] == 0){
// 包含 0 的区间乘积为 0,直接跳过
i++;
continue;
}
int j = i;
while(j < n && a[j] != 0) j++;
int lenSeg = j - i;
long[] seg = new long[lenSeg];
for(int k = 0; k < lenSeg; k++){
seg[k] = a[i+k];
}
boolean allOnes = true;
for(long num : seg){
if(num != 1){
allOnes = false;
break;
}
}
if(allOnes){
// 全 1 段,直接计算贡献
for(int L = 1; L <= lenSeg; L++){
ans[L] += (lenSeg - L + 1);
}
} else {
// 非全 1 段,从每个起点单独累乘,直到乘积超过 LIMIT
for(int start = 0; start < lenSeg; start++){
long prod = 1;
for(int end = start; end < lenSeg; end++){
if(prod > LIMIT / seg[end]){
break;
}
prod *= seg[end];
if(prod > LIMIT) break;
int len = end - start + 1;
ans[len] += prod;
}
}
}
i = j;
}
for(int k = 0; k < q; k++){
int L = Integer.parseInt(br.readLine());
out.println(ans[L]);
}
}
out.flush();
out.close();
}
}
题目内容
小苯有一个由n个整数组成的数组{a1,a2,..,an},他对数组的乘积非常感兴趣,因此他提出了q次询问,具体地,每次询问:
- 小苯会询问一个数字len,他想知道数组中所有长度为len的区间的乘积之和(形式化的:定义f(l,r)=al×al+1×...×ar,则对于所有1≤l≤r≤n,r−l+1=len的(l,r),求 f(l,r) 之和。)
但特别的,懒惰的小苯不希望处理太大的乘积,因此如果某个f(l,r)>109,则小苯会直接舍弃此值,即将其视为0
输入描述
每个测试文件均包含多组测试数据。第一行输入一个整数T(1≦T≦103) 代表数据组数,每组测试数据描述如下:
第一行输入两个正整数n,q(1≦n,q≦3×105),分别表示数组a的长度和小苯的询问次数。
第二行输入n个整数a1,a2,…,an(0≦a≦109),表示小苯拥有的a数组。
接下来q行,每行一个整数len(1≦len≦n),表示当前次的询问。
除此之外,保证单个测试文件的n之和不超过3×105,q之和不超过3×105。
输出描述
对于每组测试数据,输出q行,每行一个整数,表示对应询问的答案。
样例1
输入
1
5 2
1 2 2 1 2
1
3
输出
8
12
说明
对于第一组测试数据的第一次询问,长度为1的区间有5个,乘积之和为 f(1,1)+f(2,2)+f(3,3)+f(4,4)+f(5,5)=1+2+2+1+2=8
对于第一组测试数据的第二次询问,长度为3的区间有3个,乘积之和为 f(1,3)+f(2,4)+f(3,5)=1×2×2+2×1×2+2×2×1=12