快速幂运算

幂运算$$a^m$$是指m个n相乘,当a,m太大时,计算时间太长,使用快速幂算法可以降低时间复杂度。常见的算法有二进制与指数折半两种,时间复杂度相同

首先我们知道

$ a^{m+n} = a^m *a^n $

二进制的方法就是通过指数的二进制表示将幂运算拆解为一个个很小的任务

指数折半就是将指数分解为2的m次方形式,例如

$ 3^{12} = (3^2)^3 = 9^3 $

通过缩小指数的大小,减少循环的次数,以此来达到幂运算的结果不变但是加快幂运算速度的效果。

二进制取幂

当n的m次方中,m取9时,9的二进制为1001,转化关系如下:

9=123+022+021+1209 = 1*2^3+0*2^2+0*2^1+1*2^0

于是,通过幂运算的性质,原本的复杂的幂运算可以拆分为

n9(10)=n1001(2)=n123+022+021+120=n123n022n021n120n^{9_{(10)}} = n^{1001_{(2)}} = n^{1*2^3+0*2^2+0*2^1+1*2^0}=n^{1*2^3}*n^{0*2^2}*n^{0*2^1}*n^{1*2^0}

由于我们知道任何数乘0都为0,任何数的0次方都为1,所以,实际上拆分的幂运算为

n9=n123n120n^9 = n^{1*2^3}*n^{1*2^0}

由此发现,实际上只需要计算众多小任务中二进制位等于1部分的累乘即可

实现代码

1
2
3
4
5
6
7
8
def binpow(a, b):
res = 1
while b > 0:
if b & 1:
res = res * a
a = a * a
b >>= 1
return res

矩阵快速幂运算

根据线性代数中我们学到的知识,矩阵乘法的原理就是左行乘右列,这里不多描述,代码实现如下

1
2
3
4
5
6
7
8
9
10
for(i=1;i<=n;i++)
{
for(j=1;j<=n;j++)
{
for(k=1;k<=n;k++)
{
c[i][j] = c[i][j] + a[i][k] * b[k][j];
}
}
}

将矩阵乘法套在快速幂运算中,也就构成了矩阵快速幂运算

1
2
3
4
5
6
7
8
9
10
11
12
13
def multiply_matrices(a, b):
res = [[0, 0], [0, 0]]
for i in range(0, 2):
for j in range(0, 2):
for p in range(0, 2):
res[i][j] = (res[i][j] + a[i][p] * b[p][j]) % M
return res
b = 次方数
while b > 0:
if b & 1:
ans = multiply_matrices(ans, base)
base = multiply_matrices(base, base)
b >>= 1

例题

2024.11.SICTF#Round4.Math Cocktail

题目

1
2
3
4
5
6
7
8
9
from secret import key
x = key
M = 94665789456132156456789461321289656332321
n = 123456789123456789
k = x + pow(x,-1,M)
result = pow(x,n,M) + pow(x,-n,M)
print("k = " + str(k))
flag = "SICTF{"+str(result)+"}"
#k = 15396893775857205606087136852231851457937

根据题意,我们需要求

result=xn+xn1result = x^n + x^{n-1}

我们已知

k=x+x1k = x + x^{-1}

于是,设

an=xn+xn1a_n = x^n +x^{n-1}

我们有

a0=2a1=ka_0 = 2\\ a_1 = k

我们又知道

ana1=(xn+xn1)(x+x1)=xn+1+x(n+1)+xn1+x(n1)=an+1+an1an+1=kanan1an=kan1an2a_n*a_1 = (x^n+x^{n-1})*(x+x^{-1})= x^{n+1}+x^{-(n+1)}+x^{n-1}+x^{-(n-1)}=a_{n+1}+a_{n-1}\\ a_{n+1} = k*a_n - a_{n-1}\\ a_{n} = k*a_{n-1} - a_{n-2}

之后将递推关系转化为矩阵乘法

(anan1)=(k110)(an1an2)\begin{pmatrix} a_n \\ a_{n-1} \end{pmatrix} = \begin{pmatrix} k & -1 \\ 1 & 0 \end{pmatrix} \begin{pmatrix} a_{n-1} \\ a_{n-2} \end{pmatrix}

exp

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
def multiply_matrices(a, b):
res = [[0, 0], [0, 0]]
for i in range(0, 2):
for j in range(0, 2):
for p in range(0, 2):
res[i][j] = (res[i][j] + a[i][p] * b[p][j]) % M
return res
M = 94665789456132156456789461321289656332321
n = 123456789123456789
k = 15396893775857205606087136852231851457937
base = [[k, 1], [-1, 0]]
ans = [[(k*k-2) % M, k % M], [0, 0]]
b = n - 2
while b > 0:
if b & 1:
ans = multiply_matrices(ans, base)
base = multiply_matrices(base, base)
b >>= 1
print((ans[0][0] + M) % M)