#P2880. 第3题-可爱数
          
                        
                                    
                      
        
              - 
          
          
                      3000ms
            
          
                      Tried: 16
            Accepted: 7
            Difficulty: 9
            
          
          
          
                       所属公司 : 
                              美团
                                
            
                        
              时间 :2025年4月19日-技术岗
                              
                      
          
 
- 
                        算法标签>动态规划          
 
第3题-可爱数
解题思路
要在超长上界 n 下计数,不能枚举。我们结合Aho–Corasick 自动机和数位 DP(Digit DP)来做:
1. 构建 Aho–Corasick 自动机
- 插入模式串:将所有 m 个可爱数字串插入字典树,每个结尾节点记录“权重”——该串出现的次数(如果有相同串插入多次,权重大于1)。
 - 构建 fail 指针:BFS 建立 fail 链,并在每个节点汇总其 fail 链上所有模式串的权重,这样在匹配过程中,只要落到某节点,就知道以该位置为结尾的新匹配数。
 
构建复杂度:
- 插入:O(∑∣si∣)
 - BFS 构建 fail:O(states×Σ),Σ=10
 
2. 数位 DP
我们按 n 的十进制高位到低位枚举,并在 DP 中维护:
- 位置 pos:已经处理到第 pos 位。
 - AC 状态 st:当前匹配自动机的结点。
 - 已匹配次数 cnt∈{0,1,2}:累积的可爱度,超过 1 均记为 2(“≥2”)。
 - 限位标记 
tight:前缀是否已与 n 完全相等(0/1)。 - 前导零 
lz:是否仍在前导零区(0/1),前导零时不进自动机、也不计匹配。 
转移:
枚举当前位数字 d,更新
new_tight = tight && (d == n[pos])new_lz = lz && (d == 0)- 若 
new_lz,则new_st = 0,不计匹配;否则按自动机从st走到new_st并加上out[new_st](节点权重)到cnt,阈值截断到 2。 
最终答案为处理完所有位后,所有 tight∈{0,1}、lz=0、cnt=1 的 DP 值之和。
复杂度分析
- 自动机构建:O(∑∣si∣×Σ)≈104
 - 数位 DP:
- 状态数:O(L×states×3×2×2),其中 L= len(n)≤103,states≤103。
 - 转移:每状态枚举 ≤10 个数字,整体约 O(L×states×10)≈107,加上常数因子。
 
 - 空间:仅需保存两层位置的 DP,O(states×3×2×2)。
 
代码实现
Python 代码
MOD = 10**9 + 7
class AC:
    def __init__(self):
        self.next = [{}]      # 字典:digit -> 下一个节点
        self.fail = [0]       # fail 指针
        self.out = [0]        # 权重(模式串结束时的计数)
    
    def add(self, s):
        p = 0
        for ch in s:
            d = ord(ch) - 48
            if d not in self.next[p]:
                self.next[p][d] = len(self.next)
                self.next.append({})
                self.fail.append(0)
                self.out.append(0)
            p = self.next[p][d]
        self.out[p] += 1
    
    def build(self):
        from collections import deque
        q = deque()
        # 初始化 0 的 fail 为 0
        for d, v in self.next[0].items():
            q.append(v)
            self.fail[v] = 0
        while q:
            u = q.popleft()
            # 累加 fail 链上的权重
            self.out[u] += self.out[self.fail[u]]
            for d, v in self.next[u].items():
                f = self.fail[u]
                while f and d not in self.next[f]:
                    f = self.fail[f]
                self.fail[v] = self.next[f].get(d, 0)
                q.append(v)
def count_beauty_one(n_str, patterns):
    ac = AC()
    for s in patterns:
        ac.add(s)
    ac.build()
    L = len(n_str)
    S = len(ac.next)
    # dp[tight][lz][st][cnt]: 4D -> 当前位
    dp = [ [ [ [0]*3 for _ in range(S) ] for __ in range(2) ] for ___ in range(2) ]
    dp[1][1][0][0] = 1  # 从最高位开始,tight=1, lz=1, st=0, cnt=0
    
    for i, ch in enumerate(n_str):
        ndp = [ [ [ [0]*3 for _ in range(S) ] for __ in range(2) ] for ___ in range(2) ]
        limit = ord(ch) - 48
        for tight in (0,1):
            for lz in (0,1):
                for st in range(S):
                    for cnt in range(3):
                        v = dp[tight][lz][st][cnt]
                        if not v: continue
                        up = limit if tight else 9
                        for d in range(up+1):
                            nt = tight and (d==limit)
                            nlz = lz and (d==0)
                            if nlz:
                                nst = 0
                                add = 0
                            else:
                                # AC 转移
                                p = st
                                while p and d not in ac.next[p]:
                                    p = ac.fail[p]
                                nst = ac.next[p].get(d, 0)
                                add = ac.out[nst]
                            nc = cnt + add
                            if nc > 2: nc = 2
                            ndp[nt][nlz][nst][nc] = (ndp[nt][nlz][nst][nc] + v) % MOD
        dp = ndp
    
    ans = 0
    for tight in (0,1):
        for st in range(S):
            ans = (ans + dp[tight][0][st][1]) % MOD
    return ans
if __name__ == "__main__":
    import sys
    data = sys.stdin.read().split()
    n_str, m = data[0], int(data[1])
    pats = data[2:]
    print(count_beauty_one(n_str, pats))
Java 代码
import java.io.*;
import java.util.*;
public class Main {
    static final int MOD = 1_000_000_007;
    static class AC {
        List<Map<Integer, Integer>> next = new ArrayList<>();
        List<Integer> fail = new ArrayList<>();
        List<Integer> out = new ArrayList<>();
        AC() {
            next.add(new HashMap<>());
            fail.add(0);
            out.add(0);
        }
        void add(String s) {
            int p = 0;
            for (char c : s.toCharArray()) {
                int d = c - '0';
                next.get(p).putIfAbsent(d, next.size());
                if (next.get(p).get(d) == next.size()) {
                    next.add(new HashMap<>());
                    fail.add(0);
                    out.add(0);
                }
                p = next.get(p).get(d);
            }
            out.set(p, out.get(p) + 1);
        }
        void build() {
            Queue<Integer> q = new ArrayDeque<>();
            for (var e : next.get(0).entrySet()) {
                q.add(e.getValue());
                fail.set(e.getValue(), 0);
            }
            while (!q.isEmpty()) {
                int u = q.poll();
                out.set(u, out.get(u) + out.get(fail.get(u)));
                for (var e : next.get(u).entrySet()) {
                    int d = e.getKey(), v = e.getValue();
                    int f = fail.get(u);
                    while (f != 0 && !next.get(f).containsKey(d)) {
                        f = fail.get(f);
                    }
                    fail.set(v, next.get(f).getOrDefault(d, 0));
                    q.add(v);
                }
            }
        }
    }
    public static void main(String[] args) throws IOException {
        BufferedReader in = new BufferedReader(new InputStreamReader(System.in));
        String[] first = in.readLine().split(" ");
        String n = first[0];
        int m = Integer.parseInt(first[1]);
        AC ac = new AC();
        for (int i = 0; i < m; i++) {
            ac.add(in.readLine());
        }
        ac.build();
        int L = n.length(), S = ac.next.size();
        int[][][][] dp = new int[2][2][S][3];
        dp[1][1][0][0] = 1;
        for (int i = 0; i < L; i++) {
            int[][][][] ndp = new int[2][2][S][3];
            int limit = n.charAt(i) - '0';
            for (int t = 0; t < 2; t++) {
                for (int lz = 0; lz < 2; lz++) {
                    for (int st = 0; st < S; st++) {
                        for (int cnt = 0; cnt < 3; cnt++) {
                            int v = dp[t][lz][st][cnt];
                            if (v == 0) continue;
                            for (int d = 0; d <= (t == 1 ? limit : 9); d++) {
                                int nt = (t==1 && d==limit) ? 1 : 0;
                                int nlz = (lz==1 && d==0) ? 1 : 0;
                                int nst, add;
                                if (nlz == 1) {
                                    nst = 0; add = 0;
                                } else {
                                    int p = st;
                                    while (p!=0 && !ac.next.get(p).containsKey(d)) p = ac.fail.get(p);
                                    nst = ac.next.get(p).getOrDefault(d, 0);
                                    add = ac.out.get(nst);
                                }
                                int nc = cnt + add;
                                if (nc > 2) nc = 2;
                                ndp[nt][nlz][nst][nc] = (ndp[nt][nlz][nst][nc] + v) % MOD;
                            }
                        }
                    }
                }
            }
            dp = ndp;
        }
        long ans = 0;
        for (int t = 0; t < 2; t++) {
            for (int st = 0; st < ac.next.size(); st++) {
                ans = (ans + dp[t][0][st][1]) % MOD;
            }
        }
        System.out.println(ans);
    }
}
C++ 代码
#include <bits/stdc++.h>
using namespace std;
static const int MOD = 1e9 + 7;
struct AC {
    vector<array<int,10>> nxt;
    vector<int> fail, out;
    AC() { nxt.push_back({}); fail.push_back(0); out.push_back(0); }
    void add(const string &s) {
        int p = 0;
        for (char c : s) {
            int d = c - '0';
            if (!nxt[p][d]) {
                nxt[p][d] = nxt.size();
                nxt.push_back({});
                fail.push_back(0);
                out.push_back(0);
            }
            p = nxt[p][d];
        }
        out[p]++;
    }
    void build() {
        queue<int> q;
        for (int d = 0; d < 10; d++) {
            int v = nxt[0][d];
            if (v) { q.push(v); fail[v] = 0; }
        }
        while (!q.empty()) {
            int u = q.front(); q.pop();
            out[u] += out[fail[u]];
            for (int d = 0; d < 10; d++) {
                int v = nxt[u][d];
                if (!v) continue;
                int f = fail[u];
                while (f && !nxt[f][d]) f = fail[f];
                fail[v] = nxt[f][d];
                q.push(v);
            }
        }
    }
};
int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    string n; int m;
    cin >> n >> m;
    AC ac;
    for (int i = 0; i < m; i++) {
        string s; cin >> s;
        ac.add(s);
    }
    ac.build();
    int L = n.size(), S = ac.nxt.size();
    // dp[tight][lz][state][cnt]
    static int dp[2][2][2005][3], ndp[2][2][2005][3];
    dp[1][1][0][0] = 1;
    for (int i = 0; i < L; i++){
        memset(ndp, 0, sizeof ndp);
        int lim = n[i] - '0';
        for (int t = 0; t < 2; t++){
            for (int lz = 0; lz < 2; lz++){
                for (int st = 0; st < S; st++){
                    for (int cnt = 0; cnt < 3; cnt++){
                        int v = dp[t][lz][st][cnt];
                        if (!v) continue;
                        for (int d = 0; d <= (t?lim:9); d++){
                            int nt = (t && d==lim);
                            int nlz = (lz && d==0);
                            int nst, add;
                            if (nlz) {
                                nst = 0; add = 0;
                            } else {
                                int p = st;
                                while (p && !ac.nxt[p][d]) p = ac.fail[p];
                                nst = ac.nxt[p][d];
                                add = ac.out[nst];
                            }
                            int nc = cnt + add;
                            if (nc > 2) nc = 2;
                            ndp[nt][nlz][nst][nc] = (ndp[nt][nlz][nst][nc] + v) % MOD;
                        }
                    }
                }
            }
        }
        memcpy(dp, ndp, sizeof dp);
    }
    long long ans = 0;
    for (int t = 0; t < 2; t++){
        for (int st = 0; st < S; st++){
            ans = (ans + dp[t][0][st][1]) % MOD;
        }
    }
    cout << ans << "\n";
    return 0;
}
        题目内容
给定 m 个可爱数字串,它们仅由 0 ~ 9 这九个数字字符构成,且可能包含前导 0。
你需要求解,在区间 [1,n] 中,有多少个整数满足,可爱度恰好为 1。由于答案可能很大,请将答案对 (109+7) 取模后输出。在这里,一个整数的可爱度定义为:
- 
取出一段连续的数位,如果这段数位恰好是给定的 m 个可爱数字串中的一个或多个(完全匹配),则可爱度加上这个匹配的次数;
 - 
对于同一段连续的数位,仅计算一次可爱度。
 
举例说明,如果有两个可爱串,分别是 1110 和 111,那么 21110 可爱度不为 1,因为它同时包含了两个可爱串;1111 的可爱度为 2,因为包含了两次 111。
输入描述
第一行输入两个整数 n,m(1≤n≤101000;1≤m≤1000),表示询问区间的范围、给定的可爱数字串的数量。
接下来的 m 行,每行输入一个长度不超过 103、仅由 0 ~ 9 这十个数字字符构成的数字串 s,代表一个可爱数字串。可能包含前导 0。
除此之外,保证单个测试文件给出的 s 的字符数量之和不超过 103。
输出描述
输出一个整数,表示在区间 [1,n] 中可爱度恰好为 1 的数字个数。由于答案可能很大,请输出答案对 109+7 取模后的结果。
示例1
输入
2000 2
1110
111
输出
9
说明
在区间 [1,2000] 中,可爱度恰好为 1 的数字有 111,1112,1113,1114,1115,1116,1117,1118,1119,一共 9 个。
示例2
输入
1120 2
111
111
输出
0