这里给出无需容斥的解法。
|api−api+1| 只与相邻两个位置的值有关,这是连续段 DP 的重要标志。
众所周知连续段 DP 的思想是逐个加入元素来生成整个序列,并决定这个元素是新建连续段、延续连续段、抑或是合并连续段。
能否延续和合并的标准,即为这一个数和与之相邻的那个数的差的绝对值是否为 k。
这启示我们寻找约束关系以确定某特定的元素能否延续和合并某特定的连续段。
把所有差为 k 的数之间连边,约束关系应恰好是若干条链的形式(这里我们把单点也算是链的一种)。
假如只有一条链,如何解决这个问题?比如 k=1,a={1,2,3,⋯,n} 的情况。
像其它的连续段 DP 一样,我们设计的状态大致形如:fi,j 表示:填入了这条链上的前 i 个数,现在已经形成了 j 段,合法局面的方案数。
但是你发现没法转移,因为填入这条链的第 i+1 个数的时候,你并不确定第 i 个数填在了哪里,但第 i 个数填在哪里会决定这一次转移的系数。
举个例子:
- 假如第 i 个数合并了两个连续段,现在它在一整段的内部。那么填第 i+1 个数的时候,这个数可以随便乱放,因为其无论如何也不会和第 i 个数相邻。
- 但如果第 i 个数新建了一个连续段,那填第 i+1 个数的时候就会有两个位置不能填。
- 同理,假如第 i 个数延续连续段,其转移系数与第 i 个数新建或是合并连续段亦不相同。
那我们在状态里面再记上第 i 个数是新建连续段、延续连续段、抑或是合并连续段。可以转移吗?
似乎也不行,因为:
有可能第 i 个数延续连续段并被放在了当前整个序列的一端(此时你记录下来:第 i 个数的状态是延续连续段),接下来放第 i+1 个数,假如你要让其去合并连续段,你会发现即使第 i 个数是延续连续段,但是实际上不对第 i+1 个数造成干扰——随便合并任意两段都可以。
但是如果第 i 个数延续连续段并被放在了当前整个序列的非一端(此时你记录下来:第 i 个数的状态是延续连续段),接下来放第 i+1 个数,假如你要让其去合并连续段,你会发现第 i 个数是延续连续段,但是对第 i+1 个数造成干扰——并不是随便合并任意两段都可以了。
那么我们再记录第 i 个数是否放在了当前生成的序列的一端,即现在的状态表示如下:
- fi,j,0——填入了这条链上的前 i 个数,现在已经形成了 j 段,第 i 个数合并连续段,合法局面的方案数;
- fi,j,1——填入了这条链上的前 i 个数,现在已经形成了 j 段,第 i 个数延续连续段且其不在当前生成序列的任何一端,合法局面的方案数;
- fi,j,2——填入了这条链上的前 i 个数,现在已经形成了 j 段,第 i 个数延续连续段且其在当前生成序列的其中一端,合法局面的方案数;
- fi,j,3——填入了这条链上的前 i 个数,现在已经形成了 j 段,第 i 个数新建连续段且其不在当前生成序列的任何一端,合法局面的方案数;
- fi,j,4——填入了这条链上的前 i 个数,现在已经形成了 j 段,第 i 个数新建连续段且其在当前生成序列的其中一端,合法局面的方案数;
转移方程比较繁琐但并不困难,代码实现中清晰地呈现了转移方程供您参考。
那么对于多条链的情况(即原问题),如何改进算法?
多条链的情况实际上相当于一条链的情况删去若干约束。
区别于上面讨论的情况,假如第 i 个数与上一个填入的数之间不存在约束,那如何来描述这种状态?
这里不要想得太复杂,实际上从本质上来讲,这种情况是非常简单的——我们之前想了那么多方法,就是为了解决两个数之间存在约束的情况所产生的系数不一致的问题,现在好了,根本不存在约束了,你开心了吧!
你完全可以新建一个 fi,j,5——填入了这条链上的前 i 个数,现在已经形成了 j 段,第 i+1 个数与这一个数(即第 i 个数)之间不存在约束,合法局面的方案数;
转移的话也是简单的——假如第 i+1 个数与这一个数(即第 i 个数)之间不存在约束,那就直接转移到 fi,j,5 就行了嘛!
同样地,fi,j,5 向 i+1 的转移也并不复杂——你可以参照 fi,j,0 向 i+1 的转移。
(哦草,不存在约束的情况,即 fi,j,5,不就是 fi,j,0 嘛。)
所以有一个等效的写法,就是你可以直接转移到 fi,j,0,这样就不需要再新设一种状态了。
求出所有链想必对大家来说都是简单的,那么这个问题就以 O(n2) 的复杂度告终。
代码中有极其疯狂的转移!!!
#include<bits/stdc++.h>
#include<iostream>
#include<algorithm>
#include<cstdio>
#include<cstring>
#include<string>
#include<cmath>
#include<unordered_map>
#include<vector>
#include<bitset>
#include<queue>
#include<set>
#include<map>
#include<ctime>
#include<random>
#include<numeric>
using namespace std;
#define int long long
#define ll long long
#define ull unsigned long long
#define lc (x<<1)
#define rc (x<<1|1)
#define pii pair<int,int>
#define mkp make_pair
#define fi first
#define se second
const int Mx=5005,p=998244353;
int read(){
char ch=getchar();
int Alice=0,Aya=1;
while(ch<'0'||ch>'9'){
if(ch=='-') Aya=-Aya;
ch=getchar();
}
while(ch>='0'&&ch<='9')
Alice=(Alice<<3)+(Alice<<1)+(ch^48),ch=getchar();
return (Aya==1?Alice:-Alice);
}
int n,k;
int a[Mx];
int f[Mx][Mx][5];
bool vis[Mx];
int len[Mx],ed[Mx];
signed main(){
n=read(),k=read();
for(int i=1;i<=n;i++){
a[i]=read();
len[i]=1,ed[i]=1;
}
for(int i=1;i<=n;i++){
for(int j=i+1;j<=n;j++){
if(a[i]+k==a[j]) len[j]=len[i]+1,ed[j]=1,ed[i]=0;
}
}
int s=0,c=0;
vector<int>vec;
for(int i=1;i<=n;i++) if(ed[i]){
vec.push_back(len[i]);
}
sort(vec.begin(),vec.end(),greater<int>());
for(int v:vec){
s+=v,c++;
vis[s]=1;
}
if(c==n){
int ans=1;
for(int i=1;i<=n;i++) (ans*=i)%=p;
cout<<ans<<endl;
return 0;
}
f[2][2][4]=2;
for(int i=3;i<=n;i++){
if(vis[i-1]){
for(int j=1;j<=i;j++){
f[i-1][j][0]=f[i-1][j][0]+f[i-1][j][1]+f[i-1][j][2]+f[i-1][j][3]+f[i-1][j][4];
f[i-1][j][1]=0;
f[i-1][j][2]=0;
f[i-1][j][3]=0;
f[i-1][j][4]=0;
}
}
for(int j=1;j<=i;j++){
f[i][j][0]=f[i-1][j+1][0]*j
+f[i-1][j+1][1]*(j-1)
+f[i-1][j+1][2]*j
+f[i-1][j+1][3]*(j-2)
+f[i-1][j+1][4]*(j-1);
f[i][j][1]=f[i-1][j][0]*(2*j-2)
+f[i-1][j][1]*(2*j-3)
+f[i-1][j][2]*(2*j-2)
+f[i-1][j][3]*(2*j-4)
+f[i-1][j][4]*(2*j-3);
f[i][j][2]=f[i-1][j][0]*2
+f[i-1][j][1]*2
+f[i-1][j][2]*1
+f[i-1][j][3]*2
+f[i-1][j][4]*1;
f[i][j][3]=f[i-1][j-1][0]*(j-2)
+f[i-1][j-1][1]*(j-2)
+f[i-1][j-1][2]*(j-2)
+f[i-1][j-1][3]*(j-2)
+f[i-1][j-1][4]*(j-2);
f[i][j][4]=f[i-1][j-1][0]*2
+f[i-1][j-1][1]*2
+f[i-1][j-1][2]*2
+f[i-1][j-1][3]*2
+f[i-1][j-1][4]*2;
}
for(int j=1;j<=i;j++) f[i][j][0]%=p,f[i][j][1]%=p,f[i][j][2]%=p,f[i][j][3]%=p,f[i][j][4]%=p;
}
int ans=f[n][1][0]+f[n][1][1]+f[n][1][2]+f[n][1][3]+f[n][1][4];
ans%=p;
cout<<ans;
return 0;
}