題目網址(需登入 Exercism)

這個題目在 Rust 語言方面,沒有什麼太困難的地方。所以重點都是在 modular arithmetic 的運算上。

use rand::Rng;

pub fn private_key(p: u64) -> u64 {
    let mut rng = rand::thread_rng();
    rng.gen_range(2..p)
}

pub fn public_key(p: u64, g: u64, a: u64) -> u64 {
    modular_exp(g, a, p)
}

pub fn secret(p: u64, b_pub: u64, a: u64) -> u64 {
    modular_exp(b_pub, a, p)
}

首先,是要計算 r = be mod m 的結果,也就是 modular exponentiation 運算。

  • be 在 e 很大的情況下,幾乎是保證會 overflow,因此不可能硬幹。
  • 另一個直覺的方法,是在 b 連乘 e 次的過程中,每乘一次就取一次 mod m。這個方法(比較)不會爆,但時間複雜度是 O(e),e 很大的話,很慢。
  • 我們可以從 bit[0] (LSB) 至 bit[n-1] (MSB) 逐一檢視 e 的每個 bit。若第 k 個 bit 為 1 的話,就表示 r 之中乘入了 2k 個 b。因此我們可以用 loop 去掃 e 的每一個 bit,如果是 1,就乘入 b(然後 mod m)。
    又,b(2 ^ k) * b(2 ^ k) = b(2 ^ (k + 1)),所以在每 loop 一次時,b 只要再乘一次自己(然後 mod m),就可以取得下一次所需要的 b。
    這個作法可以把時間複雜度降至 O(log(e)),當 e 很大的時候,效果驚人。
    (可以參考 Wikipedia 中的 模冪 條目)
fn modular_exp(mut b: u64, mut e: u64, m: u64) -> u64 {
    if m == 1 {
        return 0;
    }
    let mut result = 1;
    b %= m;
    while e > 0 {
        if e % 2 == 1 {
            result = result * b % m;
        }
        e >>= 1;
        b = b * b % m;
    }
    result
}

解決了 modular exponentiation 的問題後,基本的 test cases 應該都能在很短的時間內完成了。但這個題目還有 bonus:萬一 m 很大,b 沒有辦法靠一開始的 mod m 變小,導致 b * b 就 overflow 了,怎麼辦?對於內建 big integer 的程式語言(例如 Python, Haskell)來說,這不是什麼問題(頂多計算速度變慢);但 Rust 並不是。所以必須要靠別的方法來解決。

要解決這個問題,基本上是要靠 Montgomery modular multiplication(中譯:蒙哥馬利乘模)來幫忙。這個演算法的詳細解說可以參考 Wikipedia 上 英文版的條目(中文版跟沒寫沒兩樣)或是這篇 CSDN 上的 簡體中文解說。後者其實寫得很不錯。老實說我都沒看完。 XDD 但看了 CSDN 上那篇的「預備知識」後,其實就有能力推導出這個程式所需要的演算法了。

首先解釋什麼是 Zn。Zn 其實就是一組正整數的集合,代表了 mod N 的相關計算時所有可能的結果,也就是 (0, 1, 2,..., N-1)。所有的計算,參數必是 Zn,結果也一樣會是 Zn。

依照這樣的概念,我們先思考 Zn 的加法該怎麼做?最直覺的想法當然就是 (Zn(a) + Zn(b)) % N。但是,假設變數的型別都是 u64,只要 N > 0x8000_0000_0000_0000,Zn(a) + Zn(b) 就有可能會 overflow。所以我們必須小心處理。在 fn zn_add() 中,利用了移項的技巧,確保每一步驟都不會 overflow。

fn zn_add(a: u64, b: u64, m: u64) -> u64 {
    let zn_a = a % m;
    let zn_b = b % m;
    if zn_a >= m - zn_b {    // === if zn_a + zn_b >= m {
        zn_a - (m - zn_b)   // ===     zn_a + zn_b - m 
    } else {
        zn_a + zn_b
    }
}

有了 Zn 的加法後,就可以進一步構築 Zn 的乘法了。我們從十進位的乘法開始思考。假設我們要計算 x * y,而 y 是 4 位數,可以表示為 (y3 * 1000 + y2 * 100 + y1 * 10 + y0 * 1)。那麼 x * y 就可以變成:

x * (y3 * 1000 + y2 * 100 + y1 * 10 + y0 * 1)
= xy3 * 1000 + xy2 * 100 + xy1 * 10 + xy0
= (((xy3) * 10 + xy2) * 10 + xy1) * 10 + xy0

看起來很複雜,但若是以程式的 loop 角度來看,其實就是「r 的初始值為 0;y 從左到右,每遇到一位數字,r 先乘以 10,再加上該位數乘以 x 的值,一直到 y 的最右邊一位」。

把 10 進位的概念轉到 2 進位:

  • 「r 先乘以 10」變成「r 先乘以 2」;然後 r * 2 就是 r + r,因此就可以套用 Zn 加法。
  • 「加上該位數乘以 x 的值」.... 2 進位只有 0 或 1,所以就是「該位數為 1 的時候,加上 x」。這也可以套用 Zn 加法。

這就是 Zn 乘法的演算法。

fn zn_mul(a: u64, b: u64, m: u64) -> u64 {
    let zn_a = a % m;
    let zn_b = b % m;
    let digits = format!("{:b}", zn_a);
    let mut result = 0;
    for ch in digits.chars() {
        result = zn_add(result, result, m);
        if ch == '1' {
            result = zn_add(result, zn_b, m);
        }
    }
    result
}

如果用 iterator 的 fn fold() 來實作的話(最近在養成儘可能少用 mut 的習慣):

fn zn_mul(a: u64, b: u64, m: u64) -> u64 {
    let zn_a = a % m;
    let zn_b = b % m;
    format!("{:b}", zn_a)
        .chars()
        .fold(0, |r, ch| {
            if ch == '1' {
                zn_add(zn_add(r, r, m), zn_b, m)
            } else {
                zn_add(r, r, m)
            }
        })
}

最後修改 fn modular_exp(),用 zn_mul() 取代原本的乘法。順便(?)再改用 iterator 消 mut

fn modular_exp(b: u64, e: u64, m: u64) -> u64 {
    if m == 1 {
        0
    } else {
        format!("{:b}", e)
            .chars()
            .rev()
            .fold((b % m, 1), |(b, r), c| {
                if c == '1' {
                    (zn_mul(b, b, m), zn_mul(r, b, m))
                } else {
                    (zn_mul(b, b, m), r)
                }
            })
            .1
    }
}

Case closed.

By closer

發表迴響