#P3678. 第2题-01串计数
-
1000ms
Tried: 101
Accepted: 20
Difficulty: 6
所属公司 :
京东
时间 :2025年9月13日
-
算法标签>数位dp
第2题-01串计数
解题思路
给定一个上界 n(以二进制字符串给出)和 m 条限制。每条限制为一对 (p,o),表示:把 x 写成二进制(最低位下标从 0 计),则从低位起的第 p 位必须等于 o∈{0,1}。问区间 [0,n] 内满足所有限制的整数个数,答案对 998244353 取模。
关键观察:
-
若某个位置被要求既为 0 又为 1,或某条限制要求在 p≥L(L=len(n))的高位为 1,则不可能,答案为 0。因为对所有 x≤n,在第 p(≥L) 位一定为 0。
-
将限制先映射到长度为 L 的数组
need[i](下标从高位到低位)。 输入的 p 是从低位计数,换算为从高位的下标:i = L-1-p。若i<0或i>=L则是上面的越界情况。 -
在长度为 L 的二进制位上做“数位 DP”(Digit DP,二进制版本):从最高位向最低位扫描,维护是否“已严格小于 n”的状态。
-
设
f0表示已经比前缀小(loose)的方案数,f1表示到当前位为止与 n 完全相等(tight)的方案数。 -
对当前位允许的取值集合
allowed(由限制决定,可能是{0,1}、{0}、{1}):- 从
f0转移:已经小于,任何允许位都能选,仍然保持小于:new0 += f0 * |allowed|。 - 从
f1转移:只能选不超过n当前位nb的允许位。 若取b < nb则转入new0;若b == nb则转入new1;b > nb不可选。
- 从
-
-
初值
f1=1, f0=0。最终答案为(f0 + f1) mod 998244353。
复杂度分析
- 预处理限制:O(m)
- 数位 DP:每一位最多尝试两种取值,O(L)
- 总时间复杂度:O(L+m)
- 额外空间:O(L)(存限制;若用数组只存 −1/0/1 可视为 O(L))
代码实现
Python
import sys
MOD = 998244353
def main():
data = sys.stdin.read().strip().split()
if not data:
return
it = iter(data)
s = next(it).strip() # n 的二进制串,无前导零
L = len(s)
m = int(next(it)) # 限制条数
# need[i]: 从高位到低位的第 i 位的要求,-1 表示不限,0/1 表示固定
need = [-1] * L
# 读取限制
for _ in range(m):
p = int(next(it))
o = int(next(it))
if p >= L:
# 对所有 x<=n,高于最高位的比特均为 0;若要求为 1 则不可能
if o == 1:
print(0)
return
# 要求为 0 则自动满足,忽略
continue
i = L - 1 - p # 转成从高位到低位的下标
if need[i] == -1:
need[i] = o
elif need[i] != o:
# 同一位出现矛盾
print(0)
return
f0, f1 = 0, 1 # f0: 已经小于;f1: 前缀相等
for i in range(L):
nb = ord(s[i]) - ord('0') # n 的当前位
# 当前位允许的集合
if need[i] == -1:
allowed = (0, 1)
else:
allowed = (need[i],)
new0, new1 = 0, 0
# 从已经小于的状态转移:任意允许位都可以,仍然保持小于
cnt_allowed = len(allowed)
new0 = (new0 + f0 * cnt_allowed) % MOD
# 从相等状态转移:不能超过 n 的当前位
if f1:
for b in allowed:
if b < nb:
new0 = (new0 + f1) % MOD
elif b == nb:
new1 = (new1 + f1) % MOD
# b > nb 直接跳过
f0, f1 = new0, new1
print((f0 + f1) % MOD)
if __name__ == "__main__":
main()
C++
#include <bits/stdc++.h>
using namespace std;
const int MOD = 998244353;
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
string s; // n 的二进制
if (!(cin >> s)) return 0;
int L = (int)s.size();
int m; cin >> m;
vector<int> need(L, -1); // -1 不限,0/1 固定
// 读取限制并检测矛盾
for (int k = 0; k < m; ++k) {
long long p; int o;
cin >> p >> o;
if (p >= L) {
if (o == 1) { // 超过最高位要求为 1,不可能
cout << 0 << '\n';
return 0;
}
continue; // 要求为 0 自动满足
}
int i = L - 1 - (int)p; // 转到从高位起的下标
if (need[i] == -1) need[i] = o;
else if (need[i] != o) { // 同位矛盾
cout << 0 << '\n';
return 0;
}
}
long long f0 = 0, f1 = 1; // f0: 已经小于;f1: 前缀相等
for (int i = 0; i < L; ++i) {
int nb = s[i] - '0';
long long new0 = 0, new1 = 0;
if (need[i] == -1) {
// allowed = {0,1}
new0 = (new0 + f0 * 2) % MOD; // 已小于可任意选两种
if (f1) {
if (0 < nb) new0 = (new0 + f1) % MOD;
if (0 == nb) new1 = (new1 + f1) % MOD;
if (1 < nb) new0 = (new0 + f1) % MOD;
if (1 == nb) new1 = (new1 + f1) % MOD;
}
} else {
int b = need[i]; // 只能取固定位
new0 = (new0 + f0) % MOD; // 已小于,选它仍小于
if (f1) {
if (b < nb) new0 = (new0 + f1) % MOD;
else if (b == nb) new1 = (new1 + f1) % MOD;
}
}
f0 = new0 % MOD;
f1 = new1 % MOD;
}
cout << (f0 + f1) % MOD << '\n';
return 0;
}
Java
import java.io.*;
import java.util.*;
public class Main {
static final int MOD = 998244353;
public static void main(String[] args) throws Exception {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
// 读取全部文本后按空白切分,避免逐行判断的繁琐
StringBuilder all = new StringBuilder();
String line;
while ((line = br.readLine()) != null) {
all.append(line).append(' ');
}
String[] tokens = all.toString().trim().split("\\s+");
if (tokens.length == 0) return;
int idx = 0;
String s = tokens[idx++]; // n 的二进制
int L = s.length();
int m = Integer.parseInt(tokens[idx++]);
int[] need = new int[L]; // -1 不限,0/1 固定
Arrays.fill(need, -1);
// 读取限制
for (int k = 0; k < m; k++) {
long p = Long.parseLong(tokens[idx++]);
int o = Integer.parseInt(tokens[idx++]);
if (p >= L) {
if (o == 1) { // 超过最高位要求为 1,不可能
System.out.println(0);
return;
}
continue; // 要求为 0 自动满足
}
int i = (int)(L - 1 - p); // 转成从高位起的下标
if (need[i] == -1) need[i] = o;
else if (need[i] != o) { // 同位矛盾
System.out.println(0);
return;
}
}
long f0 = 0, f1 = 1; // f0: 已小于;f1: 前缀相等
for (int i = 0; i < L; i++) {
int nb = s.charAt(i) - '0';
long new0 = 0, new1 = 0;
if (need[i] == -1) {
// 该位可为 0 或 1
new0 = (new0 + f0 * 2) % MOD; // 已小于,任选两种
if (f1 != 0) {
// 选择 0
if (0 < nb) new0 = (new0 + f1) % MOD;
else if (0 == nb) new1 = (new1 + f1) % MOD;
// 选择 1
if (1 < nb) new0 = (new0 + f1) % MOD;
else if (1 == nb) new1 = (new1 + f1) % MOD;
}
} else {
int b = need[i]; // 固定位
new0 = (new0 + f0) % MOD; // 已小于仍小于
if (f1 != 0) {
if (b < nb) new0 = (new0 + f1) % MOD;
else if (b == nb) new1 = (new1 + f1) % MOD;
}
}
f0 = new0 % MOD;
f1 = new1 % MOD;
}
System.out.println((f0 + f1) % MOD);
}
}
题目内容
给定一个正整数 n ,你需要计算有多少个非负整数 x∈[0,n] 且满足给定的所有 m 条限制。
每一条限制形如:将 x 表示成二进制,某一位必须是 0/ 必须是 1 。
输入描述
第一行描述 n 。本题中,n 可能很大(详见数据规模),所以 n 会以二进制表示法给出。例如,若 n=11 ,则输入为 1011 。保证输入的二进制电不含前导 0 。
第二行一个正整数 m ,表示限制条数。
接下来 m 行,每一行两个整数 p,o(0≤p≤[log2n]+1,o∈{0,1}),描述一条限制:若将 x 表示成二进制,则从低位起的第 p 位必须是 o 。
注意,我们约定最低位下标从 0 开始,例如,对于二进制数 1011 ,从低位起的第 0 位是 1 ,第 1 位是 1 ,第 2 位是 0 ,第 3 位是 1 。1≤n≤22∗105,1≤m≤2∗105
输出描述
输出一行一个整数,表示满足所有限制的 x 的个数。
因为答案可能很大,所以你只需要输出答案对 998244353 取模的结果。
样例1
输入
1011
2
1 1
3 0
输出
4
说明
满足条件的 x 有:2,3,6 和 7 。它们写成二进制分别为 (0010)2,(0011)2,(0110)2 和 (0111)2 。
样例2
输入
110011
3
0 0
2 0
4 1
输出
6