BCD Code ZOJ - 3494 | AC自动机 + 数位dp

初始化函数又忘了在主函数写,debug到永远。。。。

题意:
问A到B之间的所有整数,转换成BCD Code后,
有多少个不包含属于给定病毒串集合的子串,A,B <=10^200,病毒串总长度<= 2000.
分析:
1、先建立一颗trie树,注意:build函数end[]数组不仅仅要涉及本结点是不是结束符,还要检查其fail节点是不是结束符。具体代码:

while(!q.empty())
{
    int now = q.front(); q.pop();
    if(end[fail[now]] == 1) end[now] = 1;
    for(int i = 0; i < 2; i++)
        if(nex[now][i] == -1) nex[now][i] = nex[fail[now]][i];
        else fail[nex[now][i]] = nex[fail[now]][i], q.push(nex[now][i]);
}

2、预处理出trie上每个节点从0-9的BCD码转移的过程中会不会出现病毒串,其实对于0-9每个数转移四次就好。
3、数位dp
-注意前导零的判断
-注意字符数组减一的处理
代码:

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
//typedef __int128 lll;
#define close() ios::sync_with_stdio(0), cin.tie(0), cout.tie(0)
#define mem(a, b) memset(a, b, sizeof(a))
const ll mod = 1e9 + 9;
const int maxn = 2e3 + 10;
const int inf = 0x3f3f3f3f;
struct trie
{
    int tot, root;
    int nex[maxn][2], fail[maxn], end[maxn];
    int newnode()
    {
        for(int i = 0; i < 2; i++)
            nex[tot][i] = -1;
        end[tot++] = -1;
        return tot - 1;
    }
    void init()
    {
        tot = 0;
        root = newnode();
    }
    void insert(char s[])
    {
        int now = root;
        for(int i = 0; s[i]; i++)
        {
            if(nex[now][s[i] - '0'] == -1) nex[now][s[i] - '0'] = newnode();
            now = nex[now][s[i] - '0'];
        }
        end[now] = 1;
    }
    void build()
    {
        queue<int> q;
        fail[root] = root;
        for(int i = 0; i < 2; i++)
        {
            if(nex[root][i] == -1) nex[root][i] = root;
            else fail[nex[root][i]] = root, q.push(nex[root][i]);
        }
        while(!q.empty())
        {
            int now = q.front(); q.pop();
            if(end[fail[now]] == 1) end[now] = 1;
            for(int i = 0; i < 2; i++)
                if(nex[now][i] == -1) nex[now][i] = nex[fail[now]][i];
                else fail[nex[now][i]] = nex[fail[now]][i], q.push(nex[now][i]);
        }
    }
};
trie ti;
int v[maxn];
int dp[maxn][maxn];
int bcd[maxn][10];

int find(int pos, int num)
{
    if(ti.end[pos] != -1) return -1;
    for(int i = 3; i >= 0; i--)
    {
        int now = ti.nex[pos][(num >> i) & 1];
        if(ti.end[now] != -1) return -1;
        pos = now;
    }
    return pos;
}

void makebcd()
{
    for(int i = 0; i < ti.tot; i++)
        for(int j = 0; j < 10; j++)
            bcd[i][j] = find(i, j);
}

ll dfs(int pos, int state, int zero, int limit)
{
    if(!pos) return 1;
    if(!limit && ~dp[pos][state]) return dp[pos][state];
    int up = limit ? v[pos] : 9;
    ll res = 0;
    for(int i = 0; i <= up; i++)
    {
        if(!i && zero) res += dfs(pos - 1, state, i == 0 && zero, limit && i == up);
        else if(bcd[state][i] != -1) res += dfs(pos - 1, bcd[state][i], i == 0 && zero, limit && i == up);
        res %= mod;
    }
    if(!limit && !zero) dp[pos][state] = res;
    return res;
}

ll solve(char s[])
{
    int len = strlen(s);
    int pos = 0;
    for(int i = len - 1; i >= 0; i--) v[++pos] = s[i] - '0';
    return dfs(pos, 0, 1, 1);
}

int main()
{
    int t; cin >> t;
    while(t--)
    {
        int n; cin >> n;
        char s[300];
        ti.init();
        mem(dp, -1);
        for(int i = 1; i <= n; i++)
            scanf("%s", s), ti.insert(s);
        ti.build();
        makebcd();
        char l[maxn], r[maxn];
        scanf("%s%s", l, r);
        int len = strlen(l);
        for(int i = len - 1; i >= 0; i--)
            if(l[i] == '0') l[i] = '9';
            else 
            {
                l[i]--;
                break;
            }
        cout << (solve(r) - solve(l) + mod) % mod << endl;
    }
}