跳转至

常系数线性齐次递推

P4723 【模板】常系数齐次线性递推

求一个满足 \(k\) 阶齐次线性递推数列 \({a_i}\) 的第 \(n\)\(\text{mod} \ 998244353\) 的结果,即:

\[a_n=\sum\limits_{i=1}^{k}f_i \times a_{n-i} \ \text{mod} \ 998244353\]

\(n \leq 10^{9}, k \leq 32000\),保证读入的数字均为 \([-10^9,10^9]\) 内的整数。

引入

事实上,我们并不需要使用任何线性代数的知识来导出这个做法。

让我们采用小学层面的方法求斐波那契数列(一个经典的例子)的第 \(n\) 项。

我们换一个角度,倒过来求。假设我们现在在欲求 \(F_n\)。我们可以列出如下过程: \(F_5 = F_4 + F_3 = 2F_3 + F_2 = 3F_2 + 2F_1 = 5F_1 + 3F_0\),最后带入 \(F_0, F_1\) 计算即可。这当中的推导过程不是什么秘密,不过是每次取最前面那一项拿递推式展开罢了。

本质

考虑一下上述过程本质上究竟干了什么,上述过程,就是通过每次消去最高次项,求 \(x^n\) 对斐波那契数列的特征多项式——\(x^2 - x - 1\) 取模的结果的过程。如果还没有反应过来,展示如下:

\[ \begin{aligned} & 0x^0 + 0x^1 + 0x^2 + 0x^3 + 0x^4 + 1x^5 \\ & \xrightarrow{-1x^3(x^2 - x -1)} 0x^0 + 0x^1 + 0x^2 + 1x^3 + 1x^4 \\ & \xrightarrow{-1x^2(x^2 - x -1)} 0x^0 + 0x^1 + 1x^2 + 2x^3 \\ & \xrightarrow{-2x^1(x^2 - x -1)} 0x^0 + 2x^1 + 3x^2 \\ & \xrightarrow{-3x^0(x^2 - x -1)} 3x^0 + 5x^1 \\ \end{aligned} \]

两者之间的逻辑关联很简单,因为前者是每次取和式里下标最大那一项,替换成两个下标较小的项之后加回去;而后者是每次从多项式里减去若干倍的 \(x^2\),并替换成相应倍数的 \(x + 1\) 之后加回去。

推广到正解

理解了上述逻辑之后,能轻易将其推广到一般线性齐次递推式的情形:

\[f_i = p_1f_{i-1} + p_2f_{i-2} + \cdots + p_kf_{i-k}\]

要求 \(f_n\),只要求多项式 \(x^n\)

\[p(x) = x^k - p_1x^{k -1} - p_2x^{k -2} - \cdots - p_kx^0\]

取模的结果。因为「多项式取模」的过程恰好与「每次按递推式展开下标最大的项」的过程一致,都可以用「替换」来解读。

实现
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
#include <bits/stdc++.h>
using namespace std;
#define endl '\n'
typedef long long ll;
const ll MAXN = 5e5 + 10, p = 998244353;
int n, m, k;
ll rev[MAXN], res[MAXN], T[MAXN], F[MAXN], G[MAXN], Q[MAXN], R[MAXN], lim;
ll qpow(ll a, ll b)
{
    static ll res;
    for (res = 1; b; b >>= 1, a = a * a % p)
        if (b & 1) res = res * a % p;
    return res;
}
ll inv(ll x) { return qpow(x, p - 2); }
const ll g = 3, gi = inv(g);
void NTT(ll *A, int n, int op)
{
    for (int i = 0; i < n; i++)
        if (i < rev[i]) swap(A[i], A[rev[i]]);
    for (int i = 2; i <= n; i <<= 1)
    {
        ll g1 = qpow(op == 1 ? g : gi, (p - 1) / i);
        for (int j = 0; j < n; j += i)
        {
            ll gk = 1;
            for (int k = j; k < j + i / 2; k++)
            {
                ll x = A[k], y = A[k + i / 2] * gk % p;
                A[k] = (x + y) % p;
                A[k + i / 2] = (x - y + p) % p;
                gk = gk * g1 % p;
            }
        }
    }
    if (op == 1) return;
    const int ni = inv(n);
    for (int i = 0; i < n; i++) A[i] = A[i] * ni % p;
}
void Inv(int n, ll *F, ll *G)
{
    G[0] = inv(F[0]);
    int lim, len;
    static ll A[MAXN], B[MAXN];
    for (len = 1; len < (n << 1); len <<= 1)
    {
        lim = len << 1;
        copy(F, F + len, A), copy(G, G + len, B);
        for (int i = 0; i < lim; i++)
            rev[i] = (rev[i >> 1] >> 1) | ((lim >> 1) * (i & 1));
        NTT(A, lim, 1), NTT(B, lim, 1);
        for (int i = 0; i < lim; i++)
            G[i] = ((2 - A[i] * B[i] % p) * B[i] % p + p) % p;
        NTT(G, lim, -1);
        fill(G + len, G + lim, 0);
    }
    fill(A, A + len, 0), fill(B, B + len, 0), fill(G + n, G + len, 0);
}
void Mul(int n, int m, ll *F, ll *G, bool flag = 1)
{
    static ll A[MAXN], B[MAXN];
    fill(A, A + (n << 2), 0), fill(B, B + (n << 2), 0), copy(F, F + n, A),
        copy(G, G + n, B);
    for (int i = 0; i < n; i++) A[i] = F[i], B[i] = G[i];
    for (lim = 1; lim <= (n + m); lim <<= 1);
    for (int i = 0; i <= lim; i++)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) * (lim >> 1));
    NTT(A, lim, 1), NTT(B, lim, 1);
    for (int i = 0; i < lim; i++) A[i] = A[i] * B[i] % p;
    NTT(A, lim, -1), copy(A, A + (n << 1), F),
        fill(F + (n << 1), F + lim + 1, 0);
    if (flag) fill(F + n, F + (n << 1), 0);
}
void Mod(int n, int m, ll *F, ll *G, ll *R)
{
    static ll A[MAXN], B[MAXN], D[MAXN];
    fill(A, A + (n << 2), 0), fill(B, B + (n << 2), 0),
        fill(D, D + (n << 2), 0);
    for (int i = 0; i < n - m + 1; i++)
        A[i] = F[n - i - 1], B[i] = G[m - i - 1];
    Inv(n - m + 1, B, D), Mul(n - m + 1, n - m + 1, A, D);
    for (int i = 0; i <= n - m; i++) Q[i] = A[n - m - i];
    fill(A, A + (n << 2), 0), fill(B, B + (n << 2), 0), copy(F, F + n, A),
        copy(G, G + m, B), Mul(n, n, B, Q);
    for (int i = 0; i < m - 1; i++) R[i] = (A[i] - B[i] + p) % p;
    fill(R + m - 1, R + lim, 0);
}
void Pow(ll k, int n)
{
    for (T[1] = res[0] = 1; k;
         Mul(n, n, T, T, 0), Mod(n << 1, n, T, G, T), k >>= 1)
        if (k & 1) Mul(n, n, res, T, 0), Mod(n << 1, n, res, G, res);
}
int main()
{
    ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
    cin >> n >> m;
    G[0] = 1;
    for (int i = 1; i <= m; i++)
    {
        ll val;
        cin >> val;
        G[i] = (p - (val % p + p) % p) % p;
    }
    reverse(G, G + m + 1);
    for (int i = 0; i < m; i++) cin >> F[i];
    Pow(n, m + 1);
    ll ans = 0;
    for (int i = 0; i < m; i++) ans = (ans + res[i] * F[i] % p + p) % p;
    cout << ans << endl;
    return 0;
}