Align ML-KEM code with ML-DSA (#25964)

This will facilitate maintenance and code sharing between primitives.
This commit is contained in:
Frank Denis 2025-11-18 16:39:58 +01:00 committed by GitHub
parent 73f863a6fb
commit 4ea4728084
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -105,19 +105,20 @@ const crypto = std.crypto;
const errors = std.crypto.errors;
const math = std.math;
const mem = std.mem;
const RndGen = std.Random.DefaultPrng;
const sha3 = crypto.hash.sha3;
// Q is the parameter q = 3329 = 2¹¹ + 2¹⁰ + 2⁸ + 1.
const RndGen = std.Random.DefaultPrng;
// Q is the modulus q = 3329 = 2¹¹ + 2¹⁰ + 2⁸ + 1
const Q: i16 = 3329;
// Montgomery R
// Montgomery R = 2^16 (radix used for Montgomery multiplication modulo Q)
const R: i32 = 1 << 16;
// Parameter n, degree of polynomials.
// N is the degree of polynomials (polynomial ring dimension)
const N: usize = 256;
// Size of "small" vectors used in encryption blinds.
// eta2 is the size of "small" vectors used in encryption blinds
const eta2: u8 = 2;
const Params = struct {
@ -215,7 +216,7 @@ fn Kyber(comptime p: Params) type {
pub const ciphertext_length = Poly.compressedSize(p.du) * p.k + Poly.compressedSize(p.dv);
const Self = @This();
const V = Vec(p.k);
const V = PolyVec(p.k);
const M = Mat(p.k);
/// Length (in bytes) of a shared secret.
@ -241,7 +242,7 @@ fn Kyber(comptime p: Params) type {
hpk: [h_length]u8, // H(pk)
/// Size of a serialized representation of the key, in bytes.
pub const bytes_length = InnerPk.bytes_length;
pub const encoded_length = InnerPk.encoded_length;
/// Generates a shared secret, and encapsulates it for the public key.
/// If `seed` is `null`, a random seed is used. This is recommended.
@ -289,14 +290,14 @@ fn Kyber(comptime p: Params) type {
}
/// Serializes the key into a byte array.
pub fn toBytes(pk: PublicKey) [bytes_length]u8 {
pub fn toBytes(pk: PublicKey) [encoded_length]u8 {
return pk.pk.toBytes();
}
/// Deserializes the key from a byte array.
pub fn fromBytes(buf: *const [bytes_length]u8) errors.NonCanonicalError!PublicKey {
pub fn fromBytes(buf: *const [encoded_length]u8) errors.NonCanonicalError!PublicKey {
var ret: PublicKey = undefined;
ret.pk = try InnerPk.fromBytes(buf[0..InnerPk.bytes_length]);
ret.pk = try InnerPk.fromBytes(buf[0..InnerPk.encoded_length]);
sha3.Sha3_256.hash(buf, &ret.hpk, .{});
return ret;
}
@ -310,8 +311,8 @@ fn Kyber(comptime p: Params) type {
z: [shared_length]u8,
/// Size of a serialized representation of the key, in bytes.
pub const bytes_length: usize =
InnerSk.bytes_length + InnerPk.bytes_length + h_length + shared_length;
pub const encoded_length: usize =
InnerSk.encoded_length + InnerPk.encoded_length + h_length + shared_length;
/// Decapsulates the shared secret within ct using the private key.
pub fn decaps(sk: SecretKey, ct: *const [ciphertext_length]u8) ![shared_length]u8 {
@ -346,18 +347,18 @@ fn Kyber(comptime p: Params) type {
}
/// Serializes the key into a byte array.
pub fn toBytes(sk: SecretKey) [bytes_length]u8 {
pub fn toBytes(sk: SecretKey) [encoded_length]u8 {
return sk.sk.toBytes() ++ sk.pk.toBytes() ++ sk.hpk ++ sk.z;
}
/// Deserializes the key from a byte array.
pub fn fromBytes(buf: *const [bytes_length]u8) errors.NonCanonicalError!SecretKey {
pub fn fromBytes(buf: *const [encoded_length]u8) errors.NonCanonicalError!SecretKey {
var ret: SecretKey = undefined;
comptime var s: usize = 0;
ret.sk = InnerSk.fromBytes(buf[s .. s + InnerSk.bytes_length]);
s += InnerSk.bytes_length;
ret.pk = try InnerPk.fromBytes(buf[s .. s + InnerPk.bytes_length]);
s += InnerPk.bytes_length;
ret.sk = InnerSk.fromBytes(buf[s .. s + InnerSk.encoded_length]);
s += InnerSk.encoded_length;
ret.pk = try InnerPk.fromBytes(buf[s .. s + InnerPk.encoded_length]);
s += InnerPk.encoded_length;
ret.hpk = buf[s..][0..h_length].*;
s += h_length;
ret.z = buf[s..][0..shared_length].*;
@ -418,7 +419,7 @@ fn Kyber(comptime p: Params) type {
// Cached values
aT: M,
const bytes_length = V.bytes_length + 32;
const encoded_length = V.encoded_length + 32;
fn encrypt(
pk: InnerPk,
@ -436,7 +437,7 @@ fn Kyber(comptime p: Params) type {
// Note that coefficients of r are bounded by q and those of Aᵀ
// are bounded by 4.5q and so their product is bounded by 2¹⁵q
// as required for multiplication.
u.ps[i] = pk.aT.vs[i].dotHat(rh);
u.ps[i] = pk.aT.rows[i].dotHat(rh);
}
// Aᵀ and r were not in Montgomery form, so the Montgomery
@ -451,14 +452,14 @@ fn Kyber(comptime p: Params) type {
return u.compress(p.du) ++ v.compress(p.dv);
}
fn toBytes(pk: InnerPk) [bytes_length]u8 {
fn toBytes(pk: InnerPk) [encoded_length]u8 {
return pk.th.toBytes() ++ pk.rho;
}
fn fromBytes(buf: *const [bytes_length]u8) errors.NonCanonicalError!InnerPk {
fn fromBytes(buf: *const [encoded_length]u8) errors.NonCanonicalError!InnerPk {
var ret: InnerPk = undefined;
const th_bytes = buf[0..V.bytes_length];
const th_bytes = buf[0..V.encoded_length];
ret.th = V.fromBytes(th_bytes).normalize();
if (p.ml_kem) {
@ -468,7 +469,7 @@ fn Kyber(comptime p: Params) type {
}
}
ret.rho = buf[V.bytes_length..bytes_length].*;
ret.rho = buf[V.encoded_length..encoded_length].*;
ret.aT = M.uniform(ret.rho, true);
return ret;
}
@ -477,7 +478,7 @@ fn Kyber(comptime p: Params) type {
// Private key of the inner PKE
const InnerSk = struct {
sh: V, // NTT(s), normalized
const bytes_length = V.bytes_length;
const encoded_length = V.encoded_length;
fn decrypt(sk: InnerSk, ct: *const [ciphertext_length]u8) [inner_plaintext_length]u8 {
const u = V.decompress(p.du, ct[0..comptime V.compressedSize(p.du)]);
@ -491,11 +492,11 @@ fn Kyber(comptime p: Params) type {
.normalize().compress(1);
}
fn toBytes(sk: InnerSk) [bytes_length]u8 {
fn toBytes(sk: InnerSk) [encoded_length]u8 {
return sk.sh.toBytes();
}
fn fromBytes(buf: *const [bytes_length]u8) InnerSk {
fn fromBytes(buf: *const [encoded_length]u8) InnerSk {
var ret: InnerSk = undefined;
ret.sh = V.fromBytes(buf).normalize();
return ret;
@ -516,7 +517,7 @@ fn Kyber(comptime p: Params) type {
// Sample secret vector s.
sk.sh = V.noise(p.eta1, 0, sigma).ntt().normalize();
const eh = Vec(p.k).noise(p.eta1, p.k, sigma).ntt(); // sample blind e.
const eh = PolyVec(p.k).noise(p.eta1, p.k, sigma).ntt(); // sample blind e.
var th: V = undefined;
// Next, we compute t = A s + e.
@ -528,7 +529,7 @@ fn Kyber(comptime p: Params) type {
// multiplications in the inner product added a factor R⁻¹ which
// we'll cancel out with toMont(). This will also ensure the
// coefficients of th are bounded in absolute value by q.
th.ps[i] = pk.aT.vs[i].dotHat(sk.sh).toMont();
th.ps[i] = pk.aT.rows[i].dotHat(sk.sh).toMont();
}
pk.th = th.add(eh).normalize(); // bounded by 8q
@ -565,7 +566,6 @@ const zetas = computeZetas();
// not enough, the other coefficient is reduced as well.
//
// This is actually optimal, as proven in https://eprint.iacr.org/2020/1377.pdf
// TODO generate comptime?
const inv_ntt_reductions = [_]i16{
-1, // after layer 1
-1, // after layer 2
@ -634,31 +634,8 @@ test "invNTTReductions bounds" {
}
}
// Extended euclidean algorithm.
//
// For a, b finds x, y such that x a + y b = gcd(a, b). Used to compute
// modular inverse.
fn eea(a: anytype, b: @TypeOf(a)) EeaResult(@TypeOf(a)) {
if (a == 0) {
return .{ .gcd = b, .x = 0, .y = 1 };
}
const r = eea(@rem(b, a), a);
return .{ .gcd = r.gcd, .x = r.y - @divTrunc(b, a) * r.x, .y = r.x };
}
fn EeaResult(comptime T: type) type {
return struct { gcd: T, x: T, y: T };
}
// Returns least common multiple of a and b.
fn lcm(a: anytype, b: @TypeOf(a)) @TypeOf(a) {
const r = eea(a, b);
return a * b / r.gcd;
}
// Invert modulo p.
fn invertMod(a: anytype, p: @TypeOf(a)) @TypeOf(a) {
const r = eea(a, p);
const r = extendedEuclidean(@TypeOf(a), a, p);
assert(r.gcd == 1);
return r.x;
}
@ -788,31 +765,12 @@ test "Test csubq" {
}
}
// Compute a^s mod p.
fn mpow(a: anytype, s: @TypeOf(a), p: @TypeOf(a)) @TypeOf(a) {
var ret: @TypeOf(a) = 1;
var s2 = s;
var a2 = a;
while (true) {
if (s2 & 1 == 1) {
ret = @mod(ret * a2, p);
}
s2 >>= 1;
if (s2 == 0) {
break;
}
a2 = @mod(a2 * a2, p);
}
return ret;
}
// Computes zetas table used by ntt and invNTT.
fn computeZetas() [128]i16 {
@setEvalBranchQuota(10000);
var ret: [128]i16 = undefined;
for (&ret, 0..) |*r, i| {
const t = @as(i16, @intCast(mpow(@as(i32, zeta), @bitReverse(@as(u7, @intCast(i))), Q)));
const t = @as(i16, @intCast(modularPow(i32, zeta, @bitReverse(@as(u7, @intCast(i))), Q)));
r.* = csubq(feBarrettReduce(feToMont(t)));
}
return ret;
@ -828,9 +786,10 @@ fn computeZetas() [128]i16 {
const Poly = struct {
cs: [N]i16,
const bytes_length = N / 2 * 3;
const encoded_length = N / 2 * 3;
const zero: Poly = .{ .cs = .{0} ** N };
// Add two polynomials (coefficients not normalized)
fn add(a: Poly, b: Poly) Poly {
var ret: Poly = undefined;
for (0..N) |i| {
@ -839,6 +798,7 @@ const Poly = struct {
return ret;
}
// Subtract two polynomials (coefficients not normalized)
fn sub(a: Poly, b: Poly) Poly {
var ret: Poly = undefined;
for (0..N) |i| {
@ -847,25 +807,6 @@ const Poly = struct {
return ret;
}
// For testing, generates a random polynomial where each
// coefficient satisfies |x| ≤ q.
fn randAbsLeqQ(rnd: anytype) Poly {
var ret: Poly = undefined;
for (0..N) |i| {
ret.cs[i] = rnd.random().intRangeAtMost(i16, -Q, Q);
}
return ret;
}
// For testing, generates a random normalized polynomial.
fn randNormalized(rnd: anytype) Poly {
var ret: Poly = undefined;
for (0..N) |i| {
ret.cs[i] = rnd.random().intRangeLessThan(i16, 0, Q);
}
return ret;
}
// Executes a forward "NTT" on p.
//
// Assumes the coefficients are in absolute value ≤ q. The resulting
@ -1054,7 +995,7 @@ const Poly = struct {
var in_off: usize = 0;
var out_off: usize = 0;
const batch_size: usize = comptime lcm(@as(i16, d), 8);
const batch_size: usize = comptime math.lcm(d, 8);
const in_batch_size: usize = comptime batch_size / d;
const out_batch_size: usize = comptime batch_size / 8;
@ -1118,7 +1059,7 @@ const Poly = struct {
var in_off: usize = 0;
var out_off: usize = 0;
const batch_size: usize = comptime lcm(@as(i16, d), 8);
const batch_size: usize = comptime math.lcm(d, 8);
const in_batch_size: usize = comptime batch_size / 8;
const out_batch_size: usize = comptime batch_size / d;
@ -1275,53 +1216,23 @@ const Poly = struct {
return ret;
}
// Sample p uniformly from the given seed and x and y coordinates.
fn uniform(seed: [32]u8, x: u8, y: u8) Poly {
var h = sha3.Shake128.init(.{});
const suffix: [2]u8 = .{ x, y };
h.update(&seed);
h.update(&suffix);
const buf_len = sha3.Shake128.block_length; // rate SHAKE-128
var buf: [buf_len]u8 = undefined;
var ret: Poly = undefined;
var i: usize = 0; // index into ret.cs
outer: while (true) {
h.squeeze(&buf);
var j: usize = 0; // index into buf
while (j < buf_len) : (j += 3) {
const b0 = @as(u16, buf[j]);
const b1 = @as(u16, buf[j + 1]);
const b2 = @as(u16, buf[j + 2]);
const ts: [2]u16 = .{
b0 | ((b1 & 0xf) << 8),
(b1 >> 4) | (b2 << 4),
};
inline for (ts) |t| {
if (t < Q) {
ret.cs[i] = @as(i16, @intCast(t));
i += 1;
if (i == N) {
break :outer;
}
}
}
}
}
return ret;
const domain_sep: [2]u8 = .{ x, y };
return sampleUniformRejection(
Poly,
Q,
12,
N,
&seed,
&domain_sep,
);
}
// Packs p.
//
// Assumes p is normalized (and not just Barrett reduced).
fn toBytes(p: Poly) [bytes_length]u8 {
var ret: [bytes_length]u8 = undefined;
fn toBytes(p: Poly) [encoded_length]u8 {
var ret: [encoded_length]u8 = undefined;
for (0..comptime N / 2) |i| {
const t0 = @as(u16, @intCast(p.cs[2 * i]));
const t1 = @as(u16, @intCast(p.cs[2 * i + 1]));
@ -1335,7 +1246,7 @@ const Poly = struct {
// Unpacks a Poly from buf.
//
// p will not be normalized; instead 0 ≤ p[i] < 4096.
fn fromBytes(buf: *const [bytes_length]u8) Poly {
fn fromBytes(buf: *const [encoded_length]u8) Poly {
var ret: Poly = undefined;
for (0..comptime N / 2) |i| {
const b0 = @as(i16, buf[3 * i]);
@ -1348,71 +1259,65 @@ const Poly = struct {
}
};
// A vector of K polynomials.
fn Vec(comptime K: u8) type {
// A vector of k polynomials.
fn PolyVec(comptime k: u8) type {
return struct {
ps: [K]Poly,
ps: [k]Poly,
const Self = @This();
const bytes_length = K * Poly.bytes_length;
const encoded_length = k * Poly.encoded_length;
fn compressedSize(comptime d: u8) usize {
return Poly.compressedSize(d) * K;
return Poly.compressedSize(d) * k;
}
fn ntt(a: Self) Self {
/// Apply unary operation to each polynomial
fn map(v: Self, comptime op: fn (Poly) Poly) Self {
var ret: Self = undefined;
for (0..K) |i| {
ret.ps[i] = a.ps[i].ntt();
inline for (0..k) |i| {
ret.ps[i] = op(v.ps[i]);
}
return ret;
}
fn invNTT(a: Self) Self {
/// Apply binary operation pairwise
fn mapBinary(a: Self, b: Self, comptime op: fn (Poly, Poly) Poly) Self {
var ret: Self = undefined;
for (0..K) |i| {
ret.ps[i] = a.ps[i].invNTT();
inline for (0..k) |i| {
ret.ps[i] = op(a.ps[i], b.ps[i]);
}
return ret;
}
fn normalize(a: Self) Self {
var ret: Self = undefined;
for (0..K) |i| {
ret.ps[i] = a.ps[i].normalize();
}
return ret;
fn ntt(v: Self) Self {
return map(v, Poly.ntt);
}
fn barrettReduce(a: Self) Self {
var ret: Self = undefined;
for (0..K) |i| {
ret.ps[i] = a.ps[i].barrettReduce();
fn invNTT(v: Self) Self {
return map(v, Poly.invNTT);
}
return ret;
fn normalize(v: Self) Self {
return map(v, Poly.normalize);
}
fn barrettReduce(v: Self) Self {
return map(v, Poly.barrettReduce);
}
fn add(a: Self, b: Self) Self {
var ret: Self = undefined;
for (0..K) |i| {
ret.ps[i] = a.ps[i].add(b.ps[i]);
}
return ret;
return mapBinary(a, b, Poly.add);
}
fn sub(a: Self, b: Self) Self {
var ret: Self = undefined;
for (0..K) |i| {
ret.ps[i] = a.ps[i].sub(b.ps[i]);
}
return ret;
return mapBinary(a, b, Poly.sub);
}
// Samples v[i] from centered binomial distribution with the given η,
// seed and nonce+i.
fn noise(comptime eta: u8, nonce: u8, seed: *const [32]u8) Self {
var ret: Self = undefined;
for (0..K) |i| {
for (0..k) |i| {
ret.ps[i] = Poly.noise(eta, nonce + @as(u8, @intCast(i)), seed);
}
return ret;
@ -1428,7 +1333,7 @@ fn Vec(comptime K: u8) type {
// of the Montgomery factor.
fn dotHat(a: Self, b: Self) Poly {
var ret: Poly = Poly.zero;
for (0..K) |i| {
for (0..k) |i| {
ret = ret.add(a.ps[i].mulHat(b.ps[i]));
}
return ret;
@ -1437,7 +1342,7 @@ fn Vec(comptime K: u8) type {
fn compress(v: Self, comptime d: u8) [compressedSize(d)]u8 {
const cs = comptime Poly.compressedSize(d);
var ret: [compressedSize(d)]u8 = undefined;
inline for (0..K) |i| {
inline for (0..k) |i| {
ret[i * cs .. (i + 1) * cs].* = v.ps[i].compress(d);
}
return ret;
@ -1446,27 +1351,27 @@ fn Vec(comptime K: u8) type {
fn decompress(comptime d: u8, buf: *const [compressedSize(d)]u8) Self {
const cs = comptime Poly.compressedSize(d);
var ret: Self = undefined;
inline for (0..K) |i| {
inline for (0..k) |i| {
ret.ps[i] = Poly.decompress(d, buf[i * cs .. (i + 1) * cs]);
}
return ret;
}
/// Serializes the key into a byte array.
fn toBytes(v: Self) [bytes_length]u8 {
var ret: [bytes_length]u8 = undefined;
inline for (0..K) |i| {
ret[i * Poly.bytes_length .. (i + 1) * Poly.bytes_length].* = v.ps[i].toBytes();
fn toBytes(v: Self) [encoded_length]u8 {
var ret: [encoded_length]u8 = undefined;
inline for (0..k) |i| {
ret[i * Poly.encoded_length .. (i + 1) * Poly.encoded_length].* = v.ps[i].toBytes();
}
return ret;
}
/// Deserializes the key from a byte array.
fn fromBytes(buf: *const [bytes_length]u8) Self {
fn fromBytes(buf: *const [encoded_length]u8) Self {
var ret: Self = undefined;
inline for (0..K) |i| {
inline for (0..k) |i| {
ret.ps[i] = Poly.fromBytes(
buf[i * Poly.bytes_length .. (i + 1) * Poly.bytes_length],
buf[i * Poly.encoded_length .. (i + 1) * Poly.encoded_length],
);
}
return ret;
@ -1474,19 +1379,19 @@ fn Vec(comptime K: u8) type {
};
}
// A matrix of K vectors
fn Mat(comptime K: u8) type {
// A matrix of k vectors
fn Mat(comptime k: u8) type {
return struct {
const Self = @This();
vs: [K]Vec(K),
rows: [k]PolyVec(k),
fn uniform(seed: [32]u8, comptime transposed: bool) Self {
var ret: Self = undefined;
var i: u8 = 0;
while (i < K) : (i += 1) {
while (i < k) : (i += 1) {
var j: u8 = 0;
while (j < K) : (j += 1) {
ret.vs[i].ps[j] = Poly.uniform(
while (j < k) : (j += 1) {
ret.rows[i].ps[j] = Poly.uniform(
seed,
if (transposed) i else j,
if (transposed) j else i,
@ -1499,9 +1404,9 @@ fn Mat(comptime K: u8) type {
// Returns transpose of A
fn transpose(m: Self) Self {
var ret: Self = undefined;
for (0..K) |i| {
for (0..K) |j| {
ret.vs[i].ps[j] = m.vs[j].ps[i];
for (0..k) |i| {
for (0..k) |j| {
ret.rows[i].ps[j] = m.rows[j].ps[i];
}
}
return ret;
@ -1522,12 +1427,30 @@ fn cmov(comptime len: usize, dst: *[len]u8, src: [len]u8, b: u1) void {
}
}
// Test helper: builds a random polynomial whose coefficients all satisfy |x| ≤ q.
// `rnd` must expose a `random()` method returning a std.Random-style interface.
fn randPolyAbsLeqQ(rnd: anytype) Poly {
    var poly: Poly = undefined;
    // Each coefficient is drawn independently from the closed range [-Q, Q].
    for (&poly.cs) |*coef| {
        coef.* = rnd.random().intRangeAtMost(i16, -Q, Q);
    }
    return poly;
}
// Test helper: builds a random normalized polynomial, i.e. every coefficient lies in [0, Q).
// `rnd` must expose a `random()` method returning a std.Random-style interface.
fn randPolyNormalized(rnd: anytype) Poly {
    var poly: Poly = undefined;
    // Each coefficient is drawn independently from the half-open range [0, Q).
    for (&poly.cs) |*coef| {
        coef.* = rnd.random().intRangeLessThan(i16, 0, Q);
    }
    return poly;
}
test "MulHat" {
var rnd = RndGen.init(0);
for (0..100) |_| {
const a = Poly.randAbsLeqQ(&rnd);
const b = Poly.randAbsLeqQ(&rnd);
const a = randPolyAbsLeqQ(&rnd);
const b = randPolyAbsLeqQ(&rnd);
const p2 = a.ntt().mulHat(b.ntt()).barrettReduce().invNTT().normalize();
var p: Poly = undefined;
@ -1557,7 +1480,7 @@ test "NTT" {
var rnd = RndGen.init(0);
for (0..1000) |_| {
var p = Poly.randAbsLeqQ(&rnd);
var p = randPolyAbsLeqQ(&rnd);
const q = p.toMont().normalize();
p = p.ntt();
@ -1580,7 +1503,7 @@ test "Compression" {
var rnd = RndGen.init(0);
inline for (.{ 1, 4, 5, 10, 11 }) |d| {
for (0..1000) |_| {
const p = Poly.randNormalized(&rnd);
const p = randPolyNormalized(&rnd);
const pp = p.compress(d);
const pq = Poly.decompress(d, &pp).compress(d);
try testing.expectEqual(pp, pq);
@ -1671,7 +1594,7 @@ test "Polynomial packing" {
var rnd = RndGen.init(0);
for (0..1000) |_| {
const p = Poly.randNormalized(&rnd);
const p = randPolyNormalized(&rnd);
try testing.expectEqual(Poly.fromBytes(&p.toBytes()), p);
}
}
@ -1839,3 +1762,222 @@ const NistDRBG = struct {
return ret;
}
};
/// Extended Euclidean algorithm.
/// Returns `gcd`, `x` and `y` such that `x * a_ + y * b_ = gcd`.
/// Only meant to be used on comptime values; correctness matters, performance doesn't.
fn extendedEuclidean(comptime T: type, comptime a_: T, comptime b_: T) struct { gcd: T, x: T, y: T } {
    // Remainder sequence and the two coefficient sequences, each tracked as a
    // (previous, current) pair.
    var rem_prev = a_;
    var rem_cur = b_;
    var x_prev: T = 1;
    var x_cur: T = 0;
    var y_prev: T = 0;
    var y_cur: T = 1;

    while (rem_cur != 0) {
        const quotient = @divTrunc(rem_prev, rem_cur);

        const rem_next = rem_prev - quotient * rem_cur;
        rem_prev = rem_cur;
        rem_cur = rem_next;

        const x_next = x_prev - quotient * x_cur;
        x_prev = x_cur;
        x_cur = x_next;

        const y_next = y_prev - quotient * y_cur;
        y_prev = y_cur;
        y_cur = y_next;
    }

    // Once the remainder reaches zero, the previous remainder is the gcd and the
    // previous coefficients are the ones that produce it.
    return .{ .gcd = rem_prev, .x = x_prev, .y = y_prev };
}
/// Modular inversion: computes a^(-1) mod p.
/// Requires gcd(a, p) = 1 (checked with an assertion). The result is normalized to the range [0, p).
fn modularInverse(comptime T: type, comptime a: T, comptime p: T) T {
    // The coefficient returned by the extended Euclidean algorithm may be negative,
    // so an unsigned T is replaced by the signed type of the same width for the computation.
    const info = @typeInfo(T);
    const is_unsigned_int = info == .int and info.int.signedness == .unsigned;
    const Signed = if (is_unsigned_int) std.meta.Int(.signed, info.int.bits) else T;

    const modulus: Signed = @intCast(p);
    const bezout = extendedEuclidean(Signed, @as(Signed, @intCast(a)), modulus);
    assert(bezout.gcd == 1);

    // Lift a negative coefficient into [0, p) by repeatedly adding the modulus.
    var inverse = bezout.x;
    while (inverse < 0) : (inverse += modulus) {}
    return @intCast(inverse);
}
/// Modular exponentiation: computes a^s mod p using the square-and-multiply algorithm.
///
/// `T` must be an integer type. `a` must be non-negative and `p` positive; a non-positive
/// exponent `s` yields 1. All intermediate products are computed in an unsigned type twice
/// as wide as `T`, so they cannot overflow.
///
/// Works both at comptime and at runtime, for signed as well as unsigned `T`.
fn modularPow(comptime T: type, comptime a: T, s: T, comptime p: T) T {
    const bits = @typeInfo(T).int.bits;
    const WideT = std.meta.Int(.unsigned, bits * 2);

    // Explicit casts: a signed T does not implicitly coerce to the unsigned wide type
    // unless the value happens to be comptime-known.
    const modulus: WideT = @intCast(p);
    var ret: WideT = 1;
    var base: WideT = @intCast(a);
    var exp = s;

    while (exp > 0) {
        // Multiply the current power into the result when the low exponent bit is set.
        if ((exp & 1) == 1) {
            ret = (ret * base) % modulus;
        }
        base = (base * base) % modulus;
        exp >>= 1;
    }

    // ret is either the initial 1 or a value reduced modulo p, so it fits in T.
    return @intCast(ret);
}
/// Expands a single bit value into a full-width mask.
/// Yields all 1s (0xFF...FF) when bit == 1 and all 0s when bit == 0.
/// `T` must be an unsigned integer type; anything else is rejected at compile time.
fn bitMask(comptime T: type, bit: T) T {
    comptime {
        const info = @typeInfo(T);
        if (info != .int or info.int.signedness != .unsigned) {
            @compileError("bitMask requires an unsigned integer type");
        }
    }
    // Wrapping subtraction from zero: 0 - 1 wraps around to all ones, 0 - 0 stays zero.
    return @as(T, 0) -% bit;
}
/// Derives a full-width mask from the most significant (sign) bit of an integer.
/// Yields all 1s (0xFF...FF) if x < 0 and all 0s if x >= 0. For unsigned input the
/// top bit plays the role of the sign bit.
fn signMask(comptime T: type, x: T) std.meta.Int(.unsigned, @typeInfo(T).int.bits) {
    const info = @typeInfo(T);
    if (info != .int) {
        @compileError("signMask requires an integer type");
    }

    const width = info.int.bits;
    const Signed = std.meta.Int(.signed, width);

    // Reinterpret unsigned input as signed so that the right shift below is arithmetic.
    const as_signed: Signed = if (info.int.signedness == .unsigned) @bitCast(x) else x;

    // Shifting by width - 1 replicates the sign bit into every bit position.
    return @bitCast(as_signed >> (width - 1));
}
test "bitMask and signMask helpers" {
    // bitMask: 0 maps to all zeros and 1 maps to all ones, at several integer widths.
    inline for (.{ u32, u8, u64 }) |T| {
        try testing.expectEqual(@as(T, 0), bitMask(T, 0));
        try testing.expectEqual(~@as(T, 0), bitMask(T, 1));
    }

    // signMask on signed input: negative values give all ones, everything else all zeros.
    const signed_cases = [_]struct { input: i32, want: u32 }{
        .{ .input = -1, .want = 0xFFFFFFFF },
        .{ .input = -100, .want = 0xFFFFFFFF },
        .{ .input = 0, .want = 0x00000000 },
        .{ .input = 1, .want = 0x00000000 },
        .{ .input = 100, .want = 0x00000000 },
    };
    for (signed_cases) |case| {
        try testing.expectEqual(case.want, signMask(i32, case.input));
    }

    // signMask on unsigned input: the most significant bit acts as the sign bit.
    try testing.expectEqual(@as(u32, 0xFFFFFFFF), signMask(u32, 0x80000000)); // MSB set
    try testing.expectEqual(@as(u32, 0x00000000), signMask(u32, 0x7FFFFFFF)); // MSB clear
}
/// Montgomery reduction: for input x, returns y where y ≡ x*R^(-1) (mod q), with R = 2^r_bits.
/// This is a generic implementation parameterized by the modulus q, its inverse qInv
/// (the inverse of q modulo R), the number of bits r_bits of the Montgomery constant R,
/// and the input/output integer types.
///
/// For ML-DSA: R = 2^32, returns y < 2q
/// For ML-KEM: R = 2^16, returns y in range (-q, q)
///
/// Parameters:
/// - InT: integer type of the input (wide enough to hold the products being reduced)
/// - OutT: integer type of the result; must have the same signedness as InT
/// - q: modulus
/// - qInv: inverse of q modulo 2^r_bits
/// - r_bits: log2 of the Montgomery constant R
/// - x: value to reduce
fn montgomeryReduce(
    comptime InT: type,
    comptime OutT: type,
    comptime q: comptime_int,
    comptime qInv: comptime_int,
    comptime r_bits: comptime_int,
    x: InT,
) OutT {
    const UnsignedInT = std.meta.Int(.unsigned, @typeInfo(InT).int.bits);
    const UnsignedOutT = std.meta.Int(.unsigned, @typeInfo(OutT).int.bits);

    // m = x * q^(-1) mod R, taken from the low r_bits bits of the wrapping product.
    const mask = (@as(InT, 1) << r_bits) - 1;
    const m_full = (x *% qInv) & mask;
    const m: OutT = @truncate(m_full);

    // x - m*q is divisible by R, so the division by R below is exact.
    const yR = x -% @as(InT, m) * @as(InT, q);

    // Shift on the unsigned bit pattern, then keep only the low bits that make up OutT.
    const y_shifted = @as(UnsignedInT, @bitCast(yR)) >> r_bits;
    return @bitCast(@as(UnsignedOutT, @truncate(y_shifted)));
}
/// Uniform sampling using SHAKE-128 with rejection sampling.
/// Absorbs `seed` followed by `domain_sep`, then repeatedly squeezes one SHAKE-128 block
/// and keeps every candidate value below `q` until `n` coefficients have been written
/// to the `cs` field of the result.
///
/// Parameters:
/// - PolyType: The polynomial type to return (must have an indexable `cs` field)
/// - q: Modulus
/// - bits_per_coef: Number of bits per coefficient (12 or 23; anything else is a compile error)
/// - n: Number of coefficients
/// - seed: Random seed
/// - domain_sep: Domain separation bytes (appended to seed)
fn sampleUniformRejection(
    comptime PolyType: type,
    comptime q: comptime_int,
    comptime bits_per_coef: comptime_int,
    comptime n: comptime_int,
    seed: []const u8,
    domain_sep: []const u8,
) PolyType {
    const block_len = sha3.Shake128.block_length; // 168 bytes

    var xof = sha3.Shake128.init(.{});
    xof.update(seed);
    xof.update(domain_sep);

    var block: [block_len]u8 = undefined;
    var poly: PolyType = undefined;
    var filled: usize = 0; // number of coefficients accepted so far

    switch (bits_per_coef) {
        12 => {
            // ML-KEM path: every 3 bytes carry two 12-bit candidates.
            fill: while (true) {
                xof.squeeze(&block);
                var off: usize = 0;
                while (off < block_len) : (off += 3) {
                    const lo_byte: u16 = block[off];
                    const mid_byte: u16 = block[off + 1];
                    const hi_byte: u16 = block[off + 2];
                    const candidates = [2]u16{
                        // first candidate: byte 0 plus the low nibble of byte 1
                        lo_byte | ((mid_byte & 0xf) << 8),
                        // second candidate: high nibble of byte 1 plus byte 2
                        (mid_byte >> 4) | (hi_byte << 4),
                    };
                    inline for (candidates) |candidate| {
                        // Rejection step: values >= q are discarded.
                        if (candidate < q) {
                            poly.cs[filled] = @intCast(candidate);
                            filled += 1;
                            if (filled == n) break :fill;
                        }
                    }
                }
            }
        },
        23 => {
            // ML-DSA path: every 3 bytes carry one candidate, masked down to 23 bits.
            while (filled < n) {
                xof.squeeze(&block);
                var off: usize = 0;
                while (off < block_len and filled < n) : (off += 3) {
                    const word = @as(u32, block[off]) |
                        (@as(u32, block[off + 1]) << 8) |
                        (@as(u32, block[off + 2]) << 16);
                    const candidate = word & 0x7fffff;
                    // Rejection step: values >= q are discarded.
                    if (candidate < q) {
                        poly.cs[filled] = @intCast(candidate);
                        filled += 1;
                    }
                }
            }
        },
        else => @compileError("bits_per_coef must be 12 or 23"),
    }

    return poly;
}