diff --git a/lib/std/http.zig b/lib/std/http.zig index 00b2c56663..af966d89e7 100644 --- a/lib/std/http.zig +++ b/lib/std/http.zig @@ -3,6 +3,7 @@ pub const Server = @import("http/Server.zig"); pub const protocol = @import("http/protocol.zig"); pub const HeadParser = @import("http/HeadParser.zig"); pub const ChunkParser = @import("http/ChunkParser.zig"); +pub const HeaderIterator = @import("http/HeaderIterator.zig"); pub const Version = enum { @"HTTP/1.0", diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 1affbf9b5c..5f580bd53e 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -568,8 +568,8 @@ pub const Response = struct { try expectEqual(@as(u10, 999), parseInt3("999")); } - pub fn iterateHeaders(r: Response) proto.HeaderIterator { - return proto.HeaderIterator.init(r.parser.get()); + pub fn iterateHeaders(r: Response) http.HeaderIterator { + return http.HeaderIterator.init(r.parser.get()); } }; diff --git a/lib/std/http/HeaderIterator.zig b/lib/std/http/HeaderIterator.zig new file mode 100644 index 0000000000..8d36374f8c --- /dev/null +++ b/lib/std/http/HeaderIterator.zig @@ -0,0 +1,62 @@ +bytes: []const u8, +index: usize, +is_trailer: bool, + +pub fn init(bytes: []const u8) HeaderIterator { + return .{ + .bytes = bytes, + .index = std.mem.indexOfPosLinear(u8, bytes, 0, "\r\n").? + 2, + .is_trailer = false, + }; +} + +pub fn next(it: *HeaderIterator) ?std.http.Header { + const end = std.mem.indexOfPosLinear(u8, it.bytes, it.index, "\r\n").?; + var kv_it = std.mem.splitSequence(u8, it.bytes[it.index..end], ": "); + const name = kv_it.next().?; + const value = kv_it.rest(); + if (value.len == 0) { + if (it.is_trailer) return null; + const next_end = std.mem.indexOfPosLinear(u8, it.bytes, end + 2, "\r\n") orelse + return null; + it.is_trailer = true; + it.index = next_end + 2; + kv_it = std.mem.splitSequence(u8, it.bytes[end + 2 .. next_end], ": "); + return .{ + .name = kv_it.next().?, + .value = kv_it.rest(), + }; + } + it.index = end + 2; + return .{ + .name = name, + .value = value, + }; +} + +test next { + var it = HeaderIterator.init("200 OK\r\na: b\r\nc: d\r\n\r\ne: f\r\n\r\n"); + try std.testing.expect(!it.is_trailer); + { + const header = it.next().?; + try std.testing.expect(!it.is_trailer); + try std.testing.expectEqualStrings("a", header.name); + try std.testing.expectEqualStrings("b", header.value); + } + { + const header = it.next().?; + try std.testing.expect(!it.is_trailer); + try std.testing.expectEqualStrings("c", header.name); + try std.testing.expectEqualStrings("d", header.value); + } + { + const header = it.next().?; + try std.testing.expect(it.is_trailer); + try std.testing.expectEqualStrings("e", header.name); + try std.testing.expectEqualStrings("f", header.value); + } + try std.testing.expectEqual(null, it.next()); +} + +const HeaderIterator = @This(); +const std = @import("../std.zig"); diff --git a/lib/std/http/Server.zig b/lib/std/http/Server.zig index 226e13fc32..2d360d40a4 100644 --- a/lib/std/http/Server.zig +++ b/lib/std/http/Server.zig @@ -273,6 +273,10 @@ pub const Request = struct { } }; + pub fn iterateHeaders(r: *Request) http.HeaderIterator { + return http.HeaderIterator.init(r.server.read_buffer[0..r.head_end]); + } + pub const RespondOptions = struct { version: http.Version = .@"HTTP/1.1", status: http.Status = .ok, diff --git a/lib/std/http/protocol.zig b/lib/std/http/protocol.zig index 64c87b9287..78511f435d 100644 --- a/lib/std/http/protocol.zig +++ b/lib/std/http/protocol.zig @@ -250,68 +250,6 @@ pub const HeadersParser = struct { } }; -pub const HeaderIterator = struct { - bytes: []const u8, - index: usize, - is_trailer: bool, - - pub fn init(bytes: []const u8) HeaderIterator { - return .{ - .bytes = bytes, - .index = std.mem.indexOfPosLinear(u8, bytes, 0, "\r\n").? + 2, - .is_trailer = false, - }; - } - - pub fn next(it: *HeaderIterator) ?std.http.Header { - const end = std.mem.indexOfPosLinear(u8, it.bytes, it.index, "\r\n").?; - var kv_it = std.mem.splitSequence(u8, it.bytes[it.index..end], ": "); - const name = kv_it.next().?; - const value = kv_it.rest(); - if (value.len == 0) { - if (it.is_trailer) return null; - const next_end = std.mem.indexOfPosLinear(u8, it.bytes, end + 2, "\r\n") orelse - return null; - it.is_trailer = true; - it.index = next_end + 2; - kv_it = std.mem.splitSequence(u8, it.bytes[end + 2 .. next_end], ": "); - return .{ - .name = kv_it.next().?, - .value = kv_it.rest(), - }; - } - it.index = end + 2; - return .{ - .name = name, - .value = value, - }; - } - - test next { - var it = HeaderIterator.init("200 OK\r\na: b\r\nc: d\r\n\r\ne: f\r\n\r\n"); - try std.testing.expect(!it.is_trailer); - { - const header = it.next().?; - try std.testing.expect(!it.is_trailer); - try std.testing.expectEqualStrings("a", header.name); - try std.testing.expectEqualStrings("b", header.value); - } - { - const header = it.next().?; - try std.testing.expect(!it.is_trailer); - try std.testing.expectEqualStrings("c", header.name); - try std.testing.expectEqualStrings("d", header.value); - } - { - const header = it.next().?; - try std.testing.expect(it.is_trailer); - try std.testing.expectEqualStrings("e", header.name); - try std.testing.expectEqualStrings("f", header.value); - } - try std.testing.expectEqual(null, it.next()); - } -}; - inline fn int16(array: *const [2]u8) u16 { return @as(u16, @bitCast(array.*)); } diff --git a/lib/std/http/test.zig b/lib/std/http/test.zig index 61bd00a6e7..e36b0cdf28 100644 --- a/lib/std/http/test.zig +++ b/lib/std/http/test.zig @@ -290,6 +290,58 @@ test "Server.Request.respondStreaming non-chunked, unknown content-length" { try expectEqualStrings(expected_response.items, response); } +test "receiving arbitrary http headers from the client" { + const test_server = try createTestServer(struct { + fn run(net_server: *std.net.Server) anyerror!void { + var read_buffer: [666]u8 = undefined; + var remaining: usize = 1; + while (remaining != 0) : (remaining -= 1) { + const conn = try net_server.accept(); + defer conn.stream.close(); + + var server = http.Server.init(conn, &read_buffer); + try expectEqual(.ready, server.state); + var request = try server.receiveHead(); + try expectEqualStrings("/bar", request.head.target); + var it = request.iterateHeaders(); + { + const header = it.next().?; + try expectEqualStrings("CoNneCtIoN", header.name); + try expectEqualStrings("close", header.value); + try expect(!it.is_trailer); + } + { + const header = it.next().?; + try expectEqualStrings("aoeu", header.name); + try expectEqualStrings("asdf", header.value); + try expect(!it.is_trailer); + } + try request.respond("", .{}); + } + } + }); + defer test_server.destroy(); + + const request_bytes = "GET /bar HTTP/1.1\r\n" ++ + "CoNneCtIoN: close\r\n" ++ + "aoeu: asdf\r\n" ++ + "\r\n"; + const gpa = std.testing.allocator; + const stream = try std.net.tcpConnectToHost(gpa, "127.0.0.1", test_server.port()); + defer stream.close(); + try stream.writeAll(request_bytes); + + const response = try stream.reader().readAllAlloc(gpa, 8192); + defer gpa.free(response); + + var expected_response = std.ArrayList(u8).init(gpa); + defer expected_response.deinit(); + + try expected_response.appendSlice("HTTP/1.1 200 OK\r\n"); + try expected_response.appendSlice("content-length: 0\r\n\r\n"); + try expectEqualStrings(expected_response.items, response); +} + test "general client/server API coverage" { if (builtin.os.tag == .windows) { // This test was never passing on Windows.