#P1704. 第4题-众数和
          
                        
                                    
                      
        
              - 
          
          
                      1000ms
            
          
                      Tried: 200
            Accepted: 42
            Difficulty: 8
            
          
          
          
                       所属公司 : 
                              美团
                                
            
                        
              时间 :2024年3月16日
                              
                      
          
 
- 
                        算法标签>树状数组          
 
第4题-众数和
题目思路
 对于这题,我们可以将1视为 -1,2 视为 1,那么区间 [l,r] 的和就相当于区间中 2 的数量减去 1 的数量,如果该区间的和 > 0,则说明对于区间 [l,r] 而言,区间的众数为 2 ,否则为 1 。
 为了区间和的计算方便,这里采用前缀和来进行处理:
- 记 s[i] 为前 i 个数的和,那么区间 [l,r] 的和可以表示为: s[r]−s[l−1]
 
 将本题转换为,对于每个位置 r,找到其左侧所有满足 s[r]−s[l]>0,l∈[1,r−1] 的数量
其中查询左侧满足的数量个数可以用树状数组来维护查询。
值得注意的是:由于这里将所有的 1 都变成 -1 了, 那么所有的前缀和的范围为 [−n,n] ,因此在使数状数组的时候要加上偏移量 n+1 。
python
import sys
input = lambda:sys.stdin.readline().strip()
res = 0
n = int(input())
a = list(map(int, input().split()))
ans = 0
def add(x):
    while x <= n + n + 1:
        f[x] += 1
        x += x & -x
def get(x):
    res = 0
    while x > 0:
        res += f[x]
        x -= x & -x
    return res
f = [0] * (n + n + 2)
ans = 0
add(n + 1)
for v in a:
    if v == 1:
        ans -= 1
    else:
        ans += 1
    res += get(ans + n)
    add(ans + n + 1)
r = n * (n + 1) // 2 - res
print(res * 2 + r)
java
import java.util.*;
import java.io.*;
// 注意类名必须为Main
class Node{
    int left;
    int right;
    int val;
    public Node(int l,int r){
        left = l;
        right = r;
    }
    public String toString(){
        return left+" "+right+" "+val;
    }
}
class Main {
    static int n;
    static int[] a;
    static int[] diff;
    static long m = 0;
    static Node[] tree;
    public static long aaa(){
        //初始化diff数组
        //diff范围 -n --> n
        bbb();
        //建树
        tree = new Node[8*n];
        ccc();
        //插入第一条数据diff[0]
        // add(diff[0]);
        //遍历diff数组
        for(int i=0;i<n;i++){
            m += select(-n,diff[i]-1);
            add(diff[i]);
            if(diff[i]>0){
                m++;
            }
            // System.out.print(diff[i]+" ");
        }
        // for(int i=0;i<n;i++){
        // for(int i=0;i<8*n;i++){
        //     System.out.println(tree[i]);
        // }
        // System.out.println(m);
        return m + 1L*(n+1)*n/2;
    }
    public static long select(int leftNum,int rightNum){
        if(rightNum < leftNum){
            return 0;
        }
        return diguiSelect(1,leftNum,rightNum);
    }
    public static long diguiSelect(int treeIndex,int leftNum,int rightNum){
        Node node = tree[treeIndex];
        if(node.left > rightNum || node.right < leftNum){
            return 0;
        }
        if(node.left >= leftNum && node.right <= rightNum){
            return node.val;
        }
        return diguiSelect(treeIndex*2,leftNum,rightNum) + diguiSelect(treeIndex*2+1,leftNum,rightNum);
        
    }
    public static void add(int num){
        Node node = tree[1];
        int treeIndex = 1;
        while(node.left != node.right){
            node.val++;
            if(num >= tree[treeIndex*2].left && num <= tree[treeIndex*2].right){
                treeIndex = treeIndex*2;
                // node = tree[treeIndex];
            }else{
                treeIndex = treeIndex*2+1;
                // node = tree[treeIndex*2+1];
            }
            node = tree[treeIndex];
        }
        node.val++;
    }
    public static void ccc(){
        tree[1] = new Node(-n,n);
        digui(1);
    }
    public static void digui(int nodeIndex){
        if(tree[nodeIndex].left == tree[nodeIndex].right){
            return;
        }
        int mid = (tree[nodeIndex].left+tree[nodeIndex].right)/2;
        if(tree[nodeIndex].right <= 0){
            tree[2*nodeIndex] = new Node(tree[nodeIndex].left,mid-1);
            tree[2*nodeIndex+1] = new Node(mid,tree[nodeIndex].right);
        }else{
            tree[2*nodeIndex] = new Node(tree[nodeIndex].left,mid);
            tree[2*nodeIndex+1] = new Node(mid+1,tree[nodeIndex].right);
        }
        
        digui(2*nodeIndex);
        digui(2*nodeIndex+1);
    }
    public static void bbb(){
        int n1 = 0;
        int n2 = 0;
        diff = new int[n];
        for(int i=0;i<n;i++){
            if(a[i] == 1){
                n1++;
            }else{
                n2++;
            }
            diff[i] = n2-n1;
        }
    }
    public static void main(String[] args) throws Exception {
        // Scanner sc = new Scanner(System.in);
        BufferedReader reader = new BufferedReader(new InputStreamReader(System.in));
        // n = sc.nextInt();
        n = Integer.parseInt(reader.readLine());
        a = new int[n];
        String[] ss = reader.readLine().split(" ");
        for(int i=0;i<n;i++){
            // int t = ;
            a[i] = (ss[i].charAt(0) == '1') ? 1 : 2;
        }
        System.out.println(aaa());
        
    }
}
C++
#include <iostream>
#include <vector>
#include <cassert>
using namespace std;
using ll = long long;
int lowbit(int const& x) {
    return (x & (-x));
}
void update_tree(vector<int>& tree, int const& tree_size, int const& wh) {
    auto p = wh;
    while (p >= 1 && p <= tree_size) {
        tree[p]++;
        p += lowbit(p);
    }
}
int query_tree(vector<int> const& tree, int const& tree_size, int const& wh) {
    auto p = wh;
    auto ans = 0;
    while (p >= 1 && p <= tree_size) {
        ans += tree[p];
        p -= lowbit(p);
    }
    return ans;
}
int main(int argc, char const *argv[]) {
    std::ios::sync_with_stdio(false);
    auto m = 0;
    cin >> m;
    auto s1 = 0, s2 = 0;
    vector<int> deltas(m + 1, 0);
    for (auto i = 1; i <= m; i++) {
        auto v = 0;
        cin >> v;
        if (v == 1) {
            s1++;
        } else {
            s2++;
        }
        auto const d = s2 - s1;
        deltas[i] = d;
    }
    auto const tree_size = 2 * m + 1;
    vector<int> tree(tree_size + 5, 0);
    ll ans = 0;
    for (auto i = 1; i <= m; i++) {
        // 统计有多少个是 2 的
        ll twos = 0;
        if (deltas[i] > 0) {
            twos++;
        }
        if (i > 1) {
            // 有多少个比 x 小的
            auto x = deltas[i];
            x += m + 1;
            // 从 1 到 x - 1 有多少个
            if (x > 0) {
                twos += query_tree(tree, tree_size, x - 1);
            }
        }
        ll ones = i - twos;
        assert(twos >= 0 && ones >= 0 && ones + twos == i);
        ll tempans = ones + twos * 2;
        ans += tempans;
        update_tree(tree, tree_size, deltas[i] + m + 1);
    }
    cout << ans << endl;
    return 0;
}
        题目描述
小美拿到了一个数组,她希里你求出所有区间众数之和,你能帮帮她吗?
定义区间的众数为出现次数最多的那个数,如果有多个数出现次数最多,那么众数是其中最小的那个数。
输入描述
第一行输入一个正整数n,代表数组的大小
第二行输入n个正整数ai,代表数组的元素
1≤n≤2×105
1≤ai≤2
输出描述
一个正整数,代表所有区间的众数之和。
样例
输入
3
2 1 2
输出
9
说明
[2],[2,1,2],[2]的众数是 2.
[2,1],[1],[1,2]的众数是 1.
因此答案是 9.