#P2698. 第1题-连续非空子数组
          
                        
                                    
                      
        
              - 
          
          
                      1000ms
            
          
                      Tried: 28
            Accepted: 6
            Difficulty: 7
            
          
          
          
                       所属公司 : 
                              阿里
                                
            
                        
              时间 :2025年3月15日-阿里淘天(开发岗)
                              
                      
          
 
- 
                        算法标签>思维          
 
第1题-连续非空子数组
题解
题目描述
给定一个由n个整数构成的数组{a1,a2,…,an},其中每个ai满足0≤ai≤2。定义一个数组的mex为未出现在该数组中的最小非负整数。例如
- mex{1,2,3}=0
 - mex{0,2,5}=1
 
要求取出数组中的所有连续非空子数组,并求每个子数组的mex值之和。
连续非空子数组指从原数组中取出一段连续的元素(可以取全数组,也可以取部分),且该子数组至少包含一个元素。
思路
本题的核心在于利用数组中仅有的三个数0、1和2的特性,将所有连续子数组的mex值分为四类:不含0(mex=0)、含0但不含1(mex=1)、同时含0和1但不含2(mex=2),以及同时含0、1和2(mex=3);通过扫描数组统计不含某个或某些数字的连续区间,并利用公式2L(L+1)计算区间内子数组的数量,再利用容斥原理求出每一类子数组的个数,最后加权求和得到所有子数组的mex之和。
由于数组中元素取值仅为0、1和2,因此对于任一子数组,其可能的mex值只有以下几种情况:
- mex=0:子数组中没有出现0。
 - mex=1:子数组中出现了0但没有1。
 - mex=2:子数组中同时出现了0和1但没有2。
 - mex=3:子数组中同时出现了0、1和2。
 
我们可以统计满足上述条件的子数组个数,然后利用mex的权值计算答案。记子数组总数为
total=2n(n+1)
如何统计子数组个数
利用“缺失某个数”的思路,设F(x)为不含x的子数组个数,可以利用扫描数组,找出连续不含x的段,其长度为L时,其子数组个数为
2L(L+1)
同理,设F(x,y)表示不含x和y的子数组个数,即该子数组中的所有元素只能为剩下的那个数。
接下来分类讨论:
- 
mex=0:子数组中没有0
cnt0=F(0) - 
mex=1:子数组中出现0但没有1
统计F(1)(即不含1的子数组个数),再减去其中同时不含0和1的,即F(0,1),得到
cnt1=F(1)−F(0,1) - 
mex=2:子数组中含有0和1但没有2
先统计F(2)(不含2的子数组个数),再减去其中不含0的和不含1的部分,即
cnt2=F(2)−F(0,2)−F(1,2) (这里不必加回F(0,1,2),因为非空子数组不可能同时缺少0、1和2) - 
mex=3:子数组中同时含有0、1和2利用容斥原理:cnt_3=total−(F(0)+F(1)+F(2)) + (F(0,1)+F(0,2)+F(1,2))
 
最后答案为
ans = 0×cnt0 + 1×cnt1 + 2×cnt2 + 3×cnt3
cpp
#include <iostream>
#include <vector>
using namespace std;
typedef long long ll;
// 统计不包含数字 x 的子数组数
ll countNo(const vector<int>& arr, int x) {
    ll res = 0;
    ll cnt = 0;
    for (int v : arr) {
        if (v == x) {
            res += cnt * (cnt + 1LL) / 2;
            cnt = 0;
        } else {
            cnt++;
        }
    }
    res += cnt * (cnt + 1LL) / 2;
    return res;
}
// 统计不包含数字 x 和 y 的子数组数,即子数组中只能出现剩下的那个数
ll countNoPair(const vector<int>& arr, int x, int y) {
    ll res = 0;
    ll cnt = 0;
    for (int v : arr) {
        if (v == x || v == y) {
            res += cnt * (cnt + 1LL) / 2;
            cnt = 0;
        } else {
            cnt++;
        }
    }
    res += cnt * (cnt + 1LL) / 2;
    return res;
}
int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int n;
    cin >> n;
    vector<int> arr(n);
    for (int i = 0; i < n; i++){
        cin >> arr[i];
    }
    
    ll total = (ll)n * (n + 1LL) / 2;
    
    // 统计各个缺失情况
    ll cnt0 = countNo(arr, 0);      // 不含0的子数组数
    ll cnt1_all = countNo(arr, 1);    // 不含1的子数组数
    ll cnt2_all = countNo(arr, 2);    // 不含2的子数组数
    ll cnt01 = countNoPair(arr, 0, 1);  // 不含0和1的子数组数
    ll cnt02 = countNoPair(arr, 0, 2);  // 不含0和2的子数组数
    ll cnt12 = countNoPair(arr, 1, 2);  // 不含1和2的子数组数
    // 计算各个类别的子数组个数
    ll cnt_mex0 = cnt0;                        //mex=0的子数组
    ll cnt_mex1 = cnt1_all - cnt01;             //mex=1的子数组:包含0但不含1
    ll cnt_mex2 = cnt2_all - cnt02 - cnt12;     // mex=2的子数组:包含0和1但不含2
    ll cnt_mex3 = total - (cnt0 + cnt1_all + cnt2_all) + (cnt01 + cnt02 + cnt12); //mex=3的子数组
    // 最终答案
    ll ans = 1LL * cnt_mex1 + 2LL * cnt_mex2 + 3LL * cnt_mex3;
    cout << ans << "\n";
    return 0;
}
python
def count_no(arr, x):
    # 统计不包含数字 x 的子数组数
    res = 0
    cnt = 0
    for v in arr:
        if v == x:
            res += cnt * (cnt + 1) // 2
            cnt = 0
        else:
            cnt += 1
    res += cnt * (cnt + 1) // 2
    return res
def count_no_pair(arr, x, y):
    # 统计不包含数字 x 和 y 的子数组数(子数组中只能出现剩下的那个数)
    res = 0
    cnt = 0
    for v in arr:
        if v == x or v == y:
            res += cnt * (cnt + 1) // 2
            cnt = 0
        else:
            cnt += 1
    res += cnt * (cnt + 1) // 2
    return res
def main():
    import sys
    input_data = sys.stdin.read().split()
    n = int(input_data[0])
    arr = list(map(int, input_data[1:]))
    
    total = n * (n + 1) // 2
    
    cnt0 = count_no(arr, 0)       # 不含0的子数组数
    cnt1_all = count_no(arr, 1)     # 不含1的子数组数
    cnt2_all = count_no(arr, 2)     # 不含2的子数组数
    
    cnt01 = count_no_pair(arr, 0, 1)  # 不含0和1的子数组数
    cnt02 = count_no_pair(arr, 0, 2)  # 不含0和2的子数组数
    cnt12 = count_no_pair(arr, 1, 2)  # 不含1和2的子数组数
    # 计算各个类别的子数组个数
    cnt_mex0 = cnt0                                #mex=0的子数组
    cnt_mex1 = cnt1_all - cnt01                    #mex=1的子数组:包含0但不含1
    cnt_mex2 = cnt2_all - cnt02 - cnt12            # mex=2的子数组:包含0和1但不含2
    cnt_mex3 = total - (cnt0 + cnt1_all + cnt2_all) + (cnt01 + cnt02 + cnt12)  # mex=3的子数组
    # 最终答案
    ans = 1 * cnt_mex1 + 2 * cnt_mex2 + 3 * cnt_mex3
    print(ans)
if __name__ == '__main__':
    main()
java
import java.io.*;
import java.util.*;
public class Main {
    // 统计不包含数字 x 的子数组数
    public static long countNo(int[] arr, int x) {
        long res = 0;
        long cnt = 0;
        for (int v : arr) {
            if (v == x) {
                res += cnt * (cnt + 1) / 2;
                cnt = 0;
            } else {
                cnt++;
            }
        }
        res += cnt * (cnt + 1) / 2;
        return res;
    }
    
    // 统计不包含数字 x 和 y 的子数组数(子数组中只能出现剩下的那个数)
    public static long countNoPair(int[] arr, int x, int y) {
        long res = 0;
        long cnt = 0;
        for (int v : arr) {
            if (v == x || v == y) {
                res += cnt * (cnt + 1) / 2;
                cnt = 0;
            } else {
                cnt++;
            }
        }
        res += cnt * (cnt + 1) / 2;
        return res;
    }
    
    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        int n = Integer.parseInt(br.readLine().trim());
        String[] tokens = br.readLine().split("\\s+");
        int[] arr = new int[n];
        for (int i = 0; i < n; i++){
            arr[i] = Integer.parseInt(tokens[i]);
        }
        
        // 子数组总数 $$total = \frac{n(n+1)}{2}$$
        long total = (long) n * (n + 1) / 2;
        
        long cnt0 = countNo(arr, 0);       // 不含0的子数组数
        long cnt1_all = countNo(arr, 1);     // 不含1的子数组数
        long cnt2_all = countNo(arr, 2);     // 不含2的子数组数
        
        long cnt01 = countNoPair(arr, 0, 1);  // 不含0和1的子数组数
        long cnt02 = countNoPair(arr, 0, 2);  // 不含0和2的子数组数
        long cnt12 = countNoPair(arr, 1, 2);  // 不含1和2的子数组数
        
        // 计算各个类别的子数组个数
        long cnt_mex0 = cnt0;  //mex=0的子数组
        long cnt_mex1 = cnt1_all - cnt01;  //mex=1的子数组:包含0但不含1
        long cnt_mex2 = cnt2_all - cnt02 - cnt12;  // mex=2的子数组:包含0和1但不含2
        long cnt_mex3 = total - (cnt0 + cnt1_all + cnt2_all) + (cnt01 + cnt02 + cnt12);  //mex=3的子数组
        
        // 最终答案
        long ans = 1 * cnt_mex1 + 2 * cnt_mex2 + 3 * cnt_mex3;
        System.out.println(ans);
    }
}
        题目内容
整数数组的mex定义为没有出现在数组中的最小非负整数。例如mex(1,2,3)=0,mex(0,2,5)=1。 现在,对于给定的由n个整数组成的数组{a1,a2,...,an},取出全部连续非空子数组,并计算每个子数组的mex之和。
连续非空子数组为从原数组中,连续的选择一段元素(可以全选、可以不选)得到的新数组,且新数组中至少有一个元素。
输入描述
第一行输入一个整数n(1<=n<=2×105)代表数组中的元素数量。
第二行输入n个整数a1,a2,...,an(0<=ai<=2)代表数组元素。
输出描述
输出一个整数,代表所有子数组的mex之和。
样例1
输入
3
1 1 0
输出
3
说明
在这个样例中,答案由以下三部分构成:
长度为1的连续子数组:(1)。(1)、(0),mex之和为0+0+1=1
长度为2的连续子数组:(1,1)、(1,0),mex之和为0+2=2
长度为的固挂子数据:(1,1,0),mex之和为2。
因此,答案为1+2+2=5。