From 23ccff9cce2a5264fc84998bd2c897682ac266ea Mon Sep 17 00:00:00 2001 From: Nameless Date: Sun, 28 May 2023 09:50:51 +0200 Subject: [PATCH] std.http.Server: collapse BufferedConnection into Connection --- lib/std/http/Client.zig | 8 +- lib/std/http/Server.zig | 220 +++++++++++++++------------------------ test/standalone/http.zig | 1 - 3 files changed, 88 insertions(+), 141 deletions(-) diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 053aa2a59f..91b688a25c 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -184,7 +184,7 @@ pub const Connection = struct { pub fn fill(conn: *Connection) ReadError!void { if (conn.read_end != conn.read_start) return; - const nread = try conn.read(conn.read_buf[0..]); + const nread = try conn.rawReadAtLeast(conn.read_buf[0..], 1); if (nread == 0) return error.EndOfStream; conn.read_start = 0; conn.read_end = @intCast(u16, nread); @@ -207,13 +207,13 @@ pub const Connection = struct { const available_buffer = buffer.len - out_index; if (available_read > available_buffer) { // partially read buffered data - @memcpy(buffer[out_index..], conn.read_buf[conn.read_start..][0..available_buffer]); + @memcpy(buffer[out_index..], conn.read_buf[conn.read_start..conn.read_end][0..available_buffer]); out_index += @intCast(u16, available_buffer); conn.read_start += @intCast(u16, available_buffer); break; } else if (available_read > 0) { // fully read buffered data - @memcpy(buffer[out_index..][0..available_read], conn.read_buf[conn.read_start..]); + @memcpy(buffer[out_index..][0..available_read], conn.read_buf[conn.read_start..conn.read_end]); out_index += available_read; conn.read_start += available_read; @@ -608,6 +608,8 @@ pub const Request = struct { try w.print("{}", .{req.headers}); try w.writeAll("\r\n"); + + try buffered.flush(); } pub const TransferReadError = Connection.ReadError || proto.HeadersParser.ReadError; diff --git a/lib/std/http/Server.zig b/lib/std/http/Server.zig index 7d84cd3b58..67641eab00 100644 --- a/lib/std/http/Server.zig +++ b/lib/std/http/Server.zig @@ -16,39 +16,92 @@ socket: net.StreamServer, /// An interface to either a plain or TLS connection. pub const Connection = struct { + pub const buffer_size = std.crypto.tls.max_ciphertext_record_len; + pub const Protocol = enum { plain }; + stream: net.Stream, protocol: Protocol, closing: bool = true, - pub const Protocol = enum { plain }; + read_buf: [buffer_size]u8 = undefined, + read_start: u16 = 0, + read_end: u16 = 0, - pub fn read(conn: *Connection, buffer: []u8) ReadError!usize { + pub fn rawReadAtLeast(conn: *Connection, buffer: []u8, len: usize) ReadError!usize { return switch (conn.protocol) { - .plain => conn.stream.read(buffer), - // .tls => return conn.tls_client.read(conn.stream, buffer), - } catch |err| switch (err) { - error.ConnectionTimedOut => return error.ConnectionTimedOut, - error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer, - else => return error.UnexpectedReadFailure, + .plain => conn.stream.readAtLeast(buffer, len), + // .tls => conn.tls_client.readAtLeast(conn.stream, buffer, len), + } catch |err| { + switch (err) { + error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer, + else => return error.UnexpectedReadFailure, + } }; } + pub fn fill(conn: *Connection) ReadError!void { + if (conn.read_end != conn.read_start) return; + + const nread = try conn.rawReadAtLeast(conn.read_buf[0..], 1); + if (nread == 0) return error.EndOfStream; + conn.read_start = 0; + conn.read_end = @intCast(u16, nread); + } + + pub fn peek(conn: *Connection) []const u8 { + return conn.read_buf[conn.read_start..conn.read_end]; + } + + pub fn drop(conn: *Connection, num: u16) void { + conn.read_start += num; + } + pub fn readAtLeast(conn: *Connection, buffer: []u8, len: usize) ReadError!usize { - return switch (conn.protocol) { - .plain => conn.stream.readAtLeast(buffer, len), - // .tls => return conn.tls_client.readAtLeast(conn.stream, buffer, len), - } catch |err| switch (err) { - error.ConnectionTimedOut => return error.ConnectionTimedOut, - error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer, - else => return error.UnexpectedReadFailure, - }; + assert(len <= buffer.len); + + var out_index: u16 = 0; + while (out_index < len) { + const available_read = conn.read_end - conn.read_start; + const available_buffer = buffer.len - out_index; + + if (available_read > available_buffer) { // partially read buffered data + @memcpy(buffer[out_index..], conn.read_buf[conn.read_start..conn.read_end][0..available_buffer]); + out_index += @intCast(u16, available_buffer); + conn.read_start += @intCast(u16, available_buffer); + + break; + } else if (available_read > 0) { // fully read buffered data + @memcpy(buffer[out_index..][0..available_read], conn.read_buf[conn.read_start..conn.read_end]); + out_index += available_read; + conn.read_start += available_read; + + if (out_index >= len) break; + } + + const leftover_buffer = available_buffer - available_read; + const leftover_len = len - out_index; + + if (leftover_buffer > conn.read_buf.len) { + // skip the buffer if the output is large enough + return conn.rawReadAtLeast(buffer[out_index..], leftover_len); + } + + try conn.fill(); + } + + return out_index; + } + + pub fn read(conn: *Connection, buffer: []u8) ReadError!usize { + return conn.readAtLeast(buffer, 1); } pub const ReadError = error{ ConnectionTimedOut, ConnectionResetByPeer, UnexpectedReadFailure, + EndOfStream, }; pub const Reader = std.io.Reader(*Connection, ReadError, read); @@ -93,112 +146,6 @@ pub const Connection = struct { } }; -/// A buffered (and peekable) Connection. -pub const BufferedConnection = struct { - pub const buffer_size = std.crypto.tls.max_ciphertext_record_len; - - conn: Connection, - read_buf: [buffer_size]u8 = undefined, - read_start: u16 = 0, - read_end: u16 = 0, - - write_buf: [buffer_size]u8 = undefined, - write_end: u16 = 0, - - pub fn fill(bconn: *BufferedConnection) ReadError!void { - if (bconn.read_end != bconn.read_start) return; - - const nread = try bconn.conn.read(bconn.read_buf[0..]); - if (nread == 0) return error.EndOfStream; - bconn.read_start = 0; - bconn.read_end = @intCast(u16, nread); - } - - pub fn peek(bconn: *BufferedConnection) []const u8 { - return bconn.read_buf[bconn.read_start..bconn.read_end]; - } - - pub fn drop(bconn: *BufferedConnection, num: u16) void { - bconn.read_start += num; - } - - pub fn readAtLeast(bconn: *BufferedConnection, buffer: []u8, len: usize) ReadError!usize { - var out_index: u16 = 0; - while (out_index < len) { - const available = bconn.read_end - bconn.read_start; - const left = buffer.len - out_index; - - if (available > 0) { - const can_read = @intCast(u16, @min(available, left)); - - @memcpy(buffer[out_index..][0..can_read], bconn.read_buf[bconn.read_start..][0..can_read]); - out_index += can_read; - bconn.read_start += can_read; - - continue; - } - - if (left > bconn.read_buf.len) { - // skip the buffer if the output is large enough - return bconn.conn.read(buffer[out_index..]); - } - - try bconn.fill(); - } - - return out_index; - } - - pub fn read(bconn: *BufferedConnection, buffer: []u8) ReadError!usize { - return bconn.readAtLeast(buffer, 1); - } - - pub const ReadError = Connection.ReadError || error{EndOfStream}; - pub const Reader = std.io.Reader(*BufferedConnection, ReadError, read); - - pub fn reader(bconn: *BufferedConnection) Reader { - return Reader{ .context = bconn }; - } - - pub fn writeAll(bconn: *BufferedConnection, buffer: []const u8) WriteError!void { - if (bconn.write_buf.len - bconn.write_end >= buffer.len) { - @memcpy(bconn.write_buf[bconn.write_end..][0..buffer.len], buffer); - bconn.write_end += @intCast(u16, buffer.len); - } else { - try bconn.flush(); - try bconn.conn.writeAll(buffer); - } - } - - pub fn write(bconn: *BufferedConnection, buffer: []const u8) WriteError!usize { - if (bconn.write_buf.len - bconn.write_end >= buffer.len) { - @memcpy(bconn.write_buf[bconn.write_end..][0..buffer.len], buffer); - bconn.write_end += @intCast(u16, buffer.len); - - return buffer.len; - } else { - try bconn.flush(); - return try bconn.conn.write(buffer); - } - } - - pub fn flush(bconn: *BufferedConnection) WriteError!void { - defer bconn.write_end = 0; - return bconn.conn.writeAll(bconn.write_buf[0..bconn.write_end]); - } - - pub const WriteError = Connection.WriteError; - pub const Writer = std.io.Writer(*BufferedConnection, WriteError, write); - - pub fn writer(bconn: *BufferedConnection) Writer { - return Writer{ .context = bconn }; - } - - pub fn close(bconn: *BufferedConnection) void { - bconn.conn.close(); - } -}; - /// The mode of transport for responses. pub const ResponseTransfer = union(enum) { content_length: u64, @@ -351,7 +298,7 @@ pub const Response = struct { allocator: Allocator, address: net.Address, - connection: BufferedConnection, + connection: Connection, headers: http.Headers, request: Request, @@ -388,7 +335,7 @@ pub const Response = struct { if (!res.request.parser.done) { // If the response wasn't fully read, then we need to close the connection. - res.connection.conn.closing = true; + res.connection.closing = true; return .closing; } @@ -402,9 +349,9 @@ pub const Response = struct { const req_connection = res.request.headers.getFirstValue("connection"); const req_keepalive = req_connection != null and !std.ascii.eqlIgnoreCase("close", req_connection.?); if (req_keepalive and (res_keepalive or res_connection == null)) { - res.connection.conn.closing = false; + res.connection.closing = false; } else { - res.connection.conn.closing = true; + res.connection.closing = true; } switch (res.request.compression) { @@ -434,14 +381,14 @@ pub const Response = struct { .parser = res.request.parser, }; - if (res.connection.conn.closing) { + if (res.connection.closing) { return .closing; } else { return .reset; } } - pub const DoError = BufferedConnection.WriteError || error{ UnsupportedTransferEncoding, InvalidContentLength }; + pub const DoError = Connection.WriteError || error{ UnsupportedTransferEncoding, InvalidContentLength }; /// Send the response headers. pub fn do(res: *Response) !void { @@ -450,7 +397,8 @@ pub const Response = struct { .first, .start, .responded, .finished => unreachable, } - const w = res.connection.writer(); + var buffered = std.io.bufferedWriter(res.connection.writer()); + const w = buffered.writer(); try w.writeAll(@tagName(res.version)); try w.writeByte(' '); @@ -508,10 +456,10 @@ pub const Response = struct { try w.writeAll("\r\n"); - try res.connection.flush(); + try buffered.flush(); } - pub const TransferReadError = BufferedConnection.ReadError || proto.HeadersParser.ReadError; + pub const TransferReadError = Connection.ReadError || proto.HeadersParser.ReadError; pub const TransferReader = std.io.Reader(*Response, TransferReadError, transferRead); @@ -532,7 +480,7 @@ pub const Response = struct { return index; } - pub const WaitError = BufferedConnection.ReadError || proto.HeadersParser.CheckCompleteHeadError || Request.ParseError || error{ CompressionInitializationFailed, CompressionNotSupported }; + pub const WaitError = Connection.ReadError || proto.HeadersParser.CheckCompleteHeadError || Request.ParseError || error{ CompressionInitializationFailed, CompressionNotSupported }; /// Wait for the client to send a complete request head. pub fn wait(res: *Response) WaitError!void { @@ -637,7 +585,7 @@ pub const Response = struct { return index; } - pub const WriteError = BufferedConnection.WriteError || error{ NotWriteable, MessageTooLong }; + pub const WriteError = Connection.WriteError || error{ NotWriteable, MessageTooLong }; pub const Writer = std.io.Writer(*Response, WriteError, write); @@ -692,8 +640,6 @@ pub const Response = struct { .content_length => |len| if (len != 0) return error.MessageNotCompleted, .none => {}, } - - try res.connection.flush(); } }; @@ -742,10 +688,10 @@ pub fn accept(server: *Server, options: AcceptOptions) AcceptError!Response { return Response{ .allocator = options.allocator, .address = in.address, - .connection = .{ .conn = .{ + .connection = .{ .stream = in.stream, .protocol = .plain, - } }, + }, .headers = .{ .allocator = options.allocator }, .request = .{ .version = undefined, diff --git a/test/standalone/http.zig b/test/standalone/http.zig index 13dc278b6d..ffb7a59276 100644 --- a/test/standalone/http.zig +++ b/test/standalone/http.zig @@ -86,7 +86,6 @@ fn handleRequest(res: *Server.Response) !void { try res.writeAll("World!\n"); // try res.finish(); try res.connection.writeAll("0\r\nX-Checksum: aaaa\r\n\r\n"); - try res.connection.flush(); } else if (mem.eql(u8, res.request.target, "/redirect/1")) { res.transfer_encoding = .chunked;