//! Module-Lattice-Based Digital Signature Algorithm (ML-DSA) as specified in NIST FIPS 204.
//!
//! ML-DSA is a post-quantum secure digital signature scheme based on the hardness
//! of the Module Learning With Errors (MLWE) and Module Short Integer Solution (MSIS)
//! problems over module lattices.
//!
//! We provide three parameter sets:
//!
//! - ML-DSA-44: NIST security category 2 (128-bit security)
//! - ML-DSA-65: NIST security category 3 (192-bit security)
//! - ML-DSA-87: NIST security category 5 (256-bit security)

const std = @import("std");
const builtin = @import("builtin");
const testing = std.testing;
const assert = std.debug.assert;
const crypto = std.crypto;
const errors = std.crypto.errors;
const math = std.math;
const mem = std.mem;
const sha3 = crypto.hash.sha3;

const ContextTooLongError = errors.ContextTooLongError;
const EncodingError = errors.EncodingError;
const SignatureVerificationError = errors.SignatureVerificationError;

/// ML-DSA-44 (Module-Lattice-Based Digital Signature Algorithm, 44 parameter set)
/// as specified in NIST FIPS 204.
///
/// This is a post-quantum signature scheme providing NIST security category 2,
/// i.e. breaking it is judged to be at least as hard as collision search on SHA-256.
///
/// Key sizes:
///
/// - Public key: 1312 bytes
/// - Secret key: 2560 bytes
/// - Signature: 2420 bytes
///
/// Example usage:
///
/// ```zig
/// const kp = MLDSA44.KeyPair.generate();
/// const msg = "Hello, post-quantum world!";
/// const sig = try kp.sign(msg, null);
/// try sig.verify(msg, kp.public_key);
/// ```
pub const MLDSA44 = MLDSAImpl(.{
    .name = "ML-DSA-44",
    .k = 4,
    .l = 4,
    .eta = 2,
    .omega = 80,
    .tau = 39,
    .gamma1_bits = 17,
    .gamma2 = 95232, // (Q-1)/88
    .tr_size = 64,
    .ctilde_size = 32,
});

/// ML-DSA-65 (Module-Lattice-Based Digital Signature Algorithm, 65 parameter set)
/// as specified in NIST FIPS 204.
///
/// This is a post-quantum signature scheme providing NIST security category 3,
/// i.e. breaking it is judged to be at least as hard as key search on AES-192.
///
/// Key sizes:
///
/// - Public key: 1952 bytes
/// - Secret key: 4032 bytes
/// - Signature: 3309 bytes
///
/// This parameter set offers higher security than ML-DSA-44 at the cost of
/// larger keys and signatures.
pub const MLDSA65 = MLDSAImpl(.{
    .name = "ML-DSA-65",
    .k = 6,
    .l = 5,
    .eta = 4,
    .omega = 55,
    .tau = 49,
    .gamma1_bits = 19,
    .gamma2 = 261888, // (Q-1)/32
    .tr_size = 64,
    .ctilde_size = 48,
});

/// ML-DSA-87 (Module-Lattice-Based Digital Signature Algorithm, 87 parameter set)
/// as specified in NIST FIPS 204.
///
/// This is a post-quantum signature scheme providing NIST security category 5,
/// i.e. breaking it is judged to be at least as hard as key search on AES-256.
///
/// Key sizes:
///
/// - Public key: 2592 bytes
/// - Secret key: 4896 bytes
/// - Signature: 4627 bytes
///
/// This parameter set offers the highest security level among the three ML-DSA
/// variants, suitable for applications requiring maximum security assurance.
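/// Example usage (a sketch mirroring the ML-DSA-44 example above; the context
/// string passed to `signWithContext`/`verifyWithContext` is an application-chosen
/// label of at most 255 bytes and may be empty):
///
/// ```zig
/// const kp = MLDSA87.KeyPair.generate();
/// const msg = "Hello, post-quantum world!";
/// const sig = try kp.signWithContext(msg, null, "example-protocol-v1");
/// try sig.verifyWithContext(msg, kp.public_key, "example-protocol-v1");
/// ```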
pub const MLDSA87 = MLDSAImpl(.{
    .name = "ML-DSA-87",
    .k = 8,
    .l = 7,
    .eta = 2,
    .omega = 75,
    .tau = 60,
    .gamma1_bits = 19,
    .gamma2 = 261888, // (Q-1)/32
    .tr_size = 64,
    .ctilde_size = 64,
});

const N: usize = 256; // Degree of polynomials
const Q: u32 = 8380417; // Modulus: 2^23 - 2^13 + 1
const Q_BITS: u32 = 23;
const D: u32 = 13; // Dropped bits in power2Round

// Montgomery constant R = 2^32
const R: u64 = 1 << 32;

// Q_INV = -(q^-1) mod 2^32, used in Montgomery reduction
const Q_INV: u32 = 4236238847;

// (256)^(-1) * R^2 mod q, used in inverse NTT
const R_OVER_256: u32 = 41978;

// Primitive 512th root of unity
const ZETA: u32 = 1753;

const Params = struct {
    name: []const u8,

    // Matrix dimensions
    k: u8, // Height of matrix A
    l: u8, // Width of matrix A

    // Sampling parameter
    eta: u8, // Bound for secret coefficients

    // Hint parameters
    omega: u16, // Maximum number of hint bits

    // Challenge parameter
    tau: u16, // Weight of challenge polynomial

    // Rounding parameters
    gamma1_bits: u8, // Bits for gamma1
    gamma2: u32, // Parameter for decompose

    // Sizes
    tr_size: usize, // Size of tr hash
    ctilde_size: usize, // Size of challenge hash
};

const Poly = struct {
    cs: [N]u32,

    const zero: Poly = .{ .cs = .{0} ** N };

    // Add two polynomials (no normalization)
    fn add(a: Poly, b: Poly) Poly {
        var ret: Poly = undefined;
        for (0..N) |i| {
            ret.cs[i] = a.cs[i] + b.cs[i];
        }
        return ret;
    }

    // Subtract two polynomials (assumes b coefficients < 2q)
    fn sub(a: Poly, b: Poly) Poly {
        var ret: Poly = undefined;
        for (0..N) |i| {
            ret.cs[i] = a.cs[i] +% (@as(u32, 2 * Q) -% b.cs[i]);
        }
        return ret;
    }

    // Reduce each coefficient to < 2q
    fn reduceLe2Q(p: Poly) Poly {
        var ret = p;
        for (0..N) |i| {
            ret.cs[i] = le2Q(ret.cs[i]);
        }
        return ret;
    }

    // Normalize coefficients to [0, q)
    fn normalize(p: Poly) Poly {
        var ret = p;
        for (0..N) |i| {
            ret.cs[i] = modQ(ret.cs[i]);
        }
        return ret;
    }

    // Normalize assuming coefficients already < 2q
    fn normalizeAssumingLe2Q(p: Poly) Poly {
        var ret = p;
        for (0..N) |i| {
            ret.cs[i] = le2qModQ(ret.cs[i]);
        }
        return ret;
    }

    // Pointwise multiplication in NTT domain (Montgomery form)
    fn mulHat(a: Poly, b: Poly) Poly {
        var ret: Poly = undefined;
        for (0..N) |i| {
            ret.cs[i] = montReduceLe2Q(@as(u64, a.cs[i]) * @as(u64, b.cs[i]));
        }
        return ret;
    }

    // Forward NTT
    fn ntt(p: Poly) Poly {
        var ret = p;
        ret.nttInPlace();
        return ret;
    }

    // In-place forward NTT
    fn nttInPlace(p: *Poly) void {
        var k: usize = 0;
        var l: usize = N / 2;
        while (l > 0) : (l >>= 1) {
            var offset: usize = 0;
            while (offset < N - l) : (offset += 2 * l) {
                k += 1;
                const zeta: u64 = zetas[k];
                for (offset..offset + l) |j| {
                    const t = montReduceLe2Q(zeta * @as(u64, p.cs[j + l]));
                    p.cs[j + l] = p.cs[j] +% (2 * Q -% t);
                    p.cs[j] +%= t;
                }
            }
        }
    }

    // Inverse NTT
    fn invNTT(p: Poly) Poly {
        var ret = p;
        ret.invNTTInPlace();
        return ret;
    }

    // In-place inverse NTT
    fn invNTTInPlace(p: *Poly) void {
        var k: usize = 0;
        var l: usize = 1;
        while (l < N) : (l <<= 1) {
            var offset: usize = 0;
            while (offset < N - l) : (offset += 2 * l) {
                const zeta: u64 = inv_zetas[k];
                k += 1;
                for (offset..offset + l) |j| {
                    const t = p.cs[j];
                    p.cs[j] = t +% p.cs[j + l];
                    p.cs[j + l] = montReduceLe2Q(zeta * @as(u64, t +% 256 * Q -% p.cs[j + l]));
                }
            }
        }
        for (0..N) |j| {
            p.cs[j] = montReduceLe2Q(@as(u64, R_OVER_256) * @as(u64, p.cs[j]));
        }
    }

    /// Apply Power2Round to all coefficients
    /// Returns both t0 and t1 polynomials
    fn power2RoundPoly(p: Poly) struct { t0: Poly, t1: Poly } {
        var t0 = Poly.zero;
        var t1 = Poly.zero;
        for (0..N) |i| {
            const result = power2Round(p.cs[i]);
            t0.cs[i] = result.a0_plus_q;
            t1.cs[i] =
result.a1; } return .{ .t0 = t0, .t1 = t1 }; } // Check if infinity norm exceeds bound fn exceeds(p: Poly, bound: u32) bool { var result: u32 = 0; for (0..N) |i| { const x = @as(i32, @intCast((Q - 1) / 2)) - @as(i32, @intCast(p.cs[i])); const abs_x = x ^ (x >> 31); const norm = @as(i32, @intCast((Q - 1) / 2)) - abs_x; const exceeds_bit = @intFromBool(@as(u32, @intCast(norm)) >= bound); result |= exceeds_bit; } return result != 0; } }; fn PolyVec(comptime len: u8) type { return struct { ps: [len]Poly, const Self = @This(); const zero: Self = .{ .ps = .{Poly.zero} ** len }; /// Apply a unary operation to each polynomial in the vector fn map(v: Self, comptime op: fn (Poly) Poly) Self { var ret: Self = undefined; inline for (0..len) |i| { ret.ps[i] = op(v.ps[i]); } return ret; } /// Apply a binary operation pairwise to two vectors fn mapBinary(a: Self, b: Self, comptime op: fn (Poly, Poly) Poly) Self { var ret: Self = undefined; inline for (0..len) |i| { ret.ps[i] = op(a.ps[i], b.ps[i]); } return ret; } /// Apply a binary operation between a vector and a scalar polynomial fn mapBinaryPoly(v: Self, scalar: Poly, comptime op: fn (Poly, Poly) Poly) Self { var ret: Self = undefined; inline for (0..len) |i| { ret.ps[i] = op(v.ps[i], scalar); } return ret; } fn add(a: Self, b: Self) Self { return mapBinary(a, b, Poly.add); } fn sub(a: Self, b: Self) Self { return mapBinary(a, b, Poly.sub); } fn ntt(v: Self) Self { return map(v, Poly.ntt); } fn invNTT(v: Self) Self { return map(v, Poly.invNTT); } fn normalize(v: Self) Self { return map(v, Poly.normalize); } fn reduceLe2Q(v: Self) Self { return map(v, Poly.reduceLe2Q); } fn normalizeAssumingLe2Q(v: Self) Self { return map(v, Poly.normalizeAssumingLe2Q); } // Check if any polynomial in the vector exceeds the bound fn exceeds(v: Self, bound: u32) bool { var result = false; for (0..len) |i| { result = result or v.ps[i].exceeds(bound); } return result; } /// Apply Power2Round to each polynomial in the vector /// Returns both t0 and t1 vectors fn power2Round(v: Self, t0_out: *Self) Self { var t1: Self = undefined; for (0..len) |i| { const result = v.ps[i].power2RoundPoly(); t0_out.ps[i] = result.t0; t1.ps[i] = result.t1; } return t1; } /// Generic packing function for vectors fn packWith( v: Self, buf: []u8, comptime poly_size: usize, comptime pack_fn: fn (Poly, []u8) void, ) void { inline for (0..len) |i| { const offset = i * poly_size; pack_fn(v.ps[i], buf[offset..][0..poly_size]); } } /// Generic unpacking function for vectors fn unpackWith( comptime poly_size: usize, comptime unpack_fn: fn ([]const u8) Poly, buf: []const u8, ) Self { var result: Self = undefined; inline for (0..len) |i| { const offset = i * poly_size; result.ps[i] = unpack_fn(buf[offset..][0..poly_size]); } return result; } /// Pack T1 vector to bytes fn packT1(v: Self, buf: []u8) void { const poly_size = (N * (Q_BITS - D)) / 8; packWith(v, buf, poly_size, polyPackT1); } /// Unpack T1 vector from bytes fn unpackT1(bytes: []const u8) Self { const poly_size = (N * (Q_BITS - D)) / 8; return unpackWith(poly_size, polyUnpackT1, bytes); } /// Pack T0 vector to bytes fn packT0(v: Self, buf: []u8) void { const poly_size = (N * D) / 8; packWith(v, buf, poly_size, polyPackT0); } /// Unpack T0 vector from bytes fn unpackT0(buf: []const u8) Self { const poly_size = (N * D) / 8; return unpackWith(poly_size, polyUnpackT0, buf); } /// Pack vector with coefficients in [-eta, eta] fn packLeqEta(v: Self, comptime eta: u8, buf: []u8) void { const poly_size = if (eta == 2) 96 else 128; const pack_fn = 
struct { fn pack(p: Poly, b: []u8) void { polyPackLeqEta(p, eta, b); } }.pack; packWith(v, buf, poly_size, pack_fn); } /// Unpack vector with coefficients in [-eta, eta] fn unpackLeqEta(comptime eta: u8, buf: []const u8) Self { const poly_size = if (eta == 2) 96 else 128; const unpack_fn = struct { fn unpack(b: []const u8) Poly { return polyUnpackLeqEta(eta, b); } }.unpack; return unpackWith(poly_size, unpack_fn, buf); } /// Pack vector of polynomials with coefficients < gamma1 fn packLeGamma1(v: Self, comptime gamma1_bits: u8, buf: []u8) void { const poly_size = ((gamma1_bits + 1) * N) / 8; const pack_fn = struct { fn pack(p: Poly, b: []u8) void { polyPackLeGamma1(p, gamma1_bits, b); } }.pack; packWith(v, buf, poly_size, pack_fn); } /// Unpack vector of polynomials with coefficients < gamma1 fn unpackLeGamma1(comptime gamma1_bits: u8, buf: []const u8) Self { const poly_size = ((gamma1_bits + 1) * N) / 8; const unpack_fn = struct { fn unpack(b: []const u8) Poly { return polyUnpackLeGamma1(gamma1_bits, b); } }.unpack; return unpackWith(poly_size, unpack_fn, buf); } /// Pack high bits w1 for signature verification fn packW1(v: Self, comptime gamma1_bits: u8, buf: []u8) void { const poly_size = (N * (Q_BITS - gamma1_bits)) / 8; const pack_fn = struct { fn pack(p: Poly, b: []u8) void { polyPackW1(p, gamma1_bits, b); } }.pack; packWith(v, buf, poly_size, pack_fn); } /// Decompose each polynomial in the vector into high and low bits fn decomposeVec(v: Self, comptime gamma2: u32, w0_out: *Self) Self { var w1: Self = undefined; for (0..len) |i| { for (0..N) |j| { const r = decompose(v.ps[i].cs[j], gamma2); w0_out.ps[i].cs[j] = r.a0_plus_q; w1.ps[i].cs[j] = r.a1; } } return w1; } /// Create hints for vector, returns hint population count fn makeHintVec(w0mcs2pct0: Self, w1: Self, comptime gamma2: u32) struct { hint: Self, pop: u32 } { var hint: Self = undefined; var pop: u32 = 0; for (0..len) |i| { const result = polyMakeHint(w0mcs2pct0.ps[i], w1.ps[i], gamma2); hint.ps[i] = result.hint; pop += result.count; } return .{ .hint = hint, .pop = pop }; } /// Apply hints to recover high bits fn useHint(v: Self, hint: Self, comptime gamma2: u32) Self { var result: Self = undefined; for (0..len) |i| { result.ps[i] = polyUseHint(v.ps[i], hint.ps[i], gamma2); } return result; } /// Multiply vector by 2^D (left shift) fn mulBy2toD(v: Self) Self { var result: Self = undefined; for (0..len) |i| { for (0..N) |j| { result.ps[i].cs[j] = v.ps[i].cs[j] << D; } } return result; } /// Sample vector with coefficients uniformly in (-gamma1, gamma1] /// Wraps expandMask (FIPS 204: ExpandMask) fn deriveUniformLeGamma1(comptime gamma1_bits: u8, seed: *const [64]u8, nonce: u16) Self { var result: Self = undefined; for (0..len) |i| { result.ps[i] = expandMask(gamma1_bits, seed, nonce + @as(u16, @intCast(i))); } return result; } /// Pack hints into bytes /// Format: for each polynomial, find positions where hint[i]=1, encode those positions fn packHint(v: Self, comptime omega: u16, buf: []u8) bool { var idx: usize = 0; var count: u32 = 0; for (0..len) |i| { for (0..N) |j| { if (v.ps[i].cs[j] != 0) { count += 1; } } } if (count > omega) { return false; } // Hint encoding format per FIPS 204: // First omega bytes: positions of set bits across all polynomials // Last len bytes: boundary indices showing where each polynomial's hints end for (0..len) |i| { for (0..N) |j| { if (v.ps[i].cs[j] != 0) { buf[idx] = @intCast(j); idx += 1; } } buf[omega + i] = @intCast(idx); } while (idx < omega) : (idx += 1) { buf[idx] = 0; } return 
true;
        }

        /// Unpack hints from bytes
        fn unpackHint(comptime omega: u16, buf: []const u8) ?Self {
            var result: Self = .{ .ps = .{Poly.zero} ** len };
            var prev_sop: u8 = 0; // previous switch-over-point
            for (0..len) |i| {
                const sop = buf[omega + i]; // switch-over-point
                if (sop < prev_sop or sop > omega) {
                    return null; // ensures switch-over-points are increasing
                }
                var j = prev_sop;
                while (j < sop) : (j += 1) {
                    // Validation: indices must be strictly increasing within each polynomial
                    if (j > prev_sop and buf[j] <= buf[j - 1]) {
                        return null;
                    }
                    const pos = buf[j];
                    if (pos >= N) {
                        return null;
                    }
                    result.ps[i].cs[pos] = 1;
                }
                prev_sop = sop;
            }
            var j = prev_sop;
            while (j < omega) : (j += 1) {
                if (buf[j] != 0) {
                    return null;
                }
            }
            return result;
        }
    };
}

// Matrix of k x l polynomials
fn Mat(comptime k: u8, comptime l: u8) type {
    return struct {
        rows: [k]PolyVec(l),

        const Self = @This();
        const VecL = PolyVec(l);
        const VecK = PolyVec(k);

        /// Expand matrix A from seed rho using SHAKE-128
        /// This is the ExpandA function from FIPS 204
        fn derive(rho: *const [32]u8) Self {
            var m: Self = undefined;
            for (0..k) |i| {
                if (i + 1 < k) {
                    @prefetch(&m.rows[i + 1], .{ .rw = .write, .locality = 2 });
                }
                for (0..l) |j| {
                    // Nonce is i*256 + j
                    const nonce: u16 = (@as(u16, @intCast(i)) << 8) | @as(u16, @intCast(j));
                    m.rows[i].ps[j] = polyDeriveUniform(rho, nonce);
                }
            }
            return m;
        }

        /// Multiply matrix by vector in NTT domain and return result in regular domain.
        /// Takes a vector in NTT form and returns the product in regular form.
        fn mulVec(self: Self, v_hat: VecL) VecK {
            var result = VecK.zero;
            for (0..k) |i| {
                result.ps[i] = dotHat(l, self.rows[i], v_hat);
                result.ps[i] = result.ps[i].reduceLe2Q();
                result.ps[i] = result.ps[i].invNTT();
            }
            return result;
        }

        /// Multiply matrix by vector in NTT domain and return result in NTT domain.
        /// Takes a vector in NTT form and returns the product in NTT form.
        fn mulVecHat(self: Self, v_hat: VecL) VecK {
            var result: VecK = undefined;
            for (0..k) |i| {
                result.ps[i] = dotHat(l, self.rows[i], v_hat);
            }
            return result;
        }
    };
}

// Dot product in NTT domain
fn dotHat(comptime len: u8, a: PolyVec(len), b: PolyVec(len)) Poly {
    var ret = Poly.zero;
    for (0..len) |i| {
        const prod = a.ps[i].mulHat(b.ps[i]);
        ret = ret.add(prod);
    }
    return ret;
}

// Modular arithmetic operations

// Reduce x to [0, 2q) using the fact that 2^23 = 2^13 - 1 (mod q)
fn le2Q(x: u32) u32 {
    // Write x = x1 * 2^23 + x2 with x2 < 2^23 and x1 < 2^9
    // Then x = x2 + x1 * 2^13 - x1 (mod q)
    // and x2 + x1 * 2^13 - x1 < 2^23 + 2^22 < 2q
    const x1 = x >> 23;
    const x2 = x & 0x7FFFFF; // 2^23 - 1
    return x2 +% (x1 << 13) -% x1;
}

// Reduce x to [0, q)
fn modQ(x: u32) u32 {
    return le2qModQ(le2Q(x));
}

// Given x < 2q, reduce to [0, q)
fn le2qModQ(x: u32) u32 {
    const r = x -% Q;
    const mask = signMask(u32, r);
    return r +% (mask & Q);
}

// Montgomery reduction: for x < q*2^32, return y < 2q where y ≡ x*R^(-1) (mod q)
// where R = 2^32. This is used for efficient modular multiplication in NTT operations.
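//
// A worked sketch of the correctness argument (this is the classic Montgomery
// REDC reasoning, not text from FIPS 204): with m = x * (-q^-1) mod 2^32 we
// get m*q ≡ -x (mod 2^32), so x + m*q is exactly divisible by 2^32, the shift
// below loses no information, the result is congruent to x * 2^-32 (mod q),
// and it stays below 2q whenever x < q * 2^32.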
fn montReduceLe2Q(x: u64) u32 {
    const m = (x *% Q_INV) & 0xffffffff;
    return @truncate((x +% m * @as(u64, Q)) >> 32);
}

// Precomputed zetas for NTT (Montgomery form)
// zetas[i] = zeta^brv(i) * R mod q
const zetas = computeZetas();

fn computeZetas() [N]u32 {
    @setEvalBranchQuota(100000);
    var ret: [N]u32 = undefined;
    for (0..N) |i| {
        const brv_i = @bitReverse(@as(u8, @intCast(i)));
        const power = modularPow(u32, ZETA, brv_i, Q);
        ret[i] = toMont(power);
    }
    return ret;
}

// Precomputed inverse zetas for inverse NTT
const inv_zetas = computeInvZetas();

fn computeInvZetas() [N]u32 {
    @setEvalBranchQuota(100000);
    var ret: [N]u32 = undefined;
    const inv_zeta = modularInverse(u32, ZETA, Q);
    for (0..N) |i| {
        const idx = 255 - i;
        const brv_idx = @bitReverse(@as(u8, @intCast(idx)));
        // Exponent is -(brv_idx - 256) = 256 - brv_idx
        const exp: u32 = @as(u32, 256) - brv_idx;
        // Compute inv_zeta^exp
        const power = modularPow(u32, inv_zeta, exp, Q);
        // Convert to Montgomery form
        ret[i] = toMont(power);
    }
    return ret;
}

// Convert to Montgomery form: x -> x * R mod q
fn toMont(x: u32) u32 {
    // Montgomery-multiplying x by R^2 mod q yields x * R^2 * R^-1 = x * R (mod q),
    // so it suffices to precompute R^2 mod q at comptime and feed
    // x * (R^2 mod q) through montReduceLe2Q.
    const r_mod_q = comptime blk: {
        // 2^32 mod q - compute by successive doubling
        var r: u64 = 1;
        for (0..32) |_| {
            r = (r * 2) % Q;
        }
        break :blk @as(u32, @intCast(r));
    };
    const r2_mod_q = comptime blk: {
        const r = @as(u64, r_mod_q);
        break :blk @as(u32, @intCast((r * r) % Q));
    };
    return montReduceLe2Q(@as(u64, x) * @as(u64, r2_mod_q));
}

/// Splits 0 ≤ a < Q into a0 and a1 with a = a1*2^D + a0
/// and -2^(D-1) < a0 ≤ 2^(D-1). Returns a0 + Q and a1.
/// FIPS 204: Power2Round (Algorithm 19)
fn power2Round(a: u32) struct { a0_plus_q: u32, a1: u32 } {
    // We effectively compute a0 = a mod± 2^D
    // and a1 = (a - a0) / 2^D
    var a0 = a & ((1 << D) - 1); // a mod 2^D
    // a0 is one of 0, 1, ..., 2^(D-1)-1, 2^(D-1), 2^(D-1)+1, ..., 2^D-1
    a0 -%= (1 << (D - 1)) + 1;
    // now a0 is -2^(D-1)-1, -2^(D-1), ..., -2, -1, 0, ..., 2^(D-1)-2
    // Next, add 2^D to those a0 that are negative (seen as i32)
    a0 +%= @as(u32, @bitCast(@as(i32, @bitCast(a0)) >> 31)) & (1 << D);
    // now a0 is 2^(D-1)-1, 2^(D-1), ..., 2^D-2, 2^D-1, 0, ..., 2^(D-1)-2
    a0 -%= (1 << (D - 1)) - 1;
    // now a0 is 0, 1, 2, ..., 2^(D-1)-1, 2^(D-1), -2^(D-1)+1, ..., -1
    const a0_plus_q = Q +% a0;
    const a1 = (a -% a0) >> D;
    return .{ .a0_plus_q = a0_plus_q, .a1 = a1 };
}

/// Splits 0 ≤ a < q into a0 and a1 with a = a1*alpha + a0 and -alpha/2 < a0 ≤ alpha/2,
/// except when we would have a1 = (q-1)/alpha, in which case a1=0 is taken
/// and -alpha/2 ≤ a0 < 0. Returns a0 + q. Note 0 ≤ a1 < (q-1)/alpha.
/// Recall alpha = 2*gamma2.
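/// A worked example (values chosen purely for illustration): with
/// alpha = 523776, a = 1000000 gives a1 = 2 and
/// a0 = 1000000 - 2*523776 = -47552, which lies in (-alpha/2, alpha/2],
/// so the function returns a0_plus_q = q - 47552 = 8332865 and a1 = 2.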
fn decompose(a: u32, comptime gamma2: u32) struct { a0_plus_q: u32, a1: u32 } {
    const alpha = 2 * gamma2;
    // a1 = ⌈a / 128⌉
    var a1 = (a + 127) >> 7;
    if (alpha == 523776) {
        // For ML-DSA-65 and ML-DSA-87: gamma2 = 261888, alpha = 523776
        // 1025/2^22 is close enough to 1/4092 so that a1 becomes a/alpha rounded down
        a1 = ((a1 * 1025 + (1 << 21)) >> 22);
        // For the corner-case a1 = (q-1)/alpha = 16, we have to set a1=0
        a1 &= 15;
    } else if (alpha == 190464) {
        // For ML-DSA-44: gamma2 = 95232, alpha = 190464
        // 11275/2^24 is close enough to 1/1488 so that a1 becomes a/alpha rounded down
        a1 = ((a1 * 11275) + (1 << 23)) >> 24;
        // For the corner-case a1 = (q-1)/alpha = 44, we have to set a1=0
        a1 ^= @as(u32, @bitCast(@as(i32, @bitCast(43 -% a1)) >> 31)) & a1;
    } else {
        @compileError("unsupported gamma2/alpha value");
    }
    var a0_plus_q = a -% a1 * alpha;
    // In the corner-case, when we set a1=0, we will incorrectly
    // have a0 > (q-1)/2 and we'll need to subtract q. As we
    // return a0 + q, that comes down to adding q if a0 < (q-1)/2.
    a0_plus_q +%= @as(u32, @bitCast(@as(i32, @bitCast(a0_plus_q -% (Q - 1) / 2)) >> 31)) & Q;
    return .{ .a0_plus_q = a0_plus_q, .a1 = a1 };
}

/// Creates a hint bit to help recover high bits after a small perturbation.
/// Given:
/// - z0: the modified low bits (r0 - f mod Q) where f is small
/// - r1: the original high bits
/// Returns 1 if a hint is needed, 0 otherwise.
///
/// This implements makeHint from FIPS 204. The hint helps recover r1 from
/// r' = r - f without knowing f explicitly.
fn makeHint(z0: u32, r1: u32, comptime gamma2: u32) u32 {
    // If -alpha/2 < r0 - f <= alpha/2, then r1*alpha + r0 - f is a valid
    // decomposition of r' with the restrictions of decompose() and so r'1 = r1.
    // So the hint should be 0. This is covered by the first two inequalities.
    // There is one other case: if r0 - f = -alpha/2, then r1*alpha + r0 - f is
    // also a valid decomposition if r1 = 0. In the other cases a one is carried
    // and the hint should be 1.
    const cond1 = @intFromBool(z0 <= gamma2);
    const cond2 = @intFromBool(z0 > Q - gamma2);
    const eq_gamma2 = @intFromBool(z0 == Q - gamma2);
    const r1_is_zero = @intFromBool(r1 == 0);
    const cond3 = eq_gamma2 & r1_is_zero;
    return 1 - (cond1 | cond2 | cond3);
}

/// Uses a hint to reconstruct high bits from a perturbed value.
/// Given:
/// - rp: the perturbed value (r' = r - f)
/// - hint: the hint bit from makeHint
/// Returns the reconstructed high bits r1.
///
/// This implements useHint from FIPS 204.
fn useHint(rp: u32, hint: u32, comptime gamma2: u32) u32 {
    const decomp = decompose(rp, gamma2);
    const rp0_plus_q = decomp.a0_plus_q;
    var rp1 = decomp.a1;
    if (hint == 0) {
        return rp1;
    }
    // Depending on gamma2, handle the adjustment differently
    if (gamma2 == 261888) {
        // ML-DSA-65 and ML-DSA-87: max r1 is 15
        if (rp0_plus_q > Q) {
            rp1 = (rp1 + 1) & 15;
        } else {
            rp1 = (rp1 -% 1) & 15;
        }
    } else if (gamma2 == 95232) {
        // ML-DSA-44: max r1 is 43
        if (rp0_plus_q > Q) {
            if (rp1 == 43) {
                rp1 = 0;
            } else {
                rp1 += 1;
            }
        } else {
            if (rp1 == 0) {
                rp1 = 43;
            } else {
                rp1 -= 1;
            }
        }
    } else {
        @compileError("unsupported gamma2 value");
    }
    return rp1;
}

/// Creates a hint polynomial for the difference between perturbed and original high bits.
/// Returns the number of hint bits set to 1 (the population count).
///
/// This is used during signature generation to create hints that help verification
/// recover the high bits without access to the secret.
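/// Per coefficient (see makeHint above), the hint bit is 0 exactly when the
/// centered low part lies in (-gamma2, gamma2], or equals -gamma2 while the
/// high part is 0; in every other case the perturbation crosses a decompose()
/// boundary, a carry occurs, and the bit must be 1 so useHint() can undo it.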
fn polyMakeHint(p0: Poly, p1: Poly, comptime gamma2: u32) struct { hint: Poly, count: u32 } { var hint = Poly.zero; var count: u32 = 0; for (0..N) |i| { const h = makeHint(p0.cs[i], p1.cs[i], gamma2); hint.cs[i] = h; count += h; } return .{ .hint = hint, .count = count }; } /// Applies hints to reconstruct high bits from a perturbed polynomial. /// /// This is used during signature verification to recover the high bits /// using the hints provided in the signature. fn polyUseHint(q: Poly, hint: Poly, comptime gamma2: u32) Poly { var result = Poly.zero; for (0..N) |i| { result.cs[i] = useHint(q.cs[i], hint.cs[i], gamma2); } return result; } /// Pack polynomial with coefficients in [Q-eta, Q+eta] into bytes. /// For eta=2: packs coefficients into 3 bits each (96 bytes total) /// For eta=4: packs coefficients into 4 bits each (128 bytes total) /// Assumes coefficients are not normalized, but in [q-η, q+η]. fn polyPackLeqEta(p: Poly, comptime eta: u8, buf: []u8) void { comptime { if (eta != 2 and eta != 4) { @compileError("eta must be 2 or 4"); } } if (eta == 2) { // 3 bits per coefficient: pack 8 coefficients into 3 bytes var j: usize = 0; var i: usize = 0; while (i < buf.len) : (i += 3) { const c0 = Q + eta - p.cs[j]; const c1 = Q + eta - p.cs[j + 1]; const c2 = Q + eta - p.cs[j + 2]; const c3 = Q + eta - p.cs[j + 3]; const c4 = Q + eta - p.cs[j + 4]; const c5 = Q + eta - p.cs[j + 5]; const c6 = Q + eta - p.cs[j + 6]; const c7 = Q + eta - p.cs[j + 7]; buf[i] = @truncate(c0 | (c1 << 3) | (c2 << 6)); buf[i + 1] = @truncate((c2 >> 2) | (c3 << 1) | (c4 << 4) | (c5 << 7)); buf[i + 2] = @truncate((c5 >> 1) | (c6 << 2) | (c7 << 5)); j += 8; } } else { // eta == 4 // 4 bits per coefficient: pack 2 coefficients into 1 byte var j: usize = 0; for (0..buf.len) |i| { const c0 = Q + eta - p.cs[j]; const c1 = Q + eta - p.cs[j + 1]; buf[i] = @truncate(c0 | (c1 << 4)); j += 2; } } } /// Unpack polynomial with coefficients in [Q-eta, Q+eta] from bytes. /// Output coefficients will not be normalized, but in [q-η, q+η]. fn polyUnpackLeqEta(comptime eta: u8, buf: []const u8) Poly { comptime { if (eta != 2 and eta != 4) { @compileError("eta must be 2 or 4"); } } var p = Poly.zero; if (eta == 2) { // 3 bits per coefficient: unpack 8 coefficients from 3 bytes var j: usize = 0; var i: usize = 0; while (i < buf.len) : (i += 3) { p.cs[j] = Q + eta - (buf[i] & 7); p.cs[j + 1] = Q + eta - ((buf[i] >> 3) & 7); p.cs[j + 2] = Q + eta - ((buf[i] >> 6) | ((buf[i + 1] << 2) & 7)); p.cs[j + 3] = Q + eta - ((buf[i + 1] >> 1) & 7); p.cs[j + 4] = Q + eta - ((buf[i + 1] >> 4) & 7); p.cs[j + 5] = Q + eta - ((buf[i + 1] >> 7) | ((buf[i + 2] << 1) & 7)); p.cs[j + 6] = Q + eta - ((buf[i + 2] >> 2) & 7); p.cs[j + 7] = Q + eta - ((buf[i + 2] >> 5) & 7); j += 8; } } else { // eta == 4 // 4 bits per coefficient: unpack 2 coefficients from 1 byte var j: usize = 0; for (0..buf.len) |i| { p.cs[j] = Q + eta - (buf[i] & 15); p.cs[j + 1] = Q + eta - (buf[i] >> 4); j += 2; } } return p; } /// Pack polynomial with coefficients < 1024 (T1) into bytes. /// Packs 10 bits per coefficient: 4 coefficients into 5 bytes. /// Assumes coefficients are normalized. 
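/// Byte layout per group of four 10-bit coefficients c0..c3 (high:low bit
/// ranges shown for illustration): byte 0 = c0[7:0], byte 1 = c1[5:0]:c0[9:8],
/// byte 2 = c2[3:0]:c1[9:6], byte 3 = c3[1:0]:c2[9:4], byte 4 = c3[9:2].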
fn polyPackT1(p: Poly, buf: []u8) void {
    var j: usize = 0;
    var i: usize = 0;
    while (i < buf.len) : (i += 5) {
        buf[i] = @truncate(p.cs[j]);
        buf[i + 1] = @truncate((p.cs[j] >> 8) | (p.cs[j + 1] << 2));
        buf[i + 2] = @truncate((p.cs[j + 1] >> 6) | (p.cs[j + 2] << 4));
        buf[i + 3] = @truncate((p.cs[j + 2] >> 4) | (p.cs[j + 3] << 6));
        buf[i + 4] = @truncate(p.cs[j + 3] >> 2);
        j += 4;
    }
}

/// Unpack polynomial with coefficients < 1024 (T1) from bytes.
/// Output coefficients will be normalized.
fn polyUnpackT1(buf: []const u8) Poly {
    var p = Poly.zero;
    var j: usize = 0;
    var i: usize = 0;
    while (i < buf.len) : (i += 5) {
        p.cs[j] = (@as(u32, buf[i]) | (@as(u32, buf[i + 1]) << 8)) & 0x3ff;
        p.cs[j + 1] = ((@as(u32, buf[i + 1]) >> 2) | (@as(u32, buf[i + 2]) << 6)) & 0x3ff;
        p.cs[j + 2] = ((@as(u32, buf[i + 2]) >> 4) | (@as(u32, buf[i + 3]) << 4)) & 0x3ff;
        p.cs[j + 3] = ((@as(u32, buf[i + 3]) >> 6) | (@as(u32, buf[i + 4]) << 2)) & 0x3ff;
        j += 4;
    }
    return p;
}

/// Pack polynomial with coefficients in (-2^(D-1), 2^(D-1)] (T0) into bytes.
/// Packs 13 bits per coefficient: 8 coefficients into 13 bytes.
/// Assumes coefficients are not normalized, but in (q-2^(D-1), q+2^(D-1)].
fn polyPackT0(p: Poly, buf: []u8) void {
    const bound = 1 << (D - 1);
    var j: usize = 0;
    var i: usize = 0;
    while (i < buf.len) : (i += 13) {
        const p0 = Q + bound - p.cs[j];
        const p1 = Q + bound - p.cs[j + 1];
        const p2 = Q + bound - p.cs[j + 2];
        const p3 = Q + bound - p.cs[j + 3];
        const p4 = Q + bound - p.cs[j + 4];
        const p5 = Q + bound - p.cs[j + 5];
        const p6 = Q + bound - p.cs[j + 6];
        const p7 = Q + bound - p.cs[j + 7];
        buf[i] = @truncate(p0 >> 0);
        buf[i + 1] = @truncate((p0 >> 8) | (p1 << 5));
        buf[i + 2] = @truncate(p1 >> 3);
        buf[i + 3] = @truncate((p1 >> 11) | (p2 << 2));
        buf[i + 4] = @truncate((p2 >> 6) | (p3 << 7));
        buf[i + 5] = @truncate(p3 >> 1);
        buf[i + 6] = @truncate((p3 >> 9) | (p4 << 4));
        buf[i + 7] = @truncate(p4 >> 4);
        buf[i + 8] = @truncate((p4 >> 12) | (p5 << 1));
        buf[i + 9] = @truncate((p5 >> 7) | (p6 << 6));
        buf[i + 10] = @truncate(p6 >> 2);
        buf[i + 11] = @truncate((p6 >> 10) | (p7 << 3));
        buf[i + 12] = @truncate(p7 >> 5);
        j += 8;
    }
}

/// Unpack polynomial with coefficients in (-2^(D-1), 2^(D-1)] (T0) from bytes.
/// Output coefficients will not be normalized, but in (q-2^(D-1), q+2^(D-1)].
fn polyUnpackT0(buf: []const u8) Poly {
    const bound = 1 << (D - 1);
    var p = Poly.zero;
    var j: usize = 0;
    var i: usize = 0;
    while (i < buf.len) : (i += 13) {
        p.cs[j] = Q + bound - ((@as(u32, buf[i]) | (@as(u32, buf[i + 1]) << 8)) & 0x1fff);
        p.cs[j + 1] = Q + bound - (((@as(u32, buf[i + 1]) >> 5) | (@as(u32, buf[i + 2]) << 3) | (@as(u32, buf[i + 3]) << 11)) & 0x1fff);
        p.cs[j + 2] = Q + bound - (((@as(u32, buf[i + 3]) >> 2) | (@as(u32, buf[i + 4]) << 6)) & 0x1fff);
        p.cs[j + 3] = Q + bound - (((@as(u32, buf[i + 4]) >> 7) | (@as(u32, buf[i + 5]) << 1) | (@as(u32, buf[i + 6]) << 9)) & 0x1fff);
        p.cs[j + 4] = Q + bound - (((@as(u32, buf[i + 6]) >> 4) | (@as(u32, buf[i + 7]) << 4) | (@as(u32, buf[i + 8]) << 12)) & 0x1fff);
        p.cs[j + 5] = Q + bound - (((@as(u32, buf[i + 8]) >> 1) | (@as(u32, buf[i + 9]) << 7)) & 0x1fff);
        p.cs[j + 6] = Q + bound - (((@as(u32, buf[i + 9]) >> 6) | (@as(u32, buf[i + 10]) << 2) | (@as(u32, buf[i + 11]) << 10)) & 0x1fff);
        p.cs[j + 7] = Q + bound - ((@as(u32, buf[i + 11]) >> 3) | (@as(u32, buf[i + 12]) << 5));
        j += 8;
    }
    return p;
}

/// Convert coefficient from centered representation to non-negative.
/// Transforms value from [0,γ₁] ∪ (Q-γ₁, Q) to [0, 2γ₁).
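/// For example (illustrative, with γ₁ = 2^17 = 131072): val = 5 maps to
/// 131072 - 5 = 131067, while val = Q - 5 (the centered value -5) wraps
/// negative and, after the conditional addition of Q, maps to
/// 131072 + 5 = 131077.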
fn centeredToPositive(val: u32, comptime gamma1: u32) u32 { var result = gamma1 -% val; result +%= (signMask(u32, result) & Q); return result; } /// Pack polynomial with coefficients in (-gamma1, gamma1] into bytes. /// For gamma1_bits=17: packs 18 bits per coefficient (4 coefficients into 9 bytes) /// For gamma1_bits=19: packs 20 bits per coefficient (2 coefficients into 5 bytes) /// Assumes coefficients are normalized. fn polyPackLeGamma1(p: Poly, comptime gamma1_bits: u8, buf: []u8) void { const gamma1: u32 = @as(u32, 1) << gamma1_bits; if (gamma1_bits == 17) { // Pack 4 coefficients into 9 bytes (18 bits each) var j: usize = 0; var i: usize = 0; while (i < buf.len) : (i += 9) { // Convert from [0,γ₁] ∪ (Q-γ₁, Q) to [0, 2γ₁) const p0 = centeredToPositive(p.cs[j], gamma1); const p1 = centeredToPositive(p.cs[j + 1], gamma1); const p2 = centeredToPositive(p.cs[j + 2], gamma1); const p3 = centeredToPositive(p.cs[j + 3], gamma1); buf[i] = @truncate(p0); buf[i + 1] = @truncate(p0 >> 8); buf[i + 2] = @truncate((p0 >> 16) | (p1 << 2)); buf[i + 3] = @truncate(p1 >> 6); buf[i + 4] = @truncate((p1 >> 14) | (p2 << 4)); buf[i + 5] = @truncate(p2 >> 4); buf[i + 6] = @truncate((p2 >> 12) | (p3 << 6)); buf[i + 7] = @truncate(p3 >> 2); buf[i + 8] = @truncate(p3 >> 10); j += 4; } } else if (gamma1_bits == 19) { // Pack 2 coefficients into 5 bytes (20 bits each) var j: usize = 0; var i: usize = 0; while (i < buf.len) : (i += 5) { const p0 = centeredToPositive(p.cs[j], gamma1); const p1 = centeredToPositive(p.cs[j + 1], gamma1); buf[i] = @truncate(p0); buf[i + 1] = @truncate(p0 >> 8); buf[i + 2] = @truncate((p0 >> 16) | (p1 << 4)); buf[i + 3] = @truncate(p1 >> 4); buf[i + 4] = @truncate(p1 >> 12); j += 2; } } else { @compileError("gamma1_bits must be 17 or 19"); } } /// Unpack polynomial with coefficients in (-gamma1, gamma1] from bytes. /// Output coefficients will be normalized. fn polyUnpackLeGamma1(comptime gamma1_bits: u8, buf: []const u8) Poly { const gamma1: u32 = @as(u32, 1) << gamma1_bits; var p = Poly.zero; if (gamma1_bits == 17) { // Unpack 4 coefficients from 9 bytes (18 bits each) var j: usize = 0; var i: usize = 0; while (i < buf.len) : (i += 9) { var p0 = @as(u32, buf[i]) | (@as(u32, buf[i + 1]) << 8) | ((@as(u32, buf[i + 2]) & 0x3) << 16); var p1 = (@as(u32, buf[i + 2]) >> 2) | (@as(u32, buf[i + 3]) << 6) | ((@as(u32, buf[i + 4]) & 0xf) << 14); var p2 = (@as(u32, buf[i + 4]) >> 4) | (@as(u32, buf[i + 5]) << 4) | ((@as(u32, buf[i + 6]) & 0x3f) << 12); var p3 = (@as(u32, buf[i + 6]) >> 6) | (@as(u32, buf[i + 7]) << 2) | (@as(u32, buf[i + 8]) << 10); // Convert from [0, 2γ₁) to (-γ₁, γ₁] p0 = centeredToPositive(p0, gamma1); p1 = centeredToPositive(p1, gamma1); p2 = centeredToPositive(p2, gamma1); p3 = centeredToPositive(p3, gamma1); p.cs[j] = p0; p.cs[j + 1] = p1; p.cs[j + 2] = p2; p.cs[j + 3] = p3; j += 4; } } else if (gamma1_bits == 19) { // Unpack 2 coefficients from 5 bytes (20 bits each) var j: usize = 0; var i: usize = 0; while (i < buf.len) : (i += 5) { var p0 = @as(u32, buf[i]) | (@as(u32, buf[i + 1]) << 8) | ((@as(u32, buf[i + 2]) & 0xf) << 16); var p1 = (@as(u32, buf[i + 2]) >> 4) | (@as(u32, buf[i + 3]) << 4) | (@as(u32, buf[i + 4]) << 12); p0 = centeredToPositive(p0, gamma1); p1 = centeredToPositive(p1, gamma1); p.cs[j] = p0; p.cs[j + 1] = p1; j += 2; } } else { @compileError("gamma1_bits must be 17 or 19"); } return p; } /// Pack W1 polynomial for verification. 
/// For gamma1_bits=17: packs 6 bits per coefficient (4 coefficients into 3 bytes) /// For gamma1_bits=19: packs 4 bits per coefficient (2 coefficients into 1 byte) /// Assumes coefficients are normalized. fn polyPackW1(p: Poly, comptime gamma1_bits: u8, buf: []u8) void { if (gamma1_bits == 17) { // Pack 4 coefficients into 3 bytes (6 bits each) var j: usize = 0; var i: usize = 0; while (i < buf.len) : (i += 3) { buf[i] = @truncate(p.cs[j] | (p.cs[j + 1] << 6)); buf[i + 1] = @truncate((p.cs[j + 1] >> 2) | (p.cs[j + 2] << 4)); buf[i + 2] = @truncate((p.cs[j + 2] >> 4) | (p.cs[j + 3] << 2)); j += 4; } } else if (gamma1_bits == 19) { // Pack 2 coefficients into 1 byte (4 bits each) - equivalent to packLe16 var j: usize = 0; for (0..buf.len) |i| { buf[i] = @truncate(p.cs[j] | (p.cs[j + 1] << 4)); j += 2; } } else { @compileError("gamma1_bits must be 17 or 19"); } } fn polyDeriveUniform(seed: *const [32]u8, nonce: u16) Poly { var domain_sep: [2]u8 = undefined; domain_sep[0] = @truncate(nonce); domain_sep[1] = @truncate(nonce >> 8); return sampleUniformRejection( Poly, Q, 23, N, seed, &domain_sep, ); } /// Sample p uniformly with coefficients of norm less than or equal to η, /// using the given seed and nonce with SHAKE-256. /// The polynomial will not be normalized, but will have coefficients in [q-η, q+η]. /// FIPS 204: ExpandS (Algorithm 27) fn expandS(comptime eta: u8, seed: *const [64]u8, nonce: u16) Poly { comptime { if (eta != 2 and eta != 4) { @compileError("eta must be 2 or 4"); } } var p = Poly.zero; var i: usize = 0; var buf: [sha3.Shake256.block_length]u8 = undefined; // SHAKE-256 rate is 136 bytes // Prepare input: seed || nonce (little-endian u16) var input: [66]u8 = undefined; @memcpy(input[0..64], seed); input[64] = @truncate(nonce); input[65] = @truncate(nonce >> 8); var h = sha3.Shake256.init(.{}); h.update(&input); while (i < N) { h.squeeze(&buf); // Process buffer: extract two samples per byte (4-bit nibbles) var j: usize = 0; while (j < buf.len and i < N) : (j += 1) { var t1 = @as(u32, buf[j]) & 15; var t2 = @as(u32, buf[j]) >> 4; if (eta == 2) { // For eta=2: reject if t > 14, then reduce mod 5 if (t1 <= 14) { t1 -%= ((205 * t1) >> 10) * 5; // reduce mod 5 p.cs[i] = Q + eta - t1; i += 1; } if (t2 <= 14 and i < N) { t2 -%= ((205 * t2) >> 10) * 5; // reduce mod 5 p.cs[i] = Q + eta - t2; i += 1; } } else if (eta == 4) { // For eta=4: accept if t <= 2*eta = 8 if (t1 <= 2 * eta) { p.cs[i] = Q + eta - t1; i += 1; } if (t2 <= 2 * eta and i < N) { p.cs[i] = Q + eta - t2; i += 1; } } } } return p; } /// Sample p uniformly with τ non-zero coefficients in {Q-1, 1} using SHAKE-256. /// This creates a "ball" polynomial with exactly tau non-zero ±1 coefficients. /// The polynomial will be normalized with coefficients in {0, 1, Q-1}. 
/// FIPS 204: SampleInBall (Algorithm 18) fn sampleInBall(comptime tau: u16, seed: []const u8) Poly { var p = Poly.zero; var buf: [sha3.Shake256.block_length]u8 = undefined; // SHAKE-256 rate is 136 bytes var h = sha3.Shake256.init(.{}); h.update(seed); h.squeeze(&buf); // Extract signs from first 8 bytes var signs: u64 = 0; for (0..8) |j| { signs |= @as(u64, buf[j]) << @intCast(j * 8); } var buf_off: usize = 8; // Generate tau non-zero coefficients using Fisher-Yates shuffle // Start with N-tau zeros, then add tau ±1 values var i: u16 = N - tau; while (i < N) : (i += 1) { var b: u16 = undefined; // Find location using rejection sampling while (true) { if (buf_off >= buf.len) { h.squeeze(&buf); buf_off = 0; } b = buf[buf_off]; buf_off += 1; if (b <= i) { break; } } // Shuffle: move existing value to position i p.cs[i] = p.cs[b]; // Set position b to ±1 based on sign bit p.cs[b] = 1; const sign_bit: u1 = @truncate(signs); const mask = bitMask(u32, sign_bit); p.cs[b] ^= mask & (1 | (Q - 1)); signs >>= 1; } return p; } /// Sample a polynomial with coefficients uniformly distributed in (-gamma1, gamma1] /// Used for sampling the masking vector y during signing /// FIPS 204: ExpandMask (Algorithm 28) fn expandMask(comptime gamma1_bits: u8, seed: *const [64]u8, nonce: u16) Poly { const packed_size = ((gamma1_bits + 1) * N) / 8; var buf: [packed_size]u8 = undefined; // Construct IV: seed || nonce (little-endian) var iv: [66]u8 = undefined; @memcpy(iv[0..64], seed); iv[64] = @truncate(nonce & 0xFF); iv[65] = @truncate(nonce >> 8); var h = sha3.Shake256.init(.{}); h.update(&iv); h.squeeze(&buf); // Unpack the polynomial return polyUnpackLeGamma1(gamma1_bits, &buf); } fn MLDSAImpl(comptime p: Params) type { return struct { pub const params = p; pub const name = p.name; pub const gamma1: u32 = @as(u32, 1) << p.gamma1_bits; pub const beta: u32 = p.tau * p.eta; pub const alpha: u32 = 2 * p.gamma2; const Self = @This(); const PolyVecL = PolyVec(p.l); const PolyVecK = PolyVec(p.k); const MatKxL = Mat(p.k, p.l); /// Length of the seed used for deterministic key generation (32 bytes). pub const seed_length: usize = 32; /// Length (in bytes) of optional random bytes, for non-deterministic signatures. pub const noise_length = 32; /// Size of an encoded public key in bytes. pub const public_key_bytes: usize = 32 + polyT1PackedSize() * p.k; /// Size of an encoded secret key in bytes. pub const private_key_bytes: usize = 32 + 32 + p.tr_size + polyLeqEtaPackedSize() * (p.l + p.k) + polyT0PackedSize() * p.k; /// Size of an encoded signature in bytes. pub const signature_bytes: usize = p.ctilde_size + polyLeGamma1PackedSize() * p.l + p.omega + p.k; // Packed sizes for different polynomial representations fn polyLeqEtaPackedSize() usize { // For eta=2: 3 bits per coefficient (values in [0,4]) // For eta=4: 4 bits per coefficient (values in [0,8]) const double_eta_bits = if (p.eta == 2) 3 else 4; return (N * double_eta_bits) / 8; } fn polyLeGamma1PackedSize() usize { return ((p.gamma1_bits + 1) * N) / 8; } fn polyT1PackedSize() usize { return (N * (Q_BITS - D)) / 8; } fn polyT0PackedSize() usize { return (N * D) / 8; } fn polyW1PackedSize() usize { return (N * (Q_BITS - p.gamma1_bits)) / 8; } /// Helper function to compute CRH (Collision Resistant Hash) using SHAKE-256. /// This consolidates the repeated pattern of init-update-squeeze for hash operations. 
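/// For example, the signing path below derives its per-signature seed as
/// `crh(64, .{ &secret_key.key, &rnd, &mu })`, i.e. 64 bytes of
/// SHAKE-256(key || rnd || mu) in a single call.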
fn crh(comptime outsize: usize, inputs: anytype) [outsize]u8 { var h = sha3.Shake256.init(.{}); inline for (inputs) |input| { h.update(input); } var out: [outsize]u8 = undefined; h.squeeze(&out); return out; } /// Helper function to compute t = As1 + s2. /// This is used during key generation and public key reconstruction. fn computeT(A: MatKxL, s1_hat: PolyVecL, s2: PolyVecK) PolyVecK { const t = A.mulVec(s1_hat).add(s2); return t.normalize(); } /// ML-DSA public key pub const PublicKey = struct { /// Size of the encoded public key in bytes pub const encoded_length: usize = 32 + polyT1PackedSize() * p.k; rho: [32]u8, // Seed for matrix A t1: PolyVecK, // High bits of t = As1 + s2 // Cached values t1_packed: [polyT1PackedSize() * p.k]u8, A: MatKxL, tr: [p.tr_size]u8, // CRH(rho || t1) /// Encode public key to bytes pub fn toBytes(self: PublicKey) [encoded_length]u8 { var out: [encoded_length]u8 = undefined; @memcpy(out[0..32], &self.rho); @memcpy(out[32..], &self.t1_packed); return out; } /// Decode public key from bytes pub fn fromBytes(bytes: [encoded_length]u8) !PublicKey { var pk: PublicKey = undefined; @memcpy(&pk.rho, bytes[0..32]); @memcpy(&pk.t1_packed, bytes[32..]); pk.t1 = PolyVecK.unpackT1(pk.t1_packed[0..]); pk.A = MatKxL.derive(&pk.rho); pk.tr = crh(p.tr_size, .{&bytes}); return pk; } }; /// ML-DSA secret key pub const SecretKey = struct { /// Size of the encoded secret key in bytes pub const encoded_length: usize = 32 + 32 + p.tr_size + polyLeqEtaPackedSize() * (p.l + p.k) + polyT0PackedSize() * p.k; rho: [32]u8, // Seed for matrix A key: [32]u8, // Seed for signature generation randomness tr: [p.tr_size]u8, // CRH(rho || t1) s1: PolyVecL, // Secret vector 1 s2: PolyVecK, // Secret vector 2 t0: PolyVecK, // Low bits of t = As1 + s2 // Cached values (in NTT domain) A: MatKxL, s1_hat: PolyVecL, s2_hat: PolyVecK, t0_hat: PolyVecK, /// Encode secret key to bytes pub fn toBytes(self: SecretKey) [encoded_length]u8 { var out: [encoded_length]u8 = undefined; var offset: usize = 0; @memcpy(out[offset .. offset + 32], &self.rho); offset += 32; @memcpy(out[offset .. offset + 32], &self.key); offset += 32; @memcpy(out[offset .. offset + p.tr_size], &self.tr); offset += p.tr_size; if (p.eta == 2) { self.s1.packLeqEta(2, out[offset..][0 .. p.l * polyLeqEtaPackedSize()]); } else { self.s1.packLeqEta(4, out[offset..][0 .. p.l * polyLeqEtaPackedSize()]); } offset += p.l * polyLeqEtaPackedSize(); if (p.eta == 2) { self.s2.packLeqEta(2, out[offset..][0 .. p.k * polyLeqEtaPackedSize()]); } else { self.s2.packLeqEta(4, out[offset..][0 .. p.k * polyLeqEtaPackedSize()]); } offset += p.k * polyLeqEtaPackedSize(); self.t0.packT0(out[offset..][0 .. p.k * polyT0PackedSize()]); offset += p.k * polyT0PackedSize(); return out; } /// Decode secret key from bytes pub fn fromBytes(bytes: [encoded_length]u8) !SecretKey { var sk: SecretKey = undefined; var offset: usize = 0; @memcpy(&sk.rho, bytes[offset .. offset + 32]); offset += 32; @memcpy(&sk.key, bytes[offset .. offset + 32]); offset += 32; @memcpy(&sk.tr, bytes[offset .. offset + p.tr_size]); offset += p.tr_size; sk.s1 = if (p.eta == 2) PolyVecL.unpackLeqEta(2, bytes[offset..][0 .. p.l * polyLeqEtaPackedSize()]) else PolyVecL.unpackLeqEta(4, bytes[offset..][0 .. p.l * polyLeqEtaPackedSize()]); offset += p.l * polyLeqEtaPackedSize(); sk.s2 = if (p.eta == 2) PolyVecK.unpackLeqEta(2, bytes[offset..][0 .. p.k * polyLeqEtaPackedSize()]) else PolyVecK.unpackLeqEta(4, bytes[offset..][0 .. 
p.k * polyLeqEtaPackedSize()]); offset += p.k * polyLeqEtaPackedSize(); sk.t0 = PolyVecK.unpackT0(bytes[offset..][0 .. p.k * polyT0PackedSize()]); offset += p.k * polyT0PackedSize(); // Compute cached NTT values for efficient signing sk.A = MatKxL.derive(&sk.rho); sk.s1_hat = sk.s1.ntt(); sk.s2_hat = sk.s2.ntt(); sk.t0_hat = sk.t0.ntt(); return sk; } /// Compute the public key from this private key pub fn public(self: *const SecretKey) PublicKey { var pk: PublicKey = undefined; pk.rho = self.rho; pk.A = self.A; pk.tr = self.tr; // Reconstruct t = As1 + s2, then extract high bits t1 // Using power2Round: t = t1 * 2^D + t0 const t = computeT(self.A, self.s1_hat, self.s2); var t0_unused: PolyVecK = undefined; pk.t1 = t.power2Round(&t0_unused); pk.t1.packT1(&pk.t1_packed); return pk; } /// Create a Signer for incrementally signing a message. /// The noise parameter can be null for deterministic signatures, /// or provide randomness for hedged signatures (recommended for fault attack resistance). pub fn signer(self: *const SecretKey, noise: ?[noise_length]u8) !Signer { return self.signerWithContext(noise, ""); } /// Create a Signer for incrementally signing a message with context. /// The noise parameter can be null for deterministic signatures, /// or provide randomness for hedged signatures (recommended for fault attack resistance). /// The context parameter is an optional context string (max 255 bytes). pub fn signerWithContext(self: *const SecretKey, noise: ?[noise_length]u8, context: []const u8) ContextTooLongError!Signer { return Signer.init(self, noise, context); } }; /// Generate a new key pair from a seed (deterministic) pub fn newKeyFromSeed(seed: *const [seed_length]u8) struct { pk: PublicKey, sk: SecretKey } { var sk: SecretKey = undefined; var pk: PublicKey = undefined; // NIST mode: expand seed || k || l using SHAKE-256 to get 128-byte expanded seed const e_seed = crh(128, .{ seed, &[_]u8{ p.k, p.l } }); @memcpy(&pk.rho, e_seed[0..32]); const s_seed = e_seed[32..96]; @memcpy(&sk.key, e_seed[96..128]); @memcpy(&sk.rho, &pk.rho); sk.A = MatKxL.derive(&pk.rho); pk.A = sk.A; const s_seed_array: *const [64]u8 = s_seed[0..64]; for (0..p.l) |i| { sk.s1.ps[i] = expandS(p.eta, s_seed_array, @intCast(i)); } for (0..p.k) |i| { sk.s2.ps[i] = expandS(p.eta, s_seed_array, @intCast(p.l + i)); } sk.s1_hat = sk.s1.ntt(); sk.s2_hat = sk.s2.ntt(); const t = computeT(sk.A, sk.s1_hat, sk.s2); pk.t1 = t.power2Round(&sk.t0); sk.t0_hat = sk.t0.ntt(); pk.t1.packT1(&pk.t1_packed); // tr = H(pk) = H(rho || t1) const pk_bytes = pk.toBytes(); const tr = crh(p.tr_size, .{&pk_bytes}); sk.tr = tr; pk.tr = tr; return .{ .pk = pk, .sk = sk }; } /// ML-DSA signature pub const Signature = struct { /// Size of the encoded signature in bytes pub const encoded_length: usize = p.ctilde_size + polyLeGamma1PackedSize() * p.l + p.omega + p.k; c_tilde: [p.ctilde_size]u8, // Challenge hash z: PolyVecL, // Response vector hint: PolyVecK, // Hint vector /// Encode signature to bytes pub fn toBytes(self: Signature) [encoded_length]u8 { var out: [encoded_length]u8 = undefined; var offset: usize = 0; @memcpy(out[offset .. offset + p.ctilde_size], &self.c_tilde); offset += p.ctilde_size; self.z.packLeGamma1(p.gamma1_bits, out[offset .. 
offset + polyLeGamma1PackedSize() * p.l]); offset += polyLeGamma1PackedSize() * p.l; _ = self.hint.packHint(p.omega, out[offset..]); return out; } /// Decode signature from bytes pub fn fromBytes(bytes: [encoded_length]u8) EncodingError!Signature { var sig: Signature = undefined; var offset: usize = 0; @memcpy(&sig.c_tilde, bytes[offset .. offset + p.ctilde_size]); offset += p.ctilde_size; sig.z = PolyVecL.unpackLeGamma1(p.gamma1_bits, bytes[offset .. offset + polyLeGamma1PackedSize() * p.l]); offset += polyLeGamma1PackedSize() * p.l; // Validate ||z||_inf < gamma1 - beta per FIPS 204 if (sig.z.exceeds(gamma1 - beta)) { return error.InvalidEncoding; } sig.hint = PolyVecK.unpackHint(p.omega, bytes[offset..]) orelse return error.InvalidEncoding; return sig; } pub const VerifyError = Verifier.InitError || Verifier.VerifyError; /// Verify this signature against a message and public key. /// Returns an error if the signature is invalid. pub fn verify( sig: Signature, msg: []const u8, public_key: PublicKey, ) VerifyError!void { return sig.verifyWithContext(msg, public_key, ""); } /// Verify this signature against a message and public key with context. /// Returns an error if the signature is invalid. /// The context parameter is an optional context string (max 255 bytes). pub fn verifyWithContext( sig: Signature, msg: []const u8, public_key: PublicKey, context: []const u8, ) VerifyError!void { if (context.len > 255) { return error.SignatureVerificationFailed; } var h = sha3.Shake256.init(.{}); h.update(&public_key.tr); h.update(&[_]u8{0}); // Domain separator: 0 for pure ML-DSA h.update(&[_]u8{@intCast(context.len)}); if (context.len > 0) { h.update(context); } h.update(msg); var mu: [64]u8 = undefined; h.squeeze(&mu); const z_hat = sig.z.ntt(); const Az = public_key.A.mulVecHat(z_hat); // Compute w' ≈ Az - 2^d·c·t1 (approximate w used in signing) var Az2dct1 = public_key.t1.mulBy2toD(); Az2dct1 = Az2dct1.ntt(); const c_poly = sampleInBall(p.tau, &sig.c_tilde); const c_hat = c_poly.ntt(); for (0..p.k) |i| { Az2dct1.ps[i] = Az2dct1.ps[i].mulHat(c_hat); } Az2dct1 = Az.sub(Az2dct1); Az2dct1 = Az2dct1.reduceLe2Q(); Az2dct1 = Az2dct1.invNTT(); Az2dct1 = Az2dct1.normalizeAssumingLe2Q(); // Apply hints to recover high bits w1' var w1_prime = Az2dct1.useHint(sig.hint, p.gamma2); var w1_packed: [polyW1PackedSize() * p.k]u8 = undefined; w1_prime.packW1(p.gamma1_bits, &w1_packed); const c_prime = crh(p.ctilde_size, .{ &mu, &w1_packed }); if (!mem.eql(u8, &c_prime, &sig.c_tilde)) { return error.SignatureVerificationFailed; } } /// Create a Verifier for incrementally verifying a signature. pub fn verifier(self: Signature, public_key: PublicKey) !Verifier { return self.verifierWithContext(public_key, ""); } /// Create a Verifier for incrementally verifying a signature with context. /// The context parameter is an optional context string (max 255 bytes). pub fn verifierWithContext(self: Signature, public_key: PublicKey, context: []const u8) ContextTooLongError!Verifier { return Verifier.init(self, public_key, context); } }; /// A Signer is used to incrementally compute a signature over a streamed message. /// It can be obtained from a `SecretKey` or `KeyPair`, using the `signer()` function. pub const Signer = struct { h: sha3.Shake256, // For computing μ = CRH(tr || msg) secret_key: *const SecretKey, rnd: [32]u8, /// Initialize a new Signer. /// The noise parameter can be null for deterministic signatures, /// or provide randomness for hedged signatures (recommended for fault attack resistance). 
/// The context parameter is an optional context string (max 255 bytes). pub fn init(secret_key: *const SecretKey, noise: ?[noise_length]u8, context: []const u8) ContextTooLongError!Signer { if (context.len > 255) { return error.ContextTooLong; } var h = sha3.Shake256.init(.{}); h.update(&secret_key.tr); h.update(&[_]u8{0}); // Domain separator: 0 for pure ML-DSA h.update(&[_]u8{@intCast(context.len)}); if (context.len > 0) { h.update(context); } return Signer{ .h = h, .secret_key = secret_key, .rnd = noise orelse .{0} ** 32, }; } /// Add new data to the message being signed. pub fn update(self: *Signer, data: []const u8) void { self.h.update(data); } /// Compute a signature over the entire message. pub fn finalize(self: *Signer) Signature { var mu: [64]u8 = undefined; self.h.squeeze(&mu); const rho_prime = crh(64, .{ &self.secret_key.key, &self.rnd, &mu }); var sig: Signature = undefined; var y_nonce: u16 = 0; // Rejection sampling loop (FIPS 204 Algorithm 2, steps 5-16) var attempt: u32 = 0; while (true) { attempt += 1; if (attempt >= 576) { // (6/7)⁵⁷⁶ < 2⁻¹²⁸ @branchHint(.unlikely); unreachable; } const y = PolyVecL.deriveUniformLeGamma1(p.gamma1_bits, &rho_prime, y_nonce); y_nonce += @intCast(p.l); const y_hat = y.ntt(); var w = self.secret_key.A.mulVec(y_hat); w = w.normalize(); var w0: PolyVecK = undefined; const w1 = w.decomposeVec(p.gamma2, &w0); var w1_packed: [polyW1PackedSize() * p.k]u8 = undefined; w1.packW1(p.gamma1_bits, &w1_packed); sig.c_tilde = crh(p.ctilde_size, .{ &mu, &w1_packed }); const c_poly = sampleInBall(p.tau, &sig.c_tilde); const c_hat = c_poly.ntt(); // Rejection check: ensure masking is effective var w0mcs2: PolyVecK = undefined; for (0..p.k) |i| { w0mcs2.ps[i] = c_hat.mulHat(self.secret_key.s2_hat.ps[i]); w0mcs2.ps[i] = w0mcs2.ps[i].invNTT(); } w0mcs2 = w0.sub(w0mcs2); w0mcs2 = w0mcs2.normalize(); if (w0mcs2.exceeds(p.gamma2 - beta)) { continue; } // Compute response z = y + c·s1 for (0..p.l) |i| { sig.z.ps[i] = c_hat.mulHat(self.secret_key.s1_hat.ps[i]); sig.z.ps[i] = sig.z.ps[i].invNTT(); } sig.z = sig.z.add(y); sig.z = sig.z.normalize(); if (sig.z.exceeds(gamma1 - beta)) { continue; } var ct0: PolyVecK = undefined; for (0..p.k) |i| { ct0.ps[i] = c_hat.mulHat(self.secret_key.t0_hat.ps[i]); ct0.ps[i] = ct0.ps[i].invNTT(); } ct0 = ct0.reduceLe2Q(); ct0 = ct0.normalize(); if (ct0.exceeds(p.gamma2)) { continue; } // Generate hints for verification var w0mcs2pct0 = w0mcs2.add(ct0); w0mcs2pct0 = w0mcs2pct0.reduceLe2Q(); w0mcs2pct0 = w0mcs2pct0.normalizeAssumingLe2Q(); const hint_result = PolyVecK.makeHintVec(w0mcs2pct0, w1, p.gamma2); if (hint_result.pop > p.omega) { continue; } sig.hint = hint_result.hint; return sig; } } }; /// A Verifier is used to incrementally verify a signature over a streamed message. /// It can be obtained from a `Signature`, using the `verifier()` function. pub const Verifier = struct { h: sha3.Shake256, // For computing μ = CRH(tr || msg) signature: Signature, public_key: PublicKey, pub const InitError = EncodingError; pub const VerifyError = SignatureVerificationError; /// Initialize a new Verifier. /// The context parameter is an optional context string (max 255 bytes). 
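/// Example (a sketch; `chunk1` and `chunk2` stand for arbitrary message
/// parts, with `sig` and `public_key` obtained as in the signing examples):
///
/// ```zig
/// var v = try sig.verifierWithContext(public_key, "");
/// v.update(chunk1);
/// v.update(chunk2);
/// try v.verify();
/// ```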
pub fn init(signature: Signature, public_key: PublicKey, context: []const u8) ContextTooLongError!Verifier { if (context.len > 255) { return error.ContextTooLong; } var h = sha3.Shake256.init(.{}); h.update(&public_key.tr); h.update(&[_]u8{0}); // Domain separator: 0 for pure ML-DSA h.update(&[_]u8{@intCast(context.len)}); // Context length if (context.len > 0) { h.update(context); } return Verifier{ .h = h, .signature = signature, .public_key = public_key, }; } /// Add new content to the message to be verified. pub fn update(self: *Verifier, data: []const u8) void { self.h.update(data); } /// Verify that the signature is valid for the entire message. pub fn verify(self: *Verifier) SignatureVerificationError!void { var mu: [64]u8 = undefined; self.h.squeeze(&mu); const z_hat = self.signature.z.ntt(); const Az = self.public_key.A.mulVecHat(z_hat); // Compute w' ≈ Az - 2^d·c·t1 (approximate w used in signing) var Az2dct1 = self.public_key.t1.mulBy2toD(); Az2dct1 = Az2dct1.ntt(); const c_poly = sampleInBall(p.tau, &self.signature.c_tilde); const c_hat = c_poly.ntt(); for (0..p.k) |i| { Az2dct1.ps[i] = Az2dct1.ps[i].mulHat(c_hat); } Az2dct1 = Az.sub(Az2dct1); Az2dct1 = Az2dct1.reduceLe2Q(); Az2dct1 = Az2dct1.invNTT(); Az2dct1 = Az2dct1.normalizeAssumingLe2Q(); // Apply hints to recover high bits w1' var w1_prime = Az2dct1.useHint(self.signature.hint, p.gamma2); var w1_packed: [polyW1PackedSize() * p.k]u8 = undefined; w1_prime.packW1(p.gamma1_bits, &w1_packed); const c_prime = crh(p.ctilde_size, .{ &mu, &w1_packed }); if (!mem.eql(u8, &c_prime, &self.signature.c_tilde)) { return error.SignatureVerificationFailed; } } }; /// A key pair consisting of a secret key and its corresponding public key. pub const KeyPair = struct { /// Length (in bytes) of a seed required to create a key pair. pub const seed_length = Self.seed_length; /// The public key component. public_key: PublicKey, /// The secret key component. secret_key: SecretKey, /// Generate a new random key pair. /// This uses the system's cryptographically secure random number generator. /// /// `crypto.random.bytes` must be supported by the target. pub fn generate() KeyPair { var seed: [Self.seed_length]u8 = undefined; crypto.random.bytes(&seed); return generateDeterministic(seed) catch unreachable; } /// Generate a key pair deterministically from a seed. /// Use for testing or when reproducibility is required. /// The seed should be generated using a cryptographically secure random source. pub fn generateDeterministic(seed: [32]u8) !KeyPair { const keys = newKeyFromSeed(&seed); return .{ .public_key = keys.pk, .secret_key = keys.sk, }; } /// Derive the public key from an existing secret key. /// This recomputes the public key components from the secret key. pub fn fromSecretKey(sk: SecretKey) !KeyPair { var pk: PublicKey = undefined; pk.rho = sk.rho; pk.tr = sk.tr; pk.A = sk.A; const t = computeT(sk.A, sk.s1_hat, sk.s2); var t0: PolyVecK = undefined; pk.t1 = t.power2Round(&t0); pk.t1.packT1(&pk.t1_packed); return .{ .public_key = pk, .secret_key = sk, }; } /// Create a Signer for incrementally signing a message. /// The noise parameter can be null for deterministic signatures, /// or provide randomness for hedged signatures (recommended for fault attack resistance). pub fn signer(self: *const KeyPair, noise: ?[noise_length]u8) !Signer { return self.secret_key.signer(noise); } /// Create a Signer for incrementally signing a message with context. 
/// The noise parameter can be null for deterministic signatures, /// or provide randomness for hedged signatures (recommended for fault attack resistance). /// The context parameter is an optional context string (max 255 bytes). pub fn signerWithContext(self: *const KeyPair, noise: ?[noise_length]u8, context: []const u8) ContextTooLongError!Signer { return self.secret_key.signerWithContext(noise, context); } /// Sign a message using this key pair. /// The noise parameter can be null for deterministic signatures, /// or provide randomness for hedged signatures (recommended for fault attack resistance). pub fn sign( kp: KeyPair, msg: []const u8, noise: ?[noise_length]u8, ) !Signature { return kp.signWithContext(msg, noise, ""); } /// Sign a message using this key pair with context. /// The noise parameter can be null for deterministic signatures, /// or provide randomness for hedged signatures (recommended for fault attack resistance). /// The context parameter is an optional context string (max 255 bytes). pub fn signWithContext( kp: KeyPair, msg: []const u8, noise: ?[noise_length]u8, context: []const u8, ) ContextTooLongError!Signature { var st = try kp.signerWithContext(noise, context); st.update(msg); return st.finalize(); } }; }; } test "modular arithmetic" { // Test Montgomery reduction const x: u64 = 12345678; const y = montReduceLe2Q(x); try testing.expect(y < 2 * Q); // Test modQ try testing.expectEqual(@as(u32, 0), modQ(Q)); try testing.expectEqual(@as(u32, 1), modQ(Q + 1)); } test "polynomial operations" { var p1 = Poly.zero; p1.cs[0] = 1; p1.cs[1] = 2; var p2 = Poly.zero; p2.cs[0] = 3; p2.cs[1] = 4; const p3 = p1.add(p2); try testing.expectEqual(@as(u32, 4), p3.cs[0]); try testing.expectEqual(@as(u32, 6), p3.cs[1]); } test "NTT and inverse NTT" { // Create a test polynomial in REGULAR FORM (not Montgomery) var p = Poly.zero; for (0..N) |i| { p.cs[i] = @intCast(i % Q); } // Apply NTT then inverse NTT // According to Dilithium spec: NTT followed by invNTT multiplies by R // So result will be p * R (i.e., p in Montgomery form) var p_ntt = p.ntt(); // Reduce before invNTT (as Go test does) p_ntt = p_ntt.reduceLe2Q(); const p_restored = p_ntt.invNTT(); // Reduce and normalize const p_reduced = p_restored.reduceLe2Q(); const p_norm = p_reduced.normalize(); // Check if we get p * R (which equals toMont(p)) for (0..N) |i| { const original: u32 = @intCast(i % Q); const expected = toMont(original); const expected_norm = modQ(expected); try testing.expectEqual(expected_norm, p_norm.cs[i]); } } test "parameter set instantiation" { // Just verify we can instantiate all three parameter sets const ml44 = MLDSA44; const ml65 = MLDSA65; const ml87 = MLDSA87; try testing.expectEqualStrings("ML-DSA-44", ml44.name); try testing.expectEqualStrings("ML-DSA-65", ml65.name); try testing.expectEqualStrings("ML-DSA-87", ml87.name); } test "compare zetas with Go implementation" { // First 16 zetas from Go implementation (in Montgomery form) const go_zetas = [16]u32{ 4193792, 25847, 5771523, 7861508, 237124, 7602457, 7504169, 466468, 1826347, 2353451, 8021166, 6288512, 3119733, 5495562, 3111497, 2680103, }; // Compare our computed zetas with Go's for (0..16) |i| { try testing.expectEqual(go_zetas[i], zetas[i]); } } test "NTT with simple polynomial" { // Test with a very simple polynomial: just one coefficient set to 1 in regular form var p = Poly.zero; p.cs[0] = 1; var p_ntt = p.ntt(); // Reduce before invNTT (as Go test does) p_ntt = p_ntt.reduceLe2Q(); const p_restored = p_ntt.invNTT(); // Result should 
test "NTT with simple polynomial" {
    // Test with a very simple polynomial: just one coefficient set to 1 in regular form
    var p = Poly.zero;
    p.cs[0] = 1;

    var p_ntt = p.ntt();
    // Reduce before invNTT (as the Go test does)
    p_ntt = p_ntt.reduceLe2Q();
    const p_restored = p_ntt.invNTT();

    // Result should be 1 * R = toMont(1) in Montgomery form
    const p_reduced = p_restored.reduceLe2Q();
    const p_norm = p_reduced.normalize();

    const expected = modQ(toMont(1));
    try testing.expectEqual(expected, p_norm.cs[0]);

    // All other coefficients should be 0 * R = 0
    for (1..N) |i| {
        try testing.expectEqual(@as(u32, 0), p_norm.cs[i]);
    }
}

test "Montgomery reduction correctness" {
    // For a and b in Montgomery form (aR and bR), montReduceLe2Q(aR * bR)
    // yields abR, i.e. the product ab in Montgomery form; a second reduction
    // strips the remaining factor of R.
    const x: u32 = 12345;
    const y: u32 = 67890;

    // Convert to Montgomery form
    const x_mont = toMont(x);
    const y_mont = toMont(y);

    // Multiply in Montgomery form
    const product_mont = montReduceLe2Q(@as(u64, x_mont) * @as(u64, y_mont));

    // Convert back from Montgomery form
    const product = montReduceLe2Q(@as(u64, product_mont));

    // Direct multiplication mod q
    const expected = modQ(@as(u32, @intCast((@as(u64, x) * @as(u64, y)) % Q)));

    try testing.expectEqual(expected, modQ(product));
}

test "compare inv_zetas with Go implementation" {
    // First 16 inv_zetas from the Go implementation
    const go_inv_zetas = [16]u32{
        6403635, 846154,  6979993, 4442679,
        1362209, 48306,   4460757, 554416,
        3545687, 6767575, 976891,  8196974,
        2286327, 420899,  2235985, 2939036,
    };

    // Compare our computed inv_zetas with Go's
    for (0..16) |i| {
        try testing.expectEqual(go_inv_zetas[i], inv_zetas[i]);
    }
}

test "power2Round correctness" {
    // Test that power2Round correctly splits values.
    // For all a in [0, Q), we should have a = a1*2^D + a0,
    // where -2^(D-1) < a0 <= 2^(D-1).

    // Test a few specific values
    const test_values = [_]u32{ 0, 1, Q / 2, Q - 1, 12345, 8380416 };

    for (test_values) |a| {
        if (a >= Q) continue;

        const result = power2Round(a);
        const a0 = @as(i32, @bitCast(result.a0_plus_q -% Q));
        const a1 = result.a1;

        // Check reconstruction: a = a1*2^D + a0
        const reconstructed = @as(i32, @bitCast(a1 << D)) + a0;
        try testing.expectEqual(@as(i32, @bitCast(a)), reconstructed);

        // Check a0 bounds: -2^(D-1) < a0 <= 2^(D-1)
        const bound: i32 = 1 << (D - 1);
        try testing.expect(a0 > -bound and a0 <= bound);
    }
}

test "decompose correctness for ML-DSA-44" {
    // Test decompose with gamma2 = 95232 (ML-DSA-44)
    const gamma2 = 95232;
    const alpha = 2 * gamma2;

    const test_values = [_]u32{ 0, 1, Q / 2, Q - 1, 12345 };

    for (test_values) |a| {
        if (a >= Q) continue;

        const result = decompose(a, gamma2);
        const a0 = @as(i32, @bitCast(result.a0_plus_q -% Q));
        const a1 = result.a1;

        // Check reconstruction: a = a1*alpha + a0 (mod Q)
        var reconstructed: i64 = @as(i64, @intCast(a1)) * @as(i64, @intCast(alpha)) + @as(i64, a0);
        reconstructed = @mod(reconstructed, @as(i64, Q));
        try testing.expectEqual(@as(i64, @intCast(a)), reconstructed);

        // Check a0 bounds (approximately)
        const bound: i32 = @intCast(alpha / 2);
        try testing.expect(@abs(a0) <= bound);
    }
}

test "decompose correctness for ML-DSA-65 and ML-DSA-87" {
    // Test decompose with gamma2 = 261888 (shared by ML-DSA-65 and ML-DSA-87)
    const gamma2 = 261888;
    const alpha = 2 * gamma2;

    const test_values = [_]u32{ 0, 1, Q / 2, Q - 1, 12345 };

    for (test_values) |a| {
        if (a >= Q) continue;

        const result = decompose(a, gamma2);
        const a0 = @as(i32, @bitCast(result.a0_plus_q -% Q));
        const a1 = result.a1;

        // Check reconstruction: a = a1*alpha + a0 (mod Q)
        var reconstructed: i64 = @as(i64, @intCast(a1)) * @as(i64, @intCast(alpha)) + @as(i64, a0);
        reconstructed = 
@mod(reconstructed, @as(i64, Q)); try testing.expectEqual(@as(i64, @intCast(a)), reconstructed); // Check a0 bounds (approximately) const bound: i32 = @intCast(alpha / 2); try testing.expect(@abs(a0) <= bound); } } test "polyDeriveUniform deterministic" { // Test that polyDeriveUniform produces deterministic results const seed: [32]u8 = .{0x01} ++ .{0x00} ** 31; const nonce: u16 = 0; const p1 = polyDeriveUniform(&seed, nonce); const p2 = polyDeriveUniform(&seed, nonce); // Should be identical for (0..N) |i| { try testing.expectEqual(p1.cs[i], p2.cs[i]); } // All coefficients should be in [0, Q) for (0..N) |i| { try testing.expect(p1.cs[i] < Q); } } test "polyDeriveUniform different nonces" { // Test that different nonces produce different polynomials const seed: [32]u8 = .{0x01} ++ .{0x00} ** 31; const p1 = polyDeriveUniform(&seed, 0); const p2 = polyDeriveUniform(&seed, 1); // Should be different var different = false; for (0..N) |i| { if (p1.cs[i] != p2.cs[i]) { different = true; break; } } try testing.expect(different); } test "expandS with eta=2" { // Test eta=2 sampling const seed: [64]u8 = .{0x02} ++ .{0x00} ** 63; const nonce: u16 = 0; const p = expandS(2, &seed, nonce); // All coefficients should be in [Q-eta, Q+eta] // The function returns coefficients as Q + eta - t, where t is in [0, 2*eta] // So coefficients are in [Q-eta, Q+eta] for (0..N) |i| { const c = p.cs[i]; // Check that c is in [Q-2, Q+2] try testing.expect(c >= Q - 2 and c <= Q + 2); } } test "expandS with eta=4" { // Test eta=4 sampling const seed: [64]u8 = .{0x03} ++ .{0x00} ** 63; const nonce: u16 = 0; const p = expandS(4, &seed, nonce); // All coefficients should be in [Q-eta, Q+eta] for (0..N) |i| { const c = p.cs[i]; // Check bounds (coefficients are around Q ± eta) const diff = if (c >= Q) c - Q else Q - c; try testing.expect(diff <= 4); } } test "sampleInBall has correct weight" { // Test that ball polynomial has exactly tau non-zero coefficients const tau = 39; // From ML-DSA-44 const seed: [32]u8 = .{0x04} ++ .{0x00} ** 31; const p = sampleInBall(tau, &seed); // Count non-zero coefficients var count: u32 = 0; for (0..N) |i| { if (p.cs[i] != 0) { count += 1; // Non-zero coefficients should be 1 or Q-1 try testing.expect(p.cs[i] == 1 or p.cs[i] == Q - 1); } } try testing.expectEqual(tau, count); } test "sampleInBall deterministic" { // Test that ball sampling is deterministic const tau = 49; // From ML-DSA-65 const seed: [32]u8 = .{0x05} ++ .{0x00} ** 31; const p1 = sampleInBall(tau, &seed); const p2 = sampleInBall(tau, &seed); // Should be identical for (0..N) |i| { try testing.expectEqual(p1.cs[i], p2.cs[i]); } } test "polyPackLeqEta / polyUnpackLeqEta roundtrip for eta=2" { // Test packing and unpacking for eta=2 const eta = 2; // Create a test polynomial with coefficients in [Q-eta, Q+eta] var p = Poly.zero; for (0..N) |i| { // Use various values in range const val = @as(u32, @intCast(i % 5)); // 0, 1, 2, 3, 4 p.cs[i] = Q + eta - val; } // Pack it var buf: [96]u8 = undefined; // eta=2: 3 bits per coeff = 96 bytes polyPackLeqEta(p, eta, &buf); // Unpack it const p2 = polyUnpackLeqEta(eta, &buf); // Should be identical for (0..N) |i| { try testing.expectEqual(p.cs[i], p2.cs[i]); } } test "polyPackLeqEta / polyUnpackLeqEta roundtrip for eta=4" { // Test packing and unpacking for eta=4 const eta = 4; // Create a test polynomial with coefficients in [Q-eta, Q+eta] var p = Poly.zero; for (0..N) |i| { // Use various values in range const val = @as(u32, @intCast(i % 9)); // 0, 1, 2, ..., 8 p.cs[i] = Q + eta - val; 
} // Pack it var buf: [128]u8 = undefined; // eta=4: 4 bits per coeff = 128 bytes polyPackLeqEta(p, eta, &buf); // Unpack it const p2 = polyUnpackLeqEta(eta, &buf); // Should be identical for (0..N) |i| { try testing.expectEqual(p.cs[i], p2.cs[i]); } } test "polyPackT1 / polyUnpackT1 roundtrip" { // Create a test polynomial with coefficients < 1024 var p = Poly.zero; for (0..N) |i| { p.cs[i] = @intCast(i % 1024); } // Pack it var buf: [320]u8 = undefined; // (256 * 10) / 8 = 320 bytes polyPackT1(p, &buf); // Unpack it const p2 = polyUnpackT1(&buf); // Should be identical for (0..N) |i| { try testing.expectEqual(p.cs[i], p2.cs[i]); } } test "polyPackT0 / polyUnpackT0 roundtrip" { // Create a test polynomial with coefficients in (Q-2^12, Q+2^12] // This is the range (-2^12, 2^12] represented as unsigned around Q const bound = 1 << 12; // 2^(D-1) where D=13 var p = Poly.zero; for (0..N) |i| { // Cycle through valid range for T0 // Values should be Q + offset where offset is in (-bound, bound] const cycle_val = @as(i32, @intCast(i % (2 * bound))); // 0 to 2*bound-1 const offset = cycle_val - bound + 1; // (-bound+1) to bound p.cs[i] = @as(u32, @intCast(@as(i32, Q) + offset)); } // Pack it var buf: [416]u8 = undefined; // (256 * 13) / 8 = 416 bytes polyPackT0(p, &buf); // Unpack it const p2 = polyUnpackT0(&buf); // Should be identical for (0..N) |i| { try testing.expectEqual(p.cs[i], p2.cs[i]); } } test "polyPackLeGamma1 / polyUnpackLeGamma1 roundtrip gamma1_bits=17" { const gamma1_bits = 17; const gamma1: u32 = @as(u32, 1) << gamma1_bits; // Create a test polynomial with coefficients in (-gamma1, gamma1] // Normalized: [0, gamma1] ∪ (Q-gamma1, Q) var p = Poly.zero; for (0..N) |i| { if (i % 2 == 0) { // Positive values: [0, gamma1] p.cs[i] = @intCast((i / 2) % (gamma1 + 1)); } else { // Negative values: (Q-gamma1, Q) const neg_val: u32 = @intCast(((i / 2) % gamma1) + 1); p.cs[i] = Q - neg_val; } } // Pack it var buf: [576]u8 = undefined; // (256 * 18) / 8 = 576 bytes polyPackLeGamma1(p, gamma1_bits, &buf); // Unpack it const p2 = polyUnpackLeGamma1(gamma1_bits, &buf); // Should be identical for (0..N) |i| { try testing.expectEqual(p.cs[i], p2.cs[i]); } } test "polyPackLeGamma1 / polyUnpackLeGamma1 roundtrip gamma1_bits=19" { const gamma1_bits = 19; const gamma1: u32 = @as(u32, 1) << gamma1_bits; // Create a test polynomial with coefficients in (-gamma1, gamma1] var p = Poly.zero; for (0..N) |i| { if (i % 2 == 0) { // Positive values: [0, gamma1] p.cs[i] = @intCast((i / 2) % (gamma1 + 1)); } else { // Negative values: (Q-gamma1, Q) const neg_val: u32 = @intCast(((i / 2) % gamma1) + 1); p.cs[i] = Q - neg_val; } } // Pack it var buf: [640]u8 = undefined; // (256 * 20) / 8 = 640 bytes polyPackLeGamma1(p, gamma1_bits, &buf); // Unpack it const p2 = polyUnpackLeGamma1(gamma1_bits, &buf); // Should be identical for (0..N) |i| { try testing.expectEqual(p.cs[i], p2.cs[i]); } } test "polyPackW1 for gamma1_bits=17" { const gamma1_bits = 17; // Create a test polynomial with small coefficients (w1 values < 64) var p = Poly.zero; for (0..N) |i| { p.cs[i] = @intCast(i % 64); // 6-bit values } // Pack it var buf: [192]u8 = undefined; // (256 * 6) / 8 = 192 bytes polyPackW1(p, gamma1_bits, &buf); // Verify basic properties // All bytes should be used var non_zero = false; for (buf) |b| { if (b != 0) { non_zero = true; break; } } try testing.expect(non_zero); } test "polyPackW1 for gamma1_bits=19" { const gamma1_bits = 19; // Create a test polynomial with small coefficients (w1 values < 16) var p = Poly.zero; 
for (0..N) |i| { p.cs[i] = @intCast(i % 16); // 4-bit values } // Pack it var buf: [128]u8 = undefined; // (256 * 4) / 8 = 128 bytes polyPackW1(p, gamma1_bits, &buf); // Verify basic properties var non_zero = false; for (buf) |b| { if (b != 0) { non_zero = true; break; } } try testing.expect(non_zero); } test "makeHint and useHint correctness for gamma2=261888" { // Test for ML-DSA-65 and ML-DSA-87 const gamma2: u32 = 261888; // Test a selection of values to verify the hint mechanism works const test_values = [_]u32{ 0, 100, 1000, 10000, 100000, 1000000, Q / 2, Q - 1 }; for (test_values) |w| { // Decompose w to get w0 and w1 const decomp = decompose(w, gamma2); const w0_plus_q = decomp.a0_plus_q; const w1 = decomp.a1; // Test with various small perturbations f in [0, gamma2] const perturbations = [_]u32{ 0, 1, 10, 100, 1000, gamma2 / 2, gamma2 }; for (perturbations) |f| { // Test f (positive perturbation) const z0_pos = (w0_plus_q +% Q -% f) % Q; const hint_pos = makeHint(z0_pos, w1, gamma2); const w_perturbed_pos = (w +% Q -% f) % Q; const w1_recovered_pos = useHint(w_perturbed_pos, hint_pos, gamma2); try testing.expectEqual(w1, w1_recovered_pos); // Test -f (negative perturbation) if (f > 0) { const z0_neg = (w0_plus_q +% f) % Q; const hint_neg = makeHint(z0_neg, w1, gamma2); const w_perturbed_neg = (w +% f) % Q; const w1_recovered_neg = useHint(w_perturbed_neg, hint_neg, gamma2); try testing.expectEqual(w1, w1_recovered_neg); } } } } test "makeHint and useHint correctness for gamma2=95232" { // Test for ML-DSA-44 const gamma2: u32 = 95232; // Test a selection of values to verify the hint mechanism works const test_values = [_]u32{ 0, 100, 1000, 10000, 100000, 1000000, Q / 2, Q - 1 }; for (test_values) |w| { // Decompose w to get w0 and w1 const decomp = decompose(w, gamma2); const w0_plus_q = decomp.a0_plus_q; const w1 = decomp.a1; // Test with various small perturbations f in [0, gamma2] const perturbations = [_]u32{ 0, 1, 10, 100, 1000, gamma2 / 2, gamma2 }; for (perturbations) |f| { // Test f (positive perturbation) const z0_pos = (w0_plus_q +% Q -% f) % Q; const hint_pos = makeHint(z0_pos, w1, gamma2); const w_perturbed_pos = (w +% Q -% f) % Q; const w1_recovered_pos = useHint(w_perturbed_pos, hint_pos, gamma2); try testing.expectEqual(w1, w1_recovered_pos); // Test -f (negative perturbation) if (f > 0) { const z0_neg = (w0_plus_q +% f) % Q; const hint_neg = makeHint(z0_neg, w1, gamma2); const w_perturbed_neg = (w +% f) % Q; const w1_recovered_neg = useHint(w_perturbed_neg, hint_neg, gamma2); try testing.expectEqual(w1, w1_recovered_neg); } } } } test "polyMakeHint basic functionality" { const gamma2: u32 = 261888; // Create test polynomials var p0 = Poly.zero; var p1 = Poly.zero; // Fill with test values for (0..N) |i| { p0.cs[i] = @intCast((i * 17) % Q); p1.cs[i] = @intCast((i * 3) % 16); // High bits are at most 15 for gamma2=261888 } // Make hints const result = polyMakeHint(p0, p1, gamma2); const hint = result.hint; const count = result.count; // Verify that hints are binary for (0..N) |i| { try testing.expect(hint.cs[i] == 0 or hint.cs[i] == 1); } // Verify that count matches the number of 1s in hint var actual_count: u32 = 0; for (0..N) |i| { actual_count += hint.cs[i]; } try testing.expectEqual(count, actual_count); } test "polyUseHint reconstruction" { const gamma2: u32 = 261888; // Create a test polynomial q var q = Poly.zero; for (0..N) |i| { q.cs[i] = @intCast((i * 123) % Q); } // Decompose q to get high and low bits var q0_plus_q_array: [N]u32 = undefined; var q1_array: 
[N]u32 = undefined;
    for (0..N) |i| {
        const decomp = decompose(q.cs[i], gamma2);
        q0_plus_q_array[i] = decomp.a0_plus_q;
        q1_array[i] = decomp.a1;
    }
    const q0_plus_q = Poly{ .cs = q0_plus_q_array };
    const q1 = Poly{ .cs = q1_array };

    // Create hints; they'll all be zero here, since the low bits come from an
    // exact decomposition of q with no perturbation applied.
    const hint_result = polyMakeHint(q0_plus_q, q1, gamma2);
    const hint = hint_result.hint;

    // Use hints to recover high bits
    const recovered = polyUseHint(q, hint, gamma2);

    // Recovered should match the original high bits q1
    for (0..N) |i| {
        try testing.expectEqual(q1.cs[i], recovered.cs[i]);
    }
}

test "hint roundtrip with perturbation" {
    const gamma2: u32 = 261888;

    // Create a test polynomial w
    var w = Poly.zero;
    for (0..N) |i| {
        w.cs[i] = @intCast((i * 7919) % Q);
    }

    // Decompose w to get w0 and w1
    var w0_plus_q = Poly.zero;
    var w1 = Poly.zero;
    for (0..N) |i| {
        const decomp = decompose(w.cs[i], gamma2);
        w0_plus_q.cs[i] = decomp.a0_plus_q;
        w1.cs[i] = decomp.a1;
    }

    // Apply a small perturbation f with |f| < 1000, well within [-gamma2, gamma2]
    var f = Poly.zero;
    for (0..N) |i| {
        const f_val = @as(u32, @intCast(i % 1000));
        f.cs[i] = if (i % 2 == 0) f_val else Q -% f_val;
    }

    // Compute w' = w - f and z0 = w0 - f
    var w_prime = Poly.zero;
    var z0 = Poly.zero;
    for (0..N) |i| {
        w_prime.cs[i] = (w.cs[i] +% Q -% f.cs[i]) % Q;
        z0.cs[i] = (w0_plus_q.cs[i] +% Q -% f.cs[i]) % Q;
    }

    // Make hints
    const hint_result = polyMakeHint(z0, w1, gamma2);
    const hint = hint_result.hint;

    // Use hints to recover w1 from w_prime
    const w1_recovered = polyUseHint(w_prime, hint, gamma2);

    // Verify that we recovered the original high bits
    for (0..N) |i| {
        try testing.expectEqual(w1.cs[i], w1_recovered.cs[i]);
    }
}

// Parameterized test helper for key generation
fn testKeyGenerationBasic(comptime MlDsa: type, seed: [32]u8) !void {
    const result = MlDsa.newKeyFromSeed(&seed);
    const pk = result.pk;
    const sk = result.sk;

    // Basic sanity checks
    try testing.expect(pk.rho.len == 32);
    try testing.expect(sk.rho.len == 32);
    try testing.expectEqualSlices(u8, &pk.rho, &sk.rho);

    // Verify tr matches between pk and sk
    try testing.expectEqualSlices(u8, &pk.tr, &sk.tr);

    // Test toBytes/fromBytes round-trip for the public key
    const pk_bytes = pk.toBytes();
    const pk2 = try MlDsa.PublicKey.fromBytes(pk_bytes);
    try testing.expectEqualSlices(u8, &pk.rho, &pk2.rho);
    try testing.expectEqualSlices(u8, &pk.tr, &pk2.tr);

    // Test toBytes/fromBytes round-trip for the secret key
    const sk_bytes = sk.toBytes();
    const sk2 = try MlDsa.SecretKey.fromBytes(sk_bytes);
    try testing.expectEqualSlices(u8, &sk.rho, &sk2.rho);
    try testing.expectEqualSlices(u8, &sk.key, &sk2.key);
    try testing.expectEqualSlices(u8, &sk.tr, &sk2.tr);
}

test "Key generation basic - all variants" {
    inline for (.{
        .{ .variant = MLDSA44, .seed_byte = 0x44 },
        .{ .variant = MLDSA65, .seed_byte = 0x65 },
        .{ .variant = MLDSA87, .seed_byte = 0x87 },
    }) |config| {
        const seed = [_]u8{config.seed_byte} ** 32;
        try testKeyGenerationBasic(config.variant, seed);
    }
}

test "Key generation determinism" {
    const seed = [_]u8{ 0x12, 0x34, 0x56, 0x78 } ++ [_]u8{0xAB} ** 28;

    // Generate two key pairs from the same seed
    const result1 = MLDSA44.newKeyFromSeed(&seed);
    const result2 = MLDSA44.newKeyFromSeed(&seed);

    // They should be identical
    const pk_bytes1 = result1.pk.toBytes();
    const pk_bytes2 = result2.pk.toBytes();
    try testing.expectEqualSlices(u8, &pk_bytes1, &pk_bytes2);

    const sk_bytes1 = result1.sk.toBytes();
    const sk_bytes2 = result2.sk.toBytes();
    try testing.expectEqualSlices(u8, &sk_bytes1, &sk_bytes2);
}

test "Private key 
can compute public key" { const seed = [_]u8{0xFF} ** 32; const result = MLDSA44.newKeyFromSeed(&seed); const pk = result.pk; const sk = result.sk; // Compute public key from private key const pk_from_sk = sk.public(); // Pack both public keys and compare const pk_bytes1 = pk.toBytes(); const pk_bytes2 = pk_from_sk.toBytes(); try testing.expectEqualSlices(u8, &pk_bytes1, &pk_bytes2); } // Parameterized test helper for sign and verify fn testSignAndVerify(comptime MlDsa: type, seed: [32]u8, message: []const u8) !void { const result = MlDsa.newKeyFromSeed(&seed); const kp = try MlDsa.KeyPair.fromSecretKey(result.sk); // Sign the message const sig = try kp.sign(message, null); // Verify the signature try sig.verify(message, kp.public_key); } test "Sign and verify - all variants" { inline for (.{ .{ .variant = MLDSA44, .seed_byte = 0x44, .message = "Hello, ML-DSA-44!" }, .{ .variant = MLDSA65, .seed_byte = 0x65, .message = "Hello, ML-DSA-65!" }, .{ .variant = MLDSA87, .seed_byte = 0x87, .message = "Hello, ML-DSA-87!" }, }) |config| { const seed = [_]u8{config.seed_byte} ** 32; try testSignAndVerify(config.variant, seed, config.message); } } test "Invalid signature rejection" { const seed = [_]u8{0x99} ** 32; const result = MLDSA44.newKeyFromSeed(&seed); const kp = try MLDSA44.KeyPair.fromSecretKey(result.sk); const message = "Original message"; // Sign the message const sig = try kp.sign(message, null); // Verify with wrong message should fail const wrong_message = "Modified message"; try testing.expectError(error.SignatureVerificationFailed, sig.verify(wrong_message, kp.public_key)); // Modify signature and verify should fail var corrupted_sig_bytes = sig.toBytes(); corrupted_sig_bytes[0] ^= 0xFF; const corrupted_sig = try MLDSA44.Signature.fromBytes(corrupted_sig_bytes); try testing.expectError(error.SignatureVerificationFailed, corrupted_sig.verify(message, kp.public_key)); } test "Context string support" { const seed = [_]u8{0xAA} ** 32; const result = MLDSA44.newKeyFromSeed(&seed); const kp = try MLDSA44.KeyPair.fromSecretKey(result.sk); const message = "Test message"; const context1 = "context1"; const context2 = "context2"; // Sign with context1 const sig1 = try kp.signWithContext(message, null, context1); // Verify with correct context should succeed try sig1.verifyWithContext(message, kp.public_key, context1); // Verify with wrong context should fail try testing.expectError(error.SignatureVerificationFailed, sig1.verifyWithContext(message, kp.public_key, context2)); // Verify with empty context should fail try testing.expectError(error.SignatureVerificationFailed, sig1.verify(message, kp.public_key)); // Sign with empty context const sig2 = try kp.sign(message, null); // Verify with empty context should succeed try sig2.verify(message, kp.public_key); // Verify with non-empty context should fail try testing.expectError(error.SignatureVerificationFailed, sig2.verifyWithContext(message, kp.public_key, context1)); // Test maximum context length (255 bytes) const max_context = [_]u8{0xBB} ** 255; const sig3 = try kp.signWithContext(message, null, &max_context); try sig3.verifyWithContext(message, kp.public_key, &max_context); // Test context too long (256 bytes should fail) const too_long_context = [_]u8{0xCC} ** 256; try testing.expectError(error.ContextTooLong, kp.signWithContext(message, null, &too_long_context)); } test "Context string with streaming API" { const seed = [_]u8{0xDD} ** 32; const result = MLDSA44.newKeyFromSeed(&seed); const kp = try 
MLDSA44.KeyPair.fromSecretKey(result.sk); const context = "streaming-context"; const message_part1 = "Hello, "; const message_part2 = "World!"; // Sign using streaming API with context var signer = try kp.signerWithContext(null, context); signer.update(message_part1); signer.update(message_part2); const sig = signer.finalize(); // Verify using streaming API with context var verifier = try sig.verifierWithContext(kp.public_key, context); verifier.update(message_part1); verifier.update(message_part2); try verifier.verify(); // Verify with wrong context should fail var verifier_wrong = try sig.verifierWithContext(kp.public_key, "wrong"); verifier_wrong.update(message_part1); verifier_wrong.update(message_part2); try testing.expectError(error.SignatureVerificationFailed, verifier_wrong.verify()); } test "Signature determinism (same rnd)" { const seed = [_]u8{0x11} ** 32; const result = MLDSA44.newKeyFromSeed(&seed); const sk = result.sk; const message = "Deterministic test"; const rnd = [_]u8{0x22} ** 32; // Sign twice with same randomness using streaming API var st1 = try sk.signer(rnd); st1.update(message); const sig1 = st1.finalize(); var st2 = try sk.signer(rnd); st2.update(message); const sig2 = st2.finalize(); // Signatures should be identical try testing.expectEqualSlices(u8, &sig1.toBytes(), &sig2.toBytes()); } test "Signature toBytes/fromBytes roundtrip" { const seed = [_]u8{0x33} ** 32; const result = MLDSA44.newKeyFromSeed(&seed); const kp = try MLDSA44.KeyPair.fromSecretKey(result.sk); const message = "toBytes/fromBytes test"; // Sign the message const sig = try kp.sign(message, null); const sig_bytes = sig.toBytes(); // Unpack and repack const sig_reparsed = try MLDSA44.Signature.fromBytes(sig_bytes); const repacked = sig_reparsed.toBytes(); // Should match original try testing.expectEqualSlices(u8, &sig_bytes, &repacked); } test "Empty message signing" { const seed = [_]u8{0x44} ** 32; const result = MLDSA44.newKeyFromSeed(&seed); const kp = try MLDSA44.KeyPair.fromSecretKey(result.sk); const message = ""; // Sign empty message const sig = try kp.sign(message, null); // Verify should work try sig.verify(message, kp.public_key); } test "Long message signing" { const seed = [_]u8{0x55} ** 32; const result = MLDSA44.newKeyFromSeed(&seed); const kp = try MLDSA44.KeyPair.fromSecretKey(result.sk); // Create a long message (1KB) const long_message = [_]u8{0xAB} ** 1024; // Sign long message const sig = try kp.sign(&long_message, null); // Verify should work try sig.verify(&long_message, kp.public_key); } // Helper function to decode hex string into bytes fn hexToBytes(comptime hex: []const u8, out: []u8) !void { if (hex.len != out.len * 2) return error.InvalidLength; var i: usize = 0; while (i < out.len) : (i += 1) { const hi = try std.fmt.charToDigit(hex[i * 2], 16); const lo = try std.fmt.charToDigit(hex[i * 2 + 1], 16); out[i] = (hi << 4) | lo; } } test "ML-DSA-44 KAT test vector 0" { // Test vector from NIST ML-DSA KAT (count = 0) // xi is the seed for key generation (Algorithm 1, line 1) const xi_hex = "f696484048ec21f96cf50a56d0759c448f3779752f0383d37449690694cf7a68"; const pk_hex_start = "bd4e96f9a038ab5e36214fe69c0b1cb835ef9d7c8417e76aecd152f5cddebec8"; const msg_hex = "6dbbc4375136df3b07f7c70e639e223e"; // Parse xi (32-byte seed for key generation) var xi: [32]u8 = undefined; try hexToBytes(xi_hex, &xi); // Generate keys from xi const result = MLDSA44.newKeyFromSeed(&xi); const pk = result.pk; const sk = result.sk; // Verify public key starts with expected bytes const pk_bytes 
= pk.toBytes(); var expected_pk_start: [32]u8 = undefined; try hexToBytes(pk_hex_start, &expected_pk_start); // Check first 32 bytes of public key match try testing.expectEqualSlices(u8, &expected_pk_start, pk_bytes[0..32]); // Parse message var msg: [16]u8 = undefined; try hexToBytes(msg_hex, &msg); // Sign the message (deterministic mode with fixed randomness) const kp = try MLDSA44.KeyPair.fromSecretKey(sk); const sig = try kp.sign(&msg, null); // Verify the signature try sig.verify(&msg, kp.public_key); } test "ML-DSA-65 KAT test vector 0" { // Test vector from NIST ML-DSA KAT (count = 0) // xi is the seed for key generation (Algorithm 1, line 1) const xi_hex = "f696484048ec21f96cf50a56d0759c448f3779752f0383d37449690694cf7a68"; const pk_hex_start = "e50d03fff3b3a70961abbb92a390008dec1283f603f50cdbaaa3d00bd659bc76"; const msg_hex = "6dbbc4375136df3b07f7c70e639e223e"; // Parse xi (32-byte seed for key generation) var xi: [32]u8 = undefined; try hexToBytes(xi_hex, &xi); // Generate keys from xi const result = MLDSA65.newKeyFromSeed(&xi); const pk = result.pk; const sk = result.sk; // Verify public key starts with expected bytes const pk_bytes = pk.toBytes(); var expected_pk_start: [32]u8 = undefined; try hexToBytes(pk_hex_start, &expected_pk_start); // Check first 32 bytes of public key match try testing.expectEqualSlices(u8, &expected_pk_start, pk_bytes[0..32]); // Parse message var msg: [16]u8 = undefined; try hexToBytes(msg_hex, &msg); // Sign the message const kp = try MLDSA65.KeyPair.fromSecretKey(sk); const sig = try kp.sign(&msg, null); // Verify the signature try sig.verify(&msg, kp.public_key); } test "ML-DSA-87 KAT test vector 0" { // Test vector from NIST ML-DSA KAT (count = 0) // xi is the seed for key generation (Algorithm 1, line 1) const xi_hex = "f696484048ec21f96cf50a56d0759c448f3779752f0383d37449690694cf7a68"; const pk_hex_start = "bc89b367d4288f47c71a74679d0fcffbe041de41b5da2f5fc66d8e28c5899494"; const msg_hex = "6dbbc4375136df3b07f7c70e639e223e"; // Parse xi (32-byte seed for key generation) var xi: [32]u8 = undefined; try hexToBytes(xi_hex, &xi); // Generate keys from xi const result = MLDSA87.newKeyFromSeed(&xi); const pk = result.pk; const sk = result.sk; // Verify public key starts with expected bytes const pk_bytes = pk.toBytes(); var expected_pk_start: [32]u8 = undefined; try hexToBytes(pk_hex_start, &expected_pk_start); // Check first 32 bytes of public key match try testing.expectEqualSlices(u8, &expected_pk_start, pk_bytes[0..32]); // Parse message var msg: [16]u8 = undefined; try hexToBytes(msg_hex, &msg); // Sign the message const kp = try MLDSA87.KeyPair.fromSecretKey(sk); const sig = try kp.sign(&msg, null); // Verify the signature try sig.verify(&msg, kp.public_key); } test "KeyPair API - generate and sign" { // Test the new KeyPair API with random generation const kp = MLDSA44.KeyPair.generate(); const msg = "Test message for KeyPair API"; // Sign with deterministic mode (no noise) const sig = try kp.sign(msg, null); // Verify using Signature.verify API try sig.verify(msg, kp.public_key); } test "KeyPair API - generateDeterministic" { // Test deterministic key generation const seed = [_]u8{42} ** 32; const kp1 = try MLDSA44.KeyPair.generateDeterministic(seed); const kp2 = try MLDSA44.KeyPair.generateDeterministic(seed); // Same seed should produce same keys const pk1_bytes = kp1.public_key.toBytes(); const pk2_bytes = kp2.public_key.toBytes(); try testing.expectEqualSlices(u8, &pk1_bytes, &pk2_bytes); } test "KeyPair API - fromSecretKey" { // Generate a 
key pair const kp1 = MLDSA44.KeyPair.generate(); // Derive public key from secret key const kp2 = try MLDSA44.KeyPair.fromSecretKey(kp1.secret_key); // Public keys should match const pk1_bytes = kp1.public_key.toBytes(); const pk2_bytes = kp2.public_key.toBytes(); try testing.expectEqualSlices(u8, &pk1_bytes, &pk2_bytes); } test "Signature verification with noise" { // Test signing with randomness (hedged signatures) const kp = MLDSA65.KeyPair.generate(); const msg = "Message to be signed with randomness"; // Create some noise const noise = [_]u8{ 1, 2, 3, 4, 5 } ++ [_]u8{0} ** 27; // Sign with noise const sig = try kp.sign(msg, noise); // Verify should still work try sig.verify(msg, kp.public_key); } test "Signature verification failure" { // Test that invalid signatures are rejected const kp = MLDSA44.KeyPair.generate(); const msg = "Original message"; const sig = try kp.sign(msg, null); // Verify with wrong message should fail const wrong_msg = "Different message"; try testing.expectError(error.SignatureVerificationFailed, sig.verify(wrong_msg, kp.public_key)); } test "Streaming API - sign and verify" { const seed = [_]u8{0x55} ** 32; const kp = try MLDSA44.KeyPair.generateDeterministic(seed); const msg = "Test message for streaming API"; // Sign using streaming API var signer = try kp.signer(null); signer.update(msg); const sig = signer.finalize(); // Verify using streaming API var verifier = try sig.verifier(kp.public_key); verifier.update(msg); try verifier.verify(); } test "Streaming API - chunked message" { const seed = [_]u8{0x66} ** 32; const kp = try MLDSA44.KeyPair.generateDeterministic(seed); // Create a message in chunks const chunk1 = "Hello, "; const chunk2 = "streaming "; const chunk3 = "world!"; const full_msg = chunk1 ++ chunk2 ++ chunk3; // Sign with chunks var signer = try kp.signer(null); signer.update(chunk1); signer.update(chunk2); signer.update(chunk3); const sig_chunked = signer.finalize(); // Sign with full message for comparison var signer2 = try kp.signer(null); signer2.update(full_msg); const sig_full = signer2.finalize(); // Signatures should be identical try testing.expectEqualSlices(u8, &sig_chunked.toBytes(), &sig_full.toBytes()); // Verify with chunks const sig = sig_chunked; var verifier = try sig.verifier(kp.public_key); verifier.update(chunk1); verifier.update(chunk2); verifier.update(chunk3); try verifier.verify(); } test "Streaming API - large message" { const seed = [_]u8{0x77} ** 32; const kp = try MLDSA44.KeyPair.generateDeterministic(seed); // Create a large message (1MB) const chunk_size = 4096; const num_chunks = 256; var chunk: [chunk_size]u8 = undefined; for (0..chunk_size) |i| { chunk[i] = @intCast(i % 256); } // Sign streaming var signer = try kp.signer(null); for (0..num_chunks) |_| { signer.update(&chunk); } const sig = signer.finalize(); // Verify streaming var verifier = try sig.verifier(kp.public_key); for (0..num_chunks) |_| { verifier.update(&chunk); } try verifier.verify(); } test "Streaming API - all parameter sets" { const test_msg = "Streaming test for all ML-DSA parameter sets"; // ML-DSA-44 { const seed = [_]u8{0x44} ** 32; const kp = try MLDSA44.KeyPair.generateDeterministic(seed); var signer = try kp.signer(null); signer.update(test_msg); const sig = signer.finalize(); var verifier = try sig.verifier(kp.public_key); verifier.update(test_msg); try verifier.verify(); } // ML-DSA-65 { const seed = [_]u8{0x65} ** 32; const kp = try MLDSA65.KeyPair.generateDeterministic(seed); var signer = try kp.signer(null); 
        signer.update(test_msg);
        const sig = signer.finalize();

        var verifier = try sig.verifier(kp.public_key);
        verifier.update(test_msg);
        try verifier.verify();
    }

    // ML-DSA-87
    {
        const seed = [_]u8{0x87} ** 32;
        const kp = try MLDSA87.KeyPair.generateDeterministic(seed);

        var signer = try kp.signer(null);
        signer.update(test_msg);
        const sig = signer.finalize();

        var verifier = try sig.verifier(kp.public_key);
        verifier.update(test_msg);
        try verifier.verify();
    }
}

/// Extended Euclidean Algorithm.
/// Only meant to be used on comptime values; correctness matters, performance doesn't.
fn extendedEuclidean(comptime T: type, comptime a_: T, comptime b_: T) struct { gcd: T, x: T, y: T } {
    var a = a_;
    var b = b_;
    var x0: T = 1;
    var x1: T = 0;
    var y0: T = 0;
    var y1: T = 1;

    while (b != 0) {
        const q = @divTrunc(a, b);

        const temp_a = a;
        a = b;
        b = temp_a - q * b;

        const temp_x = x0;
        x0 = x1;
        x1 = temp_x - q * x1;

        const temp_y = y0;
        y0 = y1;
        y1 = temp_y - q * y1;
    }

    return .{ .gcd = a, .x = x0, .y = y0 };
}

/// Modular inversion: computes a^(-1) mod p.
/// Requires gcd(a, p) = 1. The result is normalized to the range [0, p).
fn modularInverse(comptime T: type, comptime a: T, comptime p: T) T {
    // Use a signed type for the EEA computation
    const type_info = @typeInfo(T);
    const SignedT = if (type_info == .int and type_info.int.signedness == .unsigned)
        std.meta.Int(.signed, type_info.int.bits)
    else
        T;

    const a_signed = @as(SignedT, @intCast(a));
    const p_signed = @as(SignedT, @intCast(p));

    const r = extendedEuclidean(SignedT, a_signed, p_signed);
    assert(r.gcd == 1);

    // Normalize result to [0, p)
    var result = r.x;
    while (result < 0) {
        result += p_signed;
    }
    return @intCast(result);
}

/// Modular exponentiation: computes a^s mod p using the square-and-multiply algorithm.
fn modularPow(comptime T: type, comptime a: T, s: T, comptime p: T) T {
    const type_info = @typeInfo(T);
    const bits = type_info.int.bits;
    const WideT = std.meta.Int(.unsigned, bits * 2);

    var ret: T = 1;
    var base: T = a;
    var exp = s;

    while (exp > 0) {
        if (exp & 1 == 1) {
            ret = @intCast((@as(WideT, ret) * @as(WideT, base)) % p);
        }
        base = @intCast((@as(WideT, base) * @as(WideT, base)) % p);
        exp >>= 1;
    }
    return ret;
}

/// Creates an all-ones or all-zeros mask from a single bit value.
/// Returns all 1s (0xFF...FF) if bit == 1, all 0s if bit == 0.
fn bitMask(comptime T: type, bit: T) T {
    const type_info = @typeInfo(T);
    if (type_info != .int or type_info.int.signedness != .unsigned) {
        @compileError("bitMask requires an unsigned integer type");
    }
    return -%bit;
}

/// Creates a mask from the sign bit of a signed integer.
/// Returns all 1s (0xFF...FF) if x < 0, all 0s if x >= 0.
fn signMask(comptime T: type, x: T) std.meta.Int(.unsigned, @typeInfo(T).int.bits) {
    const type_info = @typeInfo(T);
    if (type_info != .int) {
        @compileError("signMask requires an integer type");
    }
    const bits = type_info.int.bits;
    const SignedT = std.meta.Int(.signed, bits);

    // Convert to signed if needed; arithmetic right shift propagates the sign bit
    const x_signed: SignedT = if (type_info.int.signedness == .signed) x else @bitCast(x);
    const shifted = x_signed >> (bits - 1);
    return @bitCast(shifted);
}

/// Montgomery reduction: for input x, returns y where y ≡ x*R^(-1) (mod q).
/// This is a generic implementation parameterized by the modulus q, the negated
/// inverse qInv = -(q^-1) mod R, the Montgomery constant R = 2^r_bits, and the
/// result bound.
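///
/// Illustrative instantiation with this file's ML-DSA constants (shown for
/// exposition; the exact call sites live elsewhere in this module):
/// montgomeryReduce(u64, u32, Q, Q_INV, 32, x). For x = a*b with a, b < 2q,
/// the output y satisfies y ≡ a*b*R^(-1) (mod q) and y < 2q, since
/// x < 4q^2 implies x/R + q < 2q.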
///
/// For ML-DSA: R = 2^32; returns y < 2q.
/// For ML-KEM: R = 2^16 with qInv = -(q^-1) mod 2^16; returns y < 2q given x < q*R.
fn montgomeryReduce(
    comptime InT: type,
    comptime OutT: type,
    comptime q: comptime_int,
    comptime qInv: comptime_int,
    comptime r_bits: comptime_int,
    x: InT,
) OutT {
    const mask = (@as(InT, 1) << r_bits) - 1;
    const m_full = (x *% qInv) & mask;
    const m: OutT = @truncate(m_full);
    // With qInv = -(q^-1) mod R (e.g. Q_INV above), x + m*q is divisible by
    // R = 2^r_bits, so the shift below performs an exact division.
    const yR = x +% @as(InT, m) * @as(InT, q);
    const y_shifted = @as(std.meta.Int(.unsigned, @typeInfo(InT).int.bits), @bitCast(yR)) >> r_bits;
    return @bitCast(@as(std.meta.Int(.unsigned, @typeInfo(OutT).int.bits), @truncate(y_shifted)));
}

/// Uniform sampling using SHAKE-128 with rejection sampling.
/// Samples polynomial coefficients uniformly from [0, q).
///
/// Parameters:
/// - PolyType: The polynomial type to return
/// - q: Modulus
/// - bits_per_coef: Number of bits per coefficient (12 or 23)
/// - n: Number of coefficients
/// - seed: Random seed
/// - domain_sep: Domain separation bytes (appended to the seed)
fn sampleUniformRejection(
    comptime PolyType: type,
    comptime q: comptime_int,
    comptime bits_per_coef: comptime_int,
    comptime n: comptime_int,
    seed: []const u8,
    domain_sep: []const u8,
) PolyType {
    var h = sha3.Shake128.init(.{});
    h.update(seed);
    h.update(domain_sep);

    const buf_len = sha3.Shake128.block_length; // 168 bytes
    var buf: [buf_len]u8 = undefined;

    var ret: PolyType = undefined;
    var coef_idx: usize = 0;

    if (bits_per_coef == 12) {
        // ML-KEM path: pack 2 coefficients per 3 bytes (12 bits each)
        outer: while (true) {
            h.squeeze(&buf);
            var j: usize = 0;
            while (j < buf_len) : (j += 3) {
                const b0 = @as(u16, buf[j]);
                const b1 = @as(u16, buf[j + 1]);
                const b2 = @as(u16, buf[j + 2]);
                const ts: [2]u16 = .{
                    b0 | ((b1 & 0xf) << 8),
                    (b1 >> 4) | (b2 << 4),
                };
                inline for (ts) |t| {
                    if (t < q) {
                        ret.cs[coef_idx] = @intCast(t);
                        coef_idx += 1;
                        if (coef_idx == n) break :outer;
                    }
                }
            }
        }
    } else if (bits_per_coef == 23) {
        // ML-DSA path: 1 coefficient per 3 bytes (23 bits)
        while (coef_idx < n) {
            h.squeeze(&buf);
            var j: usize = 0;
            while (j < buf_len and coef_idx < n) : (j += 3) {
                const t = (@as(u32, buf[j]) |
                    (@as(u32, buf[j + 1]) << 8) |
                    (@as(u32, buf[j + 2]) << 16)) & 0x7fffff;
                if (t < q) {
                    ret.cs[coef_idx] = @intCast(t);
                    coef_idx += 1;
                }
            }
        }
    } else {
        @compileError("bits_per_coef must be 12 or 23");
    }

    return ret;
}

test "bitMask and signMask helpers" {
    try testing.expectEqual(@as(u32, 0x00000000), bitMask(u32, 0));
    try testing.expectEqual(@as(u32, 0xFFFFFFFF), bitMask(u32, 1));
    try testing.expectEqual(@as(u8, 0x00), bitMask(u8, 0));
    try testing.expectEqual(@as(u8, 0xFF), bitMask(u8, 1));
    try testing.expectEqual(@as(u64, 0x0000000000000000), bitMask(u64, 0));
    try testing.expectEqual(@as(u64, 0xFFFFFFFFFFFFFFFF), bitMask(u64, 1));

    try testing.expectEqual(@as(u32, 0xFFFFFFFF), signMask(i32, -1));
    try testing.expectEqual(@as(u32, 0xFFFFFFFF), signMask(i32, -100));
    try testing.expectEqual(@as(u32, 0x00000000), signMask(i32, 0));
    try testing.expectEqual(@as(u32, 0x00000000), signMask(i32, 1));
    try testing.expectEqual(@as(u32, 0x00000000), signMask(i32, 100));

    try testing.expectEqual(@as(u32, 0xFFFFFFFF), signMask(u32, 0x80000000)); // MSB set
    try testing.expectEqual(@as(u32, 0x00000000), signMask(u32, 0x7FFFFFFF)); // MSB clear
}
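// The doc comments above recommend hedged signatures for fault-attack
// resistance. The sketch below shows the intended usage pattern: draw fresh
// noise from the system CSPRNG for every signature. It relies only on APIs
// already exercised above (KeyPair.generate, sign with a 32-byte noise array,
// verify, toBytes).
test "Hedged signing usage example" {
    const kp = MLDSA44.KeyPair.generate();
    const msg = "hedged signing example";

    // Fresh randomness per signature (hedged mode)
    var noise1: [32]u8 = undefined;
    var noise2: [32]u8 = undefined;
    crypto.random.bytes(&noise1);
    crypto.random.bytes(&noise2);

    const sig1 = try kp.sign(msg, noise1);
    const sig2 = try kp.sign(msg, noise2);

    // Both signatures verify under the same public key
    try sig1.verify(msg, kp.public_key);
    try sig2.verify(msg, kp.public_key);

    // Unlike deterministic signing, distinct noise yields (with overwhelming
    // probability) distinct signatures over the same message
    const b1 = sig1.toBytes();
    const b2 = sig2.toBytes();
    try testing.expect(!mem.eql(u8, &b1, &b2));
}

// A small sanity check on the Montgomery constants defined at the top of the
// file, using only the comptime helpers above. Q_INV is -(q^-1) mod 2^32, so
// Q * Q_INV must be ≡ -1 (mod 2^32); and since Q is prime, Fermat's little
// theorem gives a^(Q-2) = a^(-1) mod Q, which must agree with modularInverse.
test "Montgomery constants sanity check" {
    const prod: u32 = @truncate(@as(u64, Q) * @as(u64, Q_INV));
    try testing.expectEqual(@as(u32, 0xFFFFFFFF), prod);

    const a: u32 = ZETA; // any unit mod Q works; the 512th root of unity is handy
    const inv_eea = comptime modularInverse(u32, a, Q);
    const inv_fermat = modularPow(u32, a, Q - 2, Q);
    try testing.expectEqual(inv_eea, inv_fermat);
}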