#P3348. 第2题-01串划分
          
                        
                                    
                      
        
              - 
          
          
                      3000ms
            
          
                      Tried: 112
            Accepted: 27
            Difficulty: 6
            
          
          
          
                       所属公司 : 
                              京东
                                
            
                        
              时间 :2025年8月9日
                              
                      
          
 
- 
                        算法标签>动态规划          
 
第2题-01串划分
思路与算法
1)预处理所有候选区间
- 将每个 ai 转为二进制串 bi(如 2→"10")。
 - 在 s 中扫描所有长度为 ∣bi∣ 的窗口,记录与 bi 完全相同的出现位置 (l,r)(下标从 1 开始,r=l+∣bi∣−1)。
 - 若某个 i 完全没有匹配位置,则本组直接输出 NO。
 
2)位压 DP(按右端点递增选区间)
任何一组不相交区间都可以按右端点从小到大排序。考虑状态:
- 
dp[mask][r] 为布尔值,表示已为 mask 中的若干个 ai 各选了一个不相交区间,且最后一个区间的右端点是 r(r=0 表示尚未选任何区间)。
 - 
初始 dp[0][0]=true。
 - 
转移:从 dp[mask][r]=true 出发,对每个尚未使用的 i,枚举其所有出现 (l,rr)。若 l>r(与已选最后区间不相交),则令
dp[mask∪{i}][rr]=true. - 
若存在任意 r 使 dp[(1≪m)−1][r]=true,则可行,输出 YES;否则 NO。
 
3)正确性说明(要点)
- 枚举候选区间保证了只考虑有效匹配。
 - 由于不相交区间可以按右端点排序,DP 始终以“上一个右端点 r”为边界扩展,不会错过解。
 - m≤6,总状态至多 2m⋅(n+1)≤64⋅101。转移枚举的候选区间总量上界约为 m⋅n(长度为 1 时最密),因此整体足够快。
 
复杂度分析
- 预处理匹配:对每个 i 在 s 上滑窗,复杂度 O(∑i(n−∣bi∣+1)⋅∣bi∣)≤ O(m⋅n⋅10)。
 - DP 转移:每个状态尝试把一个未用的 i 扩展到其所有出现处,整体 O(2m⋅∑i#occi),最坏近似 O(2m⋅m⋅n)。
 - 在给定约束下,量级 ≤4×104 级别,绰绰有余。
 
参考实现
Python
import sys
def to_bin(x: int) -> str:
    # 0 的二进制表示为 "0"
    return "0" if x == 0 else bin(x)[2:]
def solve():
    data = sys.stdin.read().strip().split()
    t = int(data[0]); p = 1
    out = []
    for _ in range(t):
        n = int(data[p]); m = int(data[p+1]); p += 2
        s = data[p].strip(); p += 1
        a = list(map(int, data[p:p+m])); p += m
        bins = [to_bin(x) for x in a]
        occ = []  # 每个 i 的出现区间列表 [(l,r), ...],1-based
        impossible = False
        for bi in bins:
            L = len(bi)
            cur = []
            for st in range(n - L + 1):
                if s[st:st+L] == bi:
                    l = st + 1
                    r = st + L
                    cur.append((l, r))
            if not cur:
                impossible = True
            occ.append(cur)
        if impossible:
            out.append("NO")
            continue
        # dp[mask] = 能达到的所有最后右端点 r 的集合(0..n)
        full = (1 << m) - 1
        dp = [set() for _ in range(1 << m)]
        dp[0].add(0)
        ok = False
        for mask in range(1 << m):
            if ok:
                break
            for last_r in list(dp[mask]):
                if ok:
                    break
                # 尝试放入尚未使用的第 i 个数
                for i in range(m):
                    if (mask >> i) & 1:
                        continue
                    for (l, r) in occ[i]:
                        if l > last_r:
                            nm = mask | (1 << i)
                            dp[nm].add(r)
                            if nm == full:
                                ok = True
                                break
                    if ok:
                        break
        out.append("YES" if ok else "NO")
    print("\n".join(out))
if __name__ == "__main__":
    solve()
C++
#include <bits/stdc++.h>
using namespace std;
string to_bin(int x){
    if(x==0) return "0";
    string t;
    while(x){
        t.push_back(char('0'+(x&1)));
        x >>= 1;
    }
    reverse(t.begin(), t.end());
    return t;
}
int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int T; 
    if(!(cin>>T)) return 0;
    while(T--){
        int n,m; 
        cin>>n>>m;
        string s; cin>>s;
        vector<int> a(m);
        for(int i=0;i<m;i++) cin>>a[i];
        vector<string> bins(m);
        for(int i=0;i<m;i++) bins[i]=to_bin(a[i]);
        // 预处理每个 a_i 的出现区间
        vector<vector<pair<int,int>>> occ(m);
        bool bad=false;
        for(int i=0;i<m;i++){
            const string &b = bins[i];
            int L = (int)b.size();
            for(int st=0; st+L<=n; ++st){
                if(s.compare(st, L, b)==0){
                    occ[i].push_back({st+1, st+L}); // 1-based
                }
            }
            if(occ[i].empty()) bad = true;
        }
        if(bad){
            cout<<"NO\n";
            continue;
        }
        int full = (1<<m)-1;
        vector<bitset<105>> dp(1<<m); // dp[mask][r] 是否可达
        dp[0].set(0);
        bool ok=false;
        for(int mask=0; mask<(1<<m) && !ok; ++mask){
            for(int r=0; r<=n && !ok; ++r){
                if(!dp[mask].test(r)) continue;
                for(int i=0;i<m && !ok;i++){
                    if(mask>>i & 1) continue;
                    for(auto [l,rr]: occ[i]){
                        if(l>r){
                            int nm = mask | (1<<i);
                            dp[nm].set(rr);
                            if(nm==full){ ok=true; break; }
                        }
                    }
                }
            }
        }
        cout<<(ok?"YES":"NO")<<"\n";
    }
    return 0;
}
Java
import java.io.*;
import java.util.*;
/* 读取输入,位压 DP 判定是否存在不相交匹配 */
public class Main {
    static String toBin(int x){
        // 0 -> "0",其它用标准二进制
        return Integer.toBinaryString(x);
    }
    public static void main(String[] args) throws Exception{
        FastScanner fs = new FastScanner(System.in);
        StringBuilder ans = new StringBuilder();
        int T = fs.nextInt();
        while(T-- > 0){
            int n = fs.nextInt(), m = fs.nextInt();
            String s = fs.next();
            int[] a = new int[m];
            for(int i=0;i<m;i++) a[i]=fs.nextInt();
            String[] bins = new String[m];
            for(int i=0;i<m;i++){
                // Java 中 0 的 toBinaryString 就是 "0"
                bins[i] = toBin(a[i]);
            }
            // 预处理每个 a_i 的出现区间
            @SuppressWarnings("unchecked")
            ArrayList<int[]>[] occ = new ArrayList[m];
            boolean bad = false;
            for(int i=0;i<m;i++){
                occ[i] = new ArrayList<>();
                String b = bins[i];
                int L = b.length();
                for(int st=0; st+L<=n; st++){
                    if(s.regionMatches(st, b, 0, L)){
                        occ[i].add(new int[]{st+1, st+L}); // 1-based
                    }
                }
                if(occ[i].isEmpty()) bad = true;
            }
            if(bad){
                ans.append("NO\n");
                continue;
            }
            int full = (1<<m)-1;
            boolean[][] dp = new boolean[1<<m][n+1];
            dp[0][0] = true;
            boolean ok = false;
            for(int mask=0; mask<(1<<m) && !ok; mask++){
                for(int r=0; r<=n && !ok; r++){
                    if(!dp[mask][r]) continue;
                    for(int i=0;i<m && !ok;i++){
                        if(((mask>>i)&1)==1) continue;
                        for(int[] seg: occ[i]){
                            int l = seg[0], rr = seg[1];
                            if(l>r){
                                int nm = mask | (1<<i);
                                dp[nm][rr] = true;
                                if(nm==full){ ok = true; break; }
                            }
                        }
                    }
                }
            }
            ans.append(ok?"YES":"NO").append('\n');
        }
        System.out.print(ans.toString());
    }
    /* 快速读入 */
    static class FastScanner {
        private final InputStream in;
        private final byte[] buffer = new byte[1<<16];
        private int ptr=0, len=0;
        FastScanner(InputStream is){ in=is; }
        private int read() throws IOException{
            if(ptr>=len){
                len = in.read(buffer);
                ptr = 0;
                if(len<=0) return -1;
            }
            return buffer[ptr++];
        }
        String next() throws IOException{
            StringBuilder sb = new StringBuilder();
            int c;
            while((c=read())!=-1 && c<=32);
            if(c==-1) return null;
            do{
                sb.append((char)c);
                c=read();
            }while(c>32);
            return sb.toString();
        }
        int nextInt() throws IOException{
            String s = next();
            return Integer.parseInt(s);
        }
    }
}
        题目内容
小钟有一个长度为n的01串s,即仅由字符0和1组成的字符串,如0101101。除此之外他还有m个数字,分别用a1,a2,..am表示。
小钟很好奇,他能否选择m个不相交的区间[l1,r1][l2,r2],...,[lm,rm],使得对于任意的ai,其二进制表示(没有前导0,0的二进制表示就是0),都能用s的某个连续子串slj,lj+1,...rj来表示。
输入描述
输入包括多组测试数据。
输入第一行包括一个正整数T(1≤T≤20),表示测试数据的组数。
每组测试数据的第一行有两个整数n(1≤n≤100),m(1≤m≤6),分别表示01串s的长度,数字个数。
第二行有一行长度为n的01串s。
第三行有m个整数a1,a2,...am(0≤ai<210),表示小钟的m个数字。
输出描述
对于每组测试数据,如果存在答案,输出一行“YES”;否则,输出一行“NO”。
样例1
输入
2
5 2
10110
2 1
5 1
00000
1
输出
YES
NO
说明
对于第一组测试数据,2的二进制表示为10,1的二进制表示为1,其中一种可以选择的区间 为 [1,2]、[3,3]。
对于第二组测试数据,1的二进制表示为1,由于01串中不存在字符1,故答案一定不存在。