The secret of being a bore is to tell everything

0%

[CF1097D] Makoto and a Blackboard

题意

一开始有一个整数$n(n\leq 10^{15})$,你可以执行以下操作$k(k\leq 10^4)$次:把$n$替换成$n$的任意一个约数(包括$1$和$n$),假设每个约数都有相同概率被选中。现在问你$k$次操作后剩下这个数的期望值是多少。

解题思路

我们知道期望公式为$\sum{x*p(x)}$,在这道题中,由于$x$的范围是离散且确定的,我们只需要知道每个$x$出现的概率就行了。

我们可以画一张图来看看转移的过程,以$6$为例:

我们可以得出几个有用的信息:

  • 图中所有节点都为原数$n$的约数
  • 每个节点只会转移到小于等于它的数
  • 因为$k$有限,这个流程可以直接模拟

于是我们就可以用动态规划,令$f(i, x)$为第$i$轮时,得到数字为$x$的概率,可得:
$$
f(i + 1, y) = f(i + 1, y) + f(i, x) * \frac{1}{|d(x)|} (y \in d(x))
$$
于是这题就解决了……

但是等等,这个玩意复杂度是多少呢?我们花$O(\sqrt{n})$预处理出来约数,约数个数大约是$O(\log{n})$,都在时间范围内。状态数量是$O(k\log{n})$,转移是$O(\log{n})$,合在一起$O(k\log^2{n})$但是交上去会T掉,为什么呢?

因为因数个数虽然是$O(\log{n})$,但是常数巨大,一般来说一个数的质因数分解可以表示为$\prod_k{(p_k^{\alpha_k})}$那么它会有$\prod_k{(\alpha_k+1})$个因数。

那么假如我给出$n=2^{20}\times3^{20}$,足足会有$21\times21=441$个因数,那么$10^4\times441^2=1.6*10^9$这不T到天上去?
那么接下来就要引出我写这篇博客的意义了,首先把我们刚刚得到的递推式优化一下。
$$
f(i+1, y) = \sum_{ky|n} \frac{f(i, ky)}{|d(ky)|}
$$
而形如$g(n) = \sum_{d|n} f(d)$这类函数给我们的提示就是它很有可能是积性的,由于地方太小就不写证明了。当$f(i, x)$是积性函数的时候,$f(i,x)=\prod_{k} f(i, p_k^{\alpha_k})$。于是我们可以分别计算对于$n$的每个质因数的$f(i, p^{\alpha_k})$,然后把他们乘起来。此时时间复杂度为$O(k\sum{\alpha_i^2})$,此前为$O(k\prod{(\alpha_i+1)})$,对于$2^{20}\times3^{20}$来说,只需要$10^4\times800=8\times10^6$,是一个巨大的提升。

时间复杂度

$O(k\sum{\alpha_i^2})$

参考代码

1
int main() {
2
#ifdef LOCALLL
3
    freopen("in", "r", stdin);
4
    freopen("out", "w", stdout);
5
#endif
6
    scanf("%lld%lld", &n, &k);
7
    ll x = n;
8
    // 获取质因数以及其指数
9
    for (ll i = 2; i * i <= n; i++) {
10
        if (i > x) break;
11
        int cnt = 0;
12
        while (x % i == 0) {
13
            x /= i;
14
            cnt++;
15
        }
16
        if (cnt) {
17
            divs.push_back({i, cnt});
18
        }
19
    }
20
    if (x > 1) divs.push_back({x, 1});
21
    ll ans = 1;
22
    // 最外层是每个质因数p^i形式
23
    for (auto pair : divs) {
24
        int cnt = pair.second;
25
        // 滚动数组优化空间
26
        memset(dp[0], 0, sizeof(dp[0]));
27
        dp[0][cnt] = 1LL;
28
        // 从这里开始是每个状态的转移
29
        for (int i = 0; i < k; i++) {
30
            memset(dp[(i & 1) ^ 1], 0, sizeof(dp[0]));
31
            for (int j = 0; j <= cnt; j++) {
32
                for (int s = 0; s <= j; s++) {
33
                    dp[(i & 1) ^ 1][s] +=
34
                        getInv(j + 1, MOD) * dp[i & 1][j] % MOD;
35
                    if (dp[(i & 1) ^ 1][s] > MOD) {
36
                        dp[(i & 1) ^ 1][s] -= MOD;
37
                    }
38
                }
39
            }
40
        }
41
        ll tmp = 0;
42
        ll pw = 1;
43
        for (int i = 0; i <= cnt; i++) {
44
            tmp += (pw * dp[k & 1][i] % MOD);
45
            if (tmp > MOD) {
46
                tmp -= MOD;
47
            }
48
            pw = pw * pair.first % MOD;
49
        }
50
        ans = (ans * tmp) % MOD;
51
    }
52
    printf("%lld", ans);
53
    return 0;
54
}