#P1704. 第4题-众数和
-
1000ms
Tried: 200
Accepted: 42
Difficulty: 8
所属公司 :
美团
时间 :2024年3月16日
-
算法标签>树状数组
第4题-众数和
题目思路
对于这题,我们可以将1视为 -1,2 视为 1,那么区间 [l,r] 的和就相当于区间中 2 的数量减去 1 的数量,如果该区间的和 > 0,则说明对于区间 [l,r] 而言,区间的众数为 2 ,否则为 1 。
为了区间和的计算方便,这里采用前缀和来进行处理:
- 记 s[i] 为前 i 个数的和,那么区间 [l,r] 的和可以表示为: s[r]−s[l−1]
将本题转换为,对于每个位置 r,找到其左侧所有满足 s[r]−s[l]>0,l∈[1,r−1] 的数量
其中查询左侧满足的数量个数可以用树状数组来维护查询。
值得注意的是:由于这里将所有的 1 都变成 -1 了, 那么所有的前缀和的范围为 [−n,n] ,因此在使数状数组的时候要加上偏移量 n+1 。
python
import sys
input = lambda:sys.stdin.readline().strip()
res = 0
n = int(input())
a = list(map(int, input().split()))
ans = 0
def add(x):
while x <= n + n + 1:
f[x] += 1
x += x & -x
def get(x):
res = 0
while x > 0:
res += f[x]
x -= x & -x
return res
f = [0] * (n + n + 2)
ans = 0
add(n + 1)
for v in a:
if v == 1:
ans -= 1
else:
ans += 1
res += get(ans + n)
add(ans + n + 1)
r = n * (n + 1) // 2 - res
print(res * 2 + r)
java
import java.util.*;
import java.io.*;
// 注意类名必须为Main
class Node{
int left;
int right;
int val;
public Node(int l,int r){
left = l;
right = r;
}
public String toString(){
return left+" "+right+" "+val;
}
}
class Main {
static int n;
static int[] a;
static int[] diff;
static long m = 0;
static Node[] tree;
public static long aaa(){
//初始化diff数组
//diff范围 -n --> n
bbb();
//建树
tree = new Node[8*n];
ccc();
//插入第一条数据diff[0]
// add(diff[0]);
//遍历diff数组
for(int i=0;i<n;i++){
m += select(-n,diff[i]-1);
add(diff[i]);
if(diff[i]>0){
m++;
}
// System.out.print(diff[i]+" ");
}
// for(int i=0;i<n;i++){
// for(int i=0;i<8*n;i++){
// System.out.println(tree[i]);
// }
// System.out.println(m);
return m + 1L*(n+1)*n/2;
}
public static long select(int leftNum,int rightNum){
if(rightNum < leftNum){
return 0;
}
return diguiSelect(1,leftNum,rightNum);
}
public static long diguiSelect(int treeIndex,int leftNum,int rightNum){
Node node = tree[treeIndex];
if(node.left > rightNum || node.right < leftNum){
return 0;
}
if(node.left >= leftNum && node.right <= rightNum){
return node.val;
}
return diguiSelect(treeIndex*2,leftNum,rightNum) + diguiSelect(treeIndex*2+1,leftNum,rightNum);
}
public static void add(int num){
Node node = tree[1];
int treeIndex = 1;
while(node.left != node.right){
node.val++;
if(num >= tree[treeIndex*2].left && num <= tree[treeIndex*2].right){
treeIndex = treeIndex*2;
// node = tree[treeIndex];
}else{
treeIndex = treeIndex*2+1;
// node = tree[treeIndex*2+1];
}
node = tree[treeIndex];
}
node.val++;
}
public static void ccc(){
tree[1] = new Node(-n,n);
digui(1);
}
public static void digui(int nodeIndex){
if(tree[nodeIndex].left == tree[nodeIndex].right){
return;
}
int mid = (tree[nodeIndex].left+tree[nodeIndex].right)/2;
if(tree[nodeIndex].right <= 0){
tree[2*nodeIndex] = new Node(tree[nodeIndex].left,mid-1);
tree[2*nodeIndex+1] = new Node(mid,tree[nodeIndex].right);
}else{
tree[2*nodeIndex] = new Node(tree[nodeIndex].left,mid);
tree[2*nodeIndex+1] = new Node(mid+1,tree[nodeIndex].right);
}
digui(2*nodeIndex);
digui(2*nodeIndex+1);
}
public static void bbb(){
int n1 = 0;
int n2 = 0;
diff = new int[n];
for(int i=0;i<n;i++){
if(a[i] == 1){
n1++;
}else{
n2++;
}
diff[i] = n2-n1;
}
}
public static void main(String[] args) throws Exception {
// Scanner sc = new Scanner(System.in);
BufferedReader reader = new BufferedReader(new InputStreamReader(System.in));
// n = sc.nextInt();
n = Integer.parseInt(reader.readLine());
a = new int[n];
String[] ss = reader.readLine().split(" ");
for(int i=0;i<n;i++){
// int t = ;
a[i] = (ss[i].charAt(0) == '1') ? 1 : 2;
}
System.out.println(aaa());
}
}
C++
#include <iostream>
#include <vector>
#include <cassert>
using namespace std;
using ll = long long;
int lowbit(int const& x) {
return (x & (-x));
}
void update_tree(vector<int>& tree, int const& tree_size, int const& wh) {
auto p = wh;
while (p >= 1 && p <= tree_size) {
tree[p]++;
p += lowbit(p);
}
}
int query_tree(vector<int> const& tree, int const& tree_size, int const& wh) {
auto p = wh;
auto ans = 0;
while (p >= 1 && p <= tree_size) {
ans += tree[p];
p -= lowbit(p);
}
return ans;
}
int main(int argc, char const *argv[]) {
std::ios::sync_with_stdio(false);
auto m = 0;
cin >> m;
auto s1 = 0, s2 = 0;
vector<int> deltas(m + 1, 0);
for (auto i = 1; i <= m; i++) {
auto v = 0;
cin >> v;
if (v == 1) {
s1++;
} else {
s2++;
}
auto const d = s2 - s1;
deltas[i] = d;
}
auto const tree_size = 2 * m + 1;
vector<int> tree(tree_size + 5, 0);
ll ans = 0;
for (auto i = 1; i <= m; i++) {
// 统计有多少个是 2 的
ll twos = 0;
if (deltas[i] > 0) {
twos++;
}
if (i > 1) {
// 有多少个比 x 小的
auto x = deltas[i];
x += m + 1;
// 从 1 到 x - 1 有多少个
if (x > 0) {
twos += query_tree(tree, tree_size, x - 1);
}
}
ll ones = i - twos;
assert(twos >= 0 && ones >= 0 && ones + twos == i);
ll tempans = ones + twos * 2;
ans += tempans;
update_tree(tree, tree_size, deltas[i] + m + 1);
}
cout << ans << endl;
return 0;
}
题目描述
小美拿到了一个数组,她希里你求出所有区间众数之和,你能帮帮她吗?
定义区间的众数为出现次数最多的那个数,如果有多个数出现次数最多,那么众数是其中最小的那个数。
输入描述
第一行输入一个正整数n,代表数组的大小
第二行输入n个正整数ai,代表数组的元素
1≤n≤2×105
1≤ai≤2
输出描述
一个正整数,代表所有区间的众数之和。
样例
输入
3
2 1 2
输出
9
说明
[2],[2,1,2],[2]的众数是 2.
[2,1],[1],[1,2]的众数是 1.
因此答案是 9.