Python 与 RSA 算法

16 min read

cover

本周,Computer Security 这门课的作业是用代码实现 RSA 算法,要求加密 + 解密都要实现,语言不限。

这种东西网上一搜一大把,但因为 RSA 太重要了,老师额外要求要 “Defend your code!” (相当于当着他面,一对一做一次 Code Review)。没办法,老老实实的把相关课件重新学了一遍,然后动笔实现。害怕到时候讲不清楚,所以特地写一篇博客理一遍思路。

虽说语言不限,但能方便快捷的操作这么大的数 (老师要求使用至少 1024 bit 的质数!) 的语言还真不多,对比了一下,最终选择了 Python。

快速模幂运算

快速模幂运算,是与 RSA 相关的算法中最基础也是最核心的一个算法了。毕竟 RSA 加密和解密的本质就是做一个模幂运算;而求我们生成一个大质数的过程中也需要用到这个函数 (费马定理,验证这个数是否为质数)。

这个算法主要是为了减少计算机的计算量,毕竟要硬算 $ 1.047247e152^{5.80901e149} \pmod{65537} $ 这样的大数的模幂可不容易。

我们都知道,关于取模,有这么一个公式:

$$ c \bmod m = (a \cdot b) \bmod m = [(a \bmod m) \cdot (b \bmod m)] \bmod m $$

再结合 $ a^{b + c} = a^{b} \cdot a^{c} $,我们可以很容易将我们的模幂运算分解:

$$ a^e \bmod m = a^{b + c} \bmod m = [(a^{b} \bmod m) \cdot (a^{c} \bmod m)] \bmod m $$

但是,如何分解 $ e = b + c $ 最高效又成了问题。这里我们使用平方求幂原理。

首先将 $ e $ 表示成二进制,即:

$$ e = \sum*{i = 0}^{n - 1} a_i 2^i, (a*{n - 1} = 1) $$

$ b^{e} $ 可以表现为:

$$ b^{e} = b^{(\sum*{i = 0}^{n - 1} a_i 2^i)} = \prod*{i = 0}^{n - 1} (b^{2^{i}})^{a_i} $$

因此 $ e $ 为:

$$ e \equiv \prod_{i = 0}^{n - 1} (b^{2^{i}})^{a_i} \pmod m $$

具体实现的代码如下 (递归):

def powmod(p, q, m):
    if not q: return 1
    if q & 1: return (p * pow_mod(p, q - 1, m)) % m
    return pow_mod((p ** 2) % n, q >> 1, m)

虽说递归的实现最为直观,但我在后面的实际使用过程中遇到了性能问题,不得已重新用 while 重构了一遍:

def powmod(p, q, m):
    r = 1
    while q:
        if q & 1:
            r = (r * p) % m
        q >>= 1
        p = (p * p) % m
    return r

生成一个大大的 (伪) 质数

接着,我们需要一个函数用来生成一个固定位数的大数,这个数可能是质数也可能是合数。

我们直接用 Python 的 random 模块来生成一个 w 位的二进制字符串 (仅有 01 组成的字符串),然后将字符串转化为数字 (十进制):

import random

def fake_prime(w):
    list = []
    list.append('1')
    for i in range(w - 2):
        list.append(random.choice(['0', '1']))
    list.append('1')
    return int(''.join(list), 2)

需要注意的是,二进制字符串的首位必须位 '1',因为如果首位为 0,就不能保证转换后的数字是 w 位了。同时,末尾也必须为 '1',因为我们的这个函数的目的是生成 (伪) 质数,而所有的偶数都是合数。

米勒-拉宾素性检验

接下来我们需要一个函数来对生成的 (伪) 质数进行验证,这里我们需要用的米勒-拉宾素性检验 (Miller-Rabin prime test)。

这种算法基于大名鼎鼎的黎曼猜想,其利用了费马小定理的结论:

对于素数 $ n $ 和任意整数 $ a $,有

$$ a^{n − 1} \equiv 1 \pmod n $$

米勒-拉宾素性检验在费马定理的基础上增加了推论:

对于任何大于 $ 2 $ 的素数 $ p $,不存在 $ 1 \bmod p $ 的“非平凡平方根”1

让我们假设 $ a $ 是 $ 1 \bmod p $ 的平方根之一 ($ 0 < a < p $),于是有:

$$ a^{2} \equiv 1 \pmod p \newline (a - 1)(a + 1) \equiv 0 \pmod p $$

因为 $ p $ 是质数,$ (a - 1) \bmod p $ 和 $ (a + 1) \bmod p $ 均不可能为 $ 0 $。

那么唯一的可能性只能是:

$$ (a - 1)(a + 1) = 0 \space or \space (a - 1)(a + 1) = p $$

所以:

$$ a = 1 \space or \space a = p - 1 $$

我们获得推论:如果 $ p $ 是素数,则 $ a^{2} \equiv 1 \pmod p $ 的解为 $ a = 1 $ 或 $ a = p−1 $。

但是对于所有数都使用米勒-拉宾素性检验有点不划算,因为这种算法的开销其实还是挺大的。

我们可以先将 $ 1 $ 至 $ 100 $ 以内的所有素数列出来,用这些素数对需要测试的数取模,其结果如果为 $ 0 $ ,则必并是合数 (素数的倍数必定是合数);然后再对非零的情况使用米勒-拉宾素性检验:

def prime_miller_rabin(a, p):
    prime_numbers = (2,3,5,7,11,13,17,19,23,29,31,37,41,
                     43,47,53,59,61,67,71,73,79,83,89,97)
    for y in prime_numbers:
        if p % y == 0:
            return False

    def fermat():
        return powmod(a, p - 1, p) == 1

    def miller_rabin():
        d = p - 1
        q = 0
        while not(d & 1):
            q = q+1
            d >>= 1
        m = d

        for i in range(q):
            u = m * (2**i)
            tmp = powmod(a, u, p)
            if tmp == 1 or tmp == p - 1:
                return True
        return False

    return fermat() and miller_rabin()

需要注意的是,米勒-拉宾素性检验是个证明素数的只是一个必要但不充分条件,有极小的概率误判;不过当我们取多个 a 都能通过测试的话,我们可以将误判的概念忽略不计。

因此用一个函数 prime_testprime_miller_rabin 包装一下,生成多个随机数 a,测试其是否都符合条件。

import random

def prime_test(n, k):
    if k <= 0: return True
    a = random.randint(2, n - 1)
    return prime_miller_rabin(a, n) and prime_test(n, k - 1)

接着用 prime_test 来测试 fake_prime 生成的数是否是质数,如果不是,则重新生成,直到获得一个质数为止。

def get_prime(bit):
    prime_number = fake_prime(bit)
    if prime_test(prime_number, 5):
        return prime_number
    return get_prime(bit)

实现 RSA 加密

终于可以正式开始实现 RSA 算法了!

我们先来看一下公式:

$$ C = M^{e} \bmod n $$

$ M $ 是需要加密的信息。$ (n, e) $ 是我们的公钥。其中 $ n $ 就是 $ p $ 与 $ q $ 的乘积,即:

$$ n = p \cdot q $$

而 $ e $ 是一个 小于 $ r $ 并且与 $ r $ 互质的正整数。

$ r $ 就是所谓的欧拉函数,公式如下:

$$ r = \phi(n) = \phi(p) \cdot \phi(q) = (p - 1) \cdot (q - 1) $$

$ e $ 的大小直接影响了加密的速度,所以一般不会很大,也不会很小,并且里面 (二进制) 的 1 要尽可能的少。这里我们选择 0b1000000000001 (即 4097) 作为默认的 $ e $。

这个 $ r $ 虽然不直接参与 RSA 加密,但其必须 $ e $ 互质;我们通过欧几里得算法求 $ e $ 与 $ r $ 的最大公约数,进而判断它们是否互质;如果不互质,就换一个 $ e $ (e -= 1):

def get_r_e(p, q, e = 4097):
    def gcd(a, b):
        if not b: return a
        return gcd(b, a % b)

    r = (p - 1) * (q - 1)

    if gcd(e, r) == 1:
        return (r, e)
    return gen_orla(p, q, e - 1)

套用公式,实现对 $ M $ 的加密:

if __name__ == "__main__":

    p = get_prime(1024)
    q = get_prime(1024)
    n = p * q

    r, e = get_r_e(p, q)

    M = int("7355608")

    C = powmod(M, e, n)

实现 RSA 解密

除了加密,我们还需要实现解密。

同样的,先来看一看公式:

$$ M = C^{d} \bmod n $$

这里的 $ C $ 是上一步得到的密文。$ (n, d) $ 是我们的私钥。

$ n $ 已经解释过了,是 $ p $ 与 $ q $ 的乘积 ($ n = p \cdot q $)。

这个 $ d $ 是 $ e $ 关于 $ r $ 的逆模元,即:

$$ e \cdot d \equiv 1 \pmod r $$

所以:

$$ d = e^{-1} \bmod r $$

拓展欧几里得算法

这里我们使用拓展欧几里得算法:

令 $ s_0 = e $,$ s_1 = r $,$ x_0 = 1 $,$ x_1 = 1 $,$ y_0 = 0 $,$ y_1 = 1 $,

当 $ s_{i} > 0 $ 时:

$$ x*{i} = \lfloor \dfrac{s*{i-2}}{s*{i-1}} \rfloor \cdot x*{i-1} + x*{i-2} \newline y*{i} = \lfloor \dfrac{s*{i-2}}{s*{i-1}} \rfloor \cdot y*{i-1} + y*{i-2} \newline s*{i} = s*{i-2} \bmod s_{i-1} $$

当 $ s_{n-1} = 0 $ 时:

$$ (e, r) = s*{n} \newline x = (-1)^{n} \cdot x*{n} \newline y = (-1)^{n+1} \cdot y_{n} $$

写成代码就是:

def mod_inverse(e, r):
    x0, x1, x2 = r, 0, 1
    while r != 0:
        x1, x2 = x2 - ((e // r) * x1), x1
        e, r = r, e % r
    if x2 < 0: x2 += x0
    return x2

然后套用公式,就可以对 $ C $ 进行解密了:

if __name__ == "__main__":

    # p, q, C, r, n 来自于上一步
    d = mod_inverse(e, r)
    D = powmod(C, d, n)

到现在为止,虽然 RSA 算法已经可以用了,但仍然很慢;特别是解密,当质数的位数非常大时就非常容易卡死。

因此,我们需要使用中国剩余算法对解密进行优化。

中国剩余算法

中国剩余算法来自于《孙子算经》的第 26 题:

今有物不知其数,三三数之剩二,五五数之剩三,七七数之剩二,问物几何? 答曰:二十三。 术曰:三三数之剩二,置一百四十;五五数之剩三,置六十三,七七数之剩二,置三十,并之。得二百三十三,以二百一十减之,即得。凡三三数之剩一,则置七十;五> 五数之剩一,则置二十一;七七数之剩一,则置十五;一百六以上以一百五减之即得。

其核心思想就是辗转相除法。

对于给定的数 $ k $,令 $ n*{i} (1 \leq i \leq k) $ 与其互质;令 $ a*{i} (1 \leq i \leq k) $ 满足 $ 0 \leq a*{i} \leq n*{i} $。此时,有且只有一个数 $ x $ 同时满足:

$$ 0 \leq x < N (N = n*{1} \cdot n*{2} \cdot \mathellipsis \cdot n_{k}) $$

$$ x \equiv a*{1} \bmod{n*{1}} \newline x \equiv a*{2} \bmod{n*{2}} \newline \vdots \newline x \equiv a*{k} \bmod{n*{k}} $$

其在 RSA 解密中的应用如下:

  1. 计算 $ d*{p} $ 和 $ d*{q} $:

    $$ d*{p} \equiv d \bmod{p - 1} \newline d*{q} \equiv d \bmod{q - 1} $$

  2. 计算 $ m*{p} $ 和 $ m*{q} $:

    $$ m*{p} \equiv c^{d*{p}} \bmod{p} \newline m*{q} \equiv c^{d*{q}} \bmod{p} $$

  3. 使用 $ m $ 将其改写为同余的形式:

    $$ m \equiv m*{p} \bmod{p} \newline m \equiv m*{q} \bmod{q} $$

  4. 使用中国剩余定理计算 $ m $:

    $$ 1 = y*{p}p + y*{q}q $$

    因为 $ n = p \cdot q $,所以:

    $$ m \equiv (m*{p} y*{q} q + m*{q} y*{p} p) \bmod{n} $$

写成代码就是:

def chinese_remainder(c, d, p, q):
    dp = powmod(d, 1, p - 1)
    dq = powmod(d, 1, q - 1)

    m1 = powmod(c, dp, p)
    m2 = powmod(c, dq, q)

    qinv = mod_inverse(q, p)
    h = (qinv * (m1 - m2)) % p
    m = m2 + h * q
    return m

然后使用 chinese_remainder 代替 powmod 进行解密:

if __name__ == "__main__":

    # p, q, C, r, n 来自于上一步
    d = mod_inverse(e, r)
    D = chinese_remainder(C, d, p, q)

用 Java 实现

传送门 -> RSA 加解密的 Java 实现

References

附录

完整的代码如下:

#!/usr/bin/env python

import random

def powmod(p, q, m):
    """
    p^q mod n
    """
    r = 1
    while q:
        if q & 1:
            r = (r * p) % m
        q >>= 1
        p = (p * p) % m
    return r

def fake_prime(w):
    list = []
    list.append('1')
    for i in range(w - 2):
        list.append(random.choice(['0', '1']))
    list.append('1')
    return int(''.join(list), 2)

def prime_miller_rabin(a, p):
    prime_numbers = (2,3,5,7,11,13,17,19,23,29,31,37,41,
                     43,47,53,59,61,67,71,73,79,83,89,97)
    for y in prime_numbers:
        if p % y == 0:
            return False

    def fermat():
        """
        a^(p-1) = 1 (mod p)
        """
        return powmod(a, p - 1, p) == 1

    def miller_rabin():
        """
        a^2 = 1 (mod p)
        a = 1 or a = p - 1
        """
        d = p - 1
        q = 0
        while not (d & 1):
            q = q+1
            d >>= 1
        m = d

        for i in range(q):
            u = m * (2**i)
            tmp = powmod(a, u, p)
            if tmp == 1 or tmp == p - 1:
                return True
        return False

    return fermat() and miller_rabin()


def prime_test(n, k):
    if k <= 0: return True
    a = random.randint(2, n - 1)
    return prime_miller_rabin(a, n) and prime_test(n, k - 1)

def get_prime(bit):
    prime_number = fake_prime(bit)
    if prime_test(prime_number, 5):
        return prime_number
    return get_prime(bit)

def get_r_e(p, q, e = 4097):
    def gcd(a, b):
        """
        Euclidean algorithm
        """
        if not b: return a
        return gcd(b, a % b)

    def euler(p, q):
        """
        Euler's totient function
        """
        return (p - 1) * (q - 1)

    r = euler(p, q)

    if gcd(e, r) == 1:
        return (r, e)
    return get_r_e(p, q, e - 1)

def mod_inverse(e, r):
    """
    Modular inverse operation\n
    y1 = e^-1 mod r
    """

    x0, x1, x2 = r, 0, 1
    while r != 0:
        x1, x2 = x2 - ((e // r) * x1), x1
        e, r = r, e % r
    if x2 < 0: x2 += x0
    return x2

def chinese_remainder(c, d, p, q):
    """
    Chinese remainder theorem
    """
    dp = powmod(d, 1, p - 1)
    dq = powmod(d, 1, q - 1)

    m1 = powmod(c, dp, p)
    m2 = powmod(c, dq, q)

    qinv = mod_inverse(q, p)
    h = (qinv * (m1 - m2)) % p
    m = m2 + h * q
    return m

if __name__ == "__main__":

    p = get_prime(1024)
    q = get_prime(1024)
    n = p * q

    r, e = get_r_e(p, q)
    d = mod_inverse(e, r)

    print('Private key:\n')
    print('p: %d\n' % p)
    print('q: %d\n' % q)
    print('d: %d\n\n' % d)

    print('Public key:\n')
    print('n: %d\n' % n)
    print('e: %d\n\n' % e)

    M = int(input("Input message: "))

    C = powmod(M, e, n) # encryption
    print('\nAfter encryption is completed, the ciphertext obtained:\n%d\n' % C)

    D = chinese_remainder(C, d, p, q) # decryption
    print('Decryption is complete, and the plaintext obtained is:\n%d\n' % D)

Footnotes

  1. 我们发现 $ 1^{2} \bmod p $ 和 $(-1)^{2} \bmod p $ 总是得到 $ 1 $,我们称 $ −1 $ 和 $ 1 $ 是 $ 1 \bmod p $ 的“平凡平方根”