#P3518. 第3题-连续数列
          
                        
                                    
                      
        
              - 
          
          
                      1000ms
            
          
                      Tried: 69
            Accepted: 15
            Difficulty: 6
            
          
          
          
                       所属公司 : 
                              美团
                                
            
                        
              时间 :2025年8月30日-算法岗
                              
                      
          
 
- 
                        算法标签>单调栈          
 
第3题-连续数列
解题思路
问题转化
对于任意子数组,它的权值等于“需要补充多少数字才能变成一个连续区间”。 这个数量可以写成: 权值 = (子数组最大值 − 子数组最小值) − (子数组长度 − 1)。
所以总答案就是:
- 所有子数组的 
(最大值 − 最小值)之和, - 减去所有子数组 
(长度 − 1)的和。 
第二部分和数组内容无关,只和 n 有关,可以直接算出来:
所有子数组长度从 1 到 n,每种长度出现的次数是 (n−k+1),因此整体和是
n*(n+1)*(n-1)/6。
如何快速求所有子数组的最大值/最小值之和
这一步是关键。 我们可以用单调栈 + 贡献法:
- 对于最大值:计算每个元素在多少个子数组中是最大值。
 - 对于最小值:计算每个元素在多少个子数组中是最小值。
 
具体做法:
- 找到当前元素左边第一个比它大(或小)的下标,记为 left。
 - 找到当前元素右边第一个比它大(或小)的下标,记为 right。
 - 当前位置能作为极值的子数组个数就是 
(i - left) * (right - i)。 
这样每个元素的贡献可以在 O(1) 算出,总复杂度 O(n)。
最终公式
答案 = 所有子数组最大值之和 − 所有子数组最小值之和 − n*(n+1)*(n-1)/6。
复杂度分析
- 时间复杂度:O(n),因为四个方向用单调栈各遍历一次。
 - 空间复杂度:O(n)。
 
代码
Python
import sys
def sum_max(a):
    n = len(a)
    pg = [-1] * n  # 前一个更大元素下标
    ng = [n] * n   # 后一个更大元素下标
    st = []
    for i in range(n):
        while st and a[st[-1]] <= a[i]:
            st.pop()
        pg[i] = st[-1] if st else -1
        st.append(i)
    st.clear()
    for i in range(n - 1, -1, -1):
        while st and a[st[-1]] <= a[i]:
            st.pop()
        ng[i] = st[-1] if st else n
        st.append(i)
    res = 0
    for i in range(n):
        L = i - pg[i]
        R = ng[i] - i
        res += a[i] * L * R
    return res
def sum_min(a):
    n = len(a)
    ps = [-1] * n  # 前一个更小元素下标
    ns = [n] * n   # 后一个更小元素下标
    st = []
    for i in range(n):
        while st and a[st[-1]] >= a[i]:
            st.pop()
        ps[i] = st[-1] if st else -1
        st.append(i)
    st.clear()
    for i in range(n - 1, -1, -1):
        while st and a[st[-1]] >= a[i]:
            st.pop()
        ns[i] = st[-1] if st else n
        st.append(i)
    res = 0
    for i in range(n):
        L = i - ps[i]
        R = ns[i] - i
        res += a[i] * L * R
    return res
def main():
    data = sys.stdin.read().strip().split()
    n = int(data[0])
    a = list(map(int, data[1:1+n]))
    smax = sum_max(a)
    smin = sum_min(a)
    gap = n * (n + 1) * (n - 1) // 6  # ∑(r-l)
    ans = smax - smin - gap
    print(ans)
if __name__ == "__main__":
    main()
C++
#include <bits/stdc++.h>
using namespace std;
long long sum_max(const vector<int>& a){
    int n = (int)a.size();
    vector<int> pg(n, -1), ng(n, n);
    vector<int> st;
    st.reserve(n);
    // 前更大
    for(int i=0;i<n;i++){
        while(!st.empty() && a[st.back()] <= a[i]) st.pop_back();
        pg[i] = st.empty() ? -1 : st.back();
        st.push_back(i);
    }
    st.clear();
    // 后更大
    for(int i=n-1;i>=0;i--){
        while(!st.empty() && a[st.back()] <= a[i]) st.pop_back();
        ng[i] = st.empty() ? n : st.back();
        st.push_back(i);
    }
    long long res = 0;
    for(int i=0;i<n;i++){
        long long L = i - pg[i];
        long long R = ng[i] - i;
        res += 1LL * a[i] * L * R;
    }
    return res;
}
long long sum_min(const vector<int>& a){
    int n = (int)a.size();
    vector<int> ps(n, -1), ns(n, n);
    vector<int> st;
    st.reserve(n);
    // 前更小
    for(int i=0;i<n;i++){
        while(!st.empty() && a[st.back()] >= a[i]) st.pop_back();
        ps[i] = st.empty() ? -1 : st.back();
        st.push_back(i);
    }
    st.clear();
    // 后更小
    for(int i=n-1;i>=0;i--){
        while(!st.empty() && a[st.back()] >= a[i]) st.pop_back();
        ns[i] = st.empty() ? n : st.back();
        st.push_back(i);
    }
    long long res = 0;
    for(int i=0;i<n;i++){
        long long L = i - ps[i];
        long long R = ns[i] - i;
        res += 1LL * a[i] * L * R;
    }
    return res;
}
int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int n;
    if(!(cin >> n)) return 0;
    vector<int> a(n);
    for(int i=0;i<n;i++) cin >> a[i];
    long long smax = sum_max(a);
    long long smin = sum_min(a);
    long long gap = 1LL * n * (n + 1) * (n - 1) / 6; // ∑(r-l)
    long long ans = smax - smin - gap;
    cout << ans << "\n";
    return 0;
}
Java
import java.io.*;
import java.util.*;
public class Main {
    static long sumMax(int[] a) {
        int n = a.length;
        int[] pg = new int[n];
        int[] ng = new int[n];
        Arrays.fill(pg, -1);
        Arrays.fill(ng, n);
        int[] st = new int[n];
        int top = -1;
        // 前更大
        for (int i = 0; i < n; i++) {
            while (top >= 0 && a[st[top]] <= a[i]) top--;
            pg[i] = (top == -1) ? -1 : st[top];
            st[++top] = i;
        }
        // 清空栈
        top = -1;
        // 后更大
        for (int i = n - 1; i >= 0; i--) {
            while (top >= 0 && a[st[top]] <= a[i]) top--;
            ng[i] = (top == -1) ? n : st[top];
            st[++top] = i;
        }
        long res = 0;
        for (int i = 0; i < n; i++) {
            long L = i - pg[i];
            long R = ng[i] - i;
            res += 1L * a[i] * L * R;
        }
        return res;
    }
    static long sumMin(int[] a) {
        int n = a.length;
        int[] ps = new int[n];
        int[] ns = new int[n];
        Arrays.fill(ps, -1);
        Arrays.fill(ns, n);
        int[] st = new int[n];
        int top = -1;
        // 前更小
        for (int i = 0; i < n; i++) {
            while (top >= 0 && a[st[top]] >= a[i]) top--;
            ps[i] = (top == -1) ? -1 : st[top];
            st[++top] = i;
        }
        // 清空栈
        top = -1;
        // 后更小
        for (int i = n - 1; i >= 0; i--) {
            while (top >= 0 && a[st[top]] >= a[i]) top--;
            ns[i] = (top == -1) ? n : st[top];
            st[++top] = i;
        }
        long res = 0;
        for (int i = 0; i < n; i++) {
            long L = i - ps[i];
            long R = ns[i] - i;
            res += 1L * a[i] * L * R;
        }
        return res;
    }
    public static void main(String[] args) throws Exception {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        String s = br.readLine();
        if (s == null || s.isEmpty()) return;
        int n = Integer.parseInt(s.trim());
        int[] a = new int[n];
        StringTokenizer st = new StringTokenizer(br.readLine());
        for (int i = 0; i < n; i++) a[i] = Integer.parseInt(st.nextToken());
        long smax = sumMax(a);
        long smin = sumMin(a);
        long gap = 1L * n * (n + 1) * (n - 1) / 6; // ∑(r-l)
        long ans = smax - smin - gap;
        System.out.println(ans);
    }
}
        题目内容
小美从一个原始的连续数列(各元素互不相同,且恰好是某个区间内的所有整数)中丢失了一些数字,剩余元素按 原顺序形成了一个长度为n 的数组{a1,a2,...,an}。
选定区间[l,r]后,一定可以通过插入若干元素,使得子数组中的元素恰好构成从 min(al,...,ar)到
max(al,...,ar)的连续整数序列。我们称这个子数组的权值为:插入的最少元素数量。
请你计算数组中所有子数组的权值之和。
输入描述
第一行输入一个整数n(1≦n≦2×105)表示数组长度。
第二行输入n个互不相同的整数{a1,a2,...,an}(1≦ai≦106)表示数组元素.
输出描述
输出一个整数,表示数组中所有子数组的权值之和。
样例1
输入
3
3 1 2
输出
1 
说明
在第一个样例中,所有子数组及其所需插入数如下:
- 
[1,1]={3}, min=3,max=3,需插入0个元素
 - 
[1,2]={3,1}, min=1,max=3需入1个元素;
 - 
[1,3]={3,1,2},需插入0个元素;
 - 
[2,2]={1},需插入0个元素;
 - 
[2,3]={1,2},需插入0个元素;
 - 
[3,3]= {2},需插入0个元素;
 
权值之和为0+1+0+0+0+0=1
样例2
输入
4
2 5 3 8
输出
14