mirror of
https://codeberg.org/ziglang/zig.git
synced 2025-12-06 05:44:20 +00:00
Align ML-KEM code with ML-DSA (#25964)
This will facilitate maintenance and code sharing between primitives.
This commit is contained in:
parent
73f863a6fb
commit
4ea4728084
1 changed files with 338 additions and 196 deletions
|
|
@ -105,19 +105,20 @@ const crypto = std.crypto;
|
|||
const errors = std.crypto.errors;
|
||||
const math = std.math;
|
||||
const mem = std.mem;
|
||||
const RndGen = std.Random.DefaultPrng;
|
||||
const sha3 = crypto.hash.sha3;
|
||||
|
||||
// Q is the parameter q ≡ 3329 = 2¹¹ + 2¹⁰ + 2⁸ + 1.
|
||||
const RndGen = std.Random.DefaultPrng;
|
||||
|
||||
// Q is the modulus q ≡ 3329 = 2¹¹ + 2¹⁰ + 2⁸ + 1
|
||||
const Q: i16 = 3329;
|
||||
|
||||
// Montgomery R
|
||||
// Montgomery R = 2^16 mod Q (for Montgomery multiplication)
|
||||
const R: i32 = 1 << 16;
|
||||
|
||||
// Parameter n, degree of polynomials.
|
||||
// N is the degree of polynomials (polynomial ring dimension)
|
||||
const N: usize = 256;
|
||||
|
||||
// Size of "small" vectors used in encryption blinds.
|
||||
// eta2 is the size of "small" vectors used in encryption blinds
|
||||
const eta2: u8 = 2;
|
||||
|
||||
const Params = struct {
|
||||
|
|
@ -215,7 +216,7 @@ fn Kyber(comptime p: Params) type {
|
|||
pub const ciphertext_length = Poly.compressedSize(p.du) * p.k + Poly.compressedSize(p.dv);
|
||||
|
||||
const Self = @This();
|
||||
const V = Vec(p.k);
|
||||
const V = PolyVec(p.k);
|
||||
const M = Mat(p.k);
|
||||
|
||||
/// Length (in bytes) of a shared secret.
|
||||
|
|
@ -241,7 +242,7 @@ fn Kyber(comptime p: Params) type {
|
|||
hpk: [h_length]u8, // H(pk)
|
||||
|
||||
/// Size of a serialized representation of the key, in bytes.
|
||||
pub const bytes_length = InnerPk.bytes_length;
|
||||
pub const encoded_length = InnerPk.encoded_length;
|
||||
|
||||
/// Generates a shared secret, and encapsulates it for the public key.
|
||||
/// If `seed` is `null`, a random seed is used. This is recommended.
|
||||
|
|
@ -289,14 +290,14 @@ fn Kyber(comptime p: Params) type {
|
|||
}
|
||||
|
||||
/// Serializes the key into a byte array.
|
||||
pub fn toBytes(pk: PublicKey) [bytes_length]u8 {
|
||||
pub fn toBytes(pk: PublicKey) [encoded_length]u8 {
|
||||
return pk.pk.toBytes();
|
||||
}
|
||||
|
||||
/// Deserializes the key from a byte array.
|
||||
pub fn fromBytes(buf: *const [bytes_length]u8) errors.NonCanonicalError!PublicKey {
|
||||
pub fn fromBytes(buf: *const [encoded_length]u8) errors.NonCanonicalError!PublicKey {
|
||||
var ret: PublicKey = undefined;
|
||||
ret.pk = try InnerPk.fromBytes(buf[0..InnerPk.bytes_length]);
|
||||
ret.pk = try InnerPk.fromBytes(buf[0..InnerPk.encoded_length]);
|
||||
sha3.Sha3_256.hash(buf, &ret.hpk, .{});
|
||||
return ret;
|
||||
}
|
||||
|
|
@ -310,8 +311,8 @@ fn Kyber(comptime p: Params) type {
|
|||
z: [shared_length]u8,
|
||||
|
||||
/// Size of a serialized representation of the key, in bytes.
|
||||
pub const bytes_length: usize =
|
||||
InnerSk.bytes_length + InnerPk.bytes_length + h_length + shared_length;
|
||||
pub const encoded_length: usize =
|
||||
InnerSk.encoded_length + InnerPk.encoded_length + h_length + shared_length;
|
||||
|
||||
/// Decapsulates the shared secret within ct using the private key.
|
||||
pub fn decaps(sk: SecretKey, ct: *const [ciphertext_length]u8) ![shared_length]u8 {
|
||||
|
|
@ -346,18 +347,18 @@ fn Kyber(comptime p: Params) type {
|
|||
}
|
||||
|
||||
/// Serializes the key into a byte array.
|
||||
pub fn toBytes(sk: SecretKey) [bytes_length]u8 {
|
||||
pub fn toBytes(sk: SecretKey) [encoded_length]u8 {
|
||||
return sk.sk.toBytes() ++ sk.pk.toBytes() ++ sk.hpk ++ sk.z;
|
||||
}
|
||||
|
||||
/// Deserializes the key from a byte array.
|
||||
pub fn fromBytes(buf: *const [bytes_length]u8) errors.NonCanonicalError!SecretKey {
|
||||
pub fn fromBytes(buf: *const [encoded_length]u8) errors.NonCanonicalError!SecretKey {
|
||||
var ret: SecretKey = undefined;
|
||||
comptime var s: usize = 0;
|
||||
ret.sk = InnerSk.fromBytes(buf[s .. s + InnerSk.bytes_length]);
|
||||
s += InnerSk.bytes_length;
|
||||
ret.pk = try InnerPk.fromBytes(buf[s .. s + InnerPk.bytes_length]);
|
||||
s += InnerPk.bytes_length;
|
||||
ret.sk = InnerSk.fromBytes(buf[s .. s + InnerSk.encoded_length]);
|
||||
s += InnerSk.encoded_length;
|
||||
ret.pk = try InnerPk.fromBytes(buf[s .. s + InnerPk.encoded_length]);
|
||||
s += InnerPk.encoded_length;
|
||||
ret.hpk = buf[s..][0..h_length].*;
|
||||
s += h_length;
|
||||
ret.z = buf[s..][0..shared_length].*;
|
||||
|
|
@ -418,7 +419,7 @@ fn Kyber(comptime p: Params) type {
|
|||
// Cached values
|
||||
aT: M,
|
||||
|
||||
const bytes_length = V.bytes_length + 32;
|
||||
const encoded_length = V.encoded_length + 32;
|
||||
|
||||
fn encrypt(
|
||||
pk: InnerPk,
|
||||
|
|
@ -436,7 +437,7 @@ fn Kyber(comptime p: Params) type {
|
|||
// Note that coefficients of r are bounded by q and those of Aᵀ
|
||||
// are bounded by 4.5q and so their product is bounded by 2¹⁵q
|
||||
// as required for multiplication.
|
||||
u.ps[i] = pk.aT.vs[i].dotHat(rh);
|
||||
u.ps[i] = pk.aT.rows[i].dotHat(rh);
|
||||
}
|
||||
|
||||
// Aᵀ and r were not in Montgomery form, so the Montgomery
|
||||
|
|
@ -451,14 +452,14 @@ fn Kyber(comptime p: Params) type {
|
|||
return u.compress(p.du) ++ v.compress(p.dv);
|
||||
}
|
||||
|
||||
fn toBytes(pk: InnerPk) [bytes_length]u8 {
|
||||
fn toBytes(pk: InnerPk) [encoded_length]u8 {
|
||||
return pk.th.toBytes() ++ pk.rho;
|
||||
}
|
||||
|
||||
fn fromBytes(buf: *const [bytes_length]u8) errors.NonCanonicalError!InnerPk {
|
||||
fn fromBytes(buf: *const [encoded_length]u8) errors.NonCanonicalError!InnerPk {
|
||||
var ret: InnerPk = undefined;
|
||||
|
||||
const th_bytes = buf[0..V.bytes_length];
|
||||
const th_bytes = buf[0..V.encoded_length];
|
||||
ret.th = V.fromBytes(th_bytes).normalize();
|
||||
|
||||
if (p.ml_kem) {
|
||||
|
|
@ -468,7 +469,7 @@ fn Kyber(comptime p: Params) type {
|
|||
}
|
||||
}
|
||||
|
||||
ret.rho = buf[V.bytes_length..bytes_length].*;
|
||||
ret.rho = buf[V.encoded_length..encoded_length].*;
|
||||
ret.aT = M.uniform(ret.rho, true);
|
||||
return ret;
|
||||
}
|
||||
|
|
@ -477,7 +478,7 @@ fn Kyber(comptime p: Params) type {
|
|||
// Private key of the inner PKE
|
||||
const InnerSk = struct {
|
||||
sh: V, // NTT(s), normalized
|
||||
const bytes_length = V.bytes_length;
|
||||
const encoded_length = V.encoded_length;
|
||||
|
||||
fn decrypt(sk: InnerSk, ct: *const [ciphertext_length]u8) [inner_plaintext_length]u8 {
|
||||
const u = V.decompress(p.du, ct[0..comptime V.compressedSize(p.du)]);
|
||||
|
|
@ -491,11 +492,11 @@ fn Kyber(comptime p: Params) type {
|
|||
.normalize().compress(1);
|
||||
}
|
||||
|
||||
fn toBytes(sk: InnerSk) [bytes_length]u8 {
|
||||
fn toBytes(sk: InnerSk) [encoded_length]u8 {
|
||||
return sk.sh.toBytes();
|
||||
}
|
||||
|
||||
fn fromBytes(buf: *const [bytes_length]u8) InnerSk {
|
||||
fn fromBytes(buf: *const [encoded_length]u8) InnerSk {
|
||||
var ret: InnerSk = undefined;
|
||||
ret.sh = V.fromBytes(buf).normalize();
|
||||
return ret;
|
||||
|
|
@ -516,7 +517,7 @@ fn Kyber(comptime p: Params) type {
|
|||
// Sample secret vector s.
|
||||
sk.sh = V.noise(p.eta1, 0, sigma).ntt().normalize();
|
||||
|
||||
const eh = Vec(p.k).noise(p.eta1, p.k, sigma).ntt(); // sample blind e.
|
||||
const eh = PolyVec(p.k).noise(p.eta1, p.k, sigma).ntt(); // sample blind e.
|
||||
var th: V = undefined;
|
||||
|
||||
// Next, we compute t = A s + e.
|
||||
|
|
@ -528,7 +529,7 @@ fn Kyber(comptime p: Params) type {
|
|||
// multiplications in the inner product added a factor R⁻¹ which
|
||||
// we'll cancel out with toMont(). This will also ensure the
|
||||
// coefficients of th are bounded in absolute value by q.
|
||||
th.ps[i] = pk.aT.vs[i].dotHat(sk.sh).toMont();
|
||||
th.ps[i] = pk.aT.rows[i].dotHat(sk.sh).toMont();
|
||||
}
|
||||
|
||||
pk.th = th.add(eh).normalize(); // bounded by 8q
|
||||
|
|
@ -565,7 +566,6 @@ const zetas = computeZetas();
|
|||
// not enough, the other coefficient is reduced as well.
|
||||
//
|
||||
// This is actually optimal, as proven in https://eprint.iacr.org/2020/1377.pdf
|
||||
// TODO generate comptime?
|
||||
const inv_ntt_reductions = [_]i16{
|
||||
-1, // after layer 1
|
||||
-1, // after layer 2
|
||||
|
|
@ -634,31 +634,8 @@ test "invNTTReductions bounds" {
|
|||
}
|
||||
}
|
||||
|
||||
// Extended euclidean algorithm.
|
||||
//
|
||||
// For a, b finds x, y such that x a + y b = gcd(a, b). Used to compute
|
||||
// modular inverse.
|
||||
fn eea(a: anytype, b: @TypeOf(a)) EeaResult(@TypeOf(a)) {
|
||||
if (a == 0) {
|
||||
return .{ .gcd = b, .x = 0, .y = 1 };
|
||||
}
|
||||
const r = eea(@rem(b, a), a);
|
||||
return .{ .gcd = r.gcd, .x = r.y - @divTrunc(b, a) * r.x, .y = r.x };
|
||||
}
|
||||
|
||||
fn EeaResult(comptime T: type) type {
|
||||
return struct { gcd: T, x: T, y: T };
|
||||
}
|
||||
|
||||
// Returns least common multiple of a and b.
|
||||
fn lcm(a: anytype, b: @TypeOf(a)) @TypeOf(a) {
|
||||
const r = eea(a, b);
|
||||
return a * b / r.gcd;
|
||||
}
|
||||
|
||||
// Invert modulo p.
|
||||
fn invertMod(a: anytype, p: @TypeOf(a)) @TypeOf(a) {
|
||||
const r = eea(a, p);
|
||||
const r = extendedEuclidean(@TypeOf(a), a, p);
|
||||
assert(r.gcd == 1);
|
||||
return r.x;
|
||||
}
|
||||
|
|
@ -788,31 +765,12 @@ test "Test csubq" {
|
|||
}
|
||||
}
|
||||
|
||||
// Compute a^s mod p.
|
||||
fn mpow(a: anytype, s: @TypeOf(a), p: @TypeOf(a)) @TypeOf(a) {
|
||||
var ret: @TypeOf(a) = 1;
|
||||
var s2 = s;
|
||||
var a2 = a;
|
||||
|
||||
while (true) {
|
||||
if (s2 & 1 == 1) {
|
||||
ret = @mod(ret * a2, p);
|
||||
}
|
||||
s2 >>= 1;
|
||||
if (s2 == 0) {
|
||||
break;
|
||||
}
|
||||
a2 = @mod(a2 * a2, p);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
// Computes zetas table used by ntt and invNTT.
|
||||
fn computeZetas() [128]i16 {
|
||||
@setEvalBranchQuota(10000);
|
||||
var ret: [128]i16 = undefined;
|
||||
for (&ret, 0..) |*r, i| {
|
||||
const t = @as(i16, @intCast(mpow(@as(i32, zeta), @bitReverse(@as(u7, @intCast(i))), Q)));
|
||||
const t = @as(i16, @intCast(modularPow(i32, zeta, @bitReverse(@as(u7, @intCast(i))), Q)));
|
||||
r.* = csubq(feBarrettReduce(feToMont(t)));
|
||||
}
|
||||
return ret;
|
||||
|
|
@ -828,9 +786,10 @@ fn computeZetas() [128]i16 {
|
|||
const Poly = struct {
|
||||
cs: [N]i16,
|
||||
|
||||
const bytes_length = N / 2 * 3;
|
||||
const encoded_length = N / 2 * 3;
|
||||
const zero: Poly = .{ .cs = .{0} ** N };
|
||||
|
||||
// Add two polynomials (coefficients not normalized)
|
||||
fn add(a: Poly, b: Poly) Poly {
|
||||
var ret: Poly = undefined;
|
||||
for (0..N) |i| {
|
||||
|
|
@ -839,6 +798,7 @@ const Poly = struct {
|
|||
return ret;
|
||||
}
|
||||
|
||||
// Subtract two polynomials (coefficients not normalized)
|
||||
fn sub(a: Poly, b: Poly) Poly {
|
||||
var ret: Poly = undefined;
|
||||
for (0..N) |i| {
|
||||
|
|
@ -847,25 +807,6 @@ const Poly = struct {
|
|||
return ret;
|
||||
}
|
||||
|
||||
// For testing, generates a random polynomial with for each
|
||||
// coefficient |x| ≤ q.
|
||||
fn randAbsLeqQ(rnd: anytype) Poly {
|
||||
var ret: Poly = undefined;
|
||||
for (0..N) |i| {
|
||||
ret.cs[i] = rnd.random().intRangeAtMost(i16, -Q, Q);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
// For testing, generates a random normalized polynomial.
|
||||
fn randNormalized(rnd: anytype) Poly {
|
||||
var ret: Poly = undefined;
|
||||
for (0..N) |i| {
|
||||
ret.cs[i] = rnd.random().intRangeLessThan(i16, 0, Q);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
// Executes a forward "NTT" on p.
|
||||
//
|
||||
// Assumes the coefficients are in absolute value ≤q. The resulting
|
||||
|
|
@ -1054,7 +995,7 @@ const Poly = struct {
|
|||
var in_off: usize = 0;
|
||||
var out_off: usize = 0;
|
||||
|
||||
const batch_size: usize = comptime lcm(@as(i16, d), 8);
|
||||
const batch_size: usize = comptime math.lcm(d, 8);
|
||||
const in_batch_size: usize = comptime batch_size / d;
|
||||
const out_batch_size: usize = comptime batch_size / 8;
|
||||
|
||||
|
|
@ -1118,7 +1059,7 @@ const Poly = struct {
|
|||
var in_off: usize = 0;
|
||||
var out_off: usize = 0;
|
||||
|
||||
const batch_size: usize = comptime lcm(@as(i16, d), 8);
|
||||
const batch_size: usize = comptime math.lcm(d, 8);
|
||||
const in_batch_size: usize = comptime batch_size / 8;
|
||||
const out_batch_size: usize = comptime batch_size / d;
|
||||
|
||||
|
|
@ -1275,53 +1216,23 @@ const Poly = struct {
|
|||
return ret;
|
||||
}
|
||||
|
||||
// Sample p uniformly from the given seed and x and y coordinates.
|
||||
fn uniform(seed: [32]u8, x: u8, y: u8) Poly {
|
||||
var h = sha3.Shake128.init(.{});
|
||||
const suffix: [2]u8 = .{ x, y };
|
||||
h.update(&seed);
|
||||
h.update(&suffix);
|
||||
|
||||
const buf_len = sha3.Shake128.block_length; // rate SHAKE-128
|
||||
var buf: [buf_len]u8 = undefined;
|
||||
|
||||
var ret: Poly = undefined;
|
||||
var i: usize = 0; // index into ret.cs
|
||||
outer: while (true) {
|
||||
h.squeeze(&buf);
|
||||
|
||||
var j: usize = 0; // index into buf
|
||||
while (j < buf_len) : (j += 3) {
|
||||
const b0 = @as(u16, buf[j]);
|
||||
const b1 = @as(u16, buf[j + 1]);
|
||||
const b2 = @as(u16, buf[j + 2]);
|
||||
|
||||
const ts: [2]u16 = .{
|
||||
b0 | ((b1 & 0xf) << 8),
|
||||
(b1 >> 4) | (b2 << 4),
|
||||
};
|
||||
|
||||
inline for (ts) |t| {
|
||||
if (t < Q) {
|
||||
ret.cs[i] = @as(i16, @intCast(t));
|
||||
i += 1;
|
||||
|
||||
if (i == N) {
|
||||
break :outer;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ret;
|
||||
const domain_sep: [2]u8 = .{ x, y };
|
||||
return sampleUniformRejection(
|
||||
Poly,
|
||||
Q,
|
||||
12,
|
||||
N,
|
||||
&seed,
|
||||
&domain_sep,
|
||||
);
|
||||
}
|
||||
|
||||
// Packs p.
|
||||
//
|
||||
// Assumes p is normalized (and not just Barrett reduced).
|
||||
fn toBytes(p: Poly) [bytes_length]u8 {
|
||||
var ret: [bytes_length]u8 = undefined;
|
||||
fn toBytes(p: Poly) [encoded_length]u8 {
|
||||
var ret: [encoded_length]u8 = undefined;
|
||||
for (0..comptime N / 2) |i| {
|
||||
const t0 = @as(u16, @intCast(p.cs[2 * i]));
|
||||
const t1 = @as(u16, @intCast(p.cs[2 * i + 1]));
|
||||
|
|
@ -1335,7 +1246,7 @@ const Poly = struct {
|
|||
// Unpacks a Poly from buf.
|
||||
//
|
||||
// p will not be normalized; instead 0 ≤ p[i] < 4096.
|
||||
fn fromBytes(buf: *const [bytes_length]u8) Poly {
|
||||
fn fromBytes(buf: *const [encoded_length]u8) Poly {
|
||||
var ret: Poly = undefined;
|
||||
for (0..comptime N / 2) |i| {
|
||||
const b0 = @as(i16, buf[3 * i]);
|
||||
|
|
@ -1348,71 +1259,65 @@ const Poly = struct {
|
|||
}
|
||||
};
|
||||
|
||||
// A vector of K polynomials.
|
||||
fn Vec(comptime K: u8) type {
|
||||
// A vector of k polynomials.
|
||||
fn PolyVec(comptime k: u8) type {
|
||||
return struct {
|
||||
ps: [K]Poly,
|
||||
ps: [k]Poly,
|
||||
|
||||
const Self = @This();
|
||||
const bytes_length = K * Poly.bytes_length;
|
||||
const encoded_length = k * Poly.encoded_length;
|
||||
|
||||
fn compressedSize(comptime d: u8) usize {
|
||||
return Poly.compressedSize(d) * K;
|
||||
return Poly.compressedSize(d) * k;
|
||||
}
|
||||
|
||||
fn ntt(a: Self) Self {
|
||||
/// Apply unary operation to each polynomial
|
||||
fn map(v: Self, comptime op: fn (Poly) Poly) Self {
|
||||
var ret: Self = undefined;
|
||||
for (0..K) |i| {
|
||||
ret.ps[i] = a.ps[i].ntt();
|
||||
inline for (0..k) |i| {
|
||||
ret.ps[i] = op(v.ps[i]);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
fn invNTT(a: Self) Self {
|
||||
/// Apply binary operation pairwise
|
||||
fn mapBinary(a: Self, b: Self, comptime op: fn (Poly, Poly) Poly) Self {
|
||||
var ret: Self = undefined;
|
||||
for (0..K) |i| {
|
||||
ret.ps[i] = a.ps[i].invNTT();
|
||||
inline for (0..k) |i| {
|
||||
ret.ps[i] = op(a.ps[i], b.ps[i]);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
fn normalize(a: Self) Self {
|
||||
var ret: Self = undefined;
|
||||
for (0..K) |i| {
|
||||
ret.ps[i] = a.ps[i].normalize();
|
||||
}
|
||||
return ret;
|
||||
fn ntt(v: Self) Self {
|
||||
return map(v, Poly.ntt);
|
||||
}
|
||||
|
||||
fn barrettReduce(a: Self) Self {
|
||||
var ret: Self = undefined;
|
||||
for (0..K) |i| {
|
||||
ret.ps[i] = a.ps[i].barrettReduce();
|
||||
}
|
||||
return ret;
|
||||
fn invNTT(v: Self) Self {
|
||||
return map(v, Poly.invNTT);
|
||||
}
|
||||
|
||||
fn normalize(v: Self) Self {
|
||||
return map(v, Poly.normalize);
|
||||
}
|
||||
|
||||
fn barrettReduce(v: Self) Self {
|
||||
return map(v, Poly.barrettReduce);
|
||||
}
|
||||
|
||||
fn add(a: Self, b: Self) Self {
|
||||
var ret: Self = undefined;
|
||||
for (0..K) |i| {
|
||||
ret.ps[i] = a.ps[i].add(b.ps[i]);
|
||||
}
|
||||
return ret;
|
||||
return mapBinary(a, b, Poly.add);
|
||||
}
|
||||
|
||||
fn sub(a: Self, b: Self) Self {
|
||||
var ret: Self = undefined;
|
||||
for (0..K) |i| {
|
||||
ret.ps[i] = a.ps[i].sub(b.ps[i]);
|
||||
}
|
||||
return ret;
|
||||
return mapBinary(a, b, Poly.sub);
|
||||
}
|
||||
|
||||
// Samples v[i] from centered binomial distribution with the given η,
|
||||
// seed and nonce+i.
|
||||
fn noise(comptime eta: u8, nonce: u8, seed: *const [32]u8) Self {
|
||||
var ret: Self = undefined;
|
||||
for (0..K) |i| {
|
||||
for (0..k) |i| {
|
||||
ret.ps[i] = Poly.noise(eta, nonce + @as(u8, @intCast(i)), seed);
|
||||
}
|
||||
return ret;
|
||||
|
|
@ -1428,7 +1333,7 @@ fn Vec(comptime K: u8) type {
|
|||
// of the Montgomery factor.
|
||||
fn dotHat(a: Self, b: Self) Poly {
|
||||
var ret: Poly = Poly.zero;
|
||||
for (0..K) |i| {
|
||||
for (0..k) |i| {
|
||||
ret = ret.add(a.ps[i].mulHat(b.ps[i]));
|
||||
}
|
||||
return ret;
|
||||
|
|
@ -1437,7 +1342,7 @@ fn Vec(comptime K: u8) type {
|
|||
fn compress(v: Self, comptime d: u8) [compressedSize(d)]u8 {
|
||||
const cs = comptime Poly.compressedSize(d);
|
||||
var ret: [compressedSize(d)]u8 = undefined;
|
||||
inline for (0..K) |i| {
|
||||
inline for (0..k) |i| {
|
||||
ret[i * cs .. (i + 1) * cs].* = v.ps[i].compress(d);
|
||||
}
|
||||
return ret;
|
||||
|
|
@ -1446,27 +1351,27 @@ fn Vec(comptime K: u8) type {
|
|||
fn decompress(comptime d: u8, buf: *const [compressedSize(d)]u8) Self {
|
||||
const cs = comptime Poly.compressedSize(d);
|
||||
var ret: Self = undefined;
|
||||
inline for (0..K) |i| {
|
||||
inline for (0..k) |i| {
|
||||
ret.ps[i] = Poly.decompress(d, buf[i * cs .. (i + 1) * cs]);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
/// Serializes the key into a byte array.
|
||||
fn toBytes(v: Self) [bytes_length]u8 {
|
||||
var ret: [bytes_length]u8 = undefined;
|
||||
inline for (0..K) |i| {
|
||||
ret[i * Poly.bytes_length .. (i + 1) * Poly.bytes_length].* = v.ps[i].toBytes();
|
||||
fn toBytes(v: Self) [encoded_length]u8 {
|
||||
var ret: [encoded_length]u8 = undefined;
|
||||
inline for (0..k) |i| {
|
||||
ret[i * Poly.encoded_length .. (i + 1) * Poly.encoded_length].* = v.ps[i].toBytes();
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
/// Deserializes the key from a byte array.
|
||||
fn fromBytes(buf: *const [bytes_length]u8) Self {
|
||||
fn fromBytes(buf: *const [encoded_length]u8) Self {
|
||||
var ret: Self = undefined;
|
||||
inline for (0..K) |i| {
|
||||
inline for (0..k) |i| {
|
||||
ret.ps[i] = Poly.fromBytes(
|
||||
buf[i * Poly.bytes_length .. (i + 1) * Poly.bytes_length],
|
||||
buf[i * Poly.encoded_length .. (i + 1) * Poly.encoded_length],
|
||||
);
|
||||
}
|
||||
return ret;
|
||||
|
|
@ -1474,19 +1379,19 @@ fn Vec(comptime K: u8) type {
|
|||
};
|
||||
}
|
||||
|
||||
// A matrix of K vectors
|
||||
fn Mat(comptime K: u8) type {
|
||||
// A matrix of k vectors
|
||||
fn Mat(comptime k: u8) type {
|
||||
return struct {
|
||||
const Self = @This();
|
||||
vs: [K]Vec(K),
|
||||
rows: [k]PolyVec(k),
|
||||
|
||||
fn uniform(seed: [32]u8, comptime transposed: bool) Self {
|
||||
var ret: Self = undefined;
|
||||
var i: u8 = 0;
|
||||
while (i < K) : (i += 1) {
|
||||
while (i < k) : (i += 1) {
|
||||
var j: u8 = 0;
|
||||
while (j < K) : (j += 1) {
|
||||
ret.vs[i].ps[j] = Poly.uniform(
|
||||
while (j < k) : (j += 1) {
|
||||
ret.rows[i].ps[j] = Poly.uniform(
|
||||
seed,
|
||||
if (transposed) i else j,
|
||||
if (transposed) j else i,
|
||||
|
|
@ -1499,9 +1404,9 @@ fn Mat(comptime K: u8) type {
|
|||
// Returns transpose of A
|
||||
fn transpose(m: Self) Self {
|
||||
var ret: Self = undefined;
|
||||
for (0..K) |i| {
|
||||
for (0..K) |j| {
|
||||
ret.vs[i].ps[j] = m.vs[j].ps[i];
|
||||
for (0..k) |i| {
|
||||
for (0..k) |j| {
|
||||
ret.rows[i].ps[j] = m.rows[j].ps[i];
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
|
|
@ -1522,12 +1427,30 @@ fn cmov(comptime len: usize, dst: *[len]u8, src: [len]u8, b: u1) void {
|
|||
}
|
||||
}
|
||||
|
||||
// Test helper: builds a polynomial whose coefficients are each drawn from
// the closed range [-Q, Q] using the supplied generator.
fn randPolyAbsLeqQ(rnd: anytype) Poly {
    var poly: Poly = undefined;
    for (&poly.cs) |*coeff| {
        coeff.* = rnd.random().intRangeAtMost(i16, -Q, Q);
    }
    return poly;
}
|
||||
|
||||
// Test helper: builds a polynomial whose coefficients are each drawn from
// the half-open range [0, Q) using the supplied generator.
fn randPolyNormalized(rnd: anytype) Poly {
    var poly: Poly = undefined;
    for (&poly.cs) |*coeff| {
        coeff.* = rnd.random().intRangeLessThan(i16, 0, Q);
    }
    return poly;
}
|
||||
|
||||
test "MulHat" {
|
||||
var rnd = RndGen.init(0);
|
||||
|
||||
for (0..100) |_| {
|
||||
const a = Poly.randAbsLeqQ(&rnd);
|
||||
const b = Poly.randAbsLeqQ(&rnd);
|
||||
const a = randPolyAbsLeqQ(&rnd);
|
||||
const b = randPolyAbsLeqQ(&rnd);
|
||||
|
||||
const p2 = a.ntt().mulHat(b.ntt()).barrettReduce().invNTT().normalize();
|
||||
var p: Poly = undefined;
|
||||
|
|
@ -1557,7 +1480,7 @@ test "NTT" {
|
|||
var rnd = RndGen.init(0);
|
||||
|
||||
for (0..1000) |_| {
|
||||
var p = Poly.randAbsLeqQ(&rnd);
|
||||
var p = randPolyAbsLeqQ(&rnd);
|
||||
const q = p.toMont().normalize();
|
||||
p = p.ntt();
|
||||
|
||||
|
|
@ -1580,7 +1503,7 @@ test "Compression" {
|
|||
var rnd = RndGen.init(0);
|
||||
inline for (.{ 1, 4, 5, 10, 11 }) |d| {
|
||||
for (0..1000) |_| {
|
||||
const p = Poly.randNormalized(&rnd);
|
||||
const p = randPolyNormalized(&rnd);
|
||||
const pp = p.compress(d);
|
||||
const pq = Poly.decompress(d, &pp).compress(d);
|
||||
try testing.expectEqual(pp, pq);
|
||||
|
|
@ -1671,7 +1594,7 @@ test "Polynomial packing" {
|
|||
var rnd = RndGen.init(0);
|
||||
|
||||
for (0..1000) |_| {
|
||||
const p = Poly.randNormalized(&rnd);
|
||||
const p = randPolyNormalized(&rnd);
|
||||
try testing.expectEqual(Poly.fromBytes(&p.toBytes()), p);
|
||||
}
|
||||
}
|
||||
|
|
@ -1839,3 +1762,222 @@ const NistDRBG = struct {
|
|||
return ret;
|
||||
}
|
||||
};
|
||||
|
||||
/// Extended Euclidean algorithm.
///
/// Returns `gcd`, `x` and `y` such that `x * a_ + y * b_ == gcd`.
/// Only meant to be used on comptime values; correctness matters, performance doesn't.
fn extendedEuclidean(comptime T: type, comptime a_: T, comptime b_: T) struct { gcd: T, x: T, y: T } {
    // Each (r, s, t) row satisfies r == s * a_ + t * b_.
    var r_prev: T = a_;
    var r_cur: T = b_;
    var s_prev: T = 1;
    var s_cur: T = 0;
    var t_prev: T = 0;
    var t_cur: T = 1;

    while (r_cur != 0) {
        const quotient = @divTrunc(r_prev, r_cur);

        const r_next = r_prev - quotient * r_cur;
        r_prev = r_cur;
        r_cur = r_next;

        const s_next = s_prev - quotient * s_cur;
        s_prev = s_cur;
        s_cur = s_next;

        const t_next = t_prev - quotient * t_cur;
        t_prev = t_cur;
        t_cur = t_next;
    }

    return .{ .gcd = r_prev, .x = s_prev, .y = t_prev };
}
|
||||
|
||||
/// Modular inversion: computes a^(-1) mod p.
///
/// Requires gcd(a, p) = 1 (asserted) and p > 0.
/// The result is normalized to the range [0, p).
fn modularInverse(comptime T: type, comptime a: T, comptime p: T) T {
    // The Bézout coefficient can be negative, so run the extended Euclidean
    // algorithm on a signed type of the same width when T is unsigned.
    const type_info = @typeInfo(T);
    const SignedT = if (type_info == .int and type_info.int.signedness == .unsigned)
        std.meta.Int(.signed, type_info.int.bits)
    else
        T;

    const a_signed = @as(SignedT, @intCast(a));
    const p_signed = @as(SignedT, @intCast(p));

    const r = extendedEuclidean(SignedT, a_signed, p_signed);
    assert(r.gcd == 1);

    // `@mod` yields a value in [0, p) whatever the sign of the coefficient,
    // so no correction loop is needed.
    return @intCast(@mod(r.x, p_signed));
}
|
||||
|
||||
/// Modular exponentiation: computes a^s mod p using the square-and-multiply algorithm.
///
/// `T` must be a fixed-width integer type; `a` and `p` must be non-negative.
/// Returns 1 when `s` is not positive. Intermediate products are computed in
/// an unsigned type twice as wide as `T`, so they cannot overflow.
fn modularPow(comptime T: type, comptime a: T, s: T, comptime p: T) T {
    const bits = @typeInfo(T).int.bits;
    const WideT = std.meta.Int(.unsigned, bits * 2);

    // Widen explicitly: implicit signed -> unsigned coercion is only accepted
    // for comptime-known values, which would restrict this function to
    // comptime evaluation whenever T is signed.
    const modulus: WideT = @intCast(p);

    var ret: WideT = 1;
    var base: WideT = @intCast(a);
    var exp = s;

    while (exp > 0) {
        if ((exp & 1) == 1) {
            ret = (ret * base) % modulus;
        }
        base = (base * base) % modulus;
        exp >>= 1;
    }

    return @intCast(ret);
}
|
||||
|
||||
/// Expands a single bit into a full-width mask.
/// Yields all 1s (0xFF...FF) when bit == 1 and all 0s when bit == 0.
/// `T` must be an unsigned integer type; anything else is a compile error.
fn bitMask(comptime T: type, bit: T) T {
    const is_unsigned_int = switch (@typeInfo(T)) {
        .int => |info| info.signedness == .unsigned,
        else => false,
    };
    if (!is_unsigned_int) {
        @compileError("bitMask requires an unsigned integer type");
    }
    // Wrapping subtraction from zero is the two's complement negation of `bit`.
    return 0 -% bit;
}
|
||||
|
||||
/// Spreads the most significant bit of `x` across a whole word.
/// Yields all 1s (0xFF...FF) when that bit is set (x < 0 for signed types)
/// and all 0s otherwise. The result is unsigned and as wide as `T`.
fn signMask(comptime T: type, x: T) std.meta.Int(.unsigned, @typeInfo(T).int.bits) {
    const info = @typeInfo(T);
    if (info != .int) {
        @compileError("signMask requires an integer type");
    }

    const width = info.int.bits;
    const Signed = std.meta.Int(.signed, width);
    const Unsigned = std.meta.Int(.unsigned, width);

    // Reinterpret unsigned input as signed so the right shift below is arithmetic.
    const as_signed: Signed = if (info.int.signedness == .signed) x else @bitCast(x);

    // Shifting by width - 1 replicates the top bit into every position.
    const smeared: Signed = as_signed >> (width - 1);
    return @as(Unsigned, @bitCast(smeared));
}
|
||||
|
||||
test "bitMask and signMask helpers" {
    // bitMask: 0 maps to all zeros and 1 to all ones, at several widths.
    inline for (.{ u32, u8, u64 }) |T| {
        try testing.expectEqual(@as(T, 0), bitMask(T, 0));
        try testing.expectEqual(~@as(T, 0), bitMask(T, 1));
    }

    const all_ones: u32 = 0xFFFFFFFF;
    const all_zeros: u32 = 0x00000000;

    // signMask on signed input: negative values give all ones.
    for ([_]i32{ -1, -100 }) |negative| {
        try testing.expectEqual(all_ones, signMask(i32, negative));
    }
    // Zero and positive values give all zeros.
    for ([_]i32{ 0, 1, 100 }) |non_negative| {
        try testing.expectEqual(all_zeros, signMask(i32, non_negative));
    }

    // signMask on unsigned input keys off the most significant bit.
    try testing.expectEqual(all_ones, signMask(u32, 0x80000000)); // MSB set
    try testing.expectEqual(all_zeros, signMask(u32, 0x7FFFFFFF)); // MSB clear
}
|
||||
|
||||
/// Montgomery reduction: for input x, returns y where y ≡ x*R^(-1) (mod q), with R = 2^r_bits.
/// This is a generic implementation parameterized by the modulus q, its inverse qInv,
/// the Montgomery constant R, and the result bound.
///
/// `qInv` is expected to satisfy qInv * q ≡ 1 (mod R); otherwise x - m*q is not a
/// multiple of R and the shifted result is meaningless.
///
/// For ML-DSA: R = 2^32, returns y < 2q
/// For ML-KEM: R = 2^16, returns y in range (-q, q)
fn montgomeryReduce(
    comptime InT: type,
    comptime OutT: type,
    comptime q: comptime_int,
    comptime qInv: comptime_int,
    comptime r_bits: comptime_int,
    x: InT,
) OutT {
    // Unsigned twins of the input/output types, used for bit-level reinterpretation.
    // `@bitSizeOf` is used rather than `@typeInfo(..).Int.bits`: the capitalised
    // field name is not what the rest of this file uses (`.int`).
    const UnsignedIn = std.meta.Int(.unsigned, @bitSizeOf(InT));
    const UnsignedOut = std.meta.Int(.unsigned, @bitSizeOf(OutT));

    // m = x * qInv mod R, reinterpreted in the output type.
    const mask = (@as(InT, 1) << r_bits) - 1;
    const m_full = (x *% qInv) & mask;
    const m: OutT = @truncate(m_full);

    // x - m*q has its low r_bits cleared; the quotient by R sits in the bits above.
    const yR = x -% @as(InT, m) * @as(InT, q);
    const y_shifted = @as(UnsignedIn, @bitCast(yR)) >> r_bits;
    return @bitCast(@as(UnsignedOut, @truncate(y_shifted)));
}
|
||||
|
||||
/// Uniform sampling using SHAKE-128 with rejection sampling.
/// Fills a polynomial with coefficients drawn from [0, q): candidates read from
/// the SHAKE-128 stream that are not below q are discarded.
///
/// Parameters:
/// - PolyType: The polynomial type to return (must expose a `cs` coefficient array)
/// - q: Modulus
/// - bits_per_coef: Number of bits per coefficient (12 or 23)
/// - n: Number of coefficients
/// - seed: Random seed
/// - domain_sep: Domain separation bytes (absorbed right after the seed)
fn sampleUniformRejection(
    comptime PolyType: type,
    comptime q: comptime_int,
    comptime bits_per_coef: comptime_int,
    comptime n: comptime_int,
    seed: []const u8,
    domain_sep: []const u8,
) PolyType {
    if (bits_per_coef != 12 and bits_per_coef != 23) {
        @compileError("bits_per_coef must be 12 or 23");
    }

    var xof = sha3.Shake128.init(.{});
    xof.update(seed);
    xof.update(domain_sep);

    const block_len = sha3.Shake128.block_length; // 168 bytes
    var block: [block_len]u8 = undefined;

    var poly: PolyType = undefined;
    var filled: usize = 0;

    if (bits_per_coef == 12) {
        // ML-KEM path: every 3 bytes carry two 12-bit candidates.
        squeezing: while (true) {
            xof.squeeze(&block);

            var off: usize = 0;
            while (off < block_len) : (off += 3) {
                const lo: u16 = block[off];
                const mid: u16 = block[off + 1];
                const hi: u16 = block[off + 2];

                const candidates: [2]u16 = .{
                    lo | ((mid & 0xf) << 8),
                    (mid >> 4) | (hi << 4),
                };

                inline for (candidates) |candidate| {
                    if (candidate < q) {
                        poly.cs[filled] = @intCast(candidate);
                        filled += 1;
                        if (filled == n) break :squeezing;
                    }
                }
            }
        }
    } else {
        // ML-DSA path: every 3 bytes carry one 23-bit candidate.
        while (filled < n) {
            xof.squeeze(&block);

            var off: usize = 0;
            while (off < block_len and filled < n) : (off += 3) {
                const b0: u32 = block[off];
                const b1: u32 = block[off + 1];
                const b2: u32 = block[off + 2];
                const candidate = (b0 | (b1 << 8) | (b2 << 16)) & 0x7fffff;

                if (candidate < q) {
                    poly.cs[filled] = @intCast(candidate);
                    filled += 1;
                }
            }
        }
    }

    return poly;
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue