From 4ea472808437c932f6240c0df5e4d5912d282052 Mon Sep 17 00:00:00 2001 From: Frank Denis <124872+jedisct1@users.noreply.github.com> Date: Tue, 18 Nov 2025 16:39:58 +0100 Subject: [PATCH] Align ML-KEM code with ML-DSA (#25964) This will facilitate maintainance and code sharing between primitives. --- lib/std/crypto/ml_kem.zig | 534 ++++++++++++++++++++++++-------------- 1 file changed, 338 insertions(+), 196 deletions(-) diff --git a/lib/std/crypto/ml_kem.zig b/lib/std/crypto/ml_kem.zig index 0a8e73f785..db9f0c2736 100644 --- a/lib/std/crypto/ml_kem.zig +++ b/lib/std/crypto/ml_kem.zig @@ -105,19 +105,20 @@ const crypto = std.crypto; const errors = std.crypto.errors; const math = std.math; const mem = std.mem; -const RndGen = std.Random.DefaultPrng; const sha3 = crypto.hash.sha3; -// Q is the parameter q ≡ 3329 = 2¹¹ + 2¹⁰ + 2⁸ + 1. +const RndGen = std.Random.DefaultPrng; + +// Q is the modulus q ≡ 3329 = 2¹¹ + 2¹⁰ + 2⁸ + 1 const Q: i16 = 3329; -// Montgomery R +// Montgomery R = 2^16 mod Q (for Montgomery multiplication) const R: i32 = 1 << 16; -// Parameter n, degree of polynomials. +// N is the degree of polynomials (polynomial ring dimension) const N: usize = 256; -// Size of "small" vectors used in encryption blinds. +// eta2 is the size of "small" vectors used in encryption blinds const eta2: u8 = 2; const Params = struct { @@ -215,7 +216,7 @@ fn Kyber(comptime p: Params) type { pub const ciphertext_length = Poly.compressedSize(p.du) * p.k + Poly.compressedSize(p.dv); const Self = @This(); - const V = Vec(p.k); + const V = PolyVec(p.k); const M = Mat(p.k); /// Length (in bytes) of a shared secret. @@ -241,7 +242,7 @@ fn Kyber(comptime p: Params) type { hpk: [h_length]u8, // H(pk) /// Size of a serialized representation of the key, in bytes. - pub const bytes_length = InnerPk.bytes_length; + pub const encoded_length = InnerPk.encoded_length; /// Generates a shared secret, and encapsulates it for the public key. /// If `seed` is `null`, a random seed is used. This is recommended. @@ -289,14 +290,14 @@ fn Kyber(comptime p: Params) type { } /// Serializes the key into a byte array. - pub fn toBytes(pk: PublicKey) [bytes_length]u8 { + pub fn toBytes(pk: PublicKey) [encoded_length]u8 { return pk.pk.toBytes(); } /// Deserializes the key from a byte array. - pub fn fromBytes(buf: *const [bytes_length]u8) errors.NonCanonicalError!PublicKey { + pub fn fromBytes(buf: *const [encoded_length]u8) errors.NonCanonicalError!PublicKey { var ret: PublicKey = undefined; - ret.pk = try InnerPk.fromBytes(buf[0..InnerPk.bytes_length]); + ret.pk = try InnerPk.fromBytes(buf[0..InnerPk.encoded_length]); sha3.Sha3_256.hash(buf, &ret.hpk, .{}); return ret; } @@ -310,8 +311,8 @@ fn Kyber(comptime p: Params) type { z: [shared_length]u8, /// Size of a serialized representation of the key, in bytes. - pub const bytes_length: usize = - InnerSk.bytes_length + InnerPk.bytes_length + h_length + shared_length; + pub const encoded_length: usize = + InnerSk.encoded_length + InnerPk.encoded_length + h_length + shared_length; /// Decapsulates the shared secret within ct using the private key. pub fn decaps(sk: SecretKey, ct: *const [ciphertext_length]u8) ![shared_length]u8 { @@ -346,18 +347,18 @@ fn Kyber(comptime p: Params) type { } /// Serializes the key into a byte array. - pub fn toBytes(sk: SecretKey) [bytes_length]u8 { + pub fn toBytes(sk: SecretKey) [encoded_length]u8 { return sk.sk.toBytes() ++ sk.pk.toBytes() ++ sk.hpk ++ sk.z; } /// Deserializes the key from a byte array. - pub fn fromBytes(buf: *const [bytes_length]u8) errors.NonCanonicalError!SecretKey { + pub fn fromBytes(buf: *const [encoded_length]u8) errors.NonCanonicalError!SecretKey { var ret: SecretKey = undefined; comptime var s: usize = 0; - ret.sk = InnerSk.fromBytes(buf[s .. s + InnerSk.bytes_length]); - s += InnerSk.bytes_length; - ret.pk = try InnerPk.fromBytes(buf[s .. s + InnerPk.bytes_length]); - s += InnerPk.bytes_length; + ret.sk = InnerSk.fromBytes(buf[s .. s + InnerSk.encoded_length]); + s += InnerSk.encoded_length; + ret.pk = try InnerPk.fromBytes(buf[s .. s + InnerPk.encoded_length]); + s += InnerPk.encoded_length; ret.hpk = buf[s..][0..h_length].*; s += h_length; ret.z = buf[s..][0..shared_length].*; @@ -418,7 +419,7 @@ fn Kyber(comptime p: Params) type { // Cached values aT: M, - const bytes_length = V.bytes_length + 32; + const encoded_length = V.encoded_length + 32; fn encrypt( pk: InnerPk, @@ -436,7 +437,7 @@ fn Kyber(comptime p: Params) type { // Note that coefficients of r are bounded by q and those of Aᵀ // are bounded by 4.5q and so their product is bounded by 2¹⁵q // as required for multiplication. - u.ps[i] = pk.aT.vs[i].dotHat(rh); + u.ps[i] = pk.aT.rows[i].dotHat(rh); } // Aᵀ and r were not in Montgomery form, so the Montgomery @@ -451,14 +452,14 @@ fn Kyber(comptime p: Params) type { return u.compress(p.du) ++ v.compress(p.dv); } - fn toBytes(pk: InnerPk) [bytes_length]u8 { + fn toBytes(pk: InnerPk) [encoded_length]u8 { return pk.th.toBytes() ++ pk.rho; } - fn fromBytes(buf: *const [bytes_length]u8) errors.NonCanonicalError!InnerPk { + fn fromBytes(buf: *const [encoded_length]u8) errors.NonCanonicalError!InnerPk { var ret: InnerPk = undefined; - const th_bytes = buf[0..V.bytes_length]; + const th_bytes = buf[0..V.encoded_length]; ret.th = V.fromBytes(th_bytes).normalize(); if (p.ml_kem) { @@ -468,7 +469,7 @@ fn Kyber(comptime p: Params) type { } } - ret.rho = buf[V.bytes_length..bytes_length].*; + ret.rho = buf[V.encoded_length..encoded_length].*; ret.aT = M.uniform(ret.rho, true); return ret; } @@ -477,7 +478,7 @@ fn Kyber(comptime p: Params) type { // Private key of the inner PKE const InnerSk = struct { sh: V, // NTT(s), normalized - const bytes_length = V.bytes_length; + const encoded_length = V.encoded_length; fn decrypt(sk: InnerSk, ct: *const [ciphertext_length]u8) [inner_plaintext_length]u8 { const u = V.decompress(p.du, ct[0..comptime V.compressedSize(p.du)]); @@ -491,11 +492,11 @@ fn Kyber(comptime p: Params) type { .normalize().compress(1); } - fn toBytes(sk: InnerSk) [bytes_length]u8 { + fn toBytes(sk: InnerSk) [encoded_length]u8 { return sk.sh.toBytes(); } - fn fromBytes(buf: *const [bytes_length]u8) InnerSk { + fn fromBytes(buf: *const [encoded_length]u8) InnerSk { var ret: InnerSk = undefined; ret.sh = V.fromBytes(buf).normalize(); return ret; @@ -516,7 +517,7 @@ fn Kyber(comptime p: Params) type { // Sample secret vector s. sk.sh = V.noise(p.eta1, 0, sigma).ntt().normalize(); - const eh = Vec(p.k).noise(p.eta1, p.k, sigma).ntt(); // sample blind e. + const eh = PolyVec(p.k).noise(p.eta1, p.k, sigma).ntt(); // sample blind e. var th: V = undefined; // Next, we compute t = A s + e. @@ -528,7 +529,7 @@ fn Kyber(comptime p: Params) type { // multiplications in the inner product added a factor R⁻¹ which // we'll cancel out with toMont(). This will also ensure the // coefficients of th are bounded in absolute value by q. - th.ps[i] = pk.aT.vs[i].dotHat(sk.sh).toMont(); + th.ps[i] = pk.aT.rows[i].dotHat(sk.sh).toMont(); } pk.th = th.add(eh).normalize(); // bounded by 8q @@ -565,7 +566,6 @@ const zetas = computeZetas(); // not enough, the other coefficient is reduced as well. // // This is actually optimal, as proven in https://eprint.iacr.org/2020/1377.pdf -// TODO generate comptime? const inv_ntt_reductions = [_]i16{ -1, // after layer 1 -1, // after layer 2 @@ -634,31 +634,8 @@ test "invNTTReductions bounds" { } } -// Extended euclidean algorithm. -// -// For a, b finds x, y such that x a + y b = gcd(a, b). Used to compute -// modular inverse. -fn eea(a: anytype, b: @TypeOf(a)) EeaResult(@TypeOf(a)) { - if (a == 0) { - return .{ .gcd = b, .x = 0, .y = 1 }; - } - const r = eea(@rem(b, a), a); - return .{ .gcd = r.gcd, .x = r.y - @divTrunc(b, a) * r.x, .y = r.x }; -} - -fn EeaResult(comptime T: type) type { - return struct { gcd: T, x: T, y: T }; -} - -// Returns least common multiple of a and b. -fn lcm(a: anytype, b: @TypeOf(a)) @TypeOf(a) { - const r = eea(a, b); - return a * b / r.gcd; -} - -// Invert modulo p. fn invertMod(a: anytype, p: @TypeOf(a)) @TypeOf(a) { - const r = eea(a, p); + const r = extendedEuclidean(@TypeOf(a), a, p); assert(r.gcd == 1); return r.x; } @@ -788,31 +765,12 @@ test "Test csubq" { } } -// Compute a^s mod p. -fn mpow(a: anytype, s: @TypeOf(a), p: @TypeOf(a)) @TypeOf(a) { - var ret: @TypeOf(a) = 1; - var s2 = s; - var a2 = a; - - while (true) { - if (s2 & 1 == 1) { - ret = @mod(ret * a2, p); - } - s2 >>= 1; - if (s2 == 0) { - break; - } - a2 = @mod(a2 * a2, p); - } - return ret; -} - // Computes zetas table used by ntt and invNTT. fn computeZetas() [128]i16 { @setEvalBranchQuota(10000); var ret: [128]i16 = undefined; for (&ret, 0..) |*r, i| { - const t = @as(i16, @intCast(mpow(@as(i32, zeta), @bitReverse(@as(u7, @intCast(i))), Q))); + const t = @as(i16, @intCast(modularPow(i32, zeta, @bitReverse(@as(u7, @intCast(i))), Q))); r.* = csubq(feBarrettReduce(feToMont(t))); } return ret; @@ -828,9 +786,10 @@ fn computeZetas() [128]i16 { const Poly = struct { cs: [N]i16, - const bytes_length = N / 2 * 3; + const encoded_length = N / 2 * 3; const zero: Poly = .{ .cs = .{0} ** N }; + // Add two polynomials (coefficients not normalized) fn add(a: Poly, b: Poly) Poly { var ret: Poly = undefined; for (0..N) |i| { @@ -839,6 +798,7 @@ const Poly = struct { return ret; } + // Subtract two polynomials (coefficients not normalized) fn sub(a: Poly, b: Poly) Poly { var ret: Poly = undefined; for (0..N) |i| { @@ -847,25 +807,6 @@ const Poly = struct { return ret; } - // For testing, generates a random polynomial with for each - // coefficient |x| ≤ q. - fn randAbsLeqQ(rnd: anytype) Poly { - var ret: Poly = undefined; - for (0..N) |i| { - ret.cs[i] = rnd.random().intRangeAtMost(i16, -Q, Q); - } - return ret; - } - - // For testing, generates a random normalized polynomial. - fn randNormalized(rnd: anytype) Poly { - var ret: Poly = undefined; - for (0..N) |i| { - ret.cs[i] = rnd.random().intRangeLessThan(i16, 0, Q); - } - return ret; - } - // Executes a forward "NTT" on p. // // Assumes the coefficients are in absolute value ≤q. The resulting @@ -1054,7 +995,7 @@ const Poly = struct { var in_off: usize = 0; var out_off: usize = 0; - const batch_size: usize = comptime lcm(@as(i16, d), 8); + const batch_size: usize = comptime math.lcm(d, 8); const in_batch_size: usize = comptime batch_size / d; const out_batch_size: usize = comptime batch_size / 8; @@ -1118,7 +1059,7 @@ const Poly = struct { var in_off: usize = 0; var out_off: usize = 0; - const batch_size: usize = comptime lcm(@as(i16, d), 8); + const batch_size: usize = comptime math.lcm(d, 8); const in_batch_size: usize = comptime batch_size / 8; const out_batch_size: usize = comptime batch_size / d; @@ -1275,53 +1216,23 @@ const Poly = struct { return ret; } - // Sample p uniformly from the given seed and x and y coordinates. fn uniform(seed: [32]u8, x: u8, y: u8) Poly { - var h = sha3.Shake128.init(.{}); - const suffix: [2]u8 = .{ x, y }; - h.update(&seed); - h.update(&suffix); - - const buf_len = sha3.Shake128.block_length; // rate SHAKE-128 - var buf: [buf_len]u8 = undefined; - - var ret: Poly = undefined; - var i: usize = 0; // index into ret.cs - outer: while (true) { - h.squeeze(&buf); - - var j: usize = 0; // index into buf - while (j < buf_len) : (j += 3) { - const b0 = @as(u16, buf[j]); - const b1 = @as(u16, buf[j + 1]); - const b2 = @as(u16, buf[j + 2]); - - const ts: [2]u16 = .{ - b0 | ((b1 & 0xf) << 8), - (b1 >> 4) | (b2 << 4), - }; - - inline for (ts) |t| { - if (t < Q) { - ret.cs[i] = @as(i16, @intCast(t)); - i += 1; - - if (i == N) { - break :outer; - } - } - } - } - } - - return ret; + const domain_sep: [2]u8 = .{ x, y }; + return sampleUniformRejection( + Poly, + Q, + 12, + N, + &seed, + &domain_sep, + ); } // Packs p. // // Assumes p is normalized (and not just Barrett reduced). - fn toBytes(p: Poly) [bytes_length]u8 { - var ret: [bytes_length]u8 = undefined; + fn toBytes(p: Poly) [encoded_length]u8 { + var ret: [encoded_length]u8 = undefined; for (0..comptime N / 2) |i| { const t0 = @as(u16, @intCast(p.cs[2 * i])); const t1 = @as(u16, @intCast(p.cs[2 * i + 1])); @@ -1335,7 +1246,7 @@ const Poly = struct { // Unpacks a Poly from buf. // // p will not be normalized; instead 0 ≤ p[i] < 4096. - fn fromBytes(buf: *const [bytes_length]u8) Poly { + fn fromBytes(buf: *const [encoded_length]u8) Poly { var ret: Poly = undefined; for (0..comptime N / 2) |i| { const b0 = @as(i16, buf[3 * i]); @@ -1348,71 +1259,65 @@ const Poly = struct { } }; -// A vector of K polynomials. -fn Vec(comptime K: u8) type { +// A vector of k polynomials. +fn PolyVec(comptime k: u8) type { return struct { - ps: [K]Poly, + ps: [k]Poly, const Self = @This(); - const bytes_length = K * Poly.bytes_length; + const encoded_length = k * Poly.encoded_length; fn compressedSize(comptime d: u8) usize { - return Poly.compressedSize(d) * K; + return Poly.compressedSize(d) * k; } - fn ntt(a: Self) Self { + /// Apply unary operation to each polynomial + fn map(v: Self, comptime op: fn (Poly) Poly) Self { var ret: Self = undefined; - for (0..K) |i| { - ret.ps[i] = a.ps[i].ntt(); + inline for (0..k) |i| { + ret.ps[i] = op(v.ps[i]); } return ret; } - fn invNTT(a: Self) Self { + /// Apply binary operation pairwise + fn mapBinary(a: Self, b: Self, comptime op: fn (Poly, Poly) Poly) Self { var ret: Self = undefined; - for (0..K) |i| { - ret.ps[i] = a.ps[i].invNTT(); + inline for (0..k) |i| { + ret.ps[i] = op(a.ps[i], b.ps[i]); } return ret; } - fn normalize(a: Self) Self { - var ret: Self = undefined; - for (0..K) |i| { - ret.ps[i] = a.ps[i].normalize(); - } - return ret; + fn ntt(v: Self) Self { + return map(v, Poly.ntt); } - fn barrettReduce(a: Self) Self { - var ret: Self = undefined; - for (0..K) |i| { - ret.ps[i] = a.ps[i].barrettReduce(); - } - return ret; + fn invNTT(v: Self) Self { + return map(v, Poly.invNTT); + } + + fn normalize(v: Self) Self { + return map(v, Poly.normalize); + } + + fn barrettReduce(v: Self) Self { + return map(v, Poly.barrettReduce); } fn add(a: Self, b: Self) Self { - var ret: Self = undefined; - for (0..K) |i| { - ret.ps[i] = a.ps[i].add(b.ps[i]); - } - return ret; + return mapBinary(a, b, Poly.add); } fn sub(a: Self, b: Self) Self { - var ret: Self = undefined; - for (0..K) |i| { - ret.ps[i] = a.ps[i].sub(b.ps[i]); - } - return ret; + return mapBinary(a, b, Poly.sub); } // Samples v[i] from centered binomial distribution with the given η, // seed and nonce+i. fn noise(comptime eta: u8, nonce: u8, seed: *const [32]u8) Self { var ret: Self = undefined; - for (0..K) |i| { + for (0..k) |i| { ret.ps[i] = Poly.noise(eta, nonce + @as(u8, @intCast(i)), seed); } return ret; @@ -1428,7 +1333,7 @@ fn Vec(comptime K: u8) type { // of the Montgomery factor. fn dotHat(a: Self, b: Self) Poly { var ret: Poly = Poly.zero; - for (0..K) |i| { + for (0..k) |i| { ret = ret.add(a.ps[i].mulHat(b.ps[i])); } return ret; @@ -1437,7 +1342,7 @@ fn Vec(comptime K: u8) type { fn compress(v: Self, comptime d: u8) [compressedSize(d)]u8 { const cs = comptime Poly.compressedSize(d); var ret: [compressedSize(d)]u8 = undefined; - inline for (0..K) |i| { + inline for (0..k) |i| { ret[i * cs .. (i + 1) * cs].* = v.ps[i].compress(d); } return ret; @@ -1446,27 +1351,27 @@ fn Vec(comptime K: u8) type { fn decompress(comptime d: u8, buf: *const [compressedSize(d)]u8) Self { const cs = comptime Poly.compressedSize(d); var ret: Self = undefined; - inline for (0..K) |i| { + inline for (0..k) |i| { ret.ps[i] = Poly.decompress(d, buf[i * cs .. (i + 1) * cs]); } return ret; } /// Serializes the key into a byte array. - fn toBytes(v: Self) [bytes_length]u8 { - var ret: [bytes_length]u8 = undefined; - inline for (0..K) |i| { - ret[i * Poly.bytes_length .. (i + 1) * Poly.bytes_length].* = v.ps[i].toBytes(); + fn toBytes(v: Self) [encoded_length]u8 { + var ret: [encoded_length]u8 = undefined; + inline for (0..k) |i| { + ret[i * Poly.encoded_length .. (i + 1) * Poly.encoded_length].* = v.ps[i].toBytes(); } return ret; } /// Deserializes the key from a byte array. - fn fromBytes(buf: *const [bytes_length]u8) Self { + fn fromBytes(buf: *const [encoded_length]u8) Self { var ret: Self = undefined; - inline for (0..K) |i| { + inline for (0..k) |i| { ret.ps[i] = Poly.fromBytes( - buf[i * Poly.bytes_length .. (i + 1) * Poly.bytes_length], + buf[i * Poly.encoded_length .. (i + 1) * Poly.encoded_length], ); } return ret; @@ -1474,19 +1379,19 @@ fn Vec(comptime K: u8) type { }; } -// A matrix of K vectors -fn Mat(comptime K: u8) type { +// A matrix of k vectors +fn Mat(comptime k: u8) type { return struct { const Self = @This(); - vs: [K]Vec(K), + rows: [k]PolyVec(k), fn uniform(seed: [32]u8, comptime transposed: bool) Self { var ret: Self = undefined; var i: u8 = 0; - while (i < K) : (i += 1) { + while (i < k) : (i += 1) { var j: u8 = 0; - while (j < K) : (j += 1) { - ret.vs[i].ps[j] = Poly.uniform( + while (j < k) : (j += 1) { + ret.rows[i].ps[j] = Poly.uniform( seed, if (transposed) i else j, if (transposed) j else i, @@ -1499,9 +1404,9 @@ fn Mat(comptime K: u8) type { // Returns transpose of A fn transpose(m: Self) Self { var ret: Self = undefined; - for (0..K) |i| { - for (0..K) |j| { - ret.vs[i].ps[j] = m.vs[j].ps[i]; + for (0..k) |i| { + for (0..k) |j| { + ret.rows[i].ps[j] = m.rows[j].ps[i]; } } return ret; @@ -1522,12 +1427,30 @@ fn cmov(comptime len: usize, dst: *[len]u8, src: [len]u8, b: u1) void { } } +// Test helper: generates a random polynomial with each coefficient |x| ≤ q +fn randPolyAbsLeqQ(rnd: anytype) Poly { + var ret: Poly = undefined; + for (0..N) |i| { + ret.cs[i] = rnd.random().intRangeAtMost(i16, -Q, Q); + } + return ret; +} + +// Test helper: generates a random normalized polynomial +fn randPolyNormalized(rnd: anytype) Poly { + var ret: Poly = undefined; + for (0..N) |i| { + ret.cs[i] = rnd.random().intRangeLessThan(i16, 0, Q); + } + return ret; +} + test "MulHat" { var rnd = RndGen.init(0); for (0..100) |_| { - const a = Poly.randAbsLeqQ(&rnd); - const b = Poly.randAbsLeqQ(&rnd); + const a = randPolyAbsLeqQ(&rnd); + const b = randPolyAbsLeqQ(&rnd); const p2 = a.ntt().mulHat(b.ntt()).barrettReduce().invNTT().normalize(); var p: Poly = undefined; @@ -1557,7 +1480,7 @@ test "NTT" { var rnd = RndGen.init(0); for (0..1000) |_| { - var p = Poly.randAbsLeqQ(&rnd); + var p = randPolyAbsLeqQ(&rnd); const q = p.toMont().normalize(); p = p.ntt(); @@ -1580,7 +1503,7 @@ test "Compression" { var rnd = RndGen.init(0); inline for (.{ 1, 4, 5, 10, 11 }) |d| { for (0..1000) |_| { - const p = Poly.randNormalized(&rnd); + const p = randPolyNormalized(&rnd); const pp = p.compress(d); const pq = Poly.decompress(d, &pp).compress(d); try testing.expectEqual(pp, pq); @@ -1671,7 +1594,7 @@ test "Polynomial packing" { var rnd = RndGen.init(0); for (0..1000) |_| { - const p = Poly.randNormalized(&rnd); + const p = randPolyNormalized(&rnd); try testing.expectEqual(Poly.fromBytes(&p.toBytes()), p); } } @@ -1839,3 +1762,222 @@ const NistDRBG = struct { return ret; } }; + +/// Extended Euclidian Algorithm +/// Only meant to be used on comptime values; correctness matters, performance doesn't. +fn extendedEuclidean(comptime T: type, comptime a_: T, comptime b_: T) struct { gcd: T, x: T, y: T } { + var a = a_; + var b = b_; + var x0: T = 1; + var x1: T = 0; + var y0: T = 0; + var y1: T = 1; + + while (b != 0) { + const q = @divTrunc(a, b); + const temp_a = a; + a = b; + b = temp_a - q * b; + + const temp_x = x0; + x0 = x1; + x1 = temp_x - q * x1; + + const temp_y = y0; + y0 = y1; + y1 = temp_y - q * y1; + } + + return .{ .gcd = a, .x = x0, .y = y0 }; +} + +/// Modular inversion: computes a^(-1) mod p +/// Requires gcd(a,p) = 1. The result is normalized to the range [0, p). +fn modularInverse(comptime T: type, comptime a: T, comptime p: T) T { + // Use a signed type for EEA computation + const type_info = @typeInfo(T); + const SignedT = if (type_info == .int and type_info.int.signedness == .unsigned) + std.meta.Int(.signed, type_info.int.bits) + else + T; + + const a_signed = @as(SignedT, @intCast(a)); + const p_signed = @as(SignedT, @intCast(p)); + + const r = extendedEuclidean(SignedT, a_signed, p_signed); + assert(r.gcd == 1); + + // Normalize result to [0, p) + var result = r.x; + while (result < 0) { + result += p_signed; + } + + return @intCast(result); +} + +/// Modular exponentiation: computes a^s mod p using square-and-multiply algorithm. +fn modularPow(comptime T: type, comptime a: T, s: T, comptime p: T) T { + const type_info = @typeInfo(T); + const bits = type_info.int.bits; + const WideT = std.meta.Int(.unsigned, bits * 2); + + var ret: T = 1; + var base: T = a; + var exp = s; + + while (exp > 0) { + if (exp & 1 == 1) { + ret = @intCast((@as(WideT, ret) * @as(WideT, base)) % p); + } + base = @intCast((@as(WideT, base) * @as(WideT, base)) % p); + exp >>= 1; + } + + return ret; +} + +/// Creates an all-ones or all-zeros mask from a single bit value. +/// Returns all 1s (0xFF...FF) if bit == 1, all 0s if bit == 0. +fn bitMask(comptime T: type, bit: T) T { + const type_info = @typeInfo(T); + if (type_info != .int or type_info.int.signedness != .unsigned) { + @compileError("bitMask requires an unsigned integer type"); + } + return -%bit; +} + +/// Creates a mask from the sign bit of a signed integer. +/// Returns all 1s (0xFF...FF) if x < 0, all 0s if x >= 0. +fn signMask(comptime T: type, x: T) std.meta.Int(.unsigned, @typeInfo(T).int.bits) { + const type_info = @typeInfo(T); + if (type_info != .int) { + @compileError("signMask requires an integer type"); + } + + const bits = type_info.int.bits; + const SignedT = std.meta.Int(.signed, bits); + + // Convert to signed if needed, arithmetic right shift to propagate sign bit + const x_signed: SignedT = if (type_info.int.signedness == .signed) x else @bitCast(x); + const shifted = x_signed >> (bits - 1); + return @bitCast(shifted); +} + +test "bitMask and signMask helpers" { + try testing.expectEqual(@as(u32, 0x00000000), bitMask(u32, 0)); + try testing.expectEqual(@as(u32, 0xFFFFFFFF), bitMask(u32, 1)); + try testing.expectEqual(@as(u8, 0x00), bitMask(u8, 0)); + try testing.expectEqual(@as(u8, 0xFF), bitMask(u8, 1)); + try testing.expectEqual(@as(u64, 0x0000000000000000), bitMask(u64, 0)); + try testing.expectEqual(@as(u64, 0xFFFFFFFFFFFFFFFF), bitMask(u64, 1)); + + try testing.expectEqual(@as(u32, 0xFFFFFFFF), signMask(i32, -1)); + try testing.expectEqual(@as(u32, 0xFFFFFFFF), signMask(i32, -100)); + try testing.expectEqual(@as(u32, 0x00000000), signMask(i32, 0)); + try testing.expectEqual(@as(u32, 0x00000000), signMask(i32, 1)); + try testing.expectEqual(@as(u32, 0x00000000), signMask(i32, 100)); + + try testing.expectEqual(@as(u32, 0xFFFFFFFF), signMask(u32, 0x80000000)); // MSB set + try testing.expectEqual(@as(u32, 0x00000000), signMask(u32, 0x7FFFFFFF)); // MSB clear +} + +/// Montgomery reduction: for input x, returns y where y ≡ x*R^(-1) (mod q). +/// This is a generic implementation parameterized by the modulus q, its inverse qInv, +/// the Montgomery constant R, and the result bound. +/// +/// For ML-DSA: R = 2^32, returns y < 2q +/// For ML-KEM: R = 2^16, returns y in range (-q, q) +fn montgomeryReduce( + comptime InT: type, + comptime OutT: type, + comptime q: comptime_int, + comptime qInv: comptime_int, + comptime r_bits: comptime_int, + x: InT, +) OutT { + const mask = (@as(InT, 1) << r_bits) - 1; + const m_full = (x *% qInv) & mask; + const m: OutT = @truncate(m_full); + + const yR = x -% @as(InT, m) * @as(InT, q); + const y_shifted = @as(std.meta.Int(.unsigned, @typeInfo(InT).Int.bits), @bitCast(yR)) >> r_bits; + return @bitCast(@as(std.meta.Int(.unsigned, @typeInfo(OutT).Int.bits), @truncate(y_shifted))); +} + +/// Uniform sampling using SHAKE-128 with rejection sampling. +/// Samples polynomial coefficients uniformly from [0, q) using rejection sampling. +/// +/// Parameters: +/// - PolyType: The polynomial type to return +/// - q: Modulus +/// - bits_per_coef: Number of bits per coefficient (12 or 23) +/// - n: Number of coefficients +/// - seed: Random seed +/// - domain_sep: Domain separation bytes (appended to seed) +fn sampleUniformRejection( + comptime PolyType: type, + comptime q: comptime_int, + comptime bits_per_coef: comptime_int, + comptime n: comptime_int, + seed: []const u8, + domain_sep: []const u8, +) PolyType { + var h = sha3.Shake128.init(.{}); + h.update(seed); + h.update(domain_sep); + + const buf_len = sha3.Shake128.block_length; // 168 bytes + var buf: [buf_len]u8 = undefined; + + var ret: PolyType = undefined; + var coef_idx: usize = 0; + + if (bits_per_coef == 12) { + // ML-KEM path: pack 2 coefficients per 3 bytes (12 bits each) + outer: while (true) { + h.squeeze(&buf); + + var j: usize = 0; + while (j < buf_len) : (j += 3) { + const b0 = @as(u16, buf[j]); + const b1 = @as(u16, buf[j + 1]); + const b2 = @as(u16, buf[j + 2]); + + const ts: [2]u16 = .{ + b0 | ((b1 & 0xf) << 8), + (b1 >> 4) | (b2 << 4), + }; + + inline for (ts) |t| { + if (t < q) { + ret.cs[coef_idx] = @intCast(t); + coef_idx += 1; + if (coef_idx == n) break :outer; + } + } + } + } + } else if (bits_per_coef == 23) { + // ML-DSA path: 1 coefficient per 3 bytes (23 bits) + while (coef_idx < n) { + h.squeeze(&buf); + + var j: usize = 0; + while (j < buf_len and coef_idx < n) : (j += 3) { + const t = (@as(u32, buf[j]) | + (@as(u32, buf[j + 1]) << 8) | + (@as(u32, buf[j + 2]) << 16)) & 0x7fffff; + + if (t < q) { + ret.cs[coef_idx] = @intCast(t); + coef_idx += 1; + } + } + } + } else { + @compileError("bits_per_coef must be 12 or 23"); + } + + return ret; +}