2022 International Collegiate Programming Contest, Jinan Site, Problem C: DFS Order 2 (tree DP + rollback knapsack)

Problem

Given a tree with n (n <= 500) nodes, rooted at node 1.

Output an n*n matrix whose entry (i, j) is the number of DFS orders (sons may be visited in any order) in which node i is the j-th visited node.

Output every entry modulo 998244353.

References

The official solution, plus the blog post by Strict Pigeon:

2022 ICPC Jinan C (Rollback Knapsack) – Zhihu

Solution

For the official solution, see the references above.

My method borrows part of the official approach and part of Strict Pigeon's, combined with my own way of counting.

1. Count the total number of DFS orders

h[i] is the number of DFS orders when only the nodes in the subtree of i are considered.

It is the product, over every node in the subtree, of the factorial of that node's number of direct sons.

Recursively: start with h[u]=1, and for each direct son v of u set h[u]*=h[v];

then set h[u]*=m!, where m is the number of direct sons of u.

2. Within the subtree of u, treat each son v of u as an item (1, sz[v]).

The first component says the item is one direct son; the second is the size of that son's subtree.

This works because once the DFS enters a direct son, it visits that son's whole subtree before returning.

Next, count the orders in which v appears exactly k+1 positions after u in the DFS order.

Let g[k + 1] be that number.

Only nodes in the subtree of u are considered here. Since g is recomputed for every pair (u, v), its (u, v) dimension can be dropped and a single 1-D array reused.

For v to be k+1 positions after u, exactly k nodes must lie between them. These nodes are the complete subtrees of some i direct sons of u (other than v) that are visited before v,

so the subtree sizes of those i sons sum to k.

So let f[i][j] be the number of ways, among the sons of u other than v, to select i sons whose subtree sizes sum to j.

First, run a knapsack to compute f. Note that this knapsack is essentially a product of several polynomials, one per son.

A knapsack solution is an unordered selection of sons,

but in DFS order different orderings of the same sons are different solutions, so the i sons visited before v must be ordered: multiply by i!.

Let m be the total number of direct sons of u. The remaining m-1-i sons, visited after v, also need to be ordered: multiply by (m-1-i)!.

The subtrees of the sons other than v also have their own internal orders, which must be multiplied in as well.

For example, if the direct sons of u are v, v1, v2, v3, v4, multiply by h[v1]*h[v2]*h[v3]*h[v4].

This product equals h[u]/m!/h[v] (computed with modular inverses). The pseudocode is:

for i in 0..m-1: // choose i direct sons of u (other than v) to be visited before v
    for k: // the chosen sons' subtrees have total size k
        g[k + 1] += f[i][k] * i! * (m-1-i)! * (h[u]/m!/h[v])

Computing g only fixes the distance between u and v.

For v we also need its distance to its father, to its father's father, and so on.

In other words, v's final position in the DFS order is determined by its distance to every node on its ancestor chain.

Therefore, let dp[i][j] be the number of ways for node i to be at position j in the DFS order, not counting the orders inside the subtree of i.

Merge the g arrays from top to bottom along the chain; each step is one more knapsack-style merge. Pseudocode:

for i: // position i of u in the DFS order
    for j: // distance j from u to v in the DFS order
        dp[v][i + j] += dp[u][i] * g[j] // v lands at position i + j

dp[i][j] leaves out the orders inside the subtree of i, which are counted by h[i] independently of everything else, so the final answer for (i, j) is dp[i][j]*h[i].

Done naively, this is O(n^4), because the f array above must exclude v:

for each son v of u, the knapsack f is rebuilt from scratch, and each rebuild costs O(m^2 * sz[u]).

① Maintain prefix knapsacks and suffix knapsacks over the direct sons, then merge the prefix before v with the suffix after v.

This is not feasible, because merging two knapsacks is itself a polynomial multiplication and still costs at least O(n^2) per son.

② Segment-tree divide and conquer (offline deletion), tracking when each son is present.

This should work: a son v is missing from the knapsack only while v itself is being processed, so the times at which v is present form at most two contiguous intervals.

That needs only a few additions or removals of items in the knapsack per son, but it is too cumbersome to write.

Removing an item from a knapsack is exactly what a rollback knapsack does, so it is simpler to just write one.

A knapsack is essentially a product of several polynomials: adding an item multiplies by that item's polynomial (here 1 + x*y^sz[v], where x counts sons and y counts size),

so removing an item just divides by that polynomial.

Concretely, when adding, traverse the count dimension i from large to small; when removing, traverse i from small to large. The loop that adds every son of u:

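for(auto &v:e[u]){
    if(v==fa)continue; // e[u] also contains the father
    for(int i=m;i>=1;--i){ // i descending, so f[i-1] still excludes v
        for(int j=sz[u];j>=sz[v];--j){
            f[i][j]=(f[i][j]+f[i-1][j-sz[v]])%mod;
        }
    }
}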

Before computing the g array and the dp values for v, remove v from the knapsack; after the computation, add it back, then recurse into v's subtree.

Complexity: O(n^3)
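A rough accounting of the O(n^3) bound, writing m_u for the number of sons of u and sz_u for its subtree size (notation introduced here). At each u, building the knapsack costs about m_u * m_u * sz_u, and for each son the removal, the g computation and the re-insertion each cost about m_u * sz_u. The chain merge for each son loops over (n+1) * sz_u pairs. Since the sum of m_u over all u is n-1:

```latex
\begin{aligned}
\sum_u m_u^2\, sz_u &\le n \sum_u m_u^2 \le n\Big(\sum_u m_u\Big)^2 = n(n-1)^2,\\
\sum_u \sum_{v \in \mathrm{son}(u)} (n+1)\, sz_u &= (n+1)\sum_u m_u\, sz_u \le (n+1)\, n\,(n-1),
\end{aligned}
```

so both parts are O(n^3).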

Code

#include<bits/stdc++.h>
//#include<iostream>
using namespace std;
#define rep(i,a,b) for(int i=(a);i<=(b);++i)
#define per(i,a,b) for(int i=(a);i>=(b);--i)
typedef long long ll;
typedef double db;
typedef pair<int,int> P;
#define fi first
#define se second
#define dbg(x) cerr<<(#x)<<":"<<x<<" ";
#define dbg2(x) cerr<<(#x)<<":"<<x<<endl;
#define SZ(a) (int)(a.size())
#define sci(a) scanf("%d",&(a))
#define pb push_back
#define pt(a) printf("%d",a);
#define pte(a) printf("%d\n",a)
#define ptlle(a) printf("%lld\n",a)
#define debug(...) fprintf(stderr, __VA_ARGS__)
const int N=505,M=500,mod=998244353;
int n,u,v,h[N],sz[N],fac[N],son[N],f[N][N],dp[N][N];
vector<int>e[N];
void add(int &x,int y){
    x=(x + y)%mod;
}
int modpow(int x,int n,int mod){
    int res=1;
    for(;n;n>>=1,x=1ll*x*x%mod){
        if(n&1)res=1ll*res*x%mod;
    }
    return res;
}
void dfs(int u,int fa){
    sz[u]=1;
    h[u]=1;
    for(auto &v:e[u]){
        if(v==fa)continue;
        dfs(v,u);
        son[u]++;
        sz[u]+=sz[v];
        h[u]=1ll*h[u]*h[v]%mod;
    }
    h[u]=1ll*h[u]*fac[son[u]]%mod;
    //printf("u:%d son:%d h:%d\n",u,son[u],h[u]);
}
void dfs2(int u,int fa){
    int m=son[u];
    vector<vector<int>>f(m + 1,vector<int>(sz[u] + 1,0));
    f[0][0]=1;//f[i][j]: number of ways to choose i sons whose subtree sizes sum to j
    for(auto &v:e[u]){
        if(v==fa)continue;//skip the father (harmless anyway: sz[fa]>sz[u] leaves the j loop empty)
        per(i,m,1){
            per(j,sz[u],sz[v]){
                add(f[i][j],f[i-1][j-sz[v]]);
            }
        }
    }
    h[u]=1ll*h[u]*modpow(fac[m],mod-2,mod)%mod;//h[u] is now the product of h over all sons of u
    for(auto &v:e[u]){
        if(v==fa)continue;
        h[u]=1ll*h[u]*modpow(h[v],mod-2,mod)%mod;//product of h over the sons other than v
        rep(i,1,m){//roll back: remove v from the knapsack, i ascending
            rep(j,sz[v],sz[u]){
                add(f[i][j],mod-f[i-1][j-sz[v]]);
            }
        }
        vector<int>g(n + 1,0);//g[k]: the number of solutions where the distance between u and v is k
        rep(i,0,m-1){
            rep(j,0,sz[u]-1){
                add(g[j + 1],1ll*f[i][j]*fac[i]%mod*fac[m-1-i]%mod*h[u]%mod);
            }
        }
        rep(i,0,n){
            if(!dp[u][i])continue;
            rep(j,1,sz[u]){
                if(!g[j])continue;
                add(dp[v][i + j],1ll*dp[u][i]*g[j]%mod);
            }
        }
        per(i,m,1){//add v back, i descending
            per(j,sz[u],sz[v]){
                add(f[i][j],f[i-1][j-sz[v]]);
            }
        }
        h[u]=1ll*h[u]*h[v]%mod;
        dfs2(v,u);
    }
    h[u]=1ll*h[u]*fac[m]%mod;
}
int main(){
    //freopen("jinan.in","r",stdin);
    //freopen("jinan.out","w",stdout);
    fac[0]=1;
    rep(i,1,M)fac[i]=1ll*fac[i-1]*i%mod;
    sci(n);
    rep(i,1,n-1){
        sci(u),sci(v);
        e[u].pb(v);
        e[v].pb(u);
    }
    dfs(1,0);
    dp[1][0]=1;
    dfs2(1,0);
    rep(i,1,n){
        rep(j,0,n-1){
            int ans=1ll*dp[i][j]*h[i]%mod;
            printf("%d%c",ans," \n"[j==n-1]);
        }
    }
    return 0;
}