跳转至

快速数论变换 (NTT)

前置知识:快速傅里叶变换(FFT)

P3803 多项式乘法

给定一个 \(n\) 次多项式 \(F(x)\),和一个 \(m\) 次多项式 \(G(x)\)

请求出 \(F(x)\)\(G(x)\) 的卷积。

对于 \(100\%\) 数据: \(1 \le n, m \leq {10}^6\)

视频讲解(来源:董晓算法)

FFT大量使用了复数运算,尤其是浮点数的运算,其精度不好控制

能否找到另一组数,其也具有单位根的优美性质呢?这就是快速数论变换(NTT)所做的。

从单位根到原根

原根

介绍原根之前,先介绍一下阶

根据欧拉定理,如果 \(a \bot n\),那么有

\[a^{\varphi(n)} \equiv 1 \pmod{n}\]

我们知道,若 \(g \bot p\),对于 \(\forall g \in \mathbf{N}\),均 \(\exists n \in \mathbf{N}\),使得 \(g^n \equiv 1 \pmod{p}\),我们将具有上述性质的最小的 \(n\) 定义为 \(g\)\(p\) 的阶,记作 \(\delta_p(g)\)

原根的定义

\(\delta_p(g)=\varphi(p)\),则称 \(g\) 为模 \(p\) 的一个原根

\(\delta_7(3)=6=\varphi(7)\)\(3\) 是模 \(7\) 的一个原根

\(\delta_7(2)=3\neq\varphi(7)\)\(2\) 不是模 \(7\) 的一个原根

原根的性质

\(g\) 是模 \(p\) 的一个原根,则 \(g^0,g^1,g^2,\cdots,g^{\delta - 1}\) 在模 \(p\) 意义下两两不同,之后进入周期

\(3^0,3^1,3^2,3^3,3^4,3^5\)\(7\) 两两不同,之后进入周期

因此我们貌似找到了可行的替代方案

为了多次二分,模数 \(p\) 应选形如 \(q\times2^k + 1\)质数,其中 \(q\) 为奇素数,\(k\) 为整数

原根 \(g\) 模数 \(p\) 分解 \(p\) 最大长度
\(3\) \(469762049\) \(7\times2^{26}+1\) \(2^{26}\)
\(3\) \(998244353\) \(119\times2^{23}+1\) \(2^{23}\)
\(3\) \(1004535809\) \(479\times2^{21}+1\) \(2^{21}\)

\(p\) 是质数,故 \(\varphi(p)=p - 1\),则 \(g^0,g^1,g^2,\cdots,g^{p - 1}\) 在模 \(p\) 下两两不同,从中选取对称的 \(n(n = 2^b, b \in \mathbf{N})\) 个值:

\(g_n^0 = 1,g_n^1 = g^{\frac{p - 1}{n}},\cdots,g_n^k = g^{\frac{p - 1}{n}k},\cdots,g_n^{n - 1}\)

对比 \(\omega_n^k(n = 2^b, b \in \mathbf{N})\) 的性质,我们发现 \(g_n^k\) 的如下性质(在模 \(p\) 意义下理解):

  1. (周期性) \(g_n^k = g_n^{k+n}\)
  2. (对称性) \(g_n^k = -g_n^{k+\frac{n}{2}}\)
  3. (折半性) \(g_{2n}^k = -g_n^{\frac{k}{2}}\)
  4. (指数性) \(g_n^{a + b} = g_n^a g_n^b\)

我们发现,\(g_n^k\) 真的可以代替 \(\omega_n^k\),并且FFT和NTT的流程几乎不变!

那IFFT和INTT也一样吗?

证明

注:以下计算均在模 \(p\) 意义下进行,其中的除法与倒数应当用乘法逆元理解

设多项式 \(A(x)=a_0 + a_1x + a_2x^2 + \cdots + a_{n - 1}x^{n - 1}\)

代入 \(g_n^0, g_n^1, g_n^2, \cdots, g_n^{n - 1}\) 得到的点值为 \(y_0, y_1, y_2, \cdots, y_{n - 1}\)

其中 \(y_i = \sum_{j = 0}^{n - 1}a_j(g_n^i)^j\)

构造多项式 \(B(x)=y_0 + y_1x + y_2x^2 + \cdots + y_{n - 1}x^{n - 1}\)

\(n\) 个原根的倒数 \(g_n^0, g_n^{-1}, g_n^{-2}, \cdots, g_n^{-(n - 1)}\) 代入 \(B(x)\)

得到的 \(n\) 个新点值,设为 \(z_0, z_1, z_2, \cdots, z_{n - 1}\)

\[ \begin{aligned} z_k &= \sum_{i = 0}^{n - 1}y_i(g_n^{-k})^i \\ &= \sum_{i = 0}^{n - 1}\sum_{j = 0}^{n - 1}a_j(g_n^i)^j(g_n^{-k})^i \\ &= \sum_{j = 0}^{n - 1}a_j\sum_{i = 0}^{n - 1}(g_n^{j - k})^i \end{aligned} \]

\(j = k\) 时,内层的和式等于 \(n\)

\(j \neq k\) 时,\(\frac{(g_n^{j - k})^n - 1}{g_n^{j - k} - 1}=\frac{(g_n^n)^{j - k} - 1}{g_n^{j - k} - 1}=0\)

所以 \(z_k = na_k\),即 \(a_k = \frac{z_k}{n}\)

除了四则运算都要取模,除法操作需要换成乘法逆元,其余操作真的一样!

这样我们得到NTT的代码:

实现(迭代版)
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
#include <bits/stdc++.h>
using namespace std;
#define endl '\n'
typedef long long ll;
const int MAXN = 4e6 + 10;
const ll p = 998244353, g = 3; // 模数和原根
ll A[MAXN], B[MAXN];
int rev[MAXN], n, m;
ll qpow(ll a, ll b)
{
    static ll res;
    for (res = 1; b; b >>= 1, a = a * a % p)
        if (b & 1) res = res * a % p;
    return res;
}
ll inv(ll x) { return qpow(x, p - 2); }
const ll gi = inv(g);
void NTT(ll *A, int n, int op)
{
    for (int i = 0; i < n; i++) // 位逆序变换
        if (i < rev[i]) swap(A[i], A[rev[i]]);
    for (int i = 2; i <= n; i <<= 1)
    {
        ll g1 = qpow(op == 1 ? g : gi, (p - 1) / i); // 对应FFT中的单位根omega1
        for (int j = 0; j < n; j += i)
        {
            ll gk = 1;
            for (int k = j; k < j + i / 2; k++) // 这里的合并与FFT都是一样的
            {
                ll x = A[k], y = A[k + i / 2] * gk % p;
                A[k] = (x + y) % p;
                A[k + i / 2] = (x - y + p) % p;
                gk = gk * g1 % p; // 求解gk
            }
        }
    }
    if (op == 1) return;
    const ll ni = inv(n);
    for (int i = 0; i < n; i++) A[i] = A[i] * ni % p; // INTT
}
int main()
{
    ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
    cin >> n >> m;
    for (int i = 0; i <= n; i++) cin >> A[i];
    for (int i = 0; i <= m; i++) cin >> B[i];
    for (m = n + m, n = 1; n <= m; n <<= 1);
    for (int i = 0; i < n; i++) rev[i] = (rev[i >> 1] >> 1) | ((n >> 1) * (i & 1));
    NTT(A, n, 1), NTT(B, n, 1);
    for (int i = 0; i < n; i++) A[i] = A[i] * B[i] % p;
    NTT(A, n, -1);
    for (int i = 0; i <= m; i++) cout << A[i] << ' ';
    return 0;
}

从中也能看出NTT的限制:

  • 所求的多项式要求是整系数。
  • 如果题目要求结果对质数 \(p\) 取模,这个质数往往只能是 \(998244353\),否则会有很多麻烦,这个会在后面谈到。
  • 所求多项式的项数应在 \(2^{23}\) 之内,因为 \(998244353 = 7 \times 17 \times 2^{23}+1\)
  • 结果的系数不应超过质数 \(P\)。( \(P\) 是自己选择的质数,一般定为 \(P = 998244353\)

扩展:任意模数NTT

P4245 【模板】任意模数多项式乘法

给定 \(2\) 个多项式 \(F(x), G(x)\) ,请求出 \(F(x) * G(x)\)

系数对 \(p\) 取模,且不保证 \(p\) 可以分解成 \(p = a \cdot 2^k + 1\) 之形式。

对于 \(100 \%\) 的数据,\(1 \leq n, m \leq 10^5\)\(0 \leq a_i, b_i \leq 10^9\)\(2 \leq p \leq 10^9 + 9\)

其实这是一个比较套路的想法,我们取用三个符合NTT要求的大模数,分别做三次卷积,就会得到每一位实际结果对三个模数取模后的结果,然后利用 excrt 合并答案。

这里并不需要真的写 excrt,手动合并就行了。

一共做 \(9\) 次NTT,常数巨大。

实现
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
#include <bits/stdc++.h>
using namespace std;
#define endl '\n'
typedef long long ll;
ll qpow(ll a, ll b, ll p)
{
    static ll res;
    for (res = 1; b; b >>= 1ll, a = a * a % p)
        if (b & 1ll) res = res * a % p;
    return res;
}
ll inv(ll a, ll p) { return qpow(a, p - 2, p); }
const ll mod[] = {0, 469762049, 998244353, 1004535809}, g = 3;
const int MAXN = 3e5 + 10;
int n, m, rev[MAXN];
ll A[4][MAXN], B[4][MAXN], p;
void NTT(ll *A, int n, int op, ll p)
{
    const ll gi = inv(g, p), ni = inv(n, p);
    for (int i = 0; i < n; i++)
        if (i < rev[i]) swap(A[i], A[rev[i]]);
    for (int i = 2; i <= n; i <<= 1)
    {
        ll g1 = qpow(op == 1 ? g : gi, (p - 1) / i, p);
        for (int j = 0; j < n; j += i)
        {
            ll gk = 1;
            for (int k = j; k < j + i / 2; k++)
            {
                ll x = A[k], y = A[k + i / 2] * gk % p;
                A[k] = (x + y) % p;
                A[k + i / 2] = (x - y + p) % p;
                gk = gk * g1 % p;
            }
        }
    }
    if (op == -1)
        for (int i = 0; i < n; i++) A[i] = A[i] * ni % p;
}
ll merge(ll a, ll b, ll c, ll p)
{
    const static ll md12 = mod[1] * mod[2], inv12 = inv(mod[1], mod[2]),
                    inv123 = inv(mod[1] * mod[2] % mod[3], mod[3]);
    ll d = (b - a + mod[2]) % mod[2] * inv12 % mod[2] * mod[1] + a;
    ll x =
        (c - d % mod[3] + mod[3]) % mod[3] * inv123 % mod[3] * (md12 % p) % p +
        d;
    return x % p;
}
int main()
{
    ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
    cin >> n >> m >> p;
    for (int x, i = 0; i <= n; i++)
        cin >> x, A[1][i] = A[2][i] = A[3][i] = x % p;
    for (int x, i = 0; i <= m; i++)
        cin >> x, B[1][i] = B[2][i] = B[3][i] = x % p;
    for (m = n + m, n = 1; n <= m; n <<= 1);
    for (int i = 0; i < n; i++)
        rev[i] = (rev[i >> 1] >> 1) | ((n >> 1) * (i & 1));
    for (int i = 1; i <= 3; i++)
    {
        NTT(A[i], n, 1, mod[i]);
        NTT(B[i], n, 1, mod[i]);
        for (int j = 0; j < n; j++) A[i][j] = A[i][j] * B[i][j] % mod[i];
        NTT(A[i], n, -1, mod[i]);
    }
    for (int i = 0; i <= m; i++)
        cout << merge(A[1][i], A[2][i], A[3][i], p) << " ";
    return 0;
}