From 6129ecd4fe88e14531db98866c92c4a5660849ee Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Sun, 18 Feb 2024 20:22:09 -0700 Subject: [PATCH] std.net, std.http: simplify --- lib/std/http/Server.zig | 960 +++++++++++------------------ lib/std/http/Server/Connection.zig | 132 ++++ lib/std/http/test.zig | 92 ++- lib/std/net.zig | 245 +++----- lib/std/net/test.zig | 26 +- lib/std/os.zig | 8 - lib/std/os/linux/io_uring.zig | 31 +- lib/std/os/test.zig | 2 +- src/main.zig | 6 +- test/standalone/http.zig | 53 +- 10 files changed, 714 insertions(+), 841 deletions(-) create mode 100644 lib/std/http/Server/Connection.zig diff --git a/lib/std/http/Server.zig b/lib/std/http/Server.zig index f3ee7710a0..4176829f07 100644 --- a/lib/std/http/Server.zig +++ b/lib/std/http/Server.zig @@ -1,156 +1,55 @@ -//! HTTP Server implementation. -//! -//! This server assumes clients are well behaved and standard compliant; it -//! deadlocks if a client holds a connection open without sending a request. +version: http.Version, +status: http.Status, +reason: ?[]const u8, +transfer_encoding: ResponseTransfer, +keep_alive: bool, +connection: Connection, -const builtin = @import("builtin"); -const std = @import("../std.zig"); -const testing = std.testing; -const http = std.http; -const mem = std.mem; -const net = std.net; -const Uri = std.Uri; -const Allocator = mem.Allocator; -const assert = std.debug.assert; +/// Externally-owned; must outlive the Server. +extra_headers: []const http.Header, -const Server = @This(); -const proto = @import("protocol.zig"); +/// The HTTP request that this response is responding to. +/// +/// This field is only valid after calling `wait`. +request: Request, -/// The underlying server socket. -socket: net.StreamServer, +state: State = .first, -/// An interface to a plain connection. 
-pub const Connection = struct { - stream: net.Stream, - protocol: Protocol, - - closing: bool = true, - - read_buf: [buffer_size]u8 = undefined, - read_start: u16 = 0, - read_end: u16 = 0, - - pub const buffer_size = std.crypto.tls.max_ciphertext_record_len; - pub const Protocol = enum { plain }; - - pub fn rawReadAtLeast(conn: *Connection, buffer: []u8, len: usize) ReadError!usize { - return switch (conn.protocol) { - .plain => conn.stream.readAtLeast(buffer, len), - // .tls => conn.tls_client.readAtLeast(conn.stream, buffer, len), - } catch |err| { - switch (err) { - error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer, - else => return error.UnexpectedReadFailure, - } - }; - } - - pub fn fill(conn: *Connection) ReadError!void { - if (conn.read_end != conn.read_start) return; - - const nread = try conn.rawReadAtLeast(conn.read_buf[0..], 1); - if (nread == 0) return error.EndOfStream; - conn.read_start = 0; - conn.read_end = @intCast(nread); - } - - pub fn peek(conn: *Connection) []const u8 { - return conn.read_buf[conn.read_start..conn.read_end]; - } - - pub fn drop(conn: *Connection, num: u16) void { - conn.read_start += num; - } - - pub fn readAtLeast(conn: *Connection, buffer: []u8, len: usize) ReadError!usize { - assert(len <= buffer.len); - - var out_index: u16 = 0; - while (out_index < len) { - const available_read = conn.read_end - conn.read_start; - const available_buffer = buffer.len - out_index; - - if (available_read > available_buffer) { // partially read buffered data - @memcpy(buffer[out_index..], conn.read_buf[conn.read_start..conn.read_end][0..available_buffer]); - out_index += @as(u16, @intCast(available_buffer)); - conn.read_start += @as(u16, @intCast(available_buffer)); - - break; - } else if (available_read > 0) { // fully read buffered data - @memcpy(buffer[out_index..][0..available_read], conn.read_buf[conn.read_start..conn.read_end]); - out_index += available_read; - conn.read_start += available_read; - - if (out_index >= len) break; - } - - const leftover_buffer = available_buffer - available_read; - const leftover_len = len - out_index; - - if (leftover_buffer > conn.read_buf.len) { - // skip the buffer if the output is large enough - return conn.rawReadAtLeast(buffer[out_index..], leftover_len); - } - - try conn.fill(); - } - - return out_index; - } - - pub fn read(conn: *Connection, buffer: []u8) ReadError!usize { - return conn.readAtLeast(buffer, 1); - } - - pub const ReadError = error{ - ConnectionTimedOut, - ConnectionResetByPeer, - UnexpectedReadFailure, - EndOfStream, +/// Initialize an HTTP server that can respond to multiple requests on the same +/// connection. +/// The returned `Server` is ready for `reset` or `wait` to be called. 
+pub fn init(connection: std.net.Server.Connection, options: Server.Request.InitOptions) Server { + return .{ + .transfer_encoding = .none, + .keep_alive = true, + .connection = .{ + .stream = connection.stream, + .protocol = .plain, + .closing = true, + .read_buf = undefined, + .read_start = 0, + .read_end = 0, + }, + .request = Server.Request.init(options), + .version = .@"HTTP/1.1", + .status = .ok, + .reason = null, + .extra_headers = &.{}, }; +} - pub const Reader = std.io.Reader(*Connection, ReadError, read); - - pub fn reader(conn: *Connection) Reader { - return Reader{ .context = conn }; - } - - pub fn writeAll(conn: *Connection, buffer: []const u8) WriteError!void { - return switch (conn.protocol) { - .plain => conn.stream.writeAll(buffer), - // .tls => return conn.tls_client.writeAll(conn.stream, buffer), - } catch |err| switch (err) { - error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer, - else => return error.UnexpectedWriteFailure, - }; - } - - pub fn write(conn: *Connection, buffer: []const u8) WriteError!usize { - return switch (conn.protocol) { - .plain => conn.stream.write(buffer), - // .tls => return conn.tls_client.write(conn.stream, buffer), - } catch |err| switch (err) { - error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer, - else => return error.UnexpectedWriteFailure, - }; - } - - pub const WriteError = error{ - ConnectionResetByPeer, - UnexpectedWriteFailure, - }; - - pub const Writer = std.io.Writer(*Connection, WriteError, write); - - pub fn writer(conn: *Connection) Writer { - return Writer{ .context = conn }; - } - - pub fn close(conn: *Connection) void { - conn.stream.close(); - } +pub const State = enum { + first, + start, + waited, + responded, + finished, }; +pub const ResetState = enum { reset, closing }; + +pub const Connection = @import("Server/Connection.zig"); + /// The mode of transport for responses. pub const ResponseTransfer = union(enum) { content_length: u64, @@ -160,10 +59,10 @@ pub const ResponseTransfer = union(enum) { /// The decompressor for request messages. pub const Compression = union(enum) { - pub const DeflateDecompressor = std.compress.zlib.Decompressor(Response.TransferReader); - pub const GzipDecompressor = std.compress.gzip.Decompressor(Response.TransferReader); + pub const DeflateDecompressor = std.compress.zlib.Decompressor(Server.TransferReader); + pub const GzipDecompressor = std.compress.gzip.Decompressor(Server.TransferReader); // https://github.com/ziglang/zig/issues/18937 - //pub const ZstdDecompressor = std.compress.zstd.DecompressStream(Response.TransferReader, .{}); + //pub const ZstdDecompressor = std.compress.zstd.DecompressStream(Server.TransferReader, .{}); deflate: DeflateDecompressor, gzip: GzipDecompressor, @@ -177,14 +76,37 @@ pub const Request = struct { method: http.Method, target: []const u8, version: http.Version, - expect: ?[]const u8 = null, - content_type: ?[]const u8 = null, - content_length: ?u64 = null, - transfer_encoding: http.TransferEncoding = .none, - transfer_compression: http.ContentEncoding = .identity, - keep_alive: bool = false, + expect: ?[]const u8, + content_type: ?[]const u8, + content_length: ?u64, + transfer_encoding: http.TransferEncoding, + transfer_compression: http.ContentEncoding, + keep_alive: bool, parser: proto.HeadersParser, - compression: Compression = .none, + compression: Compression, + + pub const InitOptions = struct { + /// Externally-owned memory used to store the client's entire HTTP header. 
+ /// `error.HttpHeadersOversize` is returned from read() when a + /// client sends too many bytes of HTTP headers. + client_header_buffer: []u8, + }; + + pub fn init(options: InitOptions) Request { + return .{ + .method = undefined, + .target = undefined, + .version = undefined, + .expect = null, + .content_type = null, + .content_length = null, + .transfer_encoding = .none, + .transfer_compression = .identity, + .keep_alive = false, + .parser = proto.HeadersParser.init(options.client_header_buffer), + .compression = .none, + }; + } pub const ParseError = Allocator.Error || error{ UnknownHttpMethod, @@ -300,478 +222,316 @@ pub const Request = struct { } }; -/// A HTTP response waiting to be sent. -/// -/// Order of operations: -/// ``` -/// [/ <--------------------------------------- \] -/// accept -> wait -> send [ -> write -> finish][ -> reset /] -/// \ -> read / -/// ``` -pub const Response = struct { - version: http.Version = .@"HTTP/1.1", - status: http.Status = .ok, - reason: ?[]const u8 = null, - transfer_encoding: ResponseTransfer, - keep_alive: bool, - - /// The peer's address - address: net.Address, - - /// The underlying connection for this response. - connection: Connection, - - /// Externally-owned; must outlive the Response. - extra_headers: []const http.Header = &.{}, - - /// The HTTP request that this response is responding to. - /// - /// This field is only valid after calling `wait`. - request: Request, - - state: State = .first, - - pub const State = enum { - first, - start, - waited, - responded, - finished, - }; - - /// Free all resources associated with this response. - pub fn deinit(res: *Response) void { - res.connection.close(); - } - - pub const ResetState = enum { reset, closing }; - - /// Reset this response to its initial state. This must be called before - /// handling a second request on the same connection. - pub fn reset(res: *Response) ResetState { - if (res.state == .first) { - res.state = .start; - return .reset; - } - - if (!res.request.parser.done) { - // If the response wasn't fully read, then we need to close the connection. - res.connection.closing = true; - return .closing; - } - - // A connection is only keep-alive if the Connection header is present - // and its value is not "close". The server and client must both agree. - // - // send() defaults to using keep-alive if the client requests it. - res.connection.closing = !res.keep_alive or !res.request.keep_alive; - +/// Reset this response to its initial state. This must be called before +/// handling a second request on the same connection. +pub fn reset(res: *Server) ResetState { + if (res.state == .first) { res.state = .start; - res.version = .@"HTTP/1.1"; - res.status = .ok; - res.reason = null; - - res.transfer_encoding = .none; - - res.request.parser.reset(); - - res.request = .{ - .version = undefined, - .method = undefined, - .target = undefined, - .parser = res.request.parser, - }; - - return if (res.connection.closing) .closing else .reset; + return .reset; } - pub const SendError = Connection.WriteError || error{ - UnsupportedTransferEncoding, - InvalidContentLength, + if (!res.request.parser.done) { + // If the response wasn't fully read, then we need to close the connection. + res.connection.closing = true; + return .closing; + } + + // A connection is only keep-alive if the Connection header is present + // and its value is not "close". The server and client must both agree. + // + // send() defaults to using keep-alive if the client requests it. 
+ res.connection.closing = !res.keep_alive or !res.request.keep_alive; + + res.state = .start; + res.version = .@"HTTP/1.1"; + res.status = .ok; + res.reason = null; + + res.transfer_encoding = .none; + + res.request = Request.init(.{ + .client_header_buffer = res.request.parser.header_bytes_buffer, + }); + + return if (res.connection.closing) .closing else .reset; +} + +pub const SendError = Connection.WriteError || error{ + UnsupportedTransferEncoding, + InvalidContentLength, +}; + +/// Send the HTTP response headers to the client. +pub fn send(res: *Server) SendError!void { + switch (res.state) { + .waited => res.state = .responded, + .first, .start, .responded, .finished => unreachable, + } + + var buffered = std.io.bufferedWriter(res.connection.writer()); + const w = buffered.writer(); + + try w.writeAll(@tagName(res.version)); + try w.writeByte(' '); + try w.print("{d}", .{@intFromEnum(res.status)}); + try w.writeByte(' '); + if (res.reason) |reason| { + try w.writeAll(reason); + } else if (res.status.phrase()) |phrase| { + try w.writeAll(phrase); + } + try w.writeAll("\r\n"); + + if (res.status == .@"continue") { + res.state = .waited; // we still need to send another request after this + } else { + if (res.keep_alive and res.request.keep_alive) { + try w.writeAll("connection: keep-alive\r\n"); + } else { + try w.writeAll("connection: close\r\n"); + } + + switch (res.transfer_encoding) { + .chunked => try w.writeAll("transfer-encoding: chunked\r\n"), + .content_length => |content_length| try w.print("content-length: {d}\r\n", .{content_length}), + .none => {}, + } + + for (res.extra_headers) |header| { + try w.print("{s}: {s}\r\n", .{ header.name, header.value }); + } + } + + if (res.request.method == .HEAD) { + res.transfer_encoding = .none; + } + + try w.writeAll("\r\n"); + + try buffered.flush(); +} + +const TransferReadError = Connection.ReadError || proto.HeadersParser.ReadError; + +const TransferReader = std.io.Reader(*Server, TransferReadError, transferRead); + +fn transferReader(res: *Server) TransferReader { + return .{ .context = res }; +} + +fn transferRead(res: *Server, buf: []u8) TransferReadError!usize { + if (res.request.parser.done) return 0; + + var index: usize = 0; + while (index == 0) { + const amt = try res.request.parser.read(&res.connection, buf[index..], false); + if (amt == 0 and res.request.parser.done) break; + index += amt; + } + + return index; +} + +pub const WaitError = Connection.ReadError || + proto.HeadersParser.CheckCompleteHeadError || Request.ParseError || + error{CompressionUnsupported}; + +/// Wait for the client to send a complete request head. +/// +/// For correct behavior, the following rules must be followed: +/// +/// * If this returns any error in `Connection.ReadError`, you MUST +/// immediately close the connection by calling `deinit`. +/// * If this returns `error.HttpHeadersInvalid`, you MAY immediately close +/// the connection by calling `deinit`. +/// * If this returns `error.HttpHeadersOversize`, you MUST +/// respond with a 431 status code and then call `deinit`. +/// * If this returns any error in `Request.ParseError`, you MUST respond +/// with a 400 status code and then call `deinit`. +/// * If this returns any other error, you MUST respond with a 400 status +/// code and then call `deinit`. +/// * If the request has an Expect header containing 100-continue, you MUST either: +/// * Respond with a 100 status code, then call `wait` again. +/// * Respond with a 417 status code. 
+pub fn wait(res: *Server) WaitError!void { + switch (res.state) { + .first, .start => res.state = .waited, + .waited, .responded, .finished => unreachable, + } + + while (true) { + try res.connection.fill(); + + const nchecked = try res.request.parser.checkCompleteHead(res.connection.peek()); + res.connection.drop(@intCast(nchecked)); + + if (res.request.parser.state.isContent()) break; + } + + try res.request.parse(res.request.parser.get()); + + switch (res.request.transfer_encoding) { + .none => { + if (res.request.content_length) |len| { + res.request.parser.next_chunk_length = len; + + if (len == 0) res.request.parser.done = true; + } else { + res.request.parser.done = true; + } + }, + .chunked => { + res.request.parser.next_chunk_length = 0; + res.request.parser.state = .chunk_head_size; + }, + } + + if (!res.request.parser.done) { + switch (res.request.transfer_compression) { + .identity => res.request.compression = .none, + .compress, .@"x-compress" => return error.CompressionUnsupported, + .deflate => res.request.compression = .{ + .deflate = std.compress.zlib.decompressor(res.transferReader()), + }, + .gzip, .@"x-gzip" => res.request.compression = .{ + .gzip = std.compress.gzip.decompressor(res.transferReader()), + }, + .zstd => { + // https://github.com/ziglang/zig/issues/18937 + return error.CompressionUnsupported; + }, + } + } +} + +pub const ReadError = TransferReadError || proto.HeadersParser.CheckCompleteHeadError || error{ DecompressionFailure, InvalidTrailers }; + +pub const Reader = std.io.Reader(*Server, ReadError, read); + +pub fn reader(res: *Server) Reader { + return .{ .context = res }; +} + +/// Reads data from the response body. Must be called after `wait`. +pub fn read(res: *Server, buffer: []u8) ReadError!usize { + switch (res.state) { + .waited, .responded, .finished => {}, + .first, .start => unreachable, + } + + const out_index = switch (res.request.compression) { + .deflate => |*deflate| deflate.read(buffer) catch return error.DecompressionFailure, + .gzip => |*gzip| gzip.read(buffer) catch return error.DecompressionFailure, + // https://github.com/ziglang/zig/issues/18937 + //.zstd => |*zstd| zstd.read(buffer) catch return error.DecompressionFailure, + else => try res.transferRead(buffer), }; - /// Send the HTTP response headers to the client. 
- pub fn send(res: *Response) SendError!void { - switch (res.state) { - .waited => res.state = .responded, - .first, .start, .responded, .finished => unreachable, - } + if (out_index == 0) { + const has_trail = !res.request.parser.state.isContent(); - var buffered = std.io.bufferedWriter(res.connection.writer()); - const w = buffered.writer(); - - try w.writeAll(@tagName(res.version)); - try w.writeByte(' '); - try w.print("{d}", .{@intFromEnum(res.status)}); - try w.writeByte(' '); - if (res.reason) |reason| { - try w.writeAll(reason); - } else if (res.status.phrase()) |phrase| { - try w.writeAll(phrase); - } - try w.writeAll("\r\n"); - - if (res.status == .@"continue") { - res.state = .waited; // we still need to send another request after this - } else { - if (res.keep_alive and res.request.keep_alive) { - try w.writeAll("connection: keep-alive\r\n"); - } else { - try w.writeAll("connection: close\r\n"); - } - - switch (res.transfer_encoding) { - .chunked => try w.writeAll("transfer-encoding: chunked\r\n"), - .content_length => |content_length| try w.print("content-length: {d}\r\n", .{content_length}), - .none => {}, - } - - for (res.extra_headers) |header| { - try w.print("{s}: {s}\r\n", .{ header.name, header.value }); - } - } - - if (res.request.method == .HEAD) { - res.transfer_encoding = .none; - } - - try w.writeAll("\r\n"); - - try buffered.flush(); - } - - const TransferReadError = Connection.ReadError || proto.HeadersParser.ReadError; - - const TransferReader = std.io.Reader(*Response, TransferReadError, transferRead); - - fn transferReader(res: *Response) TransferReader { - return .{ .context = res }; - } - - fn transferRead(res: *Response, buf: []u8) TransferReadError!usize { - if (res.request.parser.done) return 0; - - var index: usize = 0; - while (index == 0) { - const amt = try res.request.parser.read(&res.connection, buf[index..], false); - if (amt == 0 and res.request.parser.done) break; - index += amt; - } - - return index; - } - - pub const WaitError = Connection.ReadError || - proto.HeadersParser.CheckCompleteHeadError || Request.ParseError || - error{CompressionUnsupported}; - - /// Wait for the client to send a complete request head. - /// - /// For correct behavior, the following rules must be followed: - /// - /// * If this returns any error in `Connection.ReadError`, you MUST - /// immediately close the connection by calling `deinit`. - /// * If this returns `error.HttpHeadersInvalid`, you MAY immediately close - /// the connection by calling `deinit`. - /// * If this returns `error.HttpHeadersOversize`, you MUST - /// respond with a 431 status code and then call `deinit`. - /// * If this returns any error in `Request.ParseError`, you MUST respond - /// with a 400 status code and then call `deinit`. - /// * If this returns any other error, you MUST respond with a 400 status - /// code and then call `deinit`. - /// * If the request has an Expect header containing 100-continue, you MUST either: - /// * Respond with a 100 status code, then call `wait` again. - /// * Respond with a 417 status code. 
- pub fn wait(res: *Response) WaitError!void { - switch (res.state) { - .first, .start => res.state = .waited, - .waited, .responded, .finished => unreachable, - } - - while (true) { + while (!res.request.parser.state.isContent()) { // read trailing headers try res.connection.fill(); const nchecked = try res.request.parser.checkCompleteHead(res.connection.peek()); res.connection.drop(@intCast(nchecked)); - - if (res.request.parser.state.isContent()) break; } - try res.request.parse(res.request.parser.get()); - - switch (res.request.transfer_encoding) { - .none => { - if (res.request.content_length) |len| { - res.request.parser.next_chunk_length = len; - - if (len == 0) res.request.parser.done = true; - } else { - res.request.parser.done = true; - } - }, - .chunked => { - res.request.parser.next_chunk_length = 0; - res.request.parser.state = .chunk_head_size; - }, - } - - if (!res.request.parser.done) { - switch (res.request.transfer_compression) { - .identity => res.request.compression = .none, - .compress, .@"x-compress" => return error.CompressionUnsupported, - .deflate => res.request.compression = .{ - .deflate = std.compress.zlib.decompressor(res.transferReader()), - }, - .gzip, .@"x-gzip" => res.request.compression = .{ - .gzip = std.compress.gzip.decompressor(res.transferReader()), - }, - .zstd => { - // https://github.com/ziglang/zig/issues/18937 - return error.CompressionUnsupported; - }, - } + if (has_trail) { + // The response headers before the trailers are already + // guaranteed to be valid, so they will always be parsed again + // and cannot return an error. + // This will *only* fail for a malformed trailer. + res.request.parse(res.request.parser.get()) catch return error.InvalidTrailers; } } - pub const ReadError = TransferReadError || proto.HeadersParser.CheckCompleteHeadError || error{ DecompressionFailure, InvalidTrailers }; - - pub const Reader = std.io.Reader(*Response, ReadError, read); - - pub fn reader(res: *Response) Reader { - return .{ .context = res }; - } - - /// Reads data from the response body. Must be called after `wait`. - pub fn read(res: *Response, buffer: []u8) ReadError!usize { - switch (res.state) { - .waited, .responded, .finished => {}, - .first, .start => unreachable, - } - - const out_index = switch (res.request.compression) { - .deflate => |*deflate| deflate.read(buffer) catch return error.DecompressionFailure, - .gzip => |*gzip| gzip.read(buffer) catch return error.DecompressionFailure, - // https://github.com/ziglang/zig/issues/18937 - //.zstd => |*zstd| zstd.read(buffer) catch return error.DecompressionFailure, - else => try res.transferRead(buffer), - }; - - if (out_index == 0) { - const has_trail = !res.request.parser.state.isContent(); - - while (!res.request.parser.state.isContent()) { // read trailing headers - try res.connection.fill(); - - const nchecked = try res.request.parser.checkCompleteHead(res.connection.peek()); - res.connection.drop(@intCast(nchecked)); - } - - if (has_trail) { - // The response headers before the trailers are already - // guaranteed to be valid, so they will always be parsed again - // and cannot return an error. - // This will *only* fail for a malformed trailer. - res.request.parse(res.request.parser.get()) catch return error.InvalidTrailers; - } - } - - return out_index; - } - - /// Reads data from the response body. Must be called after `wait`. 
- pub fn readAll(res: *Response, buffer: []u8) !usize { - var index: usize = 0; - while (index < buffer.len) { - const amt = try read(res, buffer[index..]); - if (amt == 0) break; - index += amt; - } - return index; - } - - pub const WriteError = Connection.WriteError || error{ NotWriteable, MessageTooLong }; - - pub const Writer = std.io.Writer(*Response, WriteError, write); - - pub fn writer(res: *Response) Writer { - return .{ .context = res }; - } - - /// Write `bytes` to the server. The `transfer_encoding` request header determines how data will be sent. - /// Must be called after `send` and before `finish`. - pub fn write(res: *Response, bytes: []const u8) WriteError!usize { - switch (res.state) { - .responded => {}, - .first, .waited, .start, .finished => unreachable, - } - - switch (res.transfer_encoding) { - .chunked => { - if (bytes.len > 0) { - try res.connection.writer().print("{x}\r\n", .{bytes.len}); - try res.connection.writeAll(bytes); - try res.connection.writeAll("\r\n"); - } - - return bytes.len; - }, - .content_length => |*len| { - if (len.* < bytes.len) return error.MessageTooLong; - - const amt = try res.connection.write(bytes); - len.* -= amt; - return amt; - }, - .none => return error.NotWriteable, - } - } - - /// Write `bytes` to the server. The `transfer_encoding` request header determines how data will be sent. - /// Must be called after `send` and before `finish`. - pub fn writeAll(req: *Response, bytes: []const u8) WriteError!void { - var index: usize = 0; - while (index < bytes.len) { - index += try write(req, bytes[index..]); - } - } - - pub const FinishError = Connection.WriteError || error{MessageNotCompleted}; - - /// Finish the body of a request. This notifies the server that you have no more data to send. - /// Must be called after `send`. - pub fn finish(res: *Response) FinishError!void { - switch (res.state) { - .responded => res.state = .finished, - .first, .waited, .start, .finished => unreachable, - } - - switch (res.transfer_encoding) { - .chunked => try res.connection.writeAll("0\r\n\r\n"), - .content_length => |len| if (len != 0) return error.MessageNotCompleted, - .none => {}, - } - } -}; - -/// Create a new HTTP server. -pub fn init(options: net.StreamServer.Options) Server { - return .{ - .socket = net.StreamServer.init(options), - }; + return out_index; } -/// Free all resources associated with this server. -pub fn deinit(server: *Server) void { - server.socket.deinit(); +/// Reads data from the response body. Must be called after `wait`. +pub fn readAll(res: *Server, buffer: []u8) !usize { + var index: usize = 0; + while (index < buffer.len) { + const amt = try read(res, buffer[index..]); + if (amt == 0) break; + index += amt; + } + return index; } -pub const ListenError = std.os.SocketError || std.os.BindError || std.os.ListenError || std.os.SetSockOptError || std.os.GetSockNameError; +pub const WriteError = Connection.WriteError || error{ NotWriteable, MessageTooLong }; -/// Start the HTTP server listening on the given address. -pub fn listen(server: *Server, address: net.Address) ListenError!void { - try server.socket.listen(address); +pub const Writer = std.io.Writer(*Server, WriteError, write); + +pub fn writer(res: *Server) Writer { + return .{ .context = res }; } -pub const AcceptError = net.StreamServer.AcceptError; +/// Write `bytes` to the server. The `transfer_encoding` request header determines how data will be sent. +/// Must be called after `send` and before `finish`. 
+pub fn write(res: *Server, bytes: []const u8) WriteError!usize { + switch (res.state) { + .responded => {}, + .first, .waited, .start, .finished => unreachable, + } -pub const AcceptOptions = struct { - /// Externally-owned memory used to store the client's entire HTTP header. - /// `error.HttpHeadersOversize` is returned from read() when a - /// client sends too many bytes of HTTP headers. - client_header_buffer: []u8, -}; + switch (res.transfer_encoding) { + .chunked => { + if (bytes.len > 0) { + try res.connection.writer().print("{x}\r\n", .{bytes.len}); + try res.connection.writeAll(bytes); + try res.connection.writeAll("\r\n"); + } -pub fn accept(server: *Server, options: AcceptOptions) AcceptError!Response { - const in = try server.socket.accept(); - - return .{ - .transfer_encoding = .none, - .keep_alive = true, - .address = in.address, - .connection = .{ - .stream = in.stream, - .protocol = .plain, + return bytes.len; }, - .request = .{ - .version = undefined, - .method = undefined, - .target = undefined, - .parser = proto.HeadersParser.init(options.client_header_buffer), + .content_length => |*len| { + if (len.* < bytes.len) return error.MessageTooLong; + + const amt = try res.connection.write(bytes); + len.* -= amt; + return amt; }, - }; + .none => return error.NotWriteable, + } } -test "HTTP server handles a chunked transfer coding request" { - // This test requires spawning threads. - if (builtin.single_threaded) { - return error.SkipZigTest; +/// Write `bytes` to the server. The `transfer_encoding` request header determines how data will be sent. +/// Must be called after `send` and before `finish`. +pub fn writeAll(req: *Server, bytes: []const u8) WriteError!void { + var index: usize = 0; + while (index < bytes.len) { + index += try write(req, bytes[index..]); } - - const native_endian = comptime builtin.cpu.arch.endian(); - if (builtin.zig_backend == .stage2_llvm and native_endian == .big) { - // https://github.com/ziglang/zig/issues/13782 - return error.SkipZigTest; - } - - if (builtin.os.tag == .wasi) return error.SkipZigTest; - - const allocator = std.testing.allocator; - const expect = std.testing.expect; - - const max_header_size = 8192; - var server = std.http.Server.init(.{ .reuse_address = true }); - defer server.deinit(); - - const address = try std.net.Address.parseIp("127.0.0.1", 0); - try server.listen(address); - const server_port = server.socket.listen_address.in.getPort(); - - const server_thread = try std.Thread.spawn(.{}, (struct { - fn apply(s: *std.http.Server) !void { - var header_buffer: [max_header_size]u8 = undefined; - var res = try s.accept(.{ - .allocator = allocator, - .client_header_buffer = &header_buffer, - }); - defer res.deinit(); - defer _ = res.reset(); - try res.wait(); - - try expect(res.request.transfer_encoding == .chunked); - - const server_body: []const u8 = "message from server!\n"; - res.transfer_encoding = .{ .content_length = server_body.len }; - res.extra_headers = &.{ - .{ .name = "content-type", .value = "text/plain" }, - }; - res.keep_alive = false; - try res.send(); - - var buf: [128]u8 = undefined; - const n = try res.readAll(&buf); - try expect(std.mem.eql(u8, buf[0..n], "ABCD")); - _ = try res.writer().writeAll(server_body); - try res.finish(); - } - }).apply, .{&server}); - - const request_bytes = - "POST / HTTP/1.1\r\n" ++ - "Content-Type: text/plain\r\n" ++ - "Transfer-Encoding: chunked\r\n" ++ - "\r\n" ++ - "1\r\n" ++ - "A\r\n" ++ - "1\r\n" ++ - "B\r\n" ++ - "2\r\n" ++ - "CD\r\n" ++ - "0\r\n" ++ - "\r\n"; - - const 
stream = try std.net.tcpConnectToHost(allocator, "127.0.0.1", server_port); - defer stream.close(); - _ = try stream.writeAll(request_bytes[0..]); - - server_thread.join(); } + +pub const FinishError = Connection.WriteError || error{MessageNotCompleted}; + +/// Finish the body of a request. This notifies the server that you have no more data to send. +/// Must be called after `send`. +pub fn finish(res: *Server) FinishError!void { + switch (res.state) { + .responded => res.state = .finished, + .first, .waited, .start, .finished => unreachable, + } + + switch (res.transfer_encoding) { + .chunked => try res.connection.writeAll("0\r\n\r\n"), + .content_length => |len| if (len != 0) return error.MessageNotCompleted, + .none => {}, + } +} + +const builtin = @import("builtin"); +const std = @import("../std.zig"); +const testing = std.testing; +const http = std.http; +const mem = std.mem; +const net = std.net; +const Uri = std.Uri; +const Allocator = mem.Allocator; +const assert = std.debug.assert; + +const Server = @This(); +const proto = @import("protocol.zig"); diff --git a/lib/std/http/Server/Connection.zig b/lib/std/http/Server/Connection.zig new file mode 100644 index 0000000000..52b870992a --- /dev/null +++ b/lib/std/http/Server/Connection.zig @@ -0,0 +1,132 @@ +stream: std.net.Stream, +protocol: Protocol, + +closing: bool, + +read_buf: [buffer_size]u8, +read_start: u16, +read_end: u16, + +pub const buffer_size = std.crypto.tls.max_ciphertext_record_len; +pub const Protocol = enum { plain }; + +pub fn rawReadAtLeast(conn: *Connection, buffer: []u8, len: usize) ReadError!usize { + return switch (conn.protocol) { + .plain => conn.stream.readAtLeast(buffer, len), + // .tls => conn.tls_client.readAtLeast(conn.stream, buffer, len), + } catch |err| { + switch (err) { + error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer, + else => return error.UnexpectedReadFailure, + } + }; +} + +pub fn fill(conn: *Connection) ReadError!void { + if (conn.read_end != conn.read_start) return; + + const nread = try conn.rawReadAtLeast(conn.read_buf[0..], 1); + if (nread == 0) return error.EndOfStream; + conn.read_start = 0; + conn.read_end = @intCast(nread); +} + +pub fn peek(conn: *Connection) []const u8 { + return conn.read_buf[conn.read_start..conn.read_end]; +} + +pub fn drop(conn: *Connection, num: u16) void { + conn.read_start += num; +} + +pub fn readAtLeast(conn: *Connection, buffer: []u8, len: usize) ReadError!usize { + assert(len <= buffer.len); + + var out_index: u16 = 0; + while (out_index < len) { + const available_read = conn.read_end - conn.read_start; + const available_buffer = buffer.len - out_index; + + if (available_read > available_buffer) { // partially read buffered data + @memcpy(buffer[out_index..], conn.read_buf[conn.read_start..conn.read_end][0..available_buffer]); + out_index += @as(u16, @intCast(available_buffer)); + conn.read_start += @as(u16, @intCast(available_buffer)); + + break; + } else if (available_read > 0) { // fully read buffered data + @memcpy(buffer[out_index..][0..available_read], conn.read_buf[conn.read_start..conn.read_end]); + out_index += available_read; + conn.read_start += available_read; + + if (out_index >= len) break; + } + + const leftover_buffer = available_buffer - available_read; + const leftover_len = len - out_index; + + if (leftover_buffer > conn.read_buf.len) { + // skip the buffer if the output is large enough + return conn.rawReadAtLeast(buffer[out_index..], leftover_len); + } + + try conn.fill(); + } + + return out_index; 
+} + +pub fn read(conn: *Connection, buffer: []u8) ReadError!usize { + return conn.readAtLeast(buffer, 1); +} + +pub const ReadError = error{ + ConnectionTimedOut, + ConnectionResetByPeer, + UnexpectedReadFailure, + EndOfStream, +}; + +pub const Reader = std.io.Reader(*Connection, ReadError, read); + +pub fn reader(conn: *Connection) Reader { + return .{ .context = conn }; +} + +pub fn writeAll(conn: *Connection, buffer: []const u8) WriteError!void { + return switch (conn.protocol) { + .plain => conn.stream.writeAll(buffer), + // .tls => return conn.tls_client.writeAll(conn.stream, buffer), + } catch |err| switch (err) { + error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer, + else => return error.UnexpectedWriteFailure, + }; +} + +pub fn write(conn: *Connection, buffer: []const u8) WriteError!usize { + return switch (conn.protocol) { + .plain => conn.stream.write(buffer), + // .tls => return conn.tls_client.write(conn.stream, buffer), + } catch |err| switch (err) { + error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer, + else => return error.UnexpectedWriteFailure, + }; +} + +pub const WriteError = error{ + ConnectionResetByPeer, + UnexpectedWriteFailure, +}; + +pub const Writer = std.io.Writer(*Connection, WriteError, write); + +pub fn writer(conn: *Connection) Writer { + return .{ .context = conn }; +} + +pub fn close(conn: *Connection) void { + conn.stream.close(); +} + +const Connection = @This(); +const std = @import("../../std.zig"); +const assert = std.debug.assert; diff --git a/lib/std/http/test.zig b/lib/std/http/test.zig index 0254e5cc2c..3441346baf 100644 --- a/lib/std/http/test.zig +++ b/lib/std/http/test.zig @@ -8,13 +8,12 @@ test "trailers" { const gpa = testing.allocator; - var http_server = std.http.Server.init(.{ + const address = try std.net.Address.parseIp("127.0.0.1", 0); + var http_server = try address.listen(.{ .reuse_address = true, }); - const address = try std.net.Address.parseIp("127.0.0.1", 0); - try http_server.listen(address); - const port = http_server.socket.listen_address.in.getPort(); + const port = http_server.listen_address.in.getPort(); const server_thread = try std.Thread.spawn(.{}, serverThread, .{&http_server}); defer server_thread.join(); @@ -67,17 +66,14 @@ test "trailers" { try testing.expect(client.connection_pool.free_len == 1); } -fn serverThread(http_server: *std.http.Server) anyerror!void { - const gpa = testing.allocator; - +fn serverThread(http_server: *std.net.Server) anyerror!void { var header_buffer: [1024]u8 = undefined; var remaining: usize = 1; accept: while (remaining != 0) : (remaining -= 1) { - var res = try http_server.accept(.{ - .allocator = gpa, - .client_header_buffer = &header_buffer, - }); - defer res.deinit(); + const conn = try http_server.accept(); + defer conn.stream.close(); + + var res = std.http.Server.init(conn, .{ .client_header_buffer = &header_buffer }); res.wait() catch |err| switch (err) { error.HttpHeadersInvalid => continue :accept, @@ -90,7 +86,7 @@ fn serverThread(http_server: *std.http.Server) anyerror!void { } } -fn serve(res: *std.http.Server.Response) !void { +fn serve(res: *std.http.Server) !void { try testing.expectEqualStrings(res.request.target, "/trailer"); res.transfer_encoding = .chunked; @@ -99,3 +95,73 @@ fn serve(res: *std.http.Server.Response) !void { try res.writeAll("World!\n"); try res.connection.writeAll("0\r\nX-Checksum: aaaa\r\n\r\n"); } + +test "HTTP server handles a chunked transfer coding request" { + // This test 
requires spawning threads. + if (builtin.single_threaded) { + return error.SkipZigTest; + } + + const native_endian = comptime builtin.cpu.arch.endian(); + if (builtin.zig_backend == .stage2_llvm and native_endian == .big) { + // https://github.com/ziglang/zig/issues/13782 + return error.SkipZigTest; + } + + if (builtin.os.tag == .wasi) return error.SkipZigTest; + + const allocator = std.testing.allocator; + const expect = std.testing.expect; + + const max_header_size = 8192; + + const address = try std.net.Address.parseIp("127.0.0.1", 0); + var server = try address.listen(.{ .reuse_address = true }); + defer server.deinit(); + const server_port = server.listen_address.in.getPort(); + + const server_thread = try std.Thread.spawn(.{}, (struct { + fn apply(s: *std.net.Server) !void { + var header_buffer: [max_header_size]u8 = undefined; + const conn = try s.accept(); + defer conn.stream.close(); + var res = std.http.Server.init(conn, .{ .client_header_buffer = &header_buffer }); + try res.wait(); + + try expect(res.request.transfer_encoding == .chunked); + const server_body: []const u8 = "message from server!\n"; + res.transfer_encoding = .{ .content_length = server_body.len }; + res.extra_headers = &.{ + .{ .name = "content-type", .value = "text/plain" }, + }; + res.keep_alive = false; + try res.send(); + + var buf: [128]u8 = undefined; + const n = try res.readAll(&buf); + try expect(std.mem.eql(u8, buf[0..n], "ABCD")); + _ = try res.writer().writeAll(server_body); + try res.finish(); + } + }).apply, .{&server}); + + const request_bytes = + "POST / HTTP/1.1\r\n" ++ + "Content-Type: text/plain\r\n" ++ + "Transfer-Encoding: chunked\r\n" ++ + "\r\n" ++ + "1\r\n" ++ + "A\r\n" ++ + "1\r\n" ++ + "B\r\n" ++ + "2\r\n" ++ + "CD\r\n" ++ + "0\r\n" ++ + "\r\n"; + + const stream = try std.net.tcpConnectToHost(allocator, "127.0.0.1", server_port); + defer stream.close(); + _ = try stream.writeAll(request_bytes[0..]); + + server_thread.join(); +} diff --git a/lib/std/net.zig b/lib/std/net.zig index fdade0447f..f7e19850d3 100644 --- a/lib/std/net.zig +++ b/lib/std/net.zig @@ -4,15 +4,17 @@ const assert = std.debug.assert; const net = @This(); const mem = std.mem; const os = std.os; +const posix = std.posix; const fs = std.fs; const io = std.io; const native_endian = builtin.target.cpu.arch.endian(); // Windows 10 added support for unix sockets in build 17063, redstone 4 is the // first release to support them. -pub const has_unix_sockets = @hasDecl(os.sockaddr, "un") and - (builtin.target.os.tag != .windows or - builtin.os.version_range.windows.isAtLeast(.win10_rs4) orelse false); +pub const has_unix_sockets = switch (builtin.os.tag) { + .windows => builtin.os.version_range.windows.isAtLeast(.win10_rs4) orelse false, + else => true, +}; pub const IPParseError = error{ Overflow, @@ -206,6 +208,57 @@ pub const Address = extern union { else => unreachable, } } + + pub const ListenError = posix.SocketError || posix.BindError || posix.ListenError || + posix.SetSockOptError || posix.GetSockNameError; + + pub const ListenOptions = struct { + /// How many connections the kernel will accept on the application's behalf. + /// If more than this many connections pool in the kernel, clients will start + /// seeing "Connection refused". + kernel_backlog: u31 = 128, + reuse_address: bool = false, + reuse_port: bool = false, + force_nonblocking: bool = false, + }; + + /// The returned `Server` has an open `stream`. 
+ pub fn listen(address: Address, options: ListenOptions) ListenError!Server { + const nonblock: u32 = if (options.force_nonblocking) posix.SOCK.NONBLOCK else 0; + const sock_flags = posix.SOCK.STREAM | posix.SOCK.CLOEXEC | nonblock; + const proto: u32 = if (address.any.family == posix.AF.UNIX) 0 else posix.IPPROTO.TCP; + + const sockfd = try posix.socket(address.any.family, sock_flags, proto); + var s: Server = .{ + .listen_address = undefined, + .stream = .{ .handle = sockfd }, + }; + errdefer s.stream.close(); + + if (options.reuse_address) { + try posix.setsockopt( + sockfd, + posix.SOL.SOCKET, + posix.SO.REUSEADDR, + &mem.toBytes(@as(c_int, 1)), + ); + } + + if (options.reuse_port) { + try posix.setsockopt( + sockfd, + posix.SOL.SOCKET, + posix.SO.REUSEPORT, + &mem.toBytes(@as(c_int, 1)), + ); + } + + var socklen = address.getOsSockLen(); + try posix.bind(sockfd, &address.any, socklen); + try posix.listen(sockfd, options.kernel_backlog); + try posix.getsockname(sockfd, &s.listen_address.any, &socklen); + return s; + } }; pub const Ip4Address = extern struct { @@ -657,7 +710,7 @@ pub fn connectUnixSocket(path: []const u8) !Stream { os.SOCK.STREAM | os.SOCK.CLOEXEC | opt_non_block, 0, ); - errdefer os.closeSocket(sockfd); + errdefer Stream.close(.{ .handle = sockfd }); var addr = try std.net.Address.initUnix(path); try os.connect(sockfd, &addr.any, addr.getOsSockLen()); @@ -669,7 +722,7 @@ fn if_nametoindex(name: []const u8) IPv6InterfaceError!u32 { if (builtin.target.os.tag == .linux) { var ifr: os.ifreq = undefined; const sockfd = try os.socket(os.AF.UNIX, os.SOCK.DGRAM | os.SOCK.CLOEXEC, 0); - defer os.closeSocket(sockfd); + defer Stream.close(.{ .handle = sockfd }); @memcpy(ifr.ifrn.name[0..name.len], name); ifr.ifrn.name[name.len] = 0; @@ -738,7 +791,7 @@ pub fn tcpConnectToAddress(address: Address) TcpConnectToAddressError!Stream { const sock_flags = os.SOCK.STREAM | nonblock | (if (builtin.target.os.tag == .windows) 0 else os.SOCK.CLOEXEC); const sockfd = try os.socket(address.any.family, sock_flags, os.IPPROTO.TCP); - errdefer os.closeSocket(sockfd); + errdefer Stream.close(.{ .handle = sockfd }); try os.connect(sockfd, &address.any, address.getOsSockLen()); @@ -1068,7 +1121,7 @@ fn linuxLookupName( var prefixlen: i32 = 0; const sock_flags = os.SOCK.DGRAM | os.SOCK.CLOEXEC; if (os.socket(addr.addr.any.family, sock_flags, os.IPPROTO.UDP)) |fd| syscalls: { - defer os.closeSocket(fd); + defer Stream.close(.{ .handle = fd }); os.connect(fd, da, dalen) catch break :syscalls; key |= DAS_USABLE; os.getsockname(fd, sa, &salen) catch break :syscalls; @@ -1553,7 +1606,7 @@ fn resMSendRc( }, else => |e| return e, }; - defer os.closeSocket(fd); + defer Stream.close(.{ .handle = fd }); // Past this point, there are no errors. Each individual query will // yield either no reply (indicated by zero length) or an answer @@ -1729,13 +1782,15 @@ fn dnsParseCallback(ctx: dpc_ctx, rr: u8, data: []const u8, packet: []const u8) } pub const Stream = struct { - // Underlying socket descriptor. - // Note that on some platforms this may not be interchangeable with a - // regular files descriptor. - handle: os.socket_t, + /// Underlying platform-defined type which may or may not be + /// interchangeable with a file system file descriptor. 
+ handle: posix.socket_t, - pub fn close(self: Stream) void { - os.closeSocket(self.handle); + pub fn close(s: Stream) void { + switch (builtin.os.tag) { + .windows => std.os.windows.closesocket(s.handle) catch unreachable, + else => posix.close(s.handle), + } } pub const ReadError = os.ReadError; @@ -1839,156 +1894,38 @@ pub const Stream = struct { } }; -pub const StreamServer = struct { - /// Copied from `Options` on `init`. - kernel_backlog: u31, - reuse_address: bool, - reuse_port: bool, - force_nonblocking: bool, - - /// `undefined` until `listen` returns successfully. +pub const Server = struct { listen_address: Address, - - sockfd: ?os.socket_t, - - pub const Options = struct { - /// How many connections the kernel will accept on the application's behalf. - /// If more than this many connections pool in the kernel, clients will start - /// seeing "Connection refused". - kernel_backlog: u31 = 128, - - /// Enable SO.REUSEADDR on the socket. - reuse_address: bool = false, - - /// Enable SO.REUSEPORT on the socket. - reuse_port: bool = false, - - /// Force non-blocking mode. - force_nonblocking: bool = false, - }; - - /// After this call succeeds, resources have been acquired and must - /// be released with `deinit`. - pub fn init(options: Options) StreamServer { - return StreamServer{ - .sockfd = null, - .kernel_backlog = options.kernel_backlog, - .reuse_address = options.reuse_address, - .reuse_port = options.reuse_port, - .force_nonblocking = options.force_nonblocking, - .listen_address = undefined, - }; - } - - /// Release all resources. The `StreamServer` memory becomes `undefined`. - pub fn deinit(self: *StreamServer) void { - self.close(); - self.* = undefined; - } - - pub fn listen(self: *StreamServer, address: Address) !void { - const nonblock = 0; - const sock_flags = os.SOCK.STREAM | os.SOCK.CLOEXEC | nonblock; - var use_sock_flags: u32 = sock_flags; - if (self.force_nonblocking) use_sock_flags |= os.SOCK.NONBLOCK; - const proto = if (address.any.family == os.AF.UNIX) @as(u32, 0) else os.IPPROTO.TCP; - - const sockfd = try os.socket(address.any.family, use_sock_flags, proto); - self.sockfd = sockfd; - errdefer { - os.closeSocket(sockfd); - self.sockfd = null; - } - - if (self.reuse_address) { - try os.setsockopt( - sockfd, - os.SOL.SOCKET, - os.SO.REUSEADDR, - &mem.toBytes(@as(c_int, 1)), - ); - } - if (@hasDecl(os.SO, "REUSEPORT") and self.reuse_port) { - try os.setsockopt( - sockfd, - os.SOL.SOCKET, - os.SO.REUSEPORT, - &mem.toBytes(@as(c_int, 1)), - ); - } - - var socklen = address.getOsSockLen(); - try os.bind(sockfd, &address.any, socklen); - try os.listen(sockfd, self.kernel_backlog); - try os.getsockname(sockfd, &self.listen_address.any, &socklen); - } - - /// Stop listening. It is still necessary to call `deinit` after stopping listening. - /// Calling `deinit` will automatically call `close`. It is safe to call `close` when - /// not listening. - pub fn close(self: *StreamServer) void { - if (self.sockfd) |fd| { - os.closeSocket(fd); - self.sockfd = null; - self.listen_address = undefined; - } - } - - pub const AcceptError = error{ - ConnectionAborted, - - /// The per-process limit on the number of open file descriptors has been reached. - ProcessFdQuotaExceeded, - - /// The system-wide limit on the total number of open files has been reached. - SystemFdQuotaExceeded, - - /// Not enough free memory. This often means that the memory allocation - /// is limited by the socket buffer limits, not by the system memory. 
- SystemResources, - - /// Socket is not listening for new connections. - SocketNotListening, - - ProtocolFailure, - - /// Socket is in non-blocking mode and there is no connection to accept. - WouldBlock, - - /// Firewall rules forbid connection. - BlockedByFirewall, - - FileDescriptorNotASocket, - - ConnectionResetByPeer, - - NetworkSubsystemFailed, - - OperationNotSupported, - } || os.UnexpectedError; + stream: std.net.Stream, pub const Connection = struct { - stream: Stream, + stream: std.net.Stream, address: Address, }; - /// If this function succeeds, the returned `Connection` is a caller-managed resource. - pub fn accept(self: *StreamServer) AcceptError!Connection { - var accepted_addr: Address = undefined; - var adr_len: os.socklen_t = @sizeOf(Address); - const accept_result = os.accept(self.sockfd.?, &accepted_addr.any, &adr_len, os.SOCK.CLOEXEC); + pub fn deinit(s: *Server) void { + s.stream.close(); + s.* = undefined; + } - if (accept_result) |fd| { - return Connection{ - .stream = Stream{ .handle = fd }, - .address = accepted_addr, - }; - } else |err| { - return err; - } + pub const AcceptError = posix.AcceptError; + + /// Blocks until a client connects to the server. The returned `Connection` has + /// an open stream. + pub fn accept(s: *Server) AcceptError!Connection { + var accepted_addr: Address = undefined; + var addr_len: posix.socklen_t = @sizeOf(Address); + const fd = try posix.accept(s.stream.handle, &accepted_addr.any, &addr_len, posix.SOCK.CLOEXEC); + return .{ + .stream = .{ .handle = fd }, + .address = accepted_addr, + }; } }; test { _ = @import("net/test.zig"); + _ = Server; + _ = Stream; + _ = Address; } diff --git a/lib/std/net/test.zig b/lib/std/net/test.zig index e359abb6d5..3e316c5456 100644 --- a/lib/std/net/test.zig +++ b/lib/std/net/test.zig @@ -181,11 +181,9 @@ test "listen on a port, send bytes, receive bytes" { // configured. 
const localhost = try net.Address.parseIp("127.0.0.1", 0); - var server = net.StreamServer.init(.{}); + var server = try localhost.listen(.{}); defer server.deinit(); - try server.listen(localhost); - const S = struct { fn clientFn(server_address: net.Address) !void { const socket = try net.tcpConnectToAddress(server_address); @@ -215,17 +213,11 @@ test "listen on an in use port" { const localhost = try net.Address.parseIp("127.0.0.1", 0); - var server1 = net.StreamServer.init(net.StreamServer.Options{ - .reuse_port = true, - }); + var server1 = try localhost.listen(.{ .reuse_port = true }); defer server1.deinit(); - try server1.listen(localhost); - var server2 = net.StreamServer.init(net.StreamServer.Options{ - .reuse_port = true, - }); + var server2 = try server1.listen_address.listen(.{ .reuse_port = true }); defer server2.deinit(); - try server2.listen(server1.listen_address); } fn testClientToHost(allocator: mem.Allocator, name: []const u8, port: u16) anyerror!void { @@ -252,7 +244,7 @@ fn testClient(addr: net.Address) anyerror!void { try testing.expect(mem.eql(u8, msg, "hello from server\n")); } -fn testServer(server: *net.StreamServer) anyerror!void { +fn testServer(server: *net.Server) anyerror!void { if (builtin.os.tag == .wasi) return error.SkipZigTest; var client = try server.accept(); @@ -274,15 +266,14 @@ test "listen on a unix socket, send bytes, receive bytes" { } } - var server = net.StreamServer.init(.{}); - defer server.deinit(); - const socket_path = try generateFileName("socket.unix"); defer testing.allocator.free(socket_path); const socket_addr = try net.Address.initUnix(socket_path); defer std.fs.cwd().deleteFile(socket_path) catch {}; - try server.listen(socket_addr); + + var server = try socket_addr.listen(.{}); + defer server.deinit(); const S = struct { fn clientFn(path: []const u8) !void { @@ -323,9 +314,8 @@ test "non-blocking tcp server" { } const localhost = try net.Address.parseIp("127.0.0.1", 0); - var server = net.StreamServer.init(.{ .force_nonblocking = true }); + var server = localhost.listen(.{ .force_nonblocking = true }); defer server.deinit(); - try server.listen(localhost); const accept_err = server.accept(); try testing.expectError(error.WouldBlock, accept_err); diff --git a/lib/std/os.zig b/lib/std/os.zig index 6880878c45..87402e49a3 100644 --- a/lib/std/os.zig +++ b/lib/std/os.zig @@ -3598,14 +3598,6 @@ pub fn shutdown(sock: socket_t, how: ShutdownHow) ShutdownError!void { } } -pub fn closeSocket(sock: socket_t) void { - if (builtin.os.tag == .windows) { - windows.closesocket(sock) catch unreachable; - } else { - close(sock); - } -} - pub const BindError = error{ /// The address is protected, and the user is not the superuser. 
/// For UNIX domain sockets: Search permission is denied on a component diff --git a/lib/std/os/linux/io_uring.zig b/lib/std/os/linux/io_uring.zig index dbde08c2c1..16c542714c 100644 --- a/lib/std/os/linux/io_uring.zig +++ b/lib/std/os/linux/io_uring.zig @@ -4,6 +4,7 @@ const assert = std.debug.assert; const mem = std.mem; const net = std.net; const os = std.os; +const posix = std.posix; const linux = os.linux; const testing = std.testing; @@ -3730,8 +3731,8 @@ const SocketTestHarness = struct { client: os.socket_t, fn close(self: SocketTestHarness) void { - os.closeSocket(self.client); - os.closeSocket(self.listener); + posix.close(self.client); + posix.close(self.listener); } }; @@ -3739,7 +3740,7 @@ fn createSocketTestHarness(ring: *IO_Uring) !SocketTestHarness { // Create a TCP server socket var address = try net.Address.parseIp4("127.0.0.1", 0); const listener_socket = try createListenerSocket(&address); - errdefer os.closeSocket(listener_socket); + errdefer posix.close(listener_socket); // Submit 1 accept var accept_addr: os.sockaddr = undefined; @@ -3748,7 +3749,7 @@ fn createSocketTestHarness(ring: *IO_Uring) !SocketTestHarness { // Create a TCP client socket const client = try os.socket(address.any.family, os.SOCK.STREAM | os.SOCK.CLOEXEC, 0); - errdefer os.closeSocket(client); + errdefer posix.close(client); _ = try ring.connect(0xcccccccc, client, &address.any, address.getOsSockLen()); try testing.expectEqual(@as(u32, 2), try ring.submit()); @@ -3788,7 +3789,7 @@ fn createSocketTestHarness(ring: *IO_Uring) !SocketTestHarness { fn createListenerSocket(address: *net.Address) !os.socket_t { const kernel_backlog = 1; const listener_socket = try os.socket(address.any.family, os.SOCK.STREAM | os.SOCK.CLOEXEC, 0); - errdefer os.closeSocket(listener_socket); + errdefer posix.close(listener_socket); try os.setsockopt(listener_socket, os.SOL.SOCKET, os.SO.REUSEADDR, &mem.toBytes(@as(c_int, 1))); try os.bind(listener_socket, &address.any, address.getOsSockLen()); @@ -3813,7 +3814,7 @@ test "accept multishot" { var address = try net.Address.parseIp4("127.0.0.1", 0); const listener_socket = try createListenerSocket(&address); - defer os.closeSocket(listener_socket); + defer posix.close(listener_socket); // submit multishot accept operation var addr: os.sockaddr = undefined; @@ -3826,7 +3827,7 @@ test "accept multishot" { while (nr > 0) : (nr -= 1) { // connect client const client = try os.socket(address.any.family, os.SOCK.STREAM | os.SOCK.CLOEXEC, 0); - errdefer os.closeSocket(client); + errdefer posix.close(client); try os.connect(client, &address.any, address.getOsSockLen()); // test accept completion @@ -3836,7 +3837,7 @@ test "accept multishot" { try testing.expect(cqe.user_data == userdata); try testing.expect(cqe.flags & linux.IORING_CQE_F_MORE > 0); // more flag is set - os.closeSocket(client); + posix.close(client); } } @@ -3909,7 +3910,7 @@ test "accept_direct" { try ring.register_files(registered_fds[0..]); const listener_socket = try createListenerSocket(&address); - defer os.closeSocket(listener_socket); + defer posix.close(listener_socket); const accept_userdata: u64 = 0xaaaaaaaa; const read_userdata: u64 = 0xbbbbbbbb; @@ -3927,7 +3928,7 @@ test "accept_direct" { // connect const client = try os.socket(address.any.family, os.SOCK.STREAM | os.SOCK.CLOEXEC, 0); try os.connect(client, &address.any, address.getOsSockLen()); - defer os.closeSocket(client); + defer posix.close(client); // accept completion const cqe_accept = try ring.copy_cqe(); @@ -3961,7 +3962,7 @@ test 
"accept_direct" { // connect const client = try os.socket(address.any.family, os.SOCK.STREAM | os.SOCK.CLOEXEC, 0); try os.connect(client, &address.any, address.getOsSockLen()); - defer os.closeSocket(client); + defer posix.close(client); // completion with error const cqe_accept = try ring.copy_cqe(); try testing.expect(cqe_accept.user_data == accept_userdata); @@ -3989,7 +3990,7 @@ test "accept_multishot_direct" { try ring.register_files(registered_fds[0..]); const listener_socket = try createListenerSocket(&address); - defer os.closeSocket(listener_socket); + defer posix.close(listener_socket); const accept_userdata: u64 = 0xaaaaaaaa; @@ -4003,7 +4004,7 @@ test "accept_multishot_direct" { // connect const client = try os.socket(address.any.family, os.SOCK.STREAM | os.SOCK.CLOEXEC, 0); try os.connect(client, &address.any, address.getOsSockLen()); - defer os.closeSocket(client); + defer posix.close(client); // accept completion const cqe_accept = try ring.copy_cqe(); @@ -4018,7 +4019,7 @@ test "accept_multishot_direct" { // connect const client = try os.socket(address.any.family, os.SOCK.STREAM | os.SOCK.CLOEXEC, 0); try os.connect(client, &address.any, address.getOsSockLen()); - defer os.closeSocket(client); + defer posix.close(client); // completion with error const cqe_accept = try ring.copy_cqe(); try testing.expect(cqe_accept.user_data == accept_userdata); @@ -4092,7 +4093,7 @@ test "socket_direct/socket_direct_alloc/close_direct" { // use sockets from registered_fds in connect operation var address = try net.Address.parseIp4("127.0.0.1", 0); const listener_socket = try createListenerSocket(&address); - defer os.closeSocket(listener_socket); + defer posix.close(listener_socket); const accept_userdata: u64 = 0xaaaaaaaa; const connect_userdata: u64 = 0xbbbbbbbb; const close_userdata: u64 = 0xcccccccc; diff --git a/lib/std/os/test.zig b/lib/std/os/test.zig index 5fee5dcc7f..0d9255641c 100644 --- a/lib/std/os/test.zig +++ b/lib/std/os/test.zig @@ -817,7 +817,7 @@ test "shutdown socket" { error.SocketNotConnected => {}, else => |e| return e, }; - os.closeSocket(sock); + std.net.Stream.close(.{ .handle = sock }); } test "sigaction" { diff --git a/src/main.zig b/src/main.zig index e6521d58c8..584a34eeee 100644 --- a/src/main.zig +++ b/src/main.zig @@ -3322,13 +3322,13 @@ fn buildOutputType( .ip4 => |ip4_addr| { if (build_options.only_core_functionality) unreachable; - var server = std.net.StreamServer.init(.{ + const addr: std.net.Address = .{ .in = ip4_addr }; + + var server = try addr.listen(.{ .reuse_address = true, }); defer server.deinit(); - try server.listen(.{ .in = ip4_addr }); - const conn = try server.accept(); defer conn.stream.close(); diff --git a/test/standalone/http.zig b/test/standalone/http.zig index 2b53ebbb81..7bf09f55c9 100644 --- a/test/standalone/http.zig +++ b/test/standalone/http.zig @@ -1,8 +1,6 @@ const std = @import("std"); const http = std.http; -const Server = http.Server; -const Client = http.Client; const mem = std.mem; const testing = std.testing; @@ -19,9 +17,7 @@ var gpa_client = std.heap.GeneralPurposeAllocator(.{ .stack_trace_frames = 12 }) const salloc = gpa_server.allocator(); const calloc = gpa_client.allocator(); -var server: Server = undefined; - -fn handleRequest(res: *Server.Response) !void { +fn handleRequest(res: *http.Server, listen_port: u16) !void { const log = std.log.scoped(.server); log.info("{} {s} {s}", .{ res.request.method, @tagName(res.request.version), res.request.target }); @@ -125,7 +121,9 @@ fn handleRequest(res: *Server.Response) 
!void { } else if (mem.eql(u8, res.request.target, "/redirect/3")) { res.transfer_encoding = .chunked; - const location = try std.fmt.allocPrint(salloc, "http://127.0.0.1:{d}/redirect/2", .{server.socket.listen_address.getPort()}); + const location = try std.fmt.allocPrint(salloc, "http://127.0.0.1:{d}/redirect/2", .{ + listen_port, + }); defer salloc.free(location); res.status = .found; @@ -168,14 +166,15 @@ fn handleRequest(res: *Server.Response) !void { var handle_new_requests = true; -fn runServer(srv: *Server) !void { +fn runServer(server: *std.net.Server) !void { var client_header_buffer: [1024]u8 = undefined; outer: while (handle_new_requests) { - var res = try srv.accept(.{ - .allocator = salloc, + var connection = try server.accept(); + defer connection.stream.close(); + + var res = http.Server.init(connection, .{ .client_header_buffer = &client_header_buffer, }); - defer res.deinit(); while (res.reset() != .closing) { res.wait() catch |err| switch (err) { @@ -184,16 +183,15 @@ fn runServer(srv: *Server) !void { else => return err, }; - try handleRequest(&res); + try handleRequest(&res, server.listen_address.getPort()); } } } -fn serverThread(srv: *Server) void { - defer srv.deinit(); +fn serverThread(server: *std.net.Server) void { defer _ = gpa_server.deinit(); - runServer(srv) catch |err| { + runServer(server) catch |err| { std.debug.print("server error: {}\n", .{err}); if (@errorReturnTrace()) |trace| { @@ -205,18 +203,10 @@ fn serverThread(srv: *Server) void { }; } -fn killServer(addr: std.net.Address) void { - handle_new_requests = false; - - const conn = std.net.tcpConnectToAddress(addr) catch return; - conn.close(); -} - fn getUnusedTcpPort() !u16 { const addr = try std.net.Address.parseIp("127.0.0.1", 0); - var s = std.net.StreamServer.init(.{}); + var s = try addr.listen(.{}); defer s.deinit(); - try s.listen(addr); return s.listen_address.in.getPort(); } @@ -225,16 +215,15 @@ pub fn main() !void { defer _ = gpa_client.deinit(); - server = Server.init(.{ .reuse_address = true }); - const addr = std.net.Address.parseIp("127.0.0.1", 0) catch unreachable; - try server.listen(addr); + var server = try addr.listen(.{ .reuse_address = true }); + defer server.deinit(); - const port = server.socket.listen_address.getPort(); + const port = server.listen_address.getPort(); const server_thread = try std.Thread.spawn(.{}, serverThread, .{&server}); - var client = Client{ .allocator = calloc }; + var client: http.Client = .{ .allocator = calloc }; errdefer client.deinit(); // defer client.deinit(); handled below @@ -691,6 +680,12 @@ pub fn main() !void { client.deinit(); - killServer(server.socket.listen_address); + { + handle_new_requests = false; + + const conn = std.net.tcpConnectToAddress(server.listen_address) catch return; + conn.close(); + } + server_thread.join(); }
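
Usage sketch (illustrative only, not part of the diff): how the reworked API fits together after this change, mirroring the accept/reset/wait/send/finish loop used in the updated tests. `std.net.Address.listen` replaces `std.net.StreamServer`, and `std.http.Server.init` wraps an accepted connection. The address, port, buffer size, and response body below are arbitrary placeholders, and error handling is collapsed to `try` rather than the per-error responses recommended in the `wait` doc comment.

const std = @import("std");

pub fn main() !void {
    const address = try std.net.Address.parseIp("127.0.0.1", 8080);
    var net_server = try address.listen(.{ .reuse_address = true });
    defer net_server.deinit();

    var header_buffer: [8192]u8 = undefined;
    while (true) {
        // Accept a TCP connection, then wrap it in the HTTP server state machine.
        const conn = try net_server.accept();
        defer conn.stream.close();

        var res = std.http.Server.init(conn, .{ .client_header_buffer = &header_buffer });

        // Serve requests on this connection until keep-alive ends.
        while (res.reset() != .closing) {
            // Real code should map wait() errors to 400/431 responses as documented.
            try res.wait();

            const body = "hello from the server\n";
            res.transfer_encoding = .{ .content_length = body.len };
            res.extra_headers = &.{
                .{ .name = "content-type", .value = "text/plain" },
            };
            try res.send();
            try res.writeAll(body);
            try res.finish();
        }
    }
}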