RSA 加解密的 Java 实现

4 min read

书接上文;我们已经用 Python 实现了 RSA 加解密算法。

今天再来撸一篇 RSA 算法的 Java 实现。本文与《Python 与 RSA 算法》采用同样的算法和思路 (推荐对比着一起看)。因此,就不讲解具体实现过程了,想看公式可以去:《Python 与 RSA 算法》

主要是帮朋友交作业,所以细节肯定不可能那么精致 (甚至可能还有点 bug),并且 (暂时) 没有实现基于中国剩余定理 (CRT) 的解密加速。

顺便吐槽一下,Java 的 BigInteger 是真的难用 (ー`´ー)。

具体代码如下:

import java.lang.StringBuffer;
import java.math.BigInteger;
import java.security.SecureRandom;
import java.util.Random;
import java.util.Scanner;

/**
 * Command-line demo of textbook RSA.
 *
 * <p>Generates two 1024-bit primes, derives the key pair (n, e, d), prints all
 * key components, then reads one decimal integer from standard input,
 * encrypts it with the public key and decrypts it again with the private key.
 *
 * <p>No padding is applied, so this is for study only. The message must be a
 * non-negative integer smaller than n for the round trip to reproduce it.
 */
public class rsa {

    /**
     * Runs the key generation / encrypt / decrypt demo.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        Calculate c = new Calculate();
        Prime primes = new Prime();

        BigInteger p = primes.getPrime(1024);
        BigInteger q = primes.getPrime(1024);
        // p == q would make n a perfect square, for which (p-1)(q-1) is not
        // the right totient; redraw in that (astronomically unlikely) case.
        while (q.equals(p)) {
            q = primes.getPrime(1024);
        }

        BigInteger n = p.multiply(q); // n = p * q
        BigInteger r = c.euler(p, q); // (p-1)(q-1)

        // Pick the public exponent first: start at 65537 and walk down until
        // it is coprime with r. r is even, so only odd candidates can qualify
        // and stepping by two reaches the same exponent as stepping by one.
        BigInteger e = new BigInteger("65537");
        while (!c.gcd(e, r).equals(BigInteger.ONE)) {
            e = e.subtract(BigInteger.TWO);
        }

        // The private exponent must be derived from the FINAL e; computing it
        // before the loop above yields a d that does not match the printed e
        // whenever 65537 had to be replaced.
        BigInteger d = c.modInverse(e, r);

        System.out.println("p: " + p);
        System.out.println("q: " + q);
        System.out.println("d: " + d);
        System.out.println("n: " + n);
        System.out.println("e: " + e);

        System.out.print("Input message: ");
        // try-with-resources closes the scanner even if the input is not a
        // valid integer and BigInteger's constructor throws.
        try (Scanner sc = new Scanner(System.in)) {
            BigInteger M = new BigInteger(sc.next());
            BigInteger C = c.powmod(M, e, n);
            System.out.println("Ciphertext: " + C);
            BigInteger M2 = c.powmod(C, d, n);
            System.out.println("Decrypted message: " + M2);
        }
    }
}

/**
 * Hand-written number-theory helpers used by the RSA demo: modular
 * exponentiation, the totient of a product of two primes, gcd and modular
 * inverse.
 */
class Calculate {

    /**
     * Computes {@code p^q mod n} by binary square-and-multiply.
     *
     * @param p the base (any sign; it is reduced modulo {@code n} first)
     * @param q the exponent, must be non-negative
     * @param n the modulus, must be positive
     * @return the value of {@code p^q mod n}, in the range {@code [0, n)}
     * @throws ArithmeticException if {@code n <= 0} or {@code q < 0}
     */
    public BigInteger powmod(BigInteger p, BigInteger q, BigInteger n) {
        if (n.signum() <= 0) {
            throw new ArithmeticException("powmod: modulus must be positive");
        }
        if (q.signum() < 0) {
            throw new ArithmeticException("powmod: exponent must be non-negative");
        }
        // Start from 1 mod n so that a modulus of 1 yields 0 even when the
        // exponent is 0 and the loop body never runs.
        BigInteger result = BigInteger.ONE.mod(n);
        BigInteger base = p.mod(n);
        BigInteger exponent = q;
        while (exponent.signum() > 0) {
            if (exponent.testBit(0)) {
                result = result.multiply(base).mod(n);
            }
            exponent = exponent.shiftRight(1);
            base = base.multiply(base).mod(n);
        }
        return result;
    }

    /**
     * Computes {@code (p - 1) * (q - 1)}, which is Euler's totient of
     * {@code p * q} when p and q are distinct primes.
     *
     * @param p first factor
     * @param q second factor
     * @return {@code (p - 1) * (q - 1)}
     */
    public BigInteger euler(BigInteger p, BigInteger q) {
        return p.subtract(BigInteger.ONE).multiply(q.subtract(BigInteger.ONE));
    }

    /**
     * Computes the greatest common divisor with the Euclidean algorithm.
     *
     * @param a first value (any sign)
     * @param b second value (any sign)
     * @return the non-negative gcd of {@code a} and {@code b};
     *         {@code gcd(0, 0)} is 0
     */
    public BigInteger gcd(BigInteger a, BigInteger b) {
        // Work on magnitudes: BigInteger.mod rejects a negative modulus, and
        // the result of a gcd is conventionally non-negative.
        BigInteger x = a.abs();
        BigInteger y = b.abs();
        while (x.signum() != 0) {
            BigInteger previous = x;
            x = y.mod(x);
            y = previous;
        }
        return y;
    }

    /**
     * Computes the multiplicative inverse of {@code a} modulo {@code b} with
     * the extended Euclidean algorithm.
     *
     * @param a the value to invert
     * @param b the modulus, must be positive
     * @return the value {@code y} in {@code [0, b)} with {@code a * y mod b == 1 mod b}
     * @throws ArithmeticException if {@code b <= 0} or if {@code a} and
     *         {@code b} are not coprime, in which case no inverse exists
     */
    public BigInteger modInverse(BigInteger a, BigInteger b) {
        if (b.signum() <= 0) {
            throw new ArithmeticException("modInverse: modulus must be positive");
        }
        // (x0, y0) is the running remainder pair; y1 ends up holding the
        // coefficient of a, i.e. a^-1 mod b (e^-1 mod r in the RSA setting).
        BigInteger x0 = a, x1 = BigInteger.ZERO,
                   y0 = b, y1 = BigInteger.ONE;
        while (y0.signum() != 0) {
            BigInteger t1 = x1;
            x1 = y1.subtract(x0.divide(y0).multiply(x1));
            y1 = t1;
            BigInteger t0 = x0;
            x0 = y0;
            y0 = t0.mod(y0);
        }
        // x0 now holds gcd(a, b). Anything other than 1 means y1 is not an
        // inverse, so fail loudly instead of returning a meaningless number.
        if (!x0.equals(BigInteger.ONE)) {
            throw new ArithmeticException("modInverse: value is not invertible for this modulus");
        }
        return y1.mod(b);
    }
}

/**
 * Random probable-prime generation for the RSA demo: random odd candidates
 * are filtered by trial division and then by repeated Miller-Rabin rounds.
 */
class Prime {

    /** Primes below 100, used for cheap trial division before the strong test. */
    private static final BigInteger[] SMALL_PRIMES = smallPrimes();

    /** Number of Miller-Rabin rounds, each with its own random base. */
    private static final int ROUNDS = 5;

    /** Bit length of the random Miller-Rabin bases drawn by {@link #getPrime(int)}. */
    private static final int WITNESS_BITS = 50;

    Calculate c = new Calculate();

    /** Cryptographically strong source: the output is used as RSA key material. */
    private final Random random = new SecureRandom();

    /** Builds the trial-division table as BigInteger values. */
    private static BigInteger[] smallPrimes() {
        int[] values = {2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47,
                        53, 59, 61, 67, 71, 73, 79, 83, 89, 97};
        BigInteger[] table = new BigInteger[values.length];
        for (int i = 0; i < values.length; i++) {
            table[i] = BigInteger.valueOf(values[i]);
        }
        return table;
    }

    /**
     * Returns a random odd integer with exactly {@code bits} bits: the highest
     * and lowest bits are forced to 1 and the bits in between are random. The
     * value is only a prime <em>candidate</em>.
     *
     * @param bits requested bit length; values below 2 are treated as 2, so
     *             the result is then always 3
     * @return a random odd integer of the requested bit length
     */
    public BigInteger fakePrime(int bits) {
        int width = Math.max(bits, 2);
        return new BigInteger(width, random).setBit(width - 1).setBit(0);
    }

    /**
     * Runs one Miller-Rabin (strong probable-prime) round on {@code n} using
     * base {@code k}, after trial division by the primes below 100.
     *
     * @param n the number to test
     * @param k the base (witness candidate); it is reduced modulo {@code n}
     * @return {@code false} if {@code n} is certainly composite (or less than
     *         2); {@code true} if {@code n} is a prime below 100 or passes the
     *         strong test for this base. A base that is a multiple of
     *         {@code n} carries no information and also yields {@code true}.
     */
    public boolean prime_miller_rabin(BigInteger n, BigInteger k) {
        if (n.compareTo(BigInteger.ONE) <= 0) {
            return false;
        }
        for (BigInteger small : SMALL_PRIMES) {
            // The table entries themselves are prime; only proper multiples
            // are composite.
            if (n.equals(small)) {
                return true;
            }
            if (n.mod(small).signum() == 0) {
                return false;
            }
        }

        // From here on n is odd and greater than 100.
        BigInteger nMinusOne = n.subtract(BigInteger.ONE);
        BigInteger base = k.mod(n);
        if (base.signum() == 0) {
            // 0 can never be a witness; treat it as "no evidence of
            // compositeness" rather than rejecting a possible prime.
            return true;
        }

        // Write n - 1 = m * 2^s with m odd.
        int s = nMinusOne.getLowestSetBit();
        BigInteger m = nMinusOne.shiftRight(s);

        BigInteger x = c.powmod(base, m, n);
        if (x.equals(BigInteger.ONE) || x.equals(nMinusOne)) {
            return true;
        }
        for (int i = 1; i < s; i++) {
            x = x.multiply(x).mod(n);
            if (x.equals(nMinusOne)) {
                return true;
            }
            if (x.equals(BigInteger.ONE)) {
                // Reached 1 without passing through -1: a non-trivial square
                // root of 1 exists, so n is composite.
                return false;
            }
        }
        return false;
    }

    /**
     * Generates a random probable prime with exactly {@code bit} bits by
     * drawing candidates until one passes every Miller-Rabin round.
     *
     * @param bit requested bit length (values below 2 behave like 2)
     * @return a probable prime of the requested bit length
     */
    public BigInteger getPrime(int bit) {
        while (true) {
            BigInteger candidate = fakePrime(bit);
            if (passesAllRounds(candidate)) {
                return candidate;
            }
        }
    }

    /** Runs the configured number of rounds, each with a fresh random base, stopping at the first failure. */
    private boolean passesAllRounds(BigInteger candidate) {
        for (int round = 0; round < ROUNDS; round++) {
            // A new base per round is what makes the rounds independent;
            // repeating one base would just repeat the same verdict.
            if (!prime_miller_rabin(candidate, fakePrime(WITNESS_BITS))) {
                return false;
            }
        }
        return true;
    }
}