std.http.Server: expose arbitrary HTTP headers

Ultimate flexibility, just be sure to destroy the correct amount of
information when looking at them.
This commit is contained in:
Andrew Kelley 2024-02-23 02:58:02 -07:00
parent 5b34a1b718
commit 653d4158cd
6 changed files with 121 additions and 64 deletions

View file

@ -3,6 +3,7 @@ pub const Server = @import("http/Server.zig");
pub const protocol = @import("http/protocol.zig"); pub const protocol = @import("http/protocol.zig");
pub const HeadParser = @import("http/HeadParser.zig"); pub const HeadParser = @import("http/HeadParser.zig");
pub const ChunkParser = @import("http/ChunkParser.zig"); pub const ChunkParser = @import("http/ChunkParser.zig");
pub const HeaderIterator = @import("http/HeaderIterator.zig");
pub const Version = enum { pub const Version = enum {
@"HTTP/1.0", @"HTTP/1.0",

View file

@ -568,8 +568,8 @@ pub const Response = struct {
try expectEqual(@as(u10, 999), parseInt3("999")); try expectEqual(@as(u10, 999), parseInt3("999"));
} }
pub fn iterateHeaders(r: Response) proto.HeaderIterator { pub fn iterateHeaders(r: Response) http.HeaderIterator {
return proto.HeaderIterator.init(r.parser.get()); return http.HeaderIterator.init(r.parser.get());
} }
}; };

View file

@ -0,0 +1,62 @@
bytes: []const u8,
index: usize,
is_trailer: bool,
pub fn init(bytes: []const u8) HeaderIterator {
return .{
.bytes = bytes,
.index = std.mem.indexOfPosLinear(u8, bytes, 0, "\r\n").? + 2,
.is_trailer = false,
};
}
pub fn next(it: *HeaderIterator) ?std.http.Header {
const end = std.mem.indexOfPosLinear(u8, it.bytes, it.index, "\r\n").?;
var kv_it = std.mem.splitSequence(u8, it.bytes[it.index..end], ": ");
const name = kv_it.next().?;
const value = kv_it.rest();
if (value.len == 0) {
if (it.is_trailer) return null;
const next_end = std.mem.indexOfPosLinear(u8, it.bytes, end + 2, "\r\n") orelse
return null;
it.is_trailer = true;
it.index = next_end + 2;
kv_it = std.mem.splitSequence(u8, it.bytes[end + 2 .. next_end], ": ");
return .{
.name = kv_it.next().?,
.value = kv_it.rest(),
};
}
it.index = end + 2;
return .{
.name = name,
.value = value,
};
}
test next {
var it = HeaderIterator.init("200 OK\r\na: b\r\nc: d\r\n\r\ne: f\r\n\r\n");
try std.testing.expect(!it.is_trailer);
{
const header = it.next().?;
try std.testing.expect(!it.is_trailer);
try std.testing.expectEqualStrings("a", header.name);
try std.testing.expectEqualStrings("b", header.value);
}
{
const header = it.next().?;
try std.testing.expect(!it.is_trailer);
try std.testing.expectEqualStrings("c", header.name);
try std.testing.expectEqualStrings("d", header.value);
}
{
const header = it.next().?;
try std.testing.expect(it.is_trailer);
try std.testing.expectEqualStrings("e", header.name);
try std.testing.expectEqualStrings("f", header.value);
}
try std.testing.expectEqual(null, it.next());
}
const HeaderIterator = @This();
const std = @import("../std.zig");

View file

@ -273,6 +273,10 @@ pub const Request = struct {
} }
}; };
pub fn iterateHeaders(r: *Request) http.HeaderIterator {
return http.HeaderIterator.init(r.server.read_buffer[0..r.head_end]);
}
pub const RespondOptions = struct { pub const RespondOptions = struct {
version: http.Version = .@"HTTP/1.1", version: http.Version = .@"HTTP/1.1",
status: http.Status = .ok, status: http.Status = .ok,

View file

@ -250,68 +250,6 @@ pub const HeadersParser = struct {
} }
}; };
pub const HeaderIterator = struct {
bytes: []const u8,
index: usize,
is_trailer: bool,
pub fn init(bytes: []const u8) HeaderIterator {
return .{
.bytes = bytes,
.index = std.mem.indexOfPosLinear(u8, bytes, 0, "\r\n").? + 2,
.is_trailer = false,
};
}
pub fn next(it: *HeaderIterator) ?std.http.Header {
const end = std.mem.indexOfPosLinear(u8, it.bytes, it.index, "\r\n").?;
var kv_it = std.mem.splitSequence(u8, it.bytes[it.index..end], ": ");
const name = kv_it.next().?;
const value = kv_it.rest();
if (value.len == 0) {
if (it.is_trailer) return null;
const next_end = std.mem.indexOfPosLinear(u8, it.bytes, end + 2, "\r\n") orelse
return null;
it.is_trailer = true;
it.index = next_end + 2;
kv_it = std.mem.splitSequence(u8, it.bytes[end + 2 .. next_end], ": ");
return .{
.name = kv_it.next().?,
.value = kv_it.rest(),
};
}
it.index = end + 2;
return .{
.name = name,
.value = value,
};
}
test next {
var it = HeaderIterator.init("200 OK\r\na: b\r\nc: d\r\n\r\ne: f\r\n\r\n");
try std.testing.expect(!it.is_trailer);
{
const header = it.next().?;
try std.testing.expect(!it.is_trailer);
try std.testing.expectEqualStrings("a", header.name);
try std.testing.expectEqualStrings("b", header.value);
}
{
const header = it.next().?;
try std.testing.expect(!it.is_trailer);
try std.testing.expectEqualStrings("c", header.name);
try std.testing.expectEqualStrings("d", header.value);
}
{
const header = it.next().?;
try std.testing.expect(it.is_trailer);
try std.testing.expectEqualStrings("e", header.name);
try std.testing.expectEqualStrings("f", header.value);
}
try std.testing.expectEqual(null, it.next());
}
};
inline fn int16(array: *const [2]u8) u16 { inline fn int16(array: *const [2]u8) u16 {
return @as(u16, @bitCast(array.*)); return @as(u16, @bitCast(array.*));
} }

View file

@ -290,6 +290,58 @@ test "Server.Request.respondStreaming non-chunked, unknown content-length" {
try expectEqualStrings(expected_response.items, response); try expectEqualStrings(expected_response.items, response);
} }
test "receiving arbitrary http headers from the client" {
const test_server = try createTestServer(struct {
fn run(net_server: *std.net.Server) anyerror!void {
var read_buffer: [666]u8 = undefined;
var remaining: usize = 1;
while (remaining != 0) : (remaining -= 1) {
const conn = try net_server.accept();
defer conn.stream.close();
var server = http.Server.init(conn, &read_buffer);
try expectEqual(.ready, server.state);
var request = try server.receiveHead();
try expectEqualStrings("/bar", request.head.target);
var it = request.iterateHeaders();
{
const header = it.next().?;
try expectEqualStrings("CoNneCtIoN", header.name);
try expectEqualStrings("close", header.value);
try expect(!it.is_trailer);
}
{
const header = it.next().?;
try expectEqualStrings("aoeu", header.name);
try expectEqualStrings("asdf", header.value);
try expect(!it.is_trailer);
}
try request.respond("", .{});
}
}
});
defer test_server.destroy();
const request_bytes = "GET /bar HTTP/1.1\r\n" ++
"CoNneCtIoN: close\r\n" ++
"aoeu: asdf\r\n" ++
"\r\n";
const gpa = std.testing.allocator;
const stream = try std.net.tcpConnectToHost(gpa, "127.0.0.1", test_server.port());
defer stream.close();
try stream.writeAll(request_bytes);
const response = try stream.reader().readAllAlloc(gpa, 8192);
defer gpa.free(response);
var expected_response = std.ArrayList(u8).init(gpa);
defer expected_response.deinit();
try expected_response.appendSlice("HTTP/1.1 200 OK\r\n");
try expected_response.appendSlice("content-length: 0\r\n\r\n");
try expectEqualStrings(expected_response.items, response);
}
test "general client/server API coverage" { test "general client/server API coverage" {
if (builtin.os.tag == .windows) { if (builtin.os.tag == .windows) {
// This test was never passing on Windows. // This test was never passing on Windows.