zig/lib/std/crypto/aes_ocb.zig
Andrew Kelley e7b18a7ce6 std.crypto: remove inline from most functions
To quote the language reference,

It is generally better to let the compiler decide when to inline a
function, except for these scenarios:

* To change how many stack frames are in the call stack, for debugging
  purposes.
* To force comptime-ness of the arguments to propagate to the return
  value of the function, as in the above example.
* Real world performance measurements demand it. Don't guess!

Note that inline actually restricts what the compiler is allowed to do.
This can harm binary size, compilation speed, and even runtime
performance.

`zig run lib/std/crypto/benchmark.zig -OReleaseFast`
[-before-] vs {+after+}

              md5:        [-990-]        {+998+} MiB/s
             sha1:       [-1144-]       {+1140+} MiB/s
           sha256:       [-2267-]       {+2275+} MiB/s
           sha512:        [-762-]        {+767+} MiB/s
         sha3-256:        [-680-]        {+683+} MiB/s
         sha3-512:        [-362-]        {+363+} MiB/s
        shake-128:        [-835-]        {+839+} MiB/s
        shake-256:        [-680-]        {+681+} MiB/s
   turboshake-128:       [-1567-]       {+1570+} MiB/s
   turboshake-256:       [-1276-]       {+1282+} MiB/s
          blake2s:        [-778-]        {+789+} MiB/s
          blake2b:       [-1071-]       {+1086+} MiB/s
           blake3:       [-1148-]       {+1137+} MiB/s
            ghash:      [-10044-]      {+10033+} MiB/s
          polyval:       [-9726-]      {+10033+} MiB/s
         poly1305:       [-2486-]       {+2703+} MiB/s
         hmac-md5:        [-991-]        {+998+} MiB/s
        hmac-sha1:       [-1134-]       {+1137+} MiB/s
      hmac-sha256:       [-2265-]       {+2288+} MiB/s
      hmac-sha512:        [-765-]        {+764+} MiB/s
      siphash-2-4:       [-4410-]       {+4438+} MiB/s
      siphash-1-3:       [-7144-]       {+7225+} MiB/s
   siphash128-2-4:       [-4397-]       {+4449+} MiB/s
   siphash128-1-3:       [-7281-]       {+7374+} MiB/s
  aegis-128x4 mac:      [-73385-]      {+74523+} MiB/s
  aegis-256x4 mac:      [-30160-]      {+30539+} MiB/s
  aegis-128x2 mac:      [-66662-]      {+67267+} MiB/s
  aegis-256x2 mac:      [-16812-]      {+16806+} MiB/s
   aegis-128l mac:      [-33876-]      {+34055+} MiB/s
    aegis-256 mac:       [-8993-]       {+9087+} MiB/s
         aes-cmac:       2036 MiB/s
           x25519:      [-20670-]      {+16844+} exchanges/s
          ed25519:      [-29763-]      {+29576+} signatures/s
       ecdsa-p256:       [-4762-]       {+4900+} signatures/s
       ecdsa-p384:       [-1465-]       {+1500+} signatures/s
  ecdsa-secp256k1:       [-5643-]       {+5769+} signatures/s
          ed25519:      [-21926-]      {+21721+} verifications/s
          ed25519:      [-51200-]      {+50880+} verifications/s (batch)
 chacha20Poly1305:       [-1189-]       {+1109+} MiB/s
xchacha20Poly1305:       [-1196-]       {+1107+} MiB/s
 xchacha8Poly1305:       [-1466-]       {+1555+} MiB/s
 xsalsa20Poly1305:        [-660-]        {+620+} MiB/s
      aegis-128x4:      [-76389-]      {+78181+} MiB/s
      aegis-128x2:      [-53946-]      {+53495+} MiB/s
       aegis-128l:      [-27219-]      {+25621+} MiB/s
      aegis-256x4:      [-49351-]      {+49542+} MiB/s
      aegis-256x2:      [-32390-]      {+32366+} MiB/s
        aegis-256:       [-8881-]       {+8944+} MiB/s
       aes128-gcm:       [-6095-]       {+6205+} MiB/s
       aes256-gcm:       [-5306-]       {+5427+} MiB/s
       aes128-ocb:       [-8529-]      {+13974+} MiB/s
       aes256-ocb:       [-7241-]       {+9442+} MiB/s
        isapa128a:        [-204-]        {+214+} MiB/s
    aes128-single:  [-133857882-]  {+134170944+} ops/s
    aes256-single:   [-96306962-]   {+96408639+} ops/s
         aes128-8: [-1083210101-] {+1073727253+} ops/s
         aes256-8:  [-762042466-]  {+767091778+} ops/s
           bcrypt:      0.009 s/ops
           scrypt:      [-0.018-]      {+0.017+} s/ops
           argon2:      [-0.037-]      {+0.060+} s/ops
      kyber512d00:     [-206057-]     {+205779+} encaps/s
      kyber768d00:     [-156074-]     {+150711+} encaps/s
     kyber1024d00:     [-116626-]     {+115469+} encaps/s
      kyber512d00:     [-181149-]     {+182046+} decaps/s
      kyber768d00:     [-136965-]     {+135676+} decaps/s
     kyber1024d00:     [-101307-]     {+100643+} decaps/s
      kyber512d00:     [-123624-]     {+123375+} keygen/s
      kyber768d00:      [-69465-]      {+70828+} keygen/s
     kyber1024d00:      [-43117-]      {+43208+} keygen/s
2025-07-13 18:26:13 +02:00

350 lines
13 KiB
Zig

const std = @import("std");
const builtin = @import("builtin");
const crypto = std.crypto;
const aes = crypto.core.aes;
const assert = std.debug.assert;
const math = std.math;
const mem = std.mem;
const AuthenticationError = crypto.errors.AuthenticationError;
pub const Aes128Ocb = AesOcb(aes.Aes128);
pub const Aes256Ocb = AesOcb(aes.Aes256);
const Block = [16]u8;
/// AES-OCB (RFC 7253 - https://competitions.cr.yp.to/round3/ocbv11.pdf)
fn AesOcb(comptime Aes: anytype) type {
const EncryptCtx = aes.AesEncryptCtx(Aes);
const DecryptCtx = aes.AesDecryptCtx(Aes);
return struct {
pub const key_length = Aes.key_bits / 8;
pub const nonce_length: usize = 12;
pub const tag_length: usize = 16;
/// Key-dependent values used to derive per-block offsets.
///
/// `star` is the encryption of an all-zero block, `dol` is `double(star)`,
/// and `table[0]` is `double(dol)`; every further table entry is the
/// `double` of the entry before it. Entries are filled in lazily by
/// `precomp`, and `upto` records the index passed to the latest call.
const Lx = struct {
    star: Block align(16),
    dol: Block align(16),
    table: [56]Block align(16) = undefined,
    upto: usize,

    /// Interprets `l` as a big-endian 128-bit integer, shifts it left by one
    /// bit and XORs in 0x87 when the bit shifted out was set.
    fn double(l: Block) Block {
        const value = mem.readInt(u128, &l, .big);
        // All ones when the top bit of `value` is set, zero otherwise.
        const reduce_mask: u128 = -%(value >> 127);
        var doubled: Block = undefined;
        mem.writeInt(u128, &doubled, (value << 1) ^ (0x87 & reduce_mask), .big);
        return doubled;
    }

    /// Fills table entries after index `lx.upto` up to and including `upto`,
    /// and returns the first `upto + 1` entries.
    /// Asserts that `upto` is a valid table index.
    fn precomp(lx: *Lx, upto: usize) []const Block {
        assert(upto < lx.table.len);
        var next = lx.upto + 1;
        while (next <= upto) : (next += 1) {
            lx.table[next] = double(lx.table[next - 1]);
        }
        lx.upto = upto;
        return lx.table[0 .. upto + 1];
    }

    /// Builds the initial state for the given encryption context; only
    /// `table[0]` is populated at this point.
    fn init(aes_enc_ctx: EncryptCtx) Lx {
        const zero_block = [_]u8{0} ** 16;
        var star: Block = undefined;
        aes_enc_ctx.encrypt(&star, &zero_block);
        const dol = double(star);
        var lx: Lx = .{ .star = star, .dol = dol, .upto = 0 };
        lx.table[0] = double(dol);
        return lx;
    }
};
/// Computes the associated-data hash (HASH in RFC 7253, section 4.1) of `a`.
///
/// `lx` provides the `star` value and the offset table; the table is
/// extended in place when `a` needs more entries than are available.
/// Returns the 16-byte value that the callers XOR into the tag. The result
/// is all zeros when `a` is empty.
fn hash(aes_enc_ctx: EncryptCtx, lx: *Lx, a: []const u8) Block {
    const full_blocks: usize = a.len / 16;
    // Highest table index needed: @ctz(i + 1) never exceeds log2(full_blocks).
    const x_max = if (full_blocks > 0) math.log2_int(usize, full_blocks) else 0;
    const lt = lx.precomp(x_max);
    var sum = [_]u8{0} ** 16;
    var offset = [_]u8{0} ** 16;
    var i: usize = 0;
    while (i < full_blocks) : (i += 1) {
        xorWith(&offset, lt[@ctz(i + 1)]);
        var e = xorBlocks(offset, a[i * 16 ..][0..16].*);
        aes_enc_ctx.encrypt(&e, &e);
        xorWith(&sum, e);
    }
    const leftover = a.len % 16;
    if (leftover > 0) {
        xorWith(&offset, lx.star);
        var padded = [_]u8{0} ** 16;
        @memcpy(padded[0..leftover], a[i * 16 ..][0..leftover]);
        // RFC 7253 pads the final partial block with a single 1 bit followed
        // by zero bits, i.e. the byte 0x80 (not 0x01). This matches the
        // padding applied to a partial message block in encrypt/decrypt.
        padded[leftover] = 0x80;
        var e = xorBlocks(offset, padded);
        aes_enc_ctx.encrypt(&e, &e);
        xorWith(&sum, e);
    }
    return sum;
}
/// Derives the initial offset block from the public nonce `npub`.
///
/// A 16-byte block is formatted from the tag length and the nonce, encrypted
/// with its low six bits cleared, and the result is expanded into a 192-bit
/// "stretch" value. The returned offset is a 128-bit window of that value
/// whose position is selected by the six bits that were cleared.
fn getOffset(aes_enc_ctx: EncryptCtx, npub: [nonce_length]u8) Block {
    var nx = [_]u8{0} ** 16;
    // Tag length in bits, reduced to 7 bits, stored in the upper 7 bits of byte 0.
    nx[0] = @as(u8, @intCast(@as(u7, @truncate(tag_length * 8)) << 1));
    // A 1 in the byte immediately before the nonce; the nonce fills the tail of the block.
    nx[16 - nonce_length - 1] = 1;
    nx[nx.len - nonce_length ..].* = npub;
    // The low 6 bits of the last byte select the window position below;
    // they are removed from the block before it is encrypted.
    const bottom: u6 = @truncate(nx[15]);
    nx[15] &= 0xc0;
    var ktop_: Block = undefined;
    aes_enc_ctx.encrypt(&ktop_, &nx);
    const ktop = mem.readInt(u128, &ktop_, .big);
    // stretch = ktop followed by 64 more bits: the top 64 bits of ktop XORed
    // with the 64 bits of ktop that start 8 bits below the top.
    const stretch = (@as(u192, ktop) << 64) | @as(u192, @as(u64, @truncate(ktop >> 64)) ^ @as(u64, @truncate(ktop >> 56)));
    var offset: Block = undefined;
    // Keep the 128 bits of `stretch` that begin `bottom` bits below its most significant bit.
    mem.writeInt(u128, &offset, @as(u128, @truncate(stretch >> (64 - @as(u7, bottom)))), .big);
    return offset;
}
/// Whether the target CPU reports the x86 `aes` feature.
const has_aesni = builtin.cpu.has(.x86, .aes);
/// Whether the target CPU reports the aarch64 `aes` feature.
const has_armaes = builtin.cpu.has(.aarch64, .aes);
/// Number of blocks handed to `encryptWide`/`decryptWide` per iteration:
/// 4 on x86_64 or aarch64 targets with the AES feature, otherwise 0.
/// A value of 0 disables the wide loops in `encrypt` and `decrypt`.
const wb: usize = if ((builtin.cpu.arch == .x86_64 and has_aesni) or (builtin.cpu.arch == .aarch64 and has_armaes)) 4 else 0;
/// Encrypts `m` into `c` and writes the authentication tag to `tag`.
///
/// `c`: ciphertext output buffer; its length must equal `m.len` (asserted).
/// `tag`: output authentication tag.
/// `m`: message to encrypt.
/// `ad`: associated data; it only contributes to the tag and is not written to `c`.
/// `npub`: public nonce.
/// `key`: secret key.
pub fn encrypt(c: []u8, tag: *[tag_length]u8, m: []const u8, ad: []const u8, npub: [nonce_length]u8, key: [key_length]u8) void {
    assert(c.len == m.len);
    const aes_enc_ctx = Aes.initEnc(key);
    const full_blocks: usize = m.len / 16;
    // Highest table index needed: @ctz(i + 1) never exceeds log2(full_blocks).
    const x_max = if (full_blocks > 0) math.log2_int(usize, full_blocks) else 0;
    var lx = Lx.init(aes_enc_ctx);
    const lt = lx.precomp(x_max);
    var offset = getOffset(aes_enc_ctx, npub);
    // Running XOR of all plaintext blocks; it is folded into the tag below.
    var sum = [_]u8{0} ** 16;
    var i: usize = 0;
    // Wide path: handles `wb` full blocks per iteration; never entered when `wb` is 0.
    while (wb > 0 and i + wb <= full_blocks) : (i += wb) {
        var offsets: [wb]Block align(16) = undefined;
        var es: [16 * wb]u8 align(16) = undefined;
        var j: usize = 0;
        // Advance the offset for each block and mask the plaintext with it.
        while (j < wb) : (j += 1) {
            xorWith(&offset, lt[@ctz(i + 1 + j)]);
            offsets[j] = offset;
            const p = m[(i + j) * 16 ..][0..16].*;
            es[j * 16 ..][0..16].* = xorBlocks(p, offsets[j]);
            xorWith(&sum, p);
        }
        aes_enc_ctx.encryptWide(wb, &es, &es);
        j = 0;
        // Mask the cipher output with the same per-block offset to obtain the ciphertext.
        while (j < wb) : (j += 1) {
            const e = es[j * 16 ..][0..16].*;
            c[(i + j) * 16 ..][0..16].* = xorBlocks(e, offsets[j]);
        }
    }
    // Remaining full blocks, one at a time.
    while (i < full_blocks) : (i += 1) {
        xorWith(&offset, lt[@ctz(i + 1)]);
        const p = m[i * 16 ..][0..16].*;
        var e = xorBlocks(p, offset);
        aes_enc_ctx.encrypt(&e, &e);
        c[i * 16 ..][0..16].* = xorBlocks(e, offset);
        xorWith(&sum, p);
    }
    const leftover = m.len % 16;
    if (leftover > 0) {
        // Final partial block: XOR it with the encryption of the offset.
        xorWith(&offset, lx.star);
        var pad = offset;
        aes_enc_ctx.encrypt(&pad, &pad);
        for (m[i * 16 ..], 0..) |x, j| {
            c[i * 16 + j] = pad[j] ^ x;
        }
        // The checksum takes the partial plaintext followed by 0x80 and zero bytes.
        var e = [_]u8{0} ** 16;
        @memcpy(e[0..leftover], m[i * 16 ..][0..leftover]);
        e[leftover] = 0x80;
        xorWith(&sum, e);
    }
    // Tag = E(sum ^ offset ^ dol) ^ hash(ad).
    var e = xorBlocks(xorBlocks(sum, offset), lx.dol);
    aes_enc_ctx.encrypt(&e, &e);
    tag.* = xorBlocks(e, hash(aes_enc_ctx, &lx, ad));
}
/// Decrypts `c` into `m` and verifies the authentication tag.
///
/// `m`: Message output buffer
/// `c`: Ciphertext
/// `tag`: Authentication tag to verify
/// `ad`: Associated data
/// `npub`: Public nonce
/// `key`: Secret key
/// Asserts `c.len == m.len`.
///
/// Returns `error.AuthenticationFailed` when the tag computed from the
/// inputs does not equal `tag`.
/// Contents of `m` are undefined if an error is returned.
pub fn decrypt(m: []u8, c: []const u8, tag: [tag_length]u8, ad: []const u8, npub: [nonce_length]u8, key: [key_length]u8) AuthenticationError!void {
    assert(c.len == m.len);
    const aes_enc_ctx = Aes.initEnc(key);
    const aes_dec_ctx = DecryptCtx.initFromEnc(aes_enc_ctx);
    const full_blocks: usize = m.len / 16;
    // Highest table index needed: @ctz(i + 1) never exceeds log2(full_blocks).
    const x_max = if (full_blocks > 0) math.log2_int(usize, full_blocks) else 0;
    var lx = Lx.init(aes_enc_ctx);
    const lt = lx.precomp(x_max);
    var offset = getOffset(aes_enc_ctx, npub);
    // Running XOR of all recovered plaintext blocks; it is folded into the tag below.
    var sum = [_]u8{0} ** 16;
    var i: usize = 0;
    // Wide path: handles `wb` full blocks per iteration; never entered when `wb` is 0.
    while (wb > 0 and i + wb <= full_blocks) : (i += wb) {
        var offsets: [wb]Block align(16) = undefined;
        var es: [16 * wb]u8 align(16) = undefined;
        var j: usize = 0;
        // Advance the offset for each block and mask the ciphertext with it.
        while (j < wb) : (j += 1) {
            xorWith(&offset, lt[@ctz(i + 1 + j)]);
            offsets[j] = offset;
            const q = c[(i + j) * 16 ..][0..16].*;
            es[j * 16 ..][0..16].* = xorBlocks(q, offsets[j]);
        }
        aes_dec_ctx.decryptWide(wb, &es, &es);
        j = 0;
        // Unmask with the same offset to recover the plaintext and update the checksum.
        while (j < wb) : (j += 1) {
            const p = xorBlocks(es[j * 16 ..][0..16].*, offsets[j]);
            m[(i + j) * 16 ..][0..16].* = p;
            xorWith(&sum, p);
        }
    }
    // Remaining full blocks, one at a time.
    while (i < full_blocks) : (i += 1) {
        xorWith(&offset, lt[@ctz(i + 1)]);
        const q = c[i * 16 ..][0..16].*;
        var e = xorBlocks(q, offset);
        aes_dec_ctx.decrypt(&e, &e);
        const p = xorBlocks(e, offset);
        m[i * 16 ..][0..16].* = p;
        xorWith(&sum, p);
    }
    const leftover = m.len % 16;
    if (leftover > 0) {
        // Final partial block: the pad comes from the encryption context, as in `encrypt`.
        xorWith(&offset, lx.star);
        var pad = offset;
        aes_enc_ctx.encrypt(&pad, &pad);
        for (c[i * 16 ..], 0..) |x, j| {
            m[i * 16 + j] = pad[j] ^ x;
        }
        // The checksum takes the partial plaintext followed by 0x80 and zero bytes.
        var e = [_]u8{0} ** 16;
        @memcpy(e[0..leftover], m[i * 16 ..][0..leftover]);
        e[leftover] = 0x80;
        xorWith(&sum, e);
    }
    // Recompute the tag: E(sum ^ offset ^ dol) ^ hash(ad).
    var e = xorBlocks(xorBlocks(sum, offset), lx.dol);
    aes_enc_ctx.encrypt(&e, &e);
    var computed_tag = xorBlocks(e, hash(aes_enc_ctx, &lx, ad));
    // Compare with the timing-safe helper rather than a plain equality check.
    const verify = crypto.timing_safe.eql([tag_length]u8, computed_tag, tag);
    if (!verify) {
        // Wipe the computed tag and mark the output as undefined before failing.
        crypto.secureZero(u8, &computed_tag);
        @memset(m, undefined);
        return error.AuthenticationFailed;
    }
}
};
}
/// Returns the byte-wise XOR of `x` and `y`.
fn xorBlocks(x: Block, y: Block) Block {
    var out: Block = undefined;
    for (&out, x, y) |*dst, a, b| {
        dst.* = a ^ b;
    }
    return out;
}
/// XORs every byte of `y` into the corresponding byte of `x`, in place.
fn xorWith(x: *Block, y: Block) void {
    for (x, y) |*dst, src| {
        dst.* ^= src;
    }
}
const hexToBytes = std.fmt.hexToBytes;
test "AesOcb test vector 1" {
    if (builtin.zig_backend == .stage2_c) return error.SkipZigTest;
    // Known-answer test: empty message and empty associated data.
    var k: [Aes128Ocb.key_length]u8 = undefined;
    var nonce: [Aes128Ocb.nonce_length]u8 = undefined;
    var tag: [Aes128Ocb.tag_length]u8 = undefined;
    _ = try hexToBytes(&k, "000102030405060708090A0B0C0D0E0F");
    _ = try hexToBytes(&nonce, "BBAA99887766554433221100");
    var c: [0]u8 = undefined;
    Aes128Ocb.encrypt(&c, &tag, "", "", nonce, k);
    var expected_tag: [tag.len]u8 = undefined;
    _ = try hexToBytes(&expected_tag, "785407BFFFC8AD9EDCC5520AC9111EE6");
    // Check the produced tag against the expected value, not just the round trip.
    try std.testing.expectEqualSlices(u8, &expected_tag, &tag);
    var m: [0]u8 = undefined;
    try Aes128Ocb.decrypt(&m, "", tag, "", nonce, k);
}
test "AesOcb test vector 2" {
    if (builtin.zig_backend == .stage2_c) return error.SkipZigTest;
    // Known-answer test: empty message, 40 bytes of associated data
    // (two full blocks plus a partial one).
    var k: [Aes128Ocb.key_length]u8 = undefined;
    var nonce: [Aes128Ocb.nonce_length]u8 = undefined;
    var tag: [Aes128Ocb.tag_length]u8 = undefined;
    var ad: [40]u8 = undefined;
    _ = try hexToBytes(&k, "000102030405060708090A0B0C0D0E0F");
    _ = try hexToBytes(&ad, "000102030405060708090A0B0C0D0E0F101112131415161718191A1B1C1D1E1F2021222324252627");
    _ = try hexToBytes(&nonce, "BBAA9988776655443322110E");
    var c: [0]u8 = undefined;
    Aes128Ocb.encrypt(&c, &tag, "", &ad, nonce, k);
    var expected_tag: [tag.len]u8 = undefined;
    _ = try hexToBytes(&expected_tag, "C5CD9D1850C141E358649994EE701B68");
    // Check the produced tag against the expected value, not just the round trip.
    try std.testing.expectEqualSlices(u8, &expected_tag, &tag);
    var m: [0]u8 = undefined;
    try Aes128Ocb.decrypt(&m, &c, tag, &ad, nonce, k);
}
test "AesOcb test vector 3" {
    if (builtin.zig_backend == .stage2_c) return error.SkipZigTest;
    // Known-answer test: 40-byte message, empty associated data.
    var k: [Aes128Ocb.key_length]u8 = undefined;
    var nonce: [Aes128Ocb.nonce_length]u8 = undefined;
    var tag: [Aes128Ocb.tag_length]u8 = undefined;
    var m: [40]u8 = undefined;
    var c: [m.len]u8 = undefined;
    _ = try hexToBytes(&k, "000102030405060708090A0B0C0D0E0F");
    _ = try hexToBytes(&m, "000102030405060708090A0B0C0D0E0F101112131415161718191A1B1C1D1E1F2021222324252627");
    _ = try hexToBytes(&nonce, "BBAA9988776655443322110F");
    Aes128Ocb.encrypt(&c, &tag, &m, "", nonce, k);
    var expected_c: [c.len]u8 = undefined;
    var expected_tag: [tag.len]u8 = undefined;
    _ = try hexToBytes(&expected_tag, "479AD363AC366B95A98CA5F3000B1479");
    _ = try hexToBytes(&expected_c, "4412923493C57D5DE0D700F753CCE0D1D2D95060122E9F15A5DDBFC5787E50B5CC55EE507BCB084E");
    // Check ciphertext and tag against the expected values, not just the round trip.
    try std.testing.expectEqualSlices(u8, &expected_c, &c);
    try std.testing.expectEqualSlices(u8, &expected_tag, &tag);
    var m2: [m.len]u8 = undefined;
    try Aes128Ocb.decrypt(&m2, &c, tag, "", nonce, k);
    try std.testing.expectEqualSlices(u8, &m, &m2);
}
test "AesOcb test vector 4" {
    if (builtin.zig_backend == .stage2_c) return error.SkipZigTest;
    // Known-answer test: the same 16 bytes are used as message and as
    // associated data, which is what the expected values below correspond to.
    var k: [Aes128Ocb.key_length]u8 = undefined;
    var nonce: [Aes128Ocb.nonce_length]u8 = undefined;
    var tag: [Aes128Ocb.tag_length]u8 = undefined;
    var m: [16]u8 = undefined;
    var c: [m.len]u8 = undefined;
    _ = try hexToBytes(&k, "000102030405060708090A0B0C0D0E0F");
    _ = try hexToBytes(&m, "000102030405060708090A0B0C0D0E0F");
    _ = try hexToBytes(&nonce, "BBAA99887766554433221104");
    // Copy the message only after it has been filled in, so the associated
    // data holds defined bytes.
    const ad = m;
    Aes128Ocb.encrypt(&c, &tag, &m, &ad, nonce, k);
    var expected_c: [c.len]u8 = undefined;
    var expected_tag: [tag.len]u8 = undefined;
    _ = try hexToBytes(&expected_tag, "3AD7A4FF3835B8C5701C1CCEC8FC3358");
    _ = try hexToBytes(&expected_c, "571D535B60B277188BE5147170A9A22C");
    // Check ciphertext and tag against the expected values, not just the round trip.
    try std.testing.expectEqualSlices(u8, &expected_c, &c);
    try std.testing.expectEqualSlices(u8, &expected_tag, &tag);
    var m2: [m.len]u8 = undefined;
    try Aes128Ocb.decrypt(&m2, &c, tag, &ad, nonce, k);
    try std.testing.expectEqualSlices(u8, &m, &m2);
}