#P2698. 第1题-连续非空子数组
-
1000ms
Tried: 29
Accepted: 7
Difficulty: 7
所属公司 :
阿里
时间 :2025年3月15日-阿里淘天(开发岗)
-
算法标签>思维
第1题-连续非空子数组
题解
题目描述
给定一个由n个整数构成的数组{a1,a2,…,an},其中每个ai满足0≤ai≤2。定义一个数组的mex为未出现在该数组中的最小非负整数。例如
- mex{1,2,3}=0
- mex{0,2,5}=1
要求取出数组中的所有连续非空子数组,并求每个子数组的mex值之和。
连续非空子数组指从原数组中取出一段连续的元素(可以取全数组,也可以取部分),且该子数组至少包含一个元素。
思路
本题的核心在于利用数组中仅有的三个数0、1和2的特性,将所有连续子数组的mex值分为四类:不含0(mex=0)、含0但不含1(mex=1)、同时含0和1但不含2(mex=2),以及同时含0、1和2(mex=3);通过扫描数组统计不含某个或某些数字的连续区间,并利用公式2L(L+1)计算区间内子数组的数量,再利用容斥原理求出每一类子数组的个数,最后加权求和得到所有子数组的mex之和。
由于数组中元素取值仅为0、1和2,因此对于任一子数组,其可能的mex值只有以下几种情况:
- mex=0:子数组中没有出现0。
- mex=1:子数组中出现了0但没有1。
- mex=2:子数组中同时出现了0和1但没有2。
- mex=3:子数组中同时出现了0、1和2。
我们可以统计满足上述条件的子数组个数,然后利用mex的权值计算答案。记子数组总数为
total=2n(n+1)
如何统计子数组个数
利用“缺失某个数”的思路,设F(x)为不含x的子数组个数,可以利用扫描数组,找出连续不含x的段,其长度为L时,其子数组个数为
2L(L+1)
同理,设F(x,y)表示不含x和y的子数组个数,即该子数组中的所有元素只能为剩下的那个数。
接下来分类讨论:
-
mex=0:子数组中没有0
cnt0=F(0) -
mex=1:子数组中出现0但没有1
统计F(1)(即不含1的子数组个数),再减去其中同时不含0和1的,即F(0,1),得到
cnt1=F(1)−F(0,1) -
mex=2:子数组中含有0和1但没有2
先统计F(2)(不含2的子数组个数),再减去其中不含0的和不含1的部分,即
cnt2=F(2)−F(0,2)−F(1,2) (这里不必加回F(0,1,2),因为非空子数组不可能同时缺少0、1和2) -
mex=3:子数组中同时含有0、1和2利用容斥原理:cnt_3=total−(F(0)+F(1)+F(2)) + (F(0,1)+F(0,2)+F(1,2))
最后答案为
ans = 0×cnt0 + 1×cnt1 + 2×cnt2 + 3×cnt3
cpp
#include <iostream>
#include <vector>
using namespace std;
typedef long long ll;
// 统计不包含数字 x 的子数组数
ll countNo(const vector<int>& arr, int x) {
ll res = 0;
ll cnt = 0;
for (int v : arr) {
if (v == x) {
res += cnt * (cnt + 1LL) / 2;
cnt = 0;
} else {
cnt++;
}
}
res += cnt * (cnt + 1LL) / 2;
return res;
}
// 统计不包含数字 x 和 y 的子数组数,即子数组中只能出现剩下的那个数
ll countNoPair(const vector<int>& arr, int x, int y) {
ll res = 0;
ll cnt = 0;
for (int v : arr) {
if (v == x || v == y) {
res += cnt * (cnt + 1LL) / 2;
cnt = 0;
} else {
cnt++;
}
}
res += cnt * (cnt + 1LL) / 2;
return res;
}
int main(){
ios::sync_with_stdio(false);
cin.tie(nullptr);
int n;
cin >> n;
vector<int> arr(n);
for (int i = 0; i < n; i++){
cin >> arr[i];
}
ll total = (ll)n * (n + 1LL) / 2;
// 统计各个缺失情况
ll cnt0 = countNo(arr, 0); // 不含0的子数组数
ll cnt1_all = countNo(arr, 1); // 不含1的子数组数
ll cnt2_all = countNo(arr, 2); // 不含2的子数组数
ll cnt01 = countNoPair(arr, 0, 1); // 不含0和1的子数组数
ll cnt02 = countNoPair(arr, 0, 2); // 不含0和2的子数组数
ll cnt12 = countNoPair(arr, 1, 2); // 不含1和2的子数组数
// 计算各个类别的子数组个数
ll cnt_mex0 = cnt0; //mex=0的子数组
ll cnt_mex1 = cnt1_all - cnt01; //mex=1的子数组:包含0但不含1
ll cnt_mex2 = cnt2_all - cnt02 - cnt12; // mex=2的子数组:包含0和1但不含2
ll cnt_mex3 = total - (cnt0 + cnt1_all + cnt2_all) + (cnt01 + cnt02 + cnt12); //mex=3的子数组
// 最终答案
ll ans = 1LL * cnt_mex1 + 2LL * cnt_mex2 + 3LL * cnt_mex3;
cout << ans << "\n";
return 0;
}
python
def count_no(arr, x):
# 统计不包含数字 x 的子数组数
res = 0
cnt = 0
for v in arr:
if v == x:
res += cnt * (cnt + 1) // 2
cnt = 0
else:
cnt += 1
res += cnt * (cnt + 1) // 2
return res
def count_no_pair(arr, x, y):
# 统计不包含数字 x 和 y 的子数组数(子数组中只能出现剩下的那个数)
res = 0
cnt = 0
for v in arr:
if v == x or v == y:
res += cnt * (cnt + 1) // 2
cnt = 0
else:
cnt += 1
res += cnt * (cnt + 1) // 2
return res
def main():
import sys
input_data = sys.stdin.read().split()
n = int(input_data[0])
arr = list(map(int, input_data[1:]))
total = n * (n + 1) // 2
cnt0 = count_no(arr, 0) # 不含0的子数组数
cnt1_all = count_no(arr, 1) # 不含1的子数组数
cnt2_all = count_no(arr, 2) # 不含2的子数组数
cnt01 = count_no_pair(arr, 0, 1) # 不含0和1的子数组数
cnt02 = count_no_pair(arr, 0, 2) # 不含0和2的子数组数
cnt12 = count_no_pair(arr, 1, 2) # 不含1和2的子数组数
# 计算各个类别的子数组个数
cnt_mex0 = cnt0 #mex=0的子数组
cnt_mex1 = cnt1_all - cnt01 #mex=1的子数组:包含0但不含1
cnt_mex2 = cnt2_all - cnt02 - cnt12 # mex=2的子数组:包含0和1但不含2
cnt_mex3 = total - (cnt0 + cnt1_all + cnt2_all) + (cnt01 + cnt02 + cnt12) # mex=3的子数组
# 最终答案
ans = 1 * cnt_mex1 + 2 * cnt_mex2 + 3 * cnt_mex3
print(ans)
if __name__ == '__main__':
main()
java
import java.io.*;
import java.util.*;
public class Main {
// 统计不包含数字 x 的子数组数
public static long countNo(int[] arr, int x) {
long res = 0;
long cnt = 0;
for (int v : arr) {
if (v == x) {
res += cnt * (cnt + 1) / 2;
cnt = 0;
} else {
cnt++;
}
}
res += cnt * (cnt + 1) / 2;
return res;
}
// 统计不包含数字 x 和 y 的子数组数(子数组中只能出现剩下的那个数)
public static long countNoPair(int[] arr, int x, int y) {
long res = 0;
long cnt = 0;
for (int v : arr) {
if (v == x || v == y) {
res += cnt * (cnt + 1) / 2;
cnt = 0;
} else {
cnt++;
}
}
res += cnt * (cnt + 1) / 2;
return res;
}
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
int n = Integer.parseInt(br.readLine().trim());
String[] tokens = br.readLine().split("\\s+");
int[] arr = new int[n];
for (int i = 0; i < n; i++){
arr[i] = Integer.parseInt(tokens[i]);
}
// 子数组总数 $$total = \frac{n(n+1)}{2}$$
long total = (long) n * (n + 1) / 2;
long cnt0 = countNo(arr, 0); // 不含0的子数组数
long cnt1_all = countNo(arr, 1); // 不含1的子数组数
long cnt2_all = countNo(arr, 2); // 不含2的子数组数
long cnt01 = countNoPair(arr, 0, 1); // 不含0和1的子数组数
long cnt02 = countNoPair(arr, 0, 2); // 不含0和2的子数组数
long cnt12 = countNoPair(arr, 1, 2); // 不含1和2的子数组数
// 计算各个类别的子数组个数
long cnt_mex0 = cnt0; //mex=0的子数组
long cnt_mex1 = cnt1_all - cnt01; //mex=1的子数组:包含0但不含1
long cnt_mex2 = cnt2_all - cnt02 - cnt12; // mex=2的子数组:包含0和1但不含2
long cnt_mex3 = total - (cnt0 + cnt1_all + cnt2_all) + (cnt01 + cnt02 + cnt12); //mex=3的子数组
// 最终答案
long ans = 1 * cnt_mex1 + 2 * cnt_mex2 + 3 * cnt_mex3;
System.out.println(ans);
}
}
题目内容
整数数组的mex定义为没有出现在数组中的最小非负整数。例如mex(1,2,3)=0,mex(0,2,5)=1。 现在,对于给定的由n个整数组成的数组{a1,a2,...,an},取出全部连续非空子数组,并计算每个子数组的mex之和。
连续非空子数组为从原数组中,连续的选择一段元素(可以全选、可以不选)得到的新数组,且新数组中至少有一个元素。
输入描述
第一行输入一个整数n(1<=n<=2×105)代表数组中的元素数量。
第二行输入n个整数a1,a2,...,an(0<=ai<=2)代表数组元素。
输出描述
输出一个整数,代表所有子数组的mex之和。
样例1
输入
3
1 1 0
输出
3
说明
在这个样例中,答案由以下三部分构成:
长度为1的连续子数组:(1)。(1)、(0),mex之和为0+0+1=1
长度为2的连续子数组:(1,1)、(1,0),mex之和为0+2=2
长度为的固挂子数据:(1,1,0),mex之和为2。
因此,答案为1+2+2=5。