#P3518. 第3题-连续数列
-
1000ms
Tried: 69
Accepted: 15
Difficulty: 6
所属公司 :
美团
时间 :2025年8月30日-算法岗
-
算法标签>单调栈
第3题-连续数列
解题思路
问题转化
对于任意子数组,它的权值等于“需要补充多少数字才能变成一个连续区间”。 这个数量可以写成: 权值 = (子数组最大值 − 子数组最小值) − (子数组长度 − 1)。
所以总答案就是:
- 所有子数组的
(最大值 − 最小值)之和, - 减去所有子数组
(长度 − 1)的和。
第二部分和数组内容无关,只和 n 有关,可以直接算出来:
所有子数组长度从 1 到 n,每种长度出现的次数是 (n−k+1),因此整体和是
n*(n+1)*(n-1)/6。
如何快速求所有子数组的最大值/最小值之和
这一步是关键。 我们可以用单调栈 + 贡献法:
- 对于最大值:计算每个元素在多少个子数组中是最大值。
- 对于最小值:计算每个元素在多少个子数组中是最小值。
具体做法:
- 找到当前元素左边第一个比它大(或小)的下标,记为 left。
- 找到当前元素右边第一个比它大(或小)的下标,记为 right。
- 当前位置能作为极值的子数组个数就是
(i - left) * (right - i)。
这样每个元素的贡献可以在 O(1) 算出,总复杂度 O(n)。
最终公式
答案 = 所有子数组最大值之和 − 所有子数组最小值之和 − n*(n+1)*(n-1)/6。
复杂度分析
- 时间复杂度:O(n),因为四个方向用单调栈各遍历一次。
- 空间复杂度:O(n)。
代码
Python
import sys
def sum_max(a):
n = len(a)
pg = [-1] * n # 前一个更大元素下标
ng = [n] * n # 后一个更大元素下标
st = []
for i in range(n):
while st and a[st[-1]] <= a[i]:
st.pop()
pg[i] = st[-1] if st else -1
st.append(i)
st.clear()
for i in range(n - 1, -1, -1):
while st and a[st[-1]] <= a[i]:
st.pop()
ng[i] = st[-1] if st else n
st.append(i)
res = 0
for i in range(n):
L = i - pg[i]
R = ng[i] - i
res += a[i] * L * R
return res
def sum_min(a):
n = len(a)
ps = [-1] * n # 前一个更小元素下标
ns = [n] * n # 后一个更小元素下标
st = []
for i in range(n):
while st and a[st[-1]] >= a[i]:
st.pop()
ps[i] = st[-1] if st else -1
st.append(i)
st.clear()
for i in range(n - 1, -1, -1):
while st and a[st[-1]] >= a[i]:
st.pop()
ns[i] = st[-1] if st else n
st.append(i)
res = 0
for i in range(n):
L = i - ps[i]
R = ns[i] - i
res += a[i] * L * R
return res
def main():
data = sys.stdin.read().strip().split()
n = int(data[0])
a = list(map(int, data[1:1+n]))
smax = sum_max(a)
smin = sum_min(a)
gap = n * (n + 1) * (n - 1) // 6 # ∑(r-l)
ans = smax - smin - gap
print(ans)
if __name__ == "__main__":
main()
C++
#include <bits/stdc++.h>
using namespace std;
long long sum_max(const vector<int>& a){
int n = (int)a.size();
vector<int> pg(n, -1), ng(n, n);
vector<int> st;
st.reserve(n);
// 前更大
for(int i=0;i<n;i++){
while(!st.empty() && a[st.back()] <= a[i]) st.pop_back();
pg[i] = st.empty() ? -1 : st.back();
st.push_back(i);
}
st.clear();
// 后更大
for(int i=n-1;i>=0;i--){
while(!st.empty() && a[st.back()] <= a[i]) st.pop_back();
ng[i] = st.empty() ? n : st.back();
st.push_back(i);
}
long long res = 0;
for(int i=0;i<n;i++){
long long L = i - pg[i];
long long R = ng[i] - i;
res += 1LL * a[i] * L * R;
}
return res;
}
long long sum_min(const vector<int>& a){
int n = (int)a.size();
vector<int> ps(n, -1), ns(n, n);
vector<int> st;
st.reserve(n);
// 前更小
for(int i=0;i<n;i++){
while(!st.empty() && a[st.back()] >= a[i]) st.pop_back();
ps[i] = st.empty() ? -1 : st.back();
st.push_back(i);
}
st.clear();
// 后更小
for(int i=n-1;i>=0;i--){
while(!st.empty() && a[st.back()] >= a[i]) st.pop_back();
ns[i] = st.empty() ? n : st.back();
st.push_back(i);
}
long long res = 0;
for(int i=0;i<n;i++){
long long L = i - ps[i];
long long R = ns[i] - i;
res += 1LL * a[i] * L * R;
}
return res;
}
int main(){
ios::sync_with_stdio(false);
cin.tie(nullptr);
int n;
if(!(cin >> n)) return 0;
vector<int> a(n);
for(int i=0;i<n;i++) cin >> a[i];
long long smax = sum_max(a);
long long smin = sum_min(a);
long long gap = 1LL * n * (n + 1) * (n - 1) / 6; // ∑(r-l)
long long ans = smax - smin - gap;
cout << ans << "\n";
return 0;
}
Java
import java.io.*;
import java.util.*;
public class Main {
static long sumMax(int[] a) {
int n = a.length;
int[] pg = new int[n];
int[] ng = new int[n];
Arrays.fill(pg, -1);
Arrays.fill(ng, n);
int[] st = new int[n];
int top = -1;
// 前更大
for (int i = 0; i < n; i++) {
while (top >= 0 && a[st[top]] <= a[i]) top--;
pg[i] = (top == -1) ? -1 : st[top];
st[++top] = i;
}
// 清空栈
top = -1;
// 后更大
for (int i = n - 1; i >= 0; i--) {
while (top >= 0 && a[st[top]] <= a[i]) top--;
ng[i] = (top == -1) ? n : st[top];
st[++top] = i;
}
long res = 0;
for (int i = 0; i < n; i++) {
long L = i - pg[i];
long R = ng[i] - i;
res += 1L * a[i] * L * R;
}
return res;
}
static long sumMin(int[] a) {
int n = a.length;
int[] ps = new int[n];
int[] ns = new int[n];
Arrays.fill(ps, -1);
Arrays.fill(ns, n);
int[] st = new int[n];
int top = -1;
// 前更小
for (int i = 0; i < n; i++) {
while (top >= 0 && a[st[top]] >= a[i]) top--;
ps[i] = (top == -1) ? -1 : st[top];
st[++top] = i;
}
// 清空栈
top = -1;
// 后更小
for (int i = n - 1; i >= 0; i--) {
while (top >= 0 && a[st[top]] >= a[i]) top--;
ns[i] = (top == -1) ? n : st[top];
st[++top] = i;
}
long res = 0;
for (int i = 0; i < n; i++) {
long L = i - ps[i];
long R = ns[i] - i;
res += 1L * a[i] * L * R;
}
return res;
}
public static void main(String[] args) throws Exception {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
String s = br.readLine();
if (s == null || s.isEmpty()) return;
int n = Integer.parseInt(s.trim());
int[] a = new int[n];
StringTokenizer st = new StringTokenizer(br.readLine());
for (int i = 0; i < n; i++) a[i] = Integer.parseInt(st.nextToken());
long smax = sumMax(a);
long smin = sumMin(a);
long gap = 1L * n * (n + 1) * (n - 1) / 6; // ∑(r-l)
long ans = smax - smin - gap;
System.out.println(ans);
}
}
题目内容
小美从一个原始的连续数列(各元素互不相同,且恰好是某个区间内的所有整数)中丢失了一些数字,剩余元素按 原顺序形成了一个长度为n 的数组{a1,a2,...,an}。
选定区间[l,r]后,一定可以通过插入若干元素,使得子数组中的元素恰好构成从 min(al,...,ar)到
max(al,...,ar)的连续整数序列。我们称这个子数组的权值为:插入的最少元素数量。
请你计算数组中所有子数组的权值之和。
输入描述
第一行输入一个整数n(1≦n≦2×105)表示数组长度。
第二行输入n个互不相同的整数{a1,a2,...,an}(1≦ai≦106)表示数组元素.
输出描述
输出一个整数,表示数组中所有子数组的权值之和。
样例1
输入
3
3 1 2
输出
1
说明
在第一个样例中,所有子数组及其所需插入数如下:
-
[1,1]={3}, min=3,max=3,需插入0个元素
-
[1,2]={3,1}, min=1,max=3需入1个元素;
-
[1,3]={3,1,2},需插入0个元素;
-
[2,2]={1},需插入0个元素;
-
[2,3]={1,2},需插入0个元素;
-
[3,3]= {2},需插入0个元素;
权值之和为0+1+0+0+0+0=1
样例2
输入
4
2 5 3 8
输出
14