emm……NTT 的原理请左转去看各位神犇的博客吧 QwQ

这里只贴代码及部分操作的推导过程。

# 首先是喜闻乐见的 NTT 多项式乘法板子

(这个就不解释了)

mark
namespace NTT{
    // Current transform length (a power of two) and its log2; both are set
    // by get_rev() and read by the other routines.
    ll lim, len;
    // Returns a^b mod `mod` by binary exponentiation (b >= 0).
    inline ll qpow(ll a, ll b){
        ll res = 1;
        while(b){
            if(b & 1) res = res * a % mod;
            a = a * a % mod, b >>= 1;
        }
        return res;
    }
    // Sets lim to the smallest power of two >= n and fills p[0..lim) with the
    // bit-reversal permutation on len = log2(lim) bits.
    // NOTE(review): `cl` is declared outside this snippet (presumably const long long).
    inline void get_rev(cl n){
        lim = 1, len = 0;
        while(lim < n) lim <<= 1, ++len;
        // With lim == 1 (n <= 1) there is no bit to reverse; shifting by
        // len - 1 == -1 would be undefined behaviour, so contribute 0 instead.
        for(int i = 0; i < lim; ++i)
            p[i] = (p[i >> 1] >> 1) | (len ? ((i & 1) << (len - 1)) : 0);
    }
    // In-place iterative NTT on A[0..lim): type == 1 is the forward transform
    // (root G), any other value is the inverse (root Gi, then divide by lim).
    // Expects p[] to have been filled by get_rev() for this lim.
    inline void ntt(ll A[], cl lim, cl type){
        // Move every element to its bit-reversed position.
        for(int i = 0; i < lim; ++i)
            if(i < p[i]) swap(A[i], A[p[i]]);
        // Butterflies over blocks of size 2 * mid.
        for(int mid = 1; mid < lim; mid <<= 1){
            // A primitive (2*mid)-th root of unity, assuming G is a primitive root of mod.
            ll Wn = qpow(type == 1 ? G : Gi, (mod - 1) / (mid << 1));
            for(int i = 0; i < lim; i += (mid << 1)){
                ll w = 1;
                for(int j = 0; j < mid; ++j, w = w * Wn % mod){
                    ll x = A[i + j], y = w * A[i + j + mid] % mod;
                    A[i + j] = (x + y) % mod;
                    A[i + j + mid] = (x - y + mod) % mod;
                }
            }
        }
        if(type == 1) return;
        ll inv = qpow(lim, mod - 2);
        for(int i = 0; i < lim; ++i) A[i] = A[i] * inv % mod;
    }
    // a = a * b. n and m are the coefficient counts of a and b; both arrays
    // must be zero from there up to the transform length. On return a holds
    // the product in coefficient form; b is left in NTT (point-value) form.
    inline void Mul(cl n, cl m, ll a[], ll b[]){
        get_rev(n + m);
        ntt(a, lim, 1), ntt(b, lim, 1);
        for(int i = 0; i < lim; ++i) a[i] = a[i] * b[i] % mod;
        ntt(a, lim, -1);
    }
}
using namespace NTT;

# 紧跟着的是求逆

已知多项式 $F(x)$,要求 $G(x)$ 使得 $G(x)F(x) \equiv 1 \pmod{x^n}$。

设已经求出 $H(x)$ 满足 $H(x)F(x) \equiv 1 \pmod{x^{\lceil\frac{n}{2}\rceil}}$。

那么

$$G(x) - H(x) \equiv 0 \pmod{x^{\lceil\frac{n}{2}\rceil}}$$

(两式相减得 $(G(x) - H(x))F(x) \equiv 0$,而 $F(x)$ 的常数项非零、可逆,所以可以约去 $F(x)$。)

两边同时平方:

$$
(G(x) - H(x))^2 \equiv 0 \pmod{x^n} \\
G(x)^2 - 2G(x)H(x) + H(x)^2 \equiv 0 \pmod{x^n}
$$

(平方之后最低非零项的次数至少翻倍,而 $2\lceil\frac{n}{2}\rceil \ge n$,所以模数可以直接变成 $x^n$。)

两边再同时乘上 $F(x)$,由于 $G(x)F(x) \equiv 1 \pmod{x^n}$,得:

$$G(x) - 2H(x) + H(x)^2F(x) \equiv 0 \pmod{x^n}$$

移项并提取公因式 $H(x)$:

$$G(x) \equiv H(x)\bigl(2 - H(x)F(x)\bigr) \pmod{x^n}$$

mark
// Polynomial inverse by Newton doubling: on return b[0..n) holds
// a(x)^{-1} mod x^n, and b[n..lim) is zero.
// Requires a[0] != 0 (mod `mod`). b is expected to be zero on entry above the
// coefficients it will hold (callers memset it first). Uses global scratch c.
inline void Inv(ll n, ll a[], ll b[]){
    if(n == 1){
        // Constant term: modular inverse via Fermat's little theorem.
        b[0] = qpow(a[0], mod - 2);
        return;
    }
    // b = H(x), the inverse modulo x^{ceil(n/2)}.
    Inv((n + 1) >> 1, a, b);
    get_rev(n << 1);
    // c = a mod x^n, zero-padded to the transform length.
    for(int i = 0; i < lim; ++i) c[i] = (i < n) ? a[i] : 0;
    ntt(c, lim, 1), ntt(b, lim, 1);
    // G = H * (2 - H * F), evaluated pointwise.
    for(int i = 0; i < lim; ++i){
        ll hf = c[i] * b[i] % mod;
        b[i] = (2ll - hf + mod) * b[i] % mod;
    }
    ntt(b, lim, -1);
    // Truncate to the first n coefficients.
    for(int i = n; i < lim; ++i) b[i] = 0;
}

# 然后是多项式开根

我们已知 $F(x)$,要求 $G(x)$ 使得 $G(x)^2 \equiv F(x) \pmod{x^n}$(下面的代码默认 $F(x)$ 的常数项为 $1$,并取 $G(0) = 1$)。

设已经求出 $H(x)$ 满足 $H(x)^2 \equiv F(x) \pmod{x^{\lceil\frac{n}{2}\rceil}}$。

那么

$$
G(x) \equiv H(x) \pmod{x^{\lceil\frac{n}{2}\rceil}} \\
G(x) - H(x) \equiv 0 \pmod{x^{\lceil\frac{n}{2}\rceil}}
$$

还是两边同时平方:

$$
(G(x) - H(x))^2 \equiv 0 \pmod{x^n} \\
G(x)^2 - 2G(x)H(x) + H(x)^2 \equiv 0 \pmod{x^n}
$$

由于 $G(x)^2 \equiv F(x) \pmod{x^n}$,得:

$$
F(x) - 2G(x)H(x) + H(x)^2 \equiv 0 \pmod{x^n} \\
G(x) \equiv \frac{F(x) + H(x)^2}{2H(x)} \pmod{x^n}
$$

那么事实上我的代码里的式子是这样的:

$$G(x) \equiv \frac{F(x)H(x)^{-1} + H(x)}{2} \pmod{x^n}$$

代码中各数组的含义:

$b(x) \rightarrow H(x)$

$d(x) \rightarrow F(x)$

$e(x) \rightarrow H(x)^{-1}$

mark
// Polynomial square root by Newton doubling: on return b[0..n) holds G(x)
// with G(x)^2 = a(x) (mod x^n), and b[n..lim) is zero.
// NOTE: the base case fixes G(0) = 1, so this assumes a[0] == 1.
// Overwrites the global scratch arrays d and e (e is cleared in full every level).
inline void Sqrt(ll n, ll a[], ll b[]){
    if(n == 1){
        b[0] = 1;
        return;
    }
    // b = H(x), the root modulo x^{ceil(n/2)}.
    Sqrt((n + 1) >> 1, a, b);
    get_rev(n << 1);
    memset(e, 0, sizeof(e));
    Inv(n, b, e);                                             // e = H^{-1} mod x^n
    for(int i = 0; i < lim; ++i) d[i] = (i < n) ? a[i] : 0;   // d = F mod x^n, padded
    ntt(d, lim, 1), ntt(e, lim, 1), ntt(b, lim, 1);
    // G = (F * H^{-1} + H) / 2, evaluated pointwise.
    for(int i = 0; i < lim; ++i){
        ll fOverH = d[i] * e[i] % mod;
        b[i] = (b[i] + fOverH) * inv2 % mod;
    }
    ntt(b, lim, -1);
    // Truncate to the first n coefficients.
    for(int i = n; i < lim; ++i) b[i] = 0;
}

# 接着自然是求 ln

这个就需要用到一点微积分的知识了,虽然我也不太会 QwQ,但是我会贺代码 (。・ω・。)

还是先来推推式子吧。

已知 $F(x)$(常数项为 $1$),求 $G(x) = \ln F(x)$。

众所周知,$g(x) = \ln f(x)$ 的导数为 $g'(x) = \frac{f'(x)}{f(x)}$,所以:

$$G'(x) = \frac{F'(x)}{F(x)}$$

也就是说,$G'(x)$ 我们可以直接求出来了(一次求导、一次求逆、一次乘法)。下面给出不定积分的一个运算公式:

$$\int x^a\,dx = \frac{1}{a + 1} x^{a + 1} + C$$

设 $G'(x) = a_0 + a_1x + a_2x^2 + \cdots + a_{n-1}x^{n-1}$,对每一项求积分即得 $G(x) = \sum_{i \ge 1} \frac{a_{i-1}}{i} x^i$。常数项取 $0$,因为 $G(0) = \ln F(0) = \ln 1 = 0$。

所以对每一项求个积分就完了,具体看代码吧。

mark
namespace Ln{
    // Formal derivative of a[0..n): b[i-1] = i * a[i] (mod `mod`), b[n-1] = 0.
    // Requires n >= 1. The product is reduced so b holds canonical residues;
    // unreduced values (up to ~n * mod) only survived ntt() because its first
    // butterfly happens to multiply by w = 1.
    inline void Diff(cl n, ll a[], ll b[]){
        for(int i = 1; i < n; ++i) b[i - 1] = i * a[i] % mod;
        b[n - 1] = 0;
    }
    // Formal integral with zero constant term: b[i] = a[i-1] / i for 1 <= i < n.
    inline void Integral(cl n, ll a[], ll b[]){
        for(int i = 1; i < n; ++i) b[i] = a[i - 1] * qpow(i, mod - 2) % mod;
        b[0] = 0;
    }
    // b[0..n) = ln(a) mod x^n, computed as the integral of a'(x) * a(x)^{-1}.
    // Assumes a[0] == 1, so ln(a) has zero constant term (matching b[0] = 0).
    // Overwrites the global scratch arrays f1 and f2.
    inline void ln(cl n, ll a[], ll b[]){
        memset(f1, 0, sizeof(f1));
        memset(f2, 0, sizeof(f2));
        Diff(n, a, f1), Inv(n, a, f2);   // f1 = a'(x), f2 = a(x)^{-1} mod x^n
        Mul(n, n, f1, f2);               // f1 = f1 * f2
        Integral(n, f1, b);              // b = integral of f1
    }
}
using namespace Ln;

# exp 当然也不能少

已知 $F(x)$(常数项为 $0$),求 $G(x) = e^{F(x)}$。
emm…… 要用到泰勒展开,牛顿迭代什么的,然鹅我太蒻了,还不会 QwQ。
所以这里就只有结论了,像我这种蒟蒻还是全文背诵吧。

牛顿迭代的一般形式(要求 $G(F(x)) \equiv 0$,已知 $F_0(x)$ 是模 $x^{\lceil\frac{n}{2}\rceil}$ 意义下的解):

$$F(x) \equiv F_0(x) - \frac{G(F_0(x))}{G'(F_0(x))} \pmod{x^n}$$

再经过一番清新简单的推导过程之后……

$$G(x) \equiv G_0(x)\bigl(1 - \ln G_0(x) + F(x)\bigr) \pmod{x^n}$$

其中 $G_0(x) \equiv e^{F(x)} \pmod{x^{\lceil\frac{n}{2}\rceil}}$ 是上一层递归求出的结果。

所以就可以算了……

mark
// Polynomial exponential by Newton doubling: on return b[0..n) holds
// e^{a(x)} mod x^n, and b[n..lim) is zero.
// Assumes a[0] == 0 (the base case fixes b[0] = 1).
// Overwrites the global scratch arrays d and e.
inline void Exp(cl n, ll a[], ll b[]){
    if(n == 1){
        b[0] = 1;
        return;
    }
    // b = G0(x), e^a modulo x^{ceil(n/2)}.
    Exp((n + 1) >> 1, a, b);
    get_rev(n << 1);
    ln(n, b, d);                                              // d = ln(G0) mod x^n
    for(int i = 0; i < lim; ++i) e[i] = (i < n) ? a[i] : 0;   // e = F mod x^n, padded
    ntt(d, lim, 1), ntt(e, lim, 1), ntt(b, lim, 1);
    // G = G0 * (1 - ln G0 + F), evaluated pointwise.
    for(int i = 0; i < lim; ++i){
        ll factor = 1ll - d[i] + e[i] + mod;
        b[i] = factor * b[i] % mod;
    }
    ntt(b, lim, -1);
    // Truncate to the first n coefficients.
    for(int i = n; i < lim; ++i) b[i] = 0;
}

# 最后是多项式除法

除法的边界比较“优秀”(这里的 $n$、$m$ 表示多项式的次数而不是项数,数组下标一直取到 $n$),故不和上面放到一起,并且这里给出完整代码。

以及我这里是计算出 $Q(x)$ 和 $R(x)$ 后一块输出。由于计算 $R(x)$ 时需要做一次多项式卷积,而卷积会对 $Q(x)$ 做一遍 NTT,所以卷积完之后一定记得要把 $Q(x)$ NTT 回来啊。

来看看推导过程。

已知 $F(x)$ 和 $G(x)$,求 $Q(x)$、$R(x)$ 使得 $F(x) = Q(x)G(x) + R(x)$。

首先就要用到一个神奇的操作:设 $F_R(x)$ 为把 $F(x)$ 的系数反过来之后的多项式,即

$$F_R(x) = a_n + a_{n - 1}x + a_{n - 2}x^2 + \cdots + a_0x^n$$

容易发现,$F_R(x) = x^nF(\frac{1}{x})$。

然后就可以愉快的推式子啦 :)

$$
F(x) = Q(x)G(x) + R(x) \\
F(\tfrac{1}{x}) = Q(\tfrac{1}{x})G(\tfrac{1}{x}) + R(\tfrac{1}{x})
$$

$F(x)$ 是 $n$ 次的,$G(x)$ 是 $m$ 次的,$Q(x)$ 是 $n - m$ 次的,$R(x)$ 的次数至多为 $m - 1$,所以两边同乘 $x^n$:

$$
x^nF(\tfrac{1}{x}) = x^{n - m}Q(\tfrac{1}{x}) \cdot x^{m}G(\tfrac{1}{x}) + x^{m - 1}R(\tfrac{1}{x}) \cdot x^{n - m + 1} \\
F_R(x) = Q_R(x)G_R(x) + x^{n - m + 1}R_R(x) \\
F_R(x) \equiv Q_R(x)G_R(x) \pmod{x^{n - m + 1}} \\
Q_R(x) \equiv F_R(x)G_R(x)^{-1} \pmod{x^{n - m + 1}}
$$

(这里 $R_R(x)$ 指把 $R(x)$ 看作 $m - 1$ 次多项式后反转。)由于 $Q_R(x)$ 恰好是 $n - m$ 次的,在模 $x^{n - m + 1}$ 意义下求出它就足够了。

$F_R(x)$ 和 $G_R(x)$ 都是已知的(预处理一下即可),$Q(x)$ 就是 $Q_R(x)$ 的系数倒过来。

计算出 $Q(x)$ 之后 $R(x)$ 也就简单了:

$$R(x) = F(x) - Q(x)G(x)$$

一定要注意边界啊啊啊啊!

mark
#include <bits/stdc++.h>
#define ll long long
using namespace std;
namespace IO{
    // Reads the next non-negative decimal integer from stdin; any non-digit
    // characters before it (including a '-' sign) are skipped.
    inline ll read(){
        char ch = getchar();
        while(!isdigit(ch)) ch = getchar();
        ll value = 0;
        for(; isdigit(ch); ch = getchar()) value = value * 10 + (ch - '0');
        return value;
    }
    // Writes x in decimal with no separator (intended for non-negative values).
    template <typename T> inline void write(T x){
        char digits[24];
        int cnt = 0;
        do digits[cnt++] = char(x % 10 + '0'); while((x /= 10) > 0);
        while(cnt > 0) putchar(digits[--cnt]);
    }
    // Prints a[0..n] (n is a degree, so n + 1 values) space-separated with a
    // trailing space, then a newline; n == -1 prints only the newline.
    inline void print(ll a[], ll n){
        for(ll i = 0; i <= n; ++i){
            write(a[i]);
            putchar(' ');
        }
        puts("");
    }
}
using namespace IO;
const ll N = 300010;            // array capacity (3e5 + 10)
const ll mod = 998244353;       // NTT-friendly prime 119 * 2^23 + 1
const ll G = 3;                 // primitive root of mod
const ll Gi = 332748118;        // inverse of G: 3 * 332748118 = 1 (mod 998244353)
ll n, m;                        // degrees of F(x) and G(x)
ll c[N];                        // scratch copy used by Inv
ll f[N], g[N];                  // F and G (g is later overwritten with Q * G)
ll fr[N], gr[N];                // coefficient-reversed F and G
ll ig[N];                       // G_R^{-1} mod x^{n-m+1}
ll q[N], r[N];                  // quotient and remainder
ll p[N];                        // bit-reversal permutation
namespace NTT{
    // Transform length (a power of two) and its log2, set by get_rev().
    // This program counts polynomials by DEGREE: callers pass the highest
    // index in use, so get_rev(n) picks lim strictly greater than n.
    ll lim, len;
    // Returns a^b mod `mod` by binary exponentiation (b >= 0).
    inline ll qpow(ll a, ll b){
        ll res = 1;
        while(b){
            if(b & 1) res = res * a % mod;
            a = a * a % mod, b >>= 1;
        }
        return res;
    }
    // Sets lim to the smallest power of two > n and fills p[0..lim) with the
    // bit-reversal permutation on len = log2(lim) bits. Only indices below lim
    // belong to the transform.
    inline void get_rev(int n){
        lim = 1, len = 0;
        while(lim <= n) lim <<= 1, ++len;
        // For lim == 1 there is no bit to reverse; shifting by len - 1 == -1
        // would be undefined behaviour, so contribute 0 instead.
        for(int i = 0; i < lim; ++i)
            p[i] = (p[i >> 1] >> 1) | (len ? ((i & 1) << (len - 1)) : 0);
    }
    // In-place iterative NTT on A[0..lim): type == 1 forward (root G),
    // otherwise inverse (root Gi, then divide by lim).
    inline void ntt(ll A[], ll lim, int type){
        for(int i = 0; i < lim; ++i)
            if(i < p[i]) swap(A[i], A[p[i]]);
        for(int mid = 1; mid < lim; mid <<= 1){
            ll Wn = qpow(type == 1 ? G : Gi, (mod - 1) / (mid << 1));
            for(int i = 0; i < lim; i += (mid << 1)){
                ll w = 1;
                for(int j = 0; j < mid; ++j, w = w * Wn % mod){
                    ll x = A[i + j], y = w * A[i + j + mid] % mod;
                    A[i + j] = (x + y) % mod;
                    A[i + j + mid] = (x - y + mod) % mod;
                }
            }
        }
        if(type == 1) return;
        ll inv = qpow(lim, mod - 2);
        for(int i = 0; i < lim; ++i) A[i] = A[i] * inv % mod;
    }
    // a = a * b, where a has degree n and b has degree m (both must be zero
    // above those degrees up to lim). b is transformed back afterwards, so the
    // caller still holds b in coefficient form.
    inline void Mul(ll n, ll m, ll a[], ll b[]){
        get_rev(n + m);
        ntt(a, lim, 1), ntt(b, lim, 1);
        for(int i = 0; i < lim; ++i) a[i] = a[i] * b[i] % mod;
        ntt(a, lim, -1), ntt(b, lim, -1); // transform b back -- callers reuse it (e.g. q is printed later)
    }
    // b = a^{-1} mod x^{n+1} (n is a degree). Requires a[0] != 0 and b zero on
    // entry above the coefficients it will hold. Uses global scratch c.
    inline void Inv(ll n, ll a[], ll b[]){
        if(n == 0) return b[0] = qpow(a[0], mod - 2), void();
        // Inverse mod x^{floor(n/2)+1}, i.e. at least half of the n + 1 terms.
        Inv(n >> 1, a, b);
        get_rev(n << 1);
        for(int i = 0; i <= n; ++i) c[i] = a[i];
        for(int i = n + 1; i < lim; ++i) c[i] = 0;
        ntt(c, lim, 1), ntt(b, lim, 1);
        // G = H * (2 - H * F), pointwise.
        for(int i = 0; i < lim; ++i) b[i] = (2ll - c[i] * b[i] % mod + mod) * b[i] % mod;
        ntt(b, lim, -1);
        for(int i = n + 1; i < lim; ++i) b[i] = 0;
    }
}
using namespace NTT;
inline void Division(){
    for(int i = n - m + 1; i <= m; ++i) gr[i] = 0;
    Inv(n - m, gr, ig);
    Mul(n, n - m, fr, ig);
    for(int i = 0; i <= n - m; ++i) q[i] = fr[n - m - i];
    Mul(n - m, m, g, q);
    for(int i = 0; i < m; ++i) r[i] = (f[i] - g[i] + mod) % mod;
    print(q, n - m);
    print(r, m - 1);
}
signed main(){
    n = read(), m = read();
    for(int i = 0; i <= n; ++i) f[i] = read(), fr[n - i] = f[i];
    for(int i = 0; i <= m; ++i) g[i] = read(), gr[m - i] = g[i];
    Division();
    return 0;
}

# 附:P5273 【模板】多项式幂函数 (加强版)

怒调 3h(

mark
#include <bits/stdc++.h>
#define ll long long
using namespace std;
const int mod = 998244353;          // NTT-friendly prime 119 * 2^23 + 1
const int G = 3;                    // primitive root of mod
const int Gi = 332748118;           // inverse of G modulo mod
const int N = 300010;               // array capacity (3e5 + 10)
ll n;                               // number of coefficients to output
ll k;                               // exponent reduced mod `mod` (set by readk)
ll flag;                            // set to 1 by readk once the true exponent exceeds n
ll phik;                            // exponent reduced mod (mod - 1) (set by readk)
namespace IO{
    // Reads the next non-negative decimal integer from stdin; any non-digit
    // characters before it (including a '-' sign) are skipped.
    inline ll read(){
        char ch = getchar();
        while(!isdigit(ch)) ch = getchar();
        ll value = 0;
        for(; isdigit(ch); ch = getchar()) value = value * 10 + (ch - '0');
        return value;
    }
    // Reads the (possibly huge) exponent digit by digit and returns it
    // reduced mod `mod`. Also accumulates the exponent mod (mod - 1) into the
    // global phik and sets the global flag once the true value exceeds n.
    inline ll readk(){
        char ch = getchar();
        while(!isdigit(ch)) ch = getchar();
        ll kmod = 0;
        for(; isdigit(ch); ch = getchar()){
            const int digit = ch - '0';
            // kmod equals the true prefix until that prefix first exceeds n
            // (assuming n < mod); flag is never reset, so later reductions
            // cannot clear it.
            if(kmod * 10 + digit > n) flag = 1;
            kmod = (kmod * 10 + digit) % mod;
            phik = (phik * 10 + digit) % (mod - 1);
        }
        return kmod;
    }
    // Writes x in decimal with no separator (intended for non-negative values).
    template <typename T> inline void write(T x){
        char digits[24];
        int cnt = 0;
        do digits[cnt++] = char(x % 10 + '0'); while((x /= 10) > 0);
        while(cnt > 0) putchar(digits[--cnt]);
    }
}
using namespace IO;
ll A[N];      // Mul: work buffer for the first factor, then the product
ll B[N];      // Mul: work buffer for the second factor
ll c[N];      // Inv: zero-padded copy of the input
ll lnb[N];    // Exp: ln of the current approximation, then (a - ln b + 1)
ll lowa[N];   // Qpow: input shifted down and normalised to constant term 1
ll lna[N];    // Qpow: ln(lowa), then k * ln(lowa)
int p[N];     // bit-reversal permutation filled by get_rev
ll f1[N];     // Ln: derivative, then derivative * inverse
ll f2[N];     // Ln: inverse of the input
ll f[N];      // input polynomial
ll g[N];      // output polynomial
// Returns a^b mod `mod` by square-and-multiply (b >= 0).
inline ll qpow(ll a, int b){
    ll result = 1;
    for(; b; b >>= 1){
        if(b & 1) result = result * a % mod;
        a = a * a % mod;
    }
    return result;
}
    
// Zeroes a[l..r); does nothing when l >= r.
inline void clear(ll a[], int l, int r) {
    while(l < r) a[l++] = 0;
}
// Copies b[0..n) into a[0..n); does nothing when n <= 0.
inline void clone(ll a[], ll b[], int n) {
    int i = 0;
    while(i < n) { a[i] = b[i]; ++i; }
}
// Writes a[0..n) space-separated (with a trailing space), then a newline.
inline void print(ll a[], int n) {
    int i = 0;
    while(i < n) { write(a[i++]); putchar(' '); }
    puts("");
}
namespace NTT{
    // Transform length (a power of two) and its log2, set by get_rev().
    int lim, len;
    // Sets lim to the smallest power of two >= n and fills p[0..lim) with the
    // bit-reversal permutation on len = log2(lim) bits.
    inline void get_rev(int n){
        lim = 1, len = 0;
        while(lim < n) lim <<= 1, ++len;
        // For lim == 1 there is no bit to reverse; shifting by len - 1 == -1
        // would be undefined behaviour, so contribute 0 instead.
        for(int i = 0; i < lim; ++i)
            p[i] = (p[i >> 1] >> 1) | (len ? ((i & 1) << (len - 1)) : 0);
    }
    // In-place iterative NTT on A[0..lim): type == 1 forward (root G),
    // otherwise inverse (root Gi, then divide by lim).
    inline void ntt(ll A[], int lim, int type){
        for(int i = 0; i < lim; ++i)
            if(i < p[i]) swap(A[i], A[p[i]]);
        for(int mid = 1; mid < lim; mid <<= 1){
            ll Wn = qpow(type == 1 ? G : Gi, (mod - 1) / (mid << 1));
            for(int i = 0; i < lim; i += (mid << 1)){
                ll w = 1;
                for(int j = 0; j < mid; ++j, w = w * Wn % mod){
                    ll x = A[i + j], y = w * A[i + j + mid] % mod;
                    A[i + j] = (x + y) % mod;
                    A[i + j + mid] = (x - y + mod) % mod;
                }
            }
        }
        if(type == 1) return;
        ll inv = qpow(lim, mod - 2);
        for(int i = 0; i < lim; ++i) A[i] = A[i] * inv % mod;
    }
    // a[0..lim) = (first n coefficients of a) * (first m coefficients of b).
    // Works on copies in the global buffers A and B, so b is left untouched.
    inline void Mul(int n, int m, ll a[], ll b[]){
        get_rev(n + m);
        clear(A, 0, lim), clear(B, 0, lim);
        clone(A, a, n), clone(B, b, m);
        ntt(A, lim, 1), ntt(B, lim, 1);
        for(int i = 0; i < lim; ++i) A[i] = A[i] * B[i] % mod;
        ntt(A, lim, -1);
        clone(a, A, lim);
    }
    // b[0..n) = a^{-1} mod x^n by Newton doubling; b[n..lim) is zeroed.
    // Requires a[0] != 0 and b zero on entry above the coefficients it will
    // hold. Uses global scratch c.
    inline void Inv(int n, ll a[], ll b[]){
        if(n == 1) return b[0] = qpow(a[0], mod - 2), void();
        Inv((n + 1) >> 1, a, b);
        get_rev(n << 1);
        clone(c, a, n), clear(c, n, lim);
        ntt(b, lim, 1), ntt(c, lim, 1);
        // G = H * (2 - H * F), pointwise.
        for(int i = 0; i < lim; ++i) b[i] = (2ll - c[i] * b[i] % mod + mod) * b[i] % mod;
        ntt(b, lim, -1);
        clear(b, n, lim);
    }
}
using namespace NTT;
namespace Poly{
    inline void Diff(int n, ll a[], ll b[]){
        for(int i = 1; i < n; ++i) b[i - 1] = i * a[i] % mod;
        b[n - 1] = 0;
    }
    inline void Integral(int n, ll a[], ll b[]){
        for(int i = 1; i < n; ++i) b[i] = a[i - 1] * qpow(i, mod - 2) % mod;
        b[0] = 0;
    }
    inline void Ln(int n, ll a[], ll b[]){
        get_rev(n << 1);
        clear(f1, 0, lim), clear(f2, 0, lim);
        Diff(n, a, f1), Inv(n, a, f2);
        Mul(n, n, f1, f2);
        Integral(n, f1, b);
    }
    inline void Exp(int n, ll a[], ll b[]){
        if(n == 1) return b[0] = 1, void();
        Exp((n + 1) >> 1, a, b);
        get_rev(n << 1);
        Ln(n, b, lnb);
        for(int i = 0; i < n; ++i) lnb[i] = (a[i] - lnb[i] + mod) % mod;
        lnb[0]++;
        Mul(n, n, b, lnb), clear(b, n, lim);
    }
    inline void Qpow(int n, int k, ll a[], ll b[], ll phik){
        clear(b, 0, n);
        int shift = 0;
        while(!a[shift]) shift++;
        if((ll)shift * k >= n) return;
        n -= shift;
        for(int i = 0; i < n; ++i) lowa[i] = a[i + shift];
        int low0 = lowa[0], inv0 = qpow(low0, mod - 2);
        for(int i = 0; i < n; ++i) lowa[i] = lowa[i] * inv0 % mod;
        Ln(n, lowa, lna);
        for(int i = 0; i < n; ++i) lna[i] = lna[i] * k % mod;
        Exp(n, lna, b);
        shift *= k;
        for(int i = n + shift - 1; i >= shift; --i) b[i] = b[i - shift] * qpow(low0, phik) % mod;
        clear(b, 0, shift);
    }
}
using namespace Poly;
// Reads n, the exponent k, and the n coefficients of F; prints F^k mod x^n.
// When F has a zero constant term and k > n, the answer is identically zero,
// so g is printed as left (all zeros) without calling Qpow.
int main(){
    n = read();
    k = readk();
    for(int i = 0; i < n; ++i) f[i] = read();
    const bool needsWork = !flag || f[0];
    if(needsWork) Qpow(n, k, f, g, phik);
    print(g, n);
    return 0;
}

\_EOF\_

阅读次数