#P2880. 第3题-可爱数
-
3000ms
Tried: 16
Accepted: 7
Difficulty: 9
所属公司 :
美团
时间 :2025年4月19日-技术岗
-
算法标签>动态规划
第3题-可爱数
解题思路
要在超长上界 n 下计数,不能枚举。我们结合Aho–Corasick 自动机和数位 DP(Digit DP)来做:
1. 构建 Aho–Corasick 自动机
- 插入模式串:将所有 m 个可爱数字串插入字典树,每个结尾节点记录“权重”——该串出现的次数(如果有相同串插入多次,权重大于1)。
- 构建 fail 指针:BFS 建立 fail 链,并在每个节点汇总其 fail 链上所有模式串的权重,这样在匹配过程中,只要落到某节点,就知道以该位置为结尾的新匹配数。
构建复杂度:
- 插入:O(∑∣si∣)
- BFS 构建 fail:O(states×Σ),Σ=10
2. 数位 DP
我们按 n 的十进制高位到低位枚举,并在 DP 中维护:
- 位置 pos:已经处理到第 pos 位。
- AC 状态 st:当前匹配自动机的结点。
- 已匹配次数 cnt∈{0,1,2}:累积的可爱度,超过 1 均记为 2(“≥2”)。
- 限位标记
tight:前缀是否已与 n 完全相等(0/1)。 - 前导零
lz:是否仍在前导零区(0/1),前导零时不进自动机、也不计匹配。
转移:
枚举当前位数字 d,更新
new_tight = tight && (d == n[pos])new_lz = lz && (d == 0)- 若
new_lz,则new_st = 0,不计匹配;否则按自动机从st走到new_st并加上out[new_st](节点权重)到cnt,阈值截断到 2。
最终答案为处理完所有位后,所有 tight∈{0,1}、lz=0、cnt=1 的 DP 值之和。
复杂度分析
- 自动机构建:O(∑∣si∣×Σ)≈104
- 数位 DP:
- 状态数:O(L×states×3×2×2),其中 L= len(n)≤103,states≤103。
- 转移:每状态枚举 ≤10 个数字,整体约 O(L×states×10)≈107,加上常数因子。
- 空间:仅需保存两层位置的 DP,O(states×3×2×2)。
代码实现
Python 代码
MOD = 10**9 + 7
class AC:
def __init__(self):
self.next = [{}] # 字典:digit -> 下一个节点
self.fail = [0] # fail 指针
self.out = [0] # 权重(模式串结束时的计数)
def add(self, s):
p = 0
for ch in s:
d = ord(ch) - 48
if d not in self.next[p]:
self.next[p][d] = len(self.next)
self.next.append({})
self.fail.append(0)
self.out.append(0)
p = self.next[p][d]
self.out[p] += 1
def build(self):
from collections import deque
q = deque()
# 初始化 0 的 fail 为 0
for d, v in self.next[0].items():
q.append(v)
self.fail[v] = 0
while q:
u = q.popleft()
# 累加 fail 链上的权重
self.out[u] += self.out[self.fail[u]]
for d, v in self.next[u].items():
f = self.fail[u]
while f and d not in self.next[f]:
f = self.fail[f]
self.fail[v] = self.next[f].get(d, 0)
q.append(v)
def count_beauty_one(n_str, patterns):
ac = AC()
for s in patterns:
ac.add(s)
ac.build()
L = len(n_str)
S = len(ac.next)
# dp[tight][lz][st][cnt]: 4D -> 当前位
dp = [ [ [ [0]*3 for _ in range(S) ] for __ in range(2) ] for ___ in range(2) ]
dp[1][1][0][0] = 1 # 从最高位开始,tight=1, lz=1, st=0, cnt=0
for i, ch in enumerate(n_str):
ndp = [ [ [ [0]*3 for _ in range(S) ] for __ in range(2) ] for ___ in range(2) ]
limit = ord(ch) - 48
for tight in (0,1):
for lz in (0,1):
for st in range(S):
for cnt in range(3):
v = dp[tight][lz][st][cnt]
if not v: continue
up = limit if tight else 9
for d in range(up+1):
nt = tight and (d==limit)
nlz = lz and (d==0)
if nlz:
nst = 0
add = 0
else:
# AC 转移
p = st
while p and d not in ac.next[p]:
p = ac.fail[p]
nst = ac.next[p].get(d, 0)
add = ac.out[nst]
nc = cnt + add
if nc > 2: nc = 2
ndp[nt][nlz][nst][nc] = (ndp[nt][nlz][nst][nc] + v) % MOD
dp = ndp
ans = 0
for tight in (0,1):
for st in range(S):
ans = (ans + dp[tight][0][st][1]) % MOD
return ans
if __name__ == "__main__":
import sys
data = sys.stdin.read().split()
n_str, m = data[0], int(data[1])
pats = data[2:]
print(count_beauty_one(n_str, pats))
Java 代码
import java.io.*;
import java.util.*;
public class Main {
static final int MOD = 1_000_000_007;
static class AC {
List<Map<Integer, Integer>> next = new ArrayList<>();
List<Integer> fail = new ArrayList<>();
List<Integer> out = new ArrayList<>();
AC() {
next.add(new HashMap<>());
fail.add(0);
out.add(0);
}
void add(String s) {
int p = 0;
for (char c : s.toCharArray()) {
int d = c - '0';
next.get(p).putIfAbsent(d, next.size());
if (next.get(p).get(d) == next.size()) {
next.add(new HashMap<>());
fail.add(0);
out.add(0);
}
p = next.get(p).get(d);
}
out.set(p, out.get(p) + 1);
}
void build() {
Queue<Integer> q = new ArrayDeque<>();
for (var e : next.get(0).entrySet()) {
q.add(e.getValue());
fail.set(e.getValue(), 0);
}
while (!q.isEmpty()) {
int u = q.poll();
out.set(u, out.get(u) + out.get(fail.get(u)));
for (var e : next.get(u).entrySet()) {
int d = e.getKey(), v = e.getValue();
int f = fail.get(u);
while (f != 0 && !next.get(f).containsKey(d)) {
f = fail.get(f);
}
fail.set(v, next.get(f).getOrDefault(d, 0));
q.add(v);
}
}
}
}
public static void main(String[] args) throws IOException {
BufferedReader in = new BufferedReader(new InputStreamReader(System.in));
String[] first = in.readLine().split(" ");
String n = first[0];
int m = Integer.parseInt(first[1]);
AC ac = new AC();
for (int i = 0; i < m; i++) {
ac.add(in.readLine());
}
ac.build();
int L = n.length(), S = ac.next.size();
int[][][][] dp = new int[2][2][S][3];
dp[1][1][0][0] = 1;
for (int i = 0; i < L; i++) {
int[][][][] ndp = new int[2][2][S][3];
int limit = n.charAt(i) - '0';
for (int t = 0; t < 2; t++) {
for (int lz = 0; lz < 2; lz++) {
for (int st = 0; st < S; st++) {
for (int cnt = 0; cnt < 3; cnt++) {
int v = dp[t][lz][st][cnt];
if (v == 0) continue;
for (int d = 0; d <= (t == 1 ? limit : 9); d++) {
int nt = (t==1 && d==limit) ? 1 : 0;
int nlz = (lz==1 && d==0) ? 1 : 0;
int nst, add;
if (nlz == 1) {
nst = 0; add = 0;
} else {
int p = st;
while (p!=0 && !ac.next.get(p).containsKey(d)) p = ac.fail.get(p);
nst = ac.next.get(p).getOrDefault(d, 0);
add = ac.out.get(nst);
}
int nc = cnt + add;
if (nc > 2) nc = 2;
ndp[nt][nlz][nst][nc] = (ndp[nt][nlz][nst][nc] + v) % MOD;
}
}
}
}
}
dp = ndp;
}
long ans = 0;
for (int t = 0; t < 2; t++) {
for (int st = 0; st < ac.next.size(); st++) {
ans = (ans + dp[t][0][st][1]) % MOD;
}
}
System.out.println(ans);
}
}
C++ 代码
#include <bits/stdc++.h>
using namespace std;
static const int MOD = 1e9 + 7;
struct AC {
vector<array<int,10>> nxt;
vector<int> fail, out;
AC() { nxt.push_back({}); fail.push_back(0); out.push_back(0); }
void add(const string &s) {
int p = 0;
for (char c : s) {
int d = c - '0';
if (!nxt[p][d]) {
nxt[p][d] = nxt.size();
nxt.push_back({});
fail.push_back(0);
out.push_back(0);
}
p = nxt[p][d];
}
out[p]++;
}
void build() {
queue<int> q;
for (int d = 0; d < 10; d++) {
int v = nxt[0][d];
if (v) { q.push(v); fail[v] = 0; }
}
while (!q.empty()) {
int u = q.front(); q.pop();
out[u] += out[fail[u]];
for (int d = 0; d < 10; d++) {
int v = nxt[u][d];
if (!v) continue;
int f = fail[u];
while (f && !nxt[f][d]) f = fail[f];
fail[v] = nxt[f][d];
q.push(v);
}
}
}
};
int main(){
ios::sync_with_stdio(false);
cin.tie(nullptr);
string n; int m;
cin >> n >> m;
AC ac;
for (int i = 0; i < m; i++) {
string s; cin >> s;
ac.add(s);
}
ac.build();
int L = n.size(), S = ac.nxt.size();
// dp[tight][lz][state][cnt]
static int dp[2][2][2005][3], ndp[2][2][2005][3];
dp[1][1][0][0] = 1;
for (int i = 0; i < L; i++){
memset(ndp, 0, sizeof ndp);
int lim = n[i] - '0';
for (int t = 0; t < 2; t++){
for (int lz = 0; lz < 2; lz++){
for (int st = 0; st < S; st++){
for (int cnt = 0; cnt < 3; cnt++){
int v = dp[t][lz][st][cnt];
if (!v) continue;
for (int d = 0; d <= (t?lim:9); d++){
int nt = (t && d==lim);
int nlz = (lz && d==0);
int nst, add;
if (nlz) {
nst = 0; add = 0;
} else {
int p = st;
while (p && !ac.nxt[p][d]) p = ac.fail[p];
nst = ac.nxt[p][d];
add = ac.out[nst];
}
int nc = cnt + add;
if (nc > 2) nc = 2;
ndp[nt][nlz][nst][nc] = (ndp[nt][nlz][nst][nc] + v) % MOD;
}
}
}
}
}
memcpy(dp, ndp, sizeof dp);
}
long long ans = 0;
for (int t = 0; t < 2; t++){
for (int st = 0; st < S; st++){
ans = (ans + dp[t][0][st][1]) % MOD;
}
}
cout << ans << "\n";
return 0;
}
题目内容
给定 m 个可爱数字串,它们仅由 0 ~ 9 这九个数字字符构成,且可能包含前导 0。
你需要求解,在区间 [1,n] 中,有多少个整数满足,可爱度恰好为 1。由于答案可能很大,请将答案对 (109+7) 取模后输出。在这里,一个整数的可爱度定义为:
-
取出一段连续的数位,如果这段数位恰好是给定的 m 个可爱数字串中的一个或多个(完全匹配),则可爱度加上这个匹配的次数;
-
对于同一段连续的数位,仅计算一次可爱度。
举例说明,如果有两个可爱串,分别是 1110 和 111,那么 21110 可爱度不为 1,因为它同时包含了两个可爱串;1111 的可爱度为 2,因为包含了两次 111。
输入描述
第一行输入两个整数 n,m(1≤n≤101000;1≤m≤1000),表示询问区间的范围、给定的可爱数字串的数量。
接下来的 m 行,每行输入一个长度不超过 103、仅由 0 ~ 9 这十个数字字符构成的数字串 s,代表一个可爱数字串。可能包含前导 0。
除此之外,保证单个测试文件给出的 s 的字符数量之和不超过 103。
输出描述
输出一个整数,表示在区间 [1,n] 中可爱度恰好为 1 的数字个数。由于答案可能很大,请输出答案对 109+7 取模后的结果。
示例1
输入
2000 2
1110
111
输出
9
说明
在区间 [1,2000] 中,可爱度恰好为 1 的数字有 111,1112,1113,1114,1115,1116,1117,1118,1119,一共 9 个。
示例2
输入
1120 2
111
111
输出
0