Skip to content

Fast Exponentiation - The Crypto Engine

"A^B trong O(log B) thay vì O(B). Đây là lý do RSA hoạt động được." - HPN

Problem Statement

Tính A^B mod M một cách hiệu quả.

python
# Naive: O(B) - too slow for large B
result = 1
for _ in range(B):
    result = (result * A) % M

# Fast: O(log B)
# With B = 10^18, only ~60 loop iterations are needed (log2(10^18) ≈ 60)!

Why It Matters

Application             | A^B mod M
------------------------|------------------------------------------
RSA Encryption          | message^e mod n
RSA Decryption          | ciphertext^d mod n
Diffie-Hellman          | g^a mod p
Digital Signatures      | hash^d mod n
Competitive Programming | Fibonacci thứ 10^18, Matrix exponentiation

📘 RSA Numbers

RSA thường dùng:

  • e = 65537 (public exponent)
  • d ≈ 10^600 (private exponent)
  • n ≈ 10^616 (modulus)

Naive approach: 10^600 phép nhân → hàng triệu năm.
Fast exponentiation: ~2000 phép nhân → milliseconds.

The Algorithm: Binary Exponentiation

Key Insight

Biểu diễn B dưới dạng binary và dùng square-and-multiply.

A^13 = A^(1101₂)
     = A^8 × A^4 × A^1
     = ((A²)²)² × (A²)² × A

13 = 1×8 + 1×4 + 0×2 + 1×1

Step-by-Step

Calculate 3^13 mod 1000:

Step | B (binary) | Action          | Result | Base
-----|------------|-----------------|--------|------
Init | 1101       |                 | 1      | 3
1    | 1101       | B&1=1 → Multiply| 3      | 9 (3²)
2    | 110        | B&1=0 → Skip    | 3      | 81 (9²)
3    | 11         | B&1=1 → Multiply| 243    | 561 (81² mod 1000)
4    | 1          | B&1=1 → Multiply| 323    | done

3^13 mod 1000 = 1594323 mod 1000 = 323 ✓

Implementation

python
def power_mod(base: int, exp: int, mod: int) -> int:
    """
    Calculate base^exp mod mod in O(log exp).

    Also known as: Binary Exponentiation, Fast Exponentiation,
    Exponentiation by Squaring.

    Args:
        base: Integer base; it is reduced modulo ``mod`` before the loop.
        exp: Non-negative integer exponent.
        mod: Positive integer modulus.

    Returns:
        ``base ** exp % mod``. For ``mod == 1`` the result is always 0,
        including when ``exp == 0``.

    Raises:
        ValueError: If ``exp`` is negative.
        ZeroDivisionError: If ``mod`` is zero.

    Time: O(log exp)
    Space: O(1)
    """
    if exp < 0:
        raise ValueError("exp must be non-negative")

    # Start from 1 % mod rather than 1 so that x^0 mod 1 comes out as 0.
    result = 1 % mod
    base %= mod

    while exp > 0:
        # Lowest bit set: this power of two contributes to the exponent.
        if exp & 1:
            result = (result * base) % mod

        # Move on to the next bit (exp = exp // 2) ...
        exp >>= 1

        # ... whose weight is the square of the current one.
        base = (base * base) % mod

    return result


# Python built-in (uses same algorithm)
# pow(base, exp, mod) → Fast!

Recursive (Educational)

python
def power_mod_recursive(base: int, exp: int, mod: int) -> int:
    """
    Recursive version for understanding.

    Recurrence:
    - A^0 = 1
    - A^n = (A^(n/2))² if n even
    - A^n = A × (A^(n/2))² if n odd  (n/2 rounded down)

    Args:
        base: Integer base.
        exp: Non-negative integer exponent.
        mod: Positive integer modulus.

    Returns:
        ``base ** exp % mod``.

    Raises:
        ValueError: If ``exp`` is negative (the recurrence would never
            reach the base case).
        ZeroDivisionError: If ``mod`` is zero.
    """
    if exp < 0:
        raise ValueError("exp must be non-negative")

    if exp == 0:
        # 1 % mod so that the result is 0 when mod == 1.
        return 1 % mod

    # exp // 2 equals (exp - 1) // 2 for odd exp, so one call serves both cases.
    half = power_mod_recursive(base, exp // 2, mod)
    squared = (half * half) % mod

    if exp % 2 == 0:
        return squared
    return (base * squared) % mod

Matrix Exponentiation

Fibonacci in O(log N)

python
import numpy as np


def matrix_mult(A: list, B: list, mod: int) -> list:
    """Return the 2x2 product A × B with every entry reduced mod ``mod``.

    Only the entries at indices 0 and 1 of each operand are read.
    """
    size = 2
    return [
        [
            (A[row][0] * B[0][col] + A[row][1] * B[1][col]) % mod
            for col in range(size)
        ]
        for row in range(size)
    ]


def matrix_power(M: list, n: int, mod: int) -> list:
    """Matrix exponentiation: M^n mod m for a 2x2 matrix.

    Applies the square-and-multiply loop to matrices, using
    ``matrix_mult`` for every product.

    Args:
        M: 2x2 matrix given as a list of two rows.
        n: Non-negative integer exponent.
        mod: Non-zero integer modulus.

    Returns:
        A new 2x2 matrix; for ``n == 0`` this is the identity reduced
        mod ``mod``.

    Raises:
        ValueError: If ``n`` is negative.
        ZeroDivisionError: If ``mod`` is zero.
    """
    if n < 0:
        raise ValueError("n must be non-negative")

    # Identity matrix, reduced so that mod == 1 yields the zero matrix.
    one = 1 % mod
    result = [[one, 0], [0, one]]

    while n > 0:
        # Lowest bit set: fold the current power of M into the result.
        if n & 1:
            result = matrix_mult(result, M, mod)
        M = matrix_mult(M, M, mod)
        n >>= 1

    return result


def fibonacci_fast(n: int, mod: int = 10**9 + 7) -> int:
    """
    Calculate n-th Fibonacci number mod ``mod`` in O(log n).

    Uses: [[1,1],[1,0]]^n = [[F(n+1), F(n)], [F(n), F(n-1)]]

    Can compute F(10^18) instantly!

    Args:
        n: Non-negative index, with F(0) = 0 and F(1) = 1.
        mod: Modulus applied to the result (default 10^9 + 7).

    Returns:
        F(n) reduced mod ``mod``.

    Raises:
        ValueError: If ``n`` is negative.
    """
    if n < 0:
        raise ValueError("n must be non-negative")

    if n <= 1:
        # Reduce the trivial cases too, so the result is always a residue.
        return n % mod

    M = [[1, 1], [1, 0]]
    result = matrix_power(M, n, mod)

    # The off-diagonal entry of M^n is F(n).
    return result[0][1]


# Demo: a single print call that emits one result per line
print(
    f"F(10) = {fibonacci_fast(10)}",      # 55
    f"F(100) = {fibonacci_fast(100)}",    # Large number mod 10^9+7
    f"F(10^18) mod 10^9+7 = {fibonacci_fast(10**18)}",
    sep="\n",
)

Production Code

python
# HPN Engineering Standard
# Implementation: Fast Exponentiation - Full Featured

from typing import List, Tuple
from functools import lru_cache


class FastPow:
    """Production-ready exponentiation utilities."""

    @staticmethod
    def power_mod(base: int, exp: int, mod: int) -> int:
        """Calculate base^exp mod mod in O(log exp).

        Args:
            base: Integer base; reduced modulo ``mod`` before the loop.
            exp: Non-negative integer exponent.
            mod: Positive integer modulus.

        Returns:
            ``base ** exp % mod`` (0 whenever ``mod == 1``).

        Raises:
            ValueError: If ``exp`` is negative.
            ZeroDivisionError: If ``mod`` is zero.
        """
        if exp < 0:
            raise ValueError("exp must be non-negative")

        # 1 % mod instead of 1 so that x^0 mod 1 is 0.
        result = 1 % mod
        base %= mod

        while exp > 0:
            if exp & 1:
                result = (result * base) % mod
            exp >>= 1
            base = (base * base) % mod

        return result

    @staticmethod
    def mod_inverse(a: int, mod: int) -> int:
        """
        Calculate modular inverse: a^(-1) mod m.

        Uses Fermat's Little Theorem: a^(-1) = a^(m-2) mod m
        Requires: mod is prime!

        The candidate is verified by multiplying it back, so a
        non-invertible ``a`` or a composite ``mod`` for which the Fermat
        formula fails raises instead of returning a wrong number.

        Raises:
            ValueError: If ``mod < 2`` or no valid inverse was produced.
        """
        if mod < 2:
            raise ValueError("mod must be a prime number (>= 2)")

        inverse = FastPow.power_mod(a, mod - 2, mod)
        if (a * inverse) % mod != 1:
            raise ValueError(
                f"{a} has no inverse modulo {mod} via Fermat's little theorem "
                "(a must not be a multiple of mod, and mod must be prime)"
            )
        return inverse

    @staticmethod
    def mod_divide(a: int, b: int, mod: int) -> int:
        """Calculate (a / b) mod m using modular inverse.

        Raises:
            ValueError: If ``b`` has no inverse modulo ``mod``
                (see ``mod_inverse``).
        """
        return (a * FastPow.mod_inverse(b, mod)) % mod


class MatrixPow:
    """Matrix exponentiation for recurrence relations."""

    @staticmethod
    def multiply(A: List[List[int]], B: List[List[int]], mod: int) -> List[List[int]]:
        """Multiply two matrices mod m.

        Works for any compatible shapes (rows×k times k×cols), not only
        square matrices, so a transition matrix can also be applied to a
        column vector.

        Args:
            A: Left operand as a list of rows.
            B: Right operand as a list of rows.
            mod: Non-zero integer modulus applied to every entry.

        Returns:
            A new matrix with ``len(A)`` rows; the operands are not modified.

        Raises:
            ValueError: If the shapes are incompatible or ``B`` is ragged.
        """
        inner = len(B)
        if any(len(row) != inner for row in A):
            raise ValueError("number of columns of A must equal number of rows of B")

        width = len(B[0]) if B else 0
        if any(len(row) != width for row in B):
            raise ValueError("B must be rectangular")

        # Transpose B once so that each output entry is a plain dot product.
        columns = list(zip(*B))
        return [
            [sum(x * y for x, y in zip(row, col)) % mod for col in columns]
            for row in A
        ]

    @staticmethod
    def power(M: List[List[int]], exp: int, mod: int) -> List[List[int]]:
        """Calculate M^exp mod m.

        Args:
            M: Square matrix as a list of rows.
            exp: Non-negative integer exponent.
            mod: Non-zero integer modulus.

        Returns:
            A new matrix; for ``exp == 0`` this is the identity reduced
            mod ``mod``.

        Raises:
            ValueError: If ``exp`` is negative or ``M`` is not square.
            ZeroDivisionError: If ``mod`` is zero.
        """
        if exp < 0:
            raise ValueError("exp must be non-negative")

        n = len(M)
        if any(len(row) != n for row in M):
            raise ValueError("M must be a square matrix")

        # Identity matrix, reduced so that mod == 1 yields the zero matrix.
        one = 1 % mod
        result = [[one if i == j else 0 for j in range(n)] for i in range(n)]

        while exp > 0:
            if exp & 1:
                result = MatrixPow.multiply(result, M, mod)
            M = MatrixPow.multiply(M, M, mod)
            exp >>= 1

        return result


class Combinatorics:
    """Combinatorics with modular arithmetic."""

    def __init__(self, max_n: int, mod: int = 10**9 + 7):
        """
        Precompute factorials and inverse factorials.

        Args:
            max_n: Largest ``n`` that ``nCr`` / ``nPr`` will be asked about.
            mod: Prime modulus; ``max_n!`` must be invertible modulo it,
                which for a prime means ``max_n < mod``.

        Raises:
            ValueError: If ``max_n!`` turns out not to be invertible
                modulo ``mod``.

        Time: O(max_n)
        Space: O(max_n)
        """
        self.mod = mod
        self.fact = [1] * (max_n + 1)
        self.inv_fact = [1] * (max_n + 1)

        for i in range(1, max_n + 1):
            self.fact[i] = (self.fact[i-1] * i) % mod

        top_inverse = FastPow.mod_inverse(self.fact[max_n], mod)
        # Multiply back to verify: a wrong inverse here would silently
        # corrupt every entry of inv_fact derived from it below.
        if (self.fact[max_n] * top_inverse) % mod != 1 % mod:
            raise ValueError(
                f"{max_n}! is not invertible modulo {mod}; "
                "mod must be prime and larger than max_n"
            )
        self.inv_fact[max_n] = top_inverse

        # Walk downwards using 1/i! = (i+1) / (i+1)!
        for i in range(max_n - 1, -1, -1):
            self.inv_fact[i] = (self.inv_fact[i+1] * (i+1)) % mod

    def nCr(self, n: int, r: int) -> int:
        """Calculate C(n, r) mod m in O(1).

        Returns 0 when ``r`` is outside ``[0, n]``. Indexing the
        precomputed tables raises IndexError when ``n`` exceeds the
        ``max_n`` given to the constructor.
        """
        if r < 0 or r > n:
            return 0
        return (self.fact[n] * self.inv_fact[r] % self.mod) * self.inv_fact[n-r] % self.mod

    def nPr(self, n: int, r: int) -> int:
        """Calculate P(n, r) mod m in O(1).

        Returns 0 when ``r`` is outside ``[0, n]``. Indexing the
        precomputed tables raises IndexError when ``n`` exceeds the
        ``max_n`` given to the constructor.
        """
        if r < 0 or r > n:
            return 0
        return self.fact[n] * self.inv_fact[n-r] % self.mod


# ============================================
# RSA DEMO (EDUCATIONAL)
# ============================================

def generate_rsa_keys(p: int, q: int, e: int = 65537) -> Tuple[Tuple[int, int], Tuple[int, int]]:
    """
    Generate RSA key pair from two primes.

    ⚠️ EDUCATIONAL ONLY! Real RSA uses 2048+ bit primes.

    Args:
        p: First prime (primality is not checked here).
        q: Second prime, different from ``p``.
        e: Public exponent; defaults to the common choice 65537. It must
            be coprime to (p - 1) * (q - 1).

    Returns: (public_key, private_key)
             public_key = (e, n)
             private_key = (d, n)

    Raises:
        ValueError: If ``p`` or ``q`` is smaller than 2, if ``p == q``,
            or if ``e`` has no inverse modulo (p - 1) * (q - 1).
    """
    if p < 2 or q < 2:
        raise ValueError("p and q must be primes (>= 2)")
    if p == q:
        # phi = (p - 1) * (q - 1) is only Euler's totient of n for distinct primes.
        raise ValueError("p and q must be distinct primes")

    n = p * q
    phi = (p - 1) * (q - 1)

    # Private exponent: d = e^(-1) mod phi
    # Three-argument pow() with exponent -1 computes the modular inverse.
    try:
        d = pow(e, -1, phi)
    except ValueError as err:
        raise ValueError(
            f"public exponent e={e} is not invertible modulo phi={phi}; "
            "choose different primes or a different e"
        ) from err

    return ((e, n), (d, n))


def rsa_encrypt(message: int, public_key: Tuple[int, int]) -> int:
    """Encrypt: ciphertext = message^e mod n

    Args:
        message: Integer plaintext with ``0 <= message < n``.
        public_key: The pair ``(e, n)``.

    Raises:
        ValueError: If ``message`` is outside ``[0, n)``. Such a value
            would be reduced mod n and could not be recovered by decryption.
    """
    e, n = public_key
    if not 0 <= message < n:
        raise ValueError("message must satisfy 0 <= message < n")
    return FastPow.power_mod(message, e, n)


def rsa_decrypt(ciphertext: int, private_key: Tuple[int, int]) -> int:
    """Decrypt: message = ciphertext^d mod n, where private_key is (d, n)."""
    private_exponent, modulus = private_key
    message = FastPow.power_mod(ciphertext, private_exponent, modulus)
    return message


# ============================================
# USAGE EXAMPLE
# ============================================

if __name__ == "__main__":
    print("=== Fast Exponentiation ===")

    # Basic power mod
    result = FastPow.power_mod(2, 100, 10**9 + 7)
    print(f"2^100 mod 10^9+7 = {result}")

    # Compare with Python built-in. An explicit raise is used instead of
    # `assert` so the check still runs under `python -O`.
    if result != pow(2, 100, 10**9 + 7):
        raise AssertionError("FastPow.power_mod disagrees with built-in pow()")
    print("Matches Python pow() ✓")

    # Fibonacci: entries are reduced mod 10^9+7, and F(50) exceeds that
    # modulus, so the label states the modulus explicitly.
    print(f"\nF(50) mod 10^9+7 = {MatrixPow.power([[1,1],[1,0]], 50, 10**9+7)[0][1]}")

    # Combinatorics
    print("\n=== Combinatorics ===")
    comb = Combinatorics(1000)
    print(f"C(10, 3) = {comb.nCr(10, 3)}")  # 120
    print(f"C(100, 50) mod 10^9+7 = {comb.nCr(100, 50)}")

    # RSA Demo
    print("\n=== RSA Demo (Educational) ===")
    p, q = 61, 53  # Small primes for demo
    public_key, private_key = generate_rsa_keys(p, q)
    print(f"Public key (e, n): {public_key}")
    print(f"Private key (d, n): {private_key}")

    message = 42
    encrypted = rsa_encrypt(message, public_key)
    decrypted = rsa_decrypt(encrypted, private_key)

    print(f"Original: {message}")
    print(f"Encrypted: {encrypted}")
    print(f"Decrypted: {decrypted}")
    # Show the check mark only when the round trip actually succeeded.
    round_trip_ok = message == decrypted
    print(f"Match: {round_trip_ok} {'✓' if round_trip_ok else '✗'}")

Complexity Analysis

Algorithm             | Time        | Space
----------------------|-------------|------
Naive                 | O(B)        | O(1)
Binary Exponentiation | O(log B)    | O(1)
Matrix Exp (NxN)      | O(N³ log B) | O(N²)

Common Applications

Problem         | Solution
----------------|----------------------------------------
A^B mod M       | Binary exponentiation
Fibonacci(10^18)| Matrix exponentiation
C(n,r) mod p    | Precompute factorials + modular inverse
A^(-1) mod p    | A^(p-2) mod p (Fermat)
RSA encryption  | M^e mod n

💡 HPN's Rule

"Exponent lớn (>1000) → Binary exponentiation. Recurrence relation → Matrix exponentiation. Modular division → Fermat's inverse."