From 79b41dbdbfd6c511b2e206397788b81bc720d266 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Thu, 29 Dec 2022 20:49:56 -0700 Subject: [PATCH] std.crypto.tls: avoid heap allocation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The code we are borrowing from https://github.com/shiguredo/tls13-zig requires an Allocator for doing RSA certificate verification. As a stopgap measure, this commit uses a FixedBufferAllocator to avoid heap allocation for these functions. Thank you to @naoki9911 for providing this great resource which has been extremely helpful for me when working on this standard library TLS implementation. Until Zig has std.crypto.rsa officially, we will borrow this implementation of RSA. 🙏 --- lib/std/crypto/Certificate.zig | 15 ++++++++------- lib/std/crypto/tls/Client.zig | 7 +++++-- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/lib/std/crypto/Certificate.zig b/lib/std/crypto/Certificate.zig index cce0193cf0..c4fd66bbc9 100644 --- a/lib/std/crypto/Certificate.zig +++ b/lib/std/crypto/Certificate.zig @@ -511,6 +511,10 @@ fn verifyRsa( var msg_hashed: [Hash.digest_length]u8 = undefined; Hash.hash(message, &msg_hashed, .{}); + var rsa_mem_buf: [512 * 32]u8 = undefined; + var fba = std.heap.FixedBufferAllocator.init(&rsa_mem_buf); + const ally = fba.allocator(); + switch (modulus.len) { inline 128, 256, 512 => |modulus_len| { const ps_len = modulus_len - (hash_der.len + msg_hashed.len) - 3; @@ -521,11 +525,11 @@ fn verifyRsa( hash_der ++ msg_hashed; - const public_key = rsa.PublicKey.fromBytes(exponent, modulus, rsa.poop) catch |err| switch (err) { - error.OutOfMemory => @panic("TODO don't heap allocate"), + const public_key = rsa.PublicKey.fromBytes(exponent, modulus, ally) catch |err| switch (err) { + error.OutOfMemory => unreachable, // rsa_mem_buf is big enough }; - const em_dec = rsa.encrypt(modulus_len, sig[0..modulus_len].*, public_key, rsa.poop) catch |err| switch (err) { - error.OutOfMemory => @panic("TODO don't heap allocate"), + const em_dec = rsa.encrypt(modulus_len, sig[0..modulus_len].*, public_key, ally) catch |err| switch (err) { + error.OutOfMemory => unreachable, // rsa_mem_buf is big enough error.MessageTooLong => unreachable, error.NegativeIntoUnsigned => @panic("TODO make RSA not emit this error"), @@ -977,7 +981,4 @@ pub const rsa = struct { return i; } - - // TODO: flush the toilet - pub const poop = std.heap.page_allocator; }; diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 6260995685..c4206862dd 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -544,11 +544,14 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) const components = try rsa.PublicKey.parseDer(main_cert_pub_key); const exponent = components.exponent; const modulus = components.modulus; + var rsa_mem_buf: [512 * 32]u8 = undefined; + var fba = std.heap.FixedBufferAllocator.init(&rsa_mem_buf); + const ally = fba.allocator(); switch (modulus.len) { inline 128, 256, 512 => |modulus_len| { - const key = try rsa.PublicKey.fromBytes(exponent, modulus, rsa.poop); + const key = try rsa.PublicKey.fromBytes(exponent, modulus, ally); const sig = rsa.PSSSignature.fromBytes(modulus_len, encoded_sig); - try rsa.PSSSignature.verify(modulus_len, sig, verify_bytes, key, Hash, rsa.poop); + try rsa.PSSSignature.verify(modulus_len, sig, verify_bytes, key, Hash, ally); }, else => { return error.TlsBadRsaSignatureBitCount;