#P1604. 第3题-封狼居胥
          
                        
                                    
                      
        
              - 
          
          
                      1000ms
            
          
                      Tried: 60
            Accepted: 22
            Difficulty: 10
            
          
          
          
                       所属公司 : 
                              阿里
                                
            
                        
              时间 :2023年9月25日-阿里淘天
                              
                      
          
 
- 
                        算法标签>组合数学          
 
第3题-封狼居胥
思路:组合数学 乘法逆元 快速幂
首先我们考虑,对于一个0,如果它只有左侧或者右侧某一侧有1,那么它只能被这一侧的1消灭,因此对应的方案数为1
如果对于两侧都有1的0,我们其实是需要考虑连续的一段0
比如下面这个例子10010001
有两段连续的0的子串,一个含有2个0,一个含有3个0,我们单独考虑每一段0
对于第一段0,我们设2个0为p0,p1,那么其实只有两种消灭顺序:[p0,p1],[p1,p0]
对于第二段0,我们设3个0为p0,p1,p2,有四种消灭顺序:$[p_0,p_1,p_2],[p_0,p_2,p_1],[p_2,p_1,p_0],[p_2,p_0,p_1]$
因此,对于一段含有连续0的个数为k的子串,则有2k−1种消灭顺序
但是,除此之外,对于有多个连续0的子串,比如我们设其长度分别为k1,k2,k3,我们还需要考虑他们之间的“合并”情况
我们设k1+k2+k3=m
则我们先从m个位置选k1位置放第一个子串,对应的方案数为C(m,k1)
然后再从m−k1个位置选k2个位置放第二个子串,对应的方案数为C(m−k1,k2)
最后再从m−k1−k2个位置选k3个位置放第三个子串,对应的方案数为C(m−k1−k2,k3)
将上述方案数累乘即可。
由于本题数据范围较大,求组合数需要使用乘法逆元求解,具体可以参考下面代码。
代码
C++
#include<bits/stdc++.h>
using namespace std;
#define ll long long
const int N = 1e5+10;
const int mod = 1e9+7;
int n;
string s;
vector<int> a;
ll jc[N],fjc[N];
ll ksm(ll a, ll b)
{
    ll res = 1;
    while (b)
    {
        if (b & 1) res = res * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return res;
}
ll C(int a, int b)
{
 if (a < 0 || b < 0 || a < b) return 0;
    if (b == 0 || a == b) return 1;
    return jc[a] * fjc[a - b] % mod * fjc[b] % mod;
}
int main() {
    jc[0] = 1;
    for (int i = 1; i < N; i++)
     jc[i] = jc[i - 1] * i % mod;
    fjc[N - 1] = ksm(jc[N - 1], mod - 2);
    for (int i = N - 2; i; i--)
     fjc[i] = fjc[i + 1] * (i + 1) % mod;
    cin >> n >> s;
    int all = 0;
    for(int i = 0 ; i < s.size() ; i ++) {
     if(s[i] == '1') continue;
     int j = i;
     while(j + 1 < s.size() && s[j + 1] == '0') {
      j ++;
     }
     a.push_back(j-i+1);
     i = j;
     all += a.back();
    }
    ll ans = 1;
    for(int t: a) {
     ans = ans * C(all, t) % mod * ksm(2, t-1) % mod;
     all -= t;
    }
    if(s[0] == '0') {
     ans = ans * ksm(ksm(2, a[0]-1), mod-2) % mod;
    }
    if(s.back() == '0') {
     ans = ans * ksm(ksm(2, a.back()-1), mod-2) % mod;
    }
    cout << ans << endl;
}
Java
import java.io.*;
import java.util.*;
public class Main{
    static final int N = (int)1e5+5;
    static final int MOD = (int)1e9+7;
    public static int qmi(int base, int k) {
        long a = base; long res = 1;
        while (k > 0) {
            if ((k & 1) == 1) res = res * a % MOD;
            a =  a * a %  MOD;
            k >>= 1;
        }
        return (int)res;
    }
    public static void main(String[] args) {
        MyScanner sc = new MyScanner();
        out = new PrintWriter(new BufferedOutputStream(System.out));
        int n = sc.nextInt();
        String s = sc.nextLine();
        // 连通块的数量
        ArrayList<Integer> list1 = new ArrayList<>();
        // 两个方向进攻的敌军数量
        ArrayList<Integer> list2 = new ArrayList<>();
        int firstOne = -1; int lastOne = -1;
        for (int i = 0; i < s.length(); i++) {
            if (s.charAt(i) == '1') {
                if (firstOne == -1) firstOne = i;
                lastOne = i;
            }
        }
        int l = 0;
        while (l < s.length()) {
            while (l < s.length() && s.charAt(l) == '1') l++;
            if (l >= s.length()) break;
            int r = l;
            while (r + 1 < s.length() && s.charAt(l) == s.charAt(r+1)) r++;
            list1.add(r-l+1);
            if (firstOne <= l && l <= lastOne) {
                list2.add(r - l + 1);
            }
            l = r + 1;
        }
        int[] a = new int[N];
        int[] b = new int[N];
        a[1] = 1; b[1] = 1;
        for (int i = 2; i < N; i++) {
            a[i] = (int)((long)a[i-1] * i % MOD);
            b[i] = qmi(a[i], MOD-2);
        }
        long res = 1; int m = list1.get(0);
        for (int i = 1; i < list1.size(); i++) {
            int p = list1.get(i);
            long tmp = ((long)a[m + p] * b[p] % MOD) * b[m] % MOD;
            res = res * tmp % MOD;
            m += p;
        }
        for (int i = 0; i < list2.size(); i++) {
            long tmp = qmi(2, list2.get(i)-1);
            res = res * tmp % MOD;
        }
        System.out.println(res);
        out.close();
    }
    //-----------PrintWriter for faster output---------------------------------
    public static PrintWriter out;
    //-----------MyScanner class for faster input----------
    public static class MyScanner {
        BufferedReader br;
        StringTokenizer st;
        public MyScanner() {
            br = new BufferedReader(new InputStreamReader(System.in));
        }
        String next() {
            while (st == null || !st.hasMoreElements()) {
                try {
                    st = new StringTokenizer(br.readLine());
                } catch (IOException e) {
                    e.printStackTrace();
                }
            }
            return st.nextToken();
        }
        int nextInt() {
            return Integer.parseInt(next());
        }
        long nextLong() {
            return Long.parseLong(next());
        }
        double nextDouble() {
            return Double.parseDouble(next());
        }
        String nextLine(){
            String str = "";
            try {
                str = br.readLine();
            } catch (IOException e) {
                e.printStackTrace();
            }
            return str;
        }
    }
    //--------------------------------------------------------
}
Python
n = int(input())
mod = int(1e9 + 7)
f = [1]
for i in range(n):
    f.append((f[-1] * (i + 1)) % mod)
def c(n, m):
    r = f[n] * pow(f[m], -1, mod) * pow(f[n - m], -1, mod)
    
    r %= mod
    #print(n, m, r, f)
    return r
s = [c for c in input()]
cur = 0
while s and s[-1] == '0':
    s.pop()
    cur += 1
cnt = cur
s = s[::-1]
cur = 0
while s and s[-1] == '0':
    s.pop()
    cur += 1
cnt += cur
ans = c(cnt, cur)
#print(ans)
s = ''.join(s)
s = [len(x) for x in s.split('1') if x]
for i in s:
    cnt += i
    ans *= pow(2, i - 1, mod)
    ans *= c(cnt, i)
    ans %= mod
    
print(ans)
        题目描述
霍去病改革下的汉军骑兵,身披甲胄,手持长矛,以冲锋形态向匈奴军队进行无畏冲刺,这是汉荡平匈奴的一大成因。现在你置身广阔的草原,手下的骑兵和匈奴骑兵共n个人已然混作一团,但在霍将军看来,实际上只是一字排开。现在塔子模拟了一个游戏,你需要指挥骑兵作战。我们假设这一排人中有你的大汉的骑兵,也有匈奴的骑兵,每一次指挥,你有两种选择:
1.如果第i个位置是汉军骑兵的,并且i+1<n,第i+1个位置是匈奴骑兵,那么你可以指挥汉军骑兵将其消灭,在第i+1个位置将自动生成一个你的骑兵。(游戏嘛,never mind)
2.如果第i个位置是汉军骑兵的,并且i−1>0,第i−1个位置是匈奴骑兵,那么你可以指挥汉军骑兵将其消灭,在第i−1个位置将自动生成一个你的骑兵。(游戏嘛,never mind)
你必然可以调度过程中生成的你的骑兵!
如此看来,胜利是必然的,但是塔子出品,必属精品,你需要回答有多少种击杀顺序,可以把所有匈奴骑兵都消灭。
输入描述
第一行一个整数n,表示骑兵的数量。 第二行一个字符串s,表示初始时哪些格子是汉军骑兵。如果第i个位置是汉军骑兵,那么si=1,否则si=0。1<n<105
输出描述
输出一个整数,表示答案对 109+7 取模的结果
样例
输入
5
01100
输出
3
说明
可能的击杀所有匈奴骑兵的顺序是,[0,3,4],[3,0,4],[3,4,0]
Limitation
1s, 1024KiB for each test case.