Merge pull request #24874 from ziglang/tls-client
std: more reliable HTTP and TLS networking
commit 399bace2f2
7 changed files with 52 additions and 48 deletions
@@ -717,7 +717,12 @@ pub fn Poller(comptime StreamEnum: type) type {
             const unused = r.buffer[r.end..];
             if (unused.len >= min_len) return unused;
         }
-        if (r.seek > 0) r.rebase(r.buffer.len) catch unreachable;
+        if (r.seek > 0) {
+            const data = r.buffer[r.seek..r.end];
+            @memmove(r.buffer[0..data.len], data);
+            r.seek = 0;
+            r.end = data.len;
+        }
         {
             var list: std.ArrayListUnmanaged(u8) = .{
                 .items = r.buffer[0..r.end],
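Note: the Poller no longer calls `r.rebase(r.buffer.len) catch unreachable`, likely because `rebase`'s capacity semantics changed (see the `Reader` hunks below); the compaction is now spelled out inline. As a standalone illustration of the idiom, here is a hypothetical `Fifo` type (not the real `std.Io.Reader`):

    const std = @import("std");

    // Hypothetical stand-in for Io.Reader's seek/end bookkeeping.
    const Fifo = struct {
        buffer: []u8,
        seek: usize, // index of the first unread byte
        end: usize, // one past the last buffered byte

        // Mirrors the inlined block above: slide unread bytes to the front.
        fn compact(f: *Fifo) void {
            if (f.seek == 0) return;
            const data = f.buffer[f.seek..f.end];
            // copyForwards stands in for the @memmove builtin; both
            // handle this front-ward overlap correctly.
            std.mem.copyForwards(u8, f.buffer[0..data.len], data);
            f.seek = 0;
            f.end = data.len;
        }
    };

    test Fifo {
        var storage: [8]u8 = .{ 0, 0, 'a', 'b', 'c', 0, 0, 0 };
        var f: Fifo = .{ .buffer = &storage, .seek = 2, .end = 5 };
        f.compact();
        try std.testing.expectEqualStrings("abc", f.buffer[0..f.end]);
        try std.testing.expectEqual(@as(usize, 0), f.seek);
    }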
@@ -86,12 +86,12 @@ pub const VTable = struct {
     /// `Reader.buffer`, whichever is bigger.
     readVec: *const fn (r: *Reader, data: [][]u8) Error!usize = defaultReadVec,

-    /// Ensures `capacity` more data can be buffered without rebasing.
+    /// Ensures `capacity` data can be buffered without rebasing.
     ///
     /// Asserts `capacity` is within buffer capacity, or that the stream ends
     /// within `capacity` bytes.
     ///
-    /// Only called when `capacity` cannot fit into the unused capacity of
+    /// Only called when `capacity` cannot be satisfied by unused capacity of
     /// `buffer`.
     ///
     /// The default implementation moves buffered data to the start of
@@ -1035,7 +1035,7 @@ fn fillUnbuffered(r: *Reader, n: usize) Error!void {
 ///
 /// Asserts buffer capacity is at least 1.
 pub fn fillMore(r: *Reader) Error!void {
-    try rebase(r, 1);
+    try rebase(r, r.end - r.seek + 1);
     var bufs: [1][]u8 = .{""};
     _ = try r.vtable.readVec(r, &bufs);
 }
@@ -1203,24 +1203,6 @@ pub fn takeLeb128(r: *Reader, comptime Result: type) TakeLeb128Error!Result {
     } }))) orelse error.Overflow;
 }

-pub fn expandTotalCapacity(r: *Reader, allocator: Allocator, n: usize) Allocator.Error!void {
-    if (n <= r.buffer.len) return;
-    if (r.seek > 0) rebase(r, r.buffer.len);
-    var list: ArrayList(u8) = .{
-        .items = r.buffer[0..r.end],
-        .capacity = r.buffer.len,
-    };
-    defer r.buffer = list.allocatedSlice();
-    try list.ensureTotalCapacity(allocator, n);
-}
-
-pub const FillAllocError = Error || Allocator.Error;
-
-pub fn fillAlloc(r: *Reader, allocator: Allocator, n: usize) FillAllocError!void {
-    try expandTotalCapacity(r, allocator, n);
-    return fill(r, n);
-}
-
 fn takeMultipleOf7Leb128(r: *Reader, comptime Result: type) TakeLeb128Error!Result {
     const result_info = @typeInfo(Result).int;
     comptime assert(result_info.bits % 7 == 0);
@@ -1251,9 +1233,9 @@ fn takeMultipleOf7Leb128(r: *Reader, comptime Result: type) TakeLeb128Error!Result {
     }
 }

-/// Ensures `capacity` more data can be buffered without rebasing.
+/// Ensures `capacity` data can be buffered without rebasing.
 pub fn rebase(r: *Reader, capacity: usize) RebaseError!void {
-    if (r.end + capacity <= r.buffer.len) {
+    if (r.buffer.len - r.seek >= capacity) {
         @branchHint(.likely);
         return;
     }
@@ -1261,11 +1243,12 @@ pub fn rebase(r: *Reader, capacity: usize) RebaseError!void {
 }

 pub fn defaultRebase(r: *Reader, capacity: usize) RebaseError!void {
-    if (r.end <= r.buffer.len - capacity) return;
+    assert(r.buffer.len - r.seek < capacity);
     const data = r.buffer[r.seek..r.end];
     @memmove(r.buffer[0..data.len], data);
     r.seek = 0;
     r.end = data.len;
+    assert(r.buffer.len - r.seek >= capacity);
 }

 test fixed {
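Note: the doc change from "`capacity` more data" to "`capacity` data" is the substantive part of these hunks. Capacity is now measured over the window starting at `seek`, counting bytes that are already buffered, instead of over the free space after `end`; this is also why `fillMore` above now requests `r.end - r.seek + 1` rather than `1`. A small illustration with made-up numbers:

    const std = @import("std");

    test "rebase capacity is now measured from seek" {
        const buffer_len: usize = 16;
        const seek: usize = 0;
        const end: usize = 10; // 10 bytes buffered, 6 bytes free
        const capacity: usize = 12;

        // Old predicate: `capacity` bytes of free space after `end`;
        // this would have forced a (futile, since seek == 0) rebase.
        try std.testing.expect(!(end + capacity <= buffer_len));

        // New predicate: the window starting at `seek` can hold
        // `capacity` bytes, counting those already buffered.
        try std.testing.expect(buffer_len - seek >= capacity);
    }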
@@ -183,7 +183,6 @@ const InitError = error{
 /// `input` is asserted to have buffer capacity at least `min_buffer_len`.
 pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client {
     assert(input.buffer.len >= min_buffer_len);
-    assert(output.buffer.len >= min_buffer_len);
     const host = switch (options.host) {
         .no_verification => "",
         .explicit => |host| host,
@@ -1124,12 +1123,6 @@ fn readIndirect(c: *Client) Reader.Error!usize {
         if (record_end > input.buffered().len) return 0;
     }

-    if (r.seek == r.end) {
-        r.seek = 0;
-        r.end = 0;
-    }
-    const cleartext_buffer = r.buffer[r.end..];
-
     const cleartext_len, const inner_ct: tls.ContentType = cleartext: switch (c.application_cipher) {
         inline else => |*p| switch (c.tls_version) {
             .tls_1_3 => {
@@ -1145,7 +1138,8 @@ fn readIndirect(c: *Client) Reader.Error!usize {
                     const operand: V = pad ++ mem.toBytes(big(c.read_seq));
                     break :nonce @as(V, pv.server_iv) ^ operand;
                 };
-                const cleartext = cleartext_buffer[0..ciphertext.len];
+                rebase(r, ciphertext.len);
+                const cleartext = r.buffer[r.end..][0..ciphertext.len];
                 P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, pv.server_key) catch
                     return failRead(c, error.TlsBadRecordMac);
                 // TODO use scalar, non-slice version
@@ -1171,7 +1165,8 @@ fn readIndirect(c: *Client) Reader.Error!usize {
                 };
                 const ciphertext = input.take(message_len) catch unreachable; // already peeked
                 const auth_tag = (input.takeArray(P.mac_length) catch unreachable).*; // already peeked
-                const cleartext = cleartext_buffer[0..ciphertext.len];
+                rebase(r, ciphertext.len);
+                const cleartext = r.buffer[r.end..][0..ciphertext.len];
                 P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, pv.server_write_key) catch
                     return failRead(c, error.TlsBadRecordMac);
                 break :cleartext .{ cleartext.len, ct };
@@ -1179,7 +1174,7 @@ fn readIndirect(c: *Client) Reader.Error!usize {
             else => unreachable,
         },
     };
-    const cleartext = cleartext_buffer[0..cleartext_len];
+    const cleartext = r.buffer[r.end..][0..cleartext_len];
    c.read_seq = std.math.add(u64, c.read_seq, 1) catch return failRead(c, error.TlsSequenceOverflow);
     switch (inner_ct) {
         .alert => {
@@ -1275,6 +1270,15 @@ fn readIndirect(c: *Client) Reader.Error!usize {
     }
 }

+fn rebase(r: *Reader, capacity: usize) void {
+    if (r.buffer.len - r.end >= capacity) return;
+    const data = r.buffer[r.seek..r.end];
+    @memmove(r.buffer[0..data.len], data);
+    r.seek = 0;
+    r.end = data.len;
+    assert(r.buffer.len - r.end >= capacity);
+}
+
 fn failRead(c: *Client, err: ReadError) error{ReadFailed} {
     c.read_err = err;
     return error.ReadFailed;
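Note: with `cleartext_buffer` removed, each record is decrypted directly into the client's own `Reader` buffer, and the file-local `rebase` above guarantees `capacity` free bytes after `end` first (this helper does measure free tail space, unlike `Reader.rebase`). A generic sketch of that make-room-then-produce-in-place pattern, as a hypothetical helper rather than the std API:

    const std = @import("std");

    fn produceInPlace(buffer: []u8, seek: *usize, end: *usize, payload: []const u8) void {
        if (buffer.len - end.* < payload.len) {
            // Slide unread bytes to the front to free tail capacity.
            const data = buffer[seek.*..end.*];
            std.mem.copyForwards(u8, buffer[0..data.len], data);
            seek.* = 0;
            end.* = data.len;
        }
        // In readIndirect this is where AEAD decryption writes cleartext.
        @memcpy(buffer[end.*..][0..payload.len], payload);
        end.* += payload.len;
    }

    test produceInPlace {
        var buf: [8]u8 = undefined;
        buf[4] = 'h';
        buf[5] = 'i';
        var seek: usize = 4;
        var end: usize = 6;
        produceInPlace(&buf, &seek, &end, "world");
        try std.testing.expectEqualStrings("hiworld", buf[0..end]);
    }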
@@ -329,6 +329,7 @@ pub const Reader = struct {
     /// read from `in`.
     trailers: []const u8 = &.{},
     body_err: ?BodyError = null,
+    max_head_len: usize,

     pub const RemainingChunkLen = enum(u64) {
         head = 0,
@@ -387,10 +388,11 @@ pub const Reader = struct {
     pub fn receiveHead(reader: *Reader) HeadError![]const u8 {
         reader.trailers = &.{};
         const in = reader.in;
+        const max_head_len = reader.max_head_len;
         var hp: HeadParser = .{};
         var head_len: usize = 0;
         while (true) {
-            if (in.buffer.len - head_len == 0) return error.HttpHeadersOversize;
+            if (head_len >= max_head_len) return error.HttpHeadersOversize;
             const remaining = in.buffered()[head_len..];
             if (remaining.len == 0) {
                 in.fillMore() catch |err| switch (err) {
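Note: the oversize check now compares against a caller-chosen `max_head_len` instead of `in.buffer.len`, which matters because a TLS connection's reader buffer is deliberately oversized (`tls_buffer_size + read_buffer_size`, see the `http/Client.zig` hunks below). A hypothetical construction showing the new field; the values mirror initializations visible elsewhere in this diff:

    const std = @import("std");

    fn initHttpReader(in: *std.Io.Reader) std.http.Reader {
        return .{
            .in = in,
            .state = .ready,
            // Populated when `http.Reader.bodyReader` is called.
            .interface = undefined,
            // Heads longer than this yield error.HttpHeadersOversize,
            // regardless of how large `in.buffer` happens to be.
            .max_head_len = 4096,
        };
    }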
@@ -42,7 +42,7 @@ connection_pool: ConnectionPool = .{},
 ///
 /// If the entire HTTP header cannot fit in this amount of bytes,
 /// `error.HttpHeadersOversize` will be returned from `Request.wait`.
-read_buffer_size: usize = 4096 + if (disable_tls) 0 else std.crypto.tls.Client.min_buffer_len,
+read_buffer_size: usize = 8192,
 /// Each `Connection` allocates this amount for the writer buffer.
 write_buffer_size: usize = 1024,

@@ -302,18 +302,22 @@ pub const Connection = struct {
         const base = try gpa.alignedAlloc(u8, .of(Tls), alloc_len);
         errdefer gpa.free(base);
         const host_buffer = base[@sizeOf(Tls)..][0..remote_host.len];
-        const tls_read_buffer = host_buffer.ptr[host_buffer.len..][0..client.tls_buffer_size];
+        // The TLS client wants enough buffer for the max encrypted frame
+        // size, and the HTTP body reader wants enough buffer for the
+        // entire HTTP header. This means we need a combined upper bound.
+        const tls_read_buffer_len = client.tls_buffer_size + client.read_buffer_size;
+        const tls_read_buffer = host_buffer.ptr[host_buffer.len..][0..tls_read_buffer_len];
         const tls_write_buffer = tls_read_buffer.ptr[tls_read_buffer.len..][0..client.tls_buffer_size];
-        const write_buffer = tls_write_buffer.ptr[tls_write_buffer.len..][0..client.write_buffer_size];
-        const read_buffer = write_buffer.ptr[write_buffer.len..][0..client.read_buffer_size];
-        assert(base.ptr + alloc_len == read_buffer.ptr + read_buffer.len);
+        const socket_write_buffer = tls_write_buffer.ptr[tls_write_buffer.len..][0..client.write_buffer_size];
+        const socket_read_buffer = socket_write_buffer.ptr[socket_write_buffer.len..][0..client.tls_buffer_size];
+        assert(base.ptr + alloc_len == socket_read_buffer.ptr + socket_read_buffer.len);
         @memcpy(host_buffer, remote_host);
         const tls: *Tls = @ptrCast(base);
         tls.* = .{
             .connection = .{
                 .client = client,
                 .stream_writer = stream.writer(tls_write_buffer),
-                .stream_reader = stream.reader(tls_read_buffer),
+                .stream_reader = stream.reader(socket_read_buffer),
                 .pool_node = .{},
                 .port = port,
                 .host_len = @intCast(remote_host.len),
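Note: a connection's buffers all live in a single allocation, each slice carved off the end of the previous one via `.ptr[len..][0..n]`, with a final assert proving the layout accounts for every byte. A self-contained sketch of the carving trick with illustrative sizes:

    const std = @import("std");

    test "carve one allocation into adjacent buffers" {
        const gpa = std.testing.allocator;
        const a_len: usize = 16;
        const b_len: usize = 32;

        const base = try gpa.alloc(u8, a_len + b_len);
        defer gpa.free(base);

        const a: []u8 = base[0..a_len];
        // `b` starts exactly where `a` ends; no manual pointer math.
        const b: []u8 = a.ptr[a.len..][0..b_len];
        // Same shape as the assert above: the pieces tile the allocation.
        try std.testing.expect(base.ptr + base.len == b.ptr + b.len);
    }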
@@ -329,8 +333,8 @@ pub const Connection = struct {
                 .host = .{ .explicit = remote_host },
                 .ca = .{ .bundle = client.ca_bundle },
                 .ssl_key_log = client.ssl_key_log,
-                .read_buffer = read_buffer,
-                .write_buffer = write_buffer,
+                .read_buffer = tls_read_buffer,
+                .write_buffer = socket_write_buffer,
                 // This is appropriate for HTTPS because the HTTP headers contain
                 // the content length which is used to detect truncation attacks.
                 .allow_truncation_attacks = true,
@@ -348,8 +352,9 @@ pub const Connection = struct {
     }

     fn allocLen(client: *Client, host_len: usize) usize {
-        return @sizeOf(Tls) + host_len + client.tls_buffer_size + client.tls_buffer_size +
-            client.write_buffer_size + client.read_buffer_size;
+        const tls_read_buffer_len = client.tls_buffer_size + client.read_buffer_size;
+        return @sizeOf(Tls) + host_len + tls_read_buffer_len + client.tls_buffer_size +
+            client.write_buffer_size + client.tls_buffer_size;
     }

     fn host(tls: *Tls) []u8 {
@@ -1214,6 +1219,7 @@ pub const Request = struct {
             .state = .ready,
             // Populated when `http.Reader.bodyReader` is called.
             .interface = undefined,
+            .max_head_len = r.client.read_buffer_size,
         };
         r.redirect_behavior.subtractOne();
     }
@@ -1679,6 +1685,7 @@ pub fn request(
             .state = .ready,
             // Populated when `http.Reader.bodyReader` is called.
             .interface = undefined,
+            .max_head_len = client.read_buffer_size,
         },
         .keep_alive = options.keep_alive,
         .method = method,
@@ -29,6 +29,7 @@ pub fn init(in: *Reader, out: *Writer) Server {
             .state = .ready,
             // Populated when `http.Reader.bodyReader` is called.
             .interface = undefined,
+            .max_head_len = in.buffer.len,
         },
         .out = out,
     };
@@ -251,6 +252,7 @@ pub const Request = struct {
             .in = undefined,
             .state = .received_head,
             .interface = undefined,
+            .max_head_len = 4096,
         },
         .out = undefined,
     };
@@ -1212,10 +1212,11 @@ fn unpackResource(
             return try unpackTarball(f, tmp_directory.handle, &adapter.new_interface);
         },
         .@"tar.zst" => {
-            const window_size = std.compress.zstd.default_window_len;
-            const window_buffer = try f.arena.allocator().create([window_size]u8);
+            const window_len = std.compress.zstd.default_window_len;
+            const window_buffer = try f.arena.allocator().alloc(u8, window_len + std.compress.zstd.block_size_max);
             var decompress: std.compress.zstd.Decompress = .init(resource.reader(), window_buffer, .{
                 .verify_checksum = false,
+                .window_len = window_len,
             });
             return try unpackTarball(f, tmp_directory.handle, &decompress.reader);
         },
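Note: the window buffer grows from `default_window_len` to `default_window_len + block_size_max`, and the window length is now passed explicitly. A sketch of the same initialization under those assumptions; the caller supplies the `Decompress`, allocator, and input reader, and the names are illustrative:

    const std = @import("std");
    const zstd = std.compress.zstd;

    // Caller owns `decompress` so the returned reader pointer stays valid.
    fn initZstdReader(
        decompress: *zstd.Decompress,
        gpa: std.mem.Allocator,
        input: *std.Io.Reader,
    ) std.mem.Allocator.Error!*std.Io.Reader {
        const window_len = zstd.default_window_len;
        // Window history plus room for one maximum-size block,
        // matching the sizing in the hunk above.
        const window_buffer = try gpa.alloc(u8, window_len + zstd.block_size_max);
        decompress.* = .init(input, window_buffer, .{
            .verify_checksum = false,
            .window_len = window_len,
        });
        return &decompress.reader;
    }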