初始化函数又忘了在主函数写,debug到永远。。。。
题意:
问A到B之间的所有整数,转换成BCD Code后,
有多少个不包含属于给定病毒串集合的子串,A,B <=10^200,病毒串总长度<= 2000.
分析:
1、先建立一颗trie树,注意:build函数end[]数组不仅仅要涉及本结点是不是结束符,还要检查其fail节点是不是结束符。具体代码:
while(!q.empty())
{
int now = q.front(); q.pop();
if(end[fail[now]] == 1) end[now] = 1;
for(int i = 0; i < 2; i++)
if(nex[now][i] == -1) nex[now][i] = nex[fail[now]][i];
else fail[nex[now][i]] = nex[fail[now]][i], q.push(nex[now][i]);
}
2、预处理出trie上每个节点从0-9的BCD码转移的过程中会不会出现病毒串,其实对于0-9每个数转移四次就好。
3、数位dp
-注意前导零的判断
-注意字符数组减一的处理
代码:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
//typedef __int128 lll;
#define close() ios::sync_with_stdio(0), cin.tie(0), cout.tie(0)
#define mem(a, b) memset(a, b, sizeof(a))
const ll mod = 1e9 + 9;
const int maxn = 2e3 + 10;
const int inf = 0x3f3f3f3f;
struct trie
{
int tot, root;
int nex[maxn][2], fail[maxn], end[maxn];
int newnode()
{
for(int i = 0; i < 2; i++)
nex[tot][i] = -1;
end[tot++] = -1;
return tot - 1;
}
void init()
{
tot = 0;
root = newnode();
}
void insert(char s[])
{
int now = root;
for(int i = 0; s[i]; i++)
{
if(nex[now][s[i] - '0'] == -1) nex[now][s[i] - '0'] = newnode();
now = nex[now][s[i] - '0'];
}
end[now] = 1;
}
void build()
{
queue<int> q;
fail[root] = root;
for(int i = 0; i < 2; i++)
{
if(nex[root][i] == -1) nex[root][i] = root;
else fail[nex[root][i]] = root, q.push(nex[root][i]);
}
while(!q.empty())
{
int now = q.front(); q.pop();
if(end[fail[now]] == 1) end[now] = 1;
for(int i = 0; i < 2; i++)
if(nex[now][i] == -1) nex[now][i] = nex[fail[now]][i];
else fail[nex[now][i]] = nex[fail[now]][i], q.push(nex[now][i]);
}
}
};
trie ti;
int v[maxn];
int dp[maxn][maxn];
int bcd[maxn][10];
int find(int pos, int num)
{
if(ti.end[pos] != -1) return -1;
for(int i = 3; i >= 0; i--)
{
int now = ti.nex[pos][(num >> i) & 1];
if(ti.end[now] != -1) return -1;
pos = now;
}
return pos;
}
void makebcd()
{
for(int i = 0; i < ti.tot; i++)
for(int j = 0; j < 10; j++)
bcd[i][j] = find(i, j);
}
ll dfs(int pos, int state, int zero, int limit)
{
if(!pos) return 1;
if(!limit && ~dp[pos][state]) return dp[pos][state];
int up = limit ? v[pos] : 9;
ll res = 0;
for(int i = 0; i <= up; i++)
{
if(!i && zero) res += dfs(pos - 1, state, i == 0 && zero, limit && i == up);
else if(bcd[state][i] != -1) res += dfs(pos - 1, bcd[state][i], i == 0 && zero, limit && i == up);
res %= mod;
}
if(!limit && !zero) dp[pos][state] = res;
return res;
}
ll solve(char s[])
{
int len = strlen(s);
int pos = 0;
for(int i = len - 1; i >= 0; i--) v[++pos] = s[i] - '0';
return dfs(pos, 0, 1, 1);
}
int main()
{
int t; cin >> t;
while(t--)
{
int n; cin >> n;
char s[300];
ti.init();
mem(dp, -1);
for(int i = 1; i <= n; i++)
scanf("%s", s), ti.insert(s);
ti.build();
makebcd();
char l[maxn], r[maxn];
scanf("%s%s", l, r);
int len = strlen(l);
for(int i = len - 1; i >= 0; i--)
if(l[i] == '0') l[i] = '9';
else
{
l[i]--;
break;
}
cout << (solve(r) - solve(l) + mod) % mod << endl;
}
}