std.http.Server: add safety for invalidated Head strings

and fix bad unit test API usage that it finds
This commit is contained in:
Andrew Kelley 2025-08-06 18:37:48 -07:00
parent 063d87132a
commit 995bfdf0ff
3 changed files with 72 additions and 40 deletions

View file

@ -444,8 +444,8 @@ pub const Connection = struct {
pub const Response = struct { pub const Response = struct {
request: *Request, request: *Request,
/// Pointers in this struct are invalidated with the next call to /// Pointers in this struct are invalidated when the response body stream
/// `receiveHead`. /// is initialized.
head: Head, head: Head,
pub const Head = struct { pub const Head = struct {
@ -671,6 +671,16 @@ pub const Response = struct {
try expectEqual(@as(u10, 418), parseInt3("418")); try expectEqual(@as(u10, 418), parseInt3("418"));
try expectEqual(@as(u10, 999), parseInt3("999")); try expectEqual(@as(u10, 999), parseInt3("999"));
} }
/// Help the programmer avoid bugs by calling this when the string
/// memory of `Head` becomes invalidated.
fn invalidateStrings(h: *Head) void {
h.bytes = undefined;
h.reason = undefined;
if (h.location) |*s| s.* = undefined;
if (h.content_type) |*s| s.* = undefined;
if (h.content_disposition) |*s| s.* = undefined;
}
}; };
/// If compressed body has been negotiated this will return compressed bytes. /// If compressed body has been negotiated this will return compressed bytes.
@ -682,7 +692,8 @@ pub const Response = struct {
/// ///
/// See also: /// See also:
/// * `readerDecompressing` /// * `readerDecompressing`
pub fn reader(response: *const Response, buffer: []u8) *Reader { pub fn reader(response: *Response, buffer: []u8) *Reader {
response.head.invalidateStrings();
const req = response.request; const req = response.request;
if (!req.method.responseHasBody()) return .ending; if (!req.method.responseHasBody()) return .ending;
const head = &response.head; const head = &response.head;
@ -703,6 +714,7 @@ pub const Response = struct {
decompressor: *http.Decompressor, decompressor: *http.Decompressor,
decompression_buffer: []u8, decompression_buffer: []u8,
) *Reader { ) *Reader {
response.head.invalidateStrings();
const head = &response.head; const head = &response.head;
return response.request.reader.bodyReaderDecompressing( return response.request.reader.bodyReaderDecompressing(
head.transfer_encoding, head.transfer_encoding,

View file

@ -55,8 +55,8 @@ pub fn receiveHead(s: *Server) ReceiveHeadError!Request {
pub const Request = struct { pub const Request = struct {
server: *Server, server: *Server,
/// Pointers in this struct are invalidated with the next call to /// Pointers in this struct are invalidated when the request body stream is
/// `receiveHead`. /// initialized.
head: Head, head: Head,
head_buffer: []const u8, head_buffer: []const u8,
respond_err: ?RespondError = null, respond_err: ?RespondError = null,
@ -224,6 +224,14 @@ pub const Request = struct {
inline fn int64(array: *const [8]u8) u64 { inline fn int64(array: *const [8]u8) u64 {
return @bitCast(array.*); return @bitCast(array.*);
} }
/// Help the programmer avoid bugs by calling this when the string
/// memory of `Head` becomes invalidated.
fn invalidateStrings(h: *Head) void {
h.target = undefined;
if (h.expect) |*s| s.* = undefined;
if (h.content_type) |*s| s.* = undefined;
}
}; };
pub fn iterateHeaders(r: *const Request) http.HeaderIterator { pub fn iterateHeaders(r: *const Request) http.HeaderIterator {
@ -578,9 +586,12 @@ pub const Request = struct {
/// this function. /// this function.
/// ///
/// Asserts that this function is only called once. /// Asserts that this function is only called once.
///
/// Invalidates the string memory inside `Head`.
pub fn readerExpectNone(request: *Request, buffer: []u8) *Reader { pub fn readerExpectNone(request: *Request, buffer: []u8) *Reader {
assert(request.server.reader.state == .received_head); assert(request.server.reader.state == .received_head);
assert(request.head.expect == null); assert(request.head.expect == null);
request.head.invalidateStrings();
if (!request.head.method.requestHasBody()) return .ending; if (!request.head.method.requestHasBody()) return .ending;
return request.server.reader.bodyReader(buffer, request.head.transfer_encoding, request.head.content_length); return request.server.reader.bodyReader(buffer, request.head.transfer_encoding, request.head.content_length);
} }

View file

@ -65,23 +65,22 @@ test "trailers" {
try req.sendBodiless(); try req.sendBodiless();
var response = try req.receiveHead(&.{}); var response = try req.receiveHead(&.{});
{
var it = response.head.iterateHeaders();
const header = it.next().?;
try expectEqualStrings("transfer-encoding", header.name);
try expectEqualStrings("chunked", header.value);
try expectEqual(null, it.next());
}
const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited); const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited);
defer gpa.free(body); defer gpa.free(body);
try expectEqualStrings("Hello, World!\n", body); try expectEqualStrings("Hello, World!\n", body);
{
var it = response.head.iterateHeaders();
const header = it.next().?;
try expect(!it.is_trailer);
try expectEqualStrings("transfer-encoding", header.name);
try expectEqualStrings("chunked", header.value);
try expectEqual(null, it.next());
}
{ {
var it = response.iterateTrailers(); var it = response.iterateTrailers();
const header = it.next().?; const header = it.next().?;
try expect(it.is_trailer);
try expectEqualStrings("X-Checksum", header.name); try expectEqualStrings("X-Checksum", header.name);
try expectEqualStrings("aaaa", header.value); try expectEqualStrings("aaaa", header.value);
try expectEqual(null, it.next()); try expectEqual(null, it.next());
@ -208,12 +207,14 @@ test "echo content server" {
// request.head.target, // request.head.target,
//}); //});
try expect(mem.startsWith(u8, request.head.target, "/echo-content"));
try expectEqualStrings("text/plain", request.head.content_type.?);
// head strings expire here
const body = try (try request.readerExpectContinue(&.{})).allocRemaining(std.testing.allocator, .unlimited); const body = try (try request.readerExpectContinue(&.{})).allocRemaining(std.testing.allocator, .unlimited);
defer std.testing.allocator.free(body); defer std.testing.allocator.free(body);
try expect(mem.startsWith(u8, request.head.target, "/echo-content"));
try expectEqualStrings("Hello, World!\n", body); try expectEqualStrings("Hello, World!\n", body);
try expectEqualStrings("text/plain", request.head.content_type.?);
var response = try request.respondStreaming(&.{}, .{ var response = try request.respondStreaming(&.{}, .{
.content_length = switch (request.head.transfer_encoding) { .content_length = switch (request.head.transfer_encoding) {
@ -410,17 +411,19 @@ test "general client/server API coverage" {
fn handleRequest(request: *http.Server.Request, listen_port: u16) !void { fn handleRequest(request: *http.Server.Request, listen_port: u16) !void {
const log = std.log.scoped(.server); const log = std.log.scoped(.server);
const gpa = std.testing.allocator;
log.info("{f} {t} {s}", .{ request.head.method, request.head.version, request.head.target }); log.info("{f} {t} {s}", .{ request.head.method, request.head.version, request.head.target });
const target = try gpa.dupe(u8, request.head.target);
defer gpa.free(target);
const gpa = std.testing.allocator;
const reader = (try request.readerExpectContinue(&.{})); const reader = (try request.readerExpectContinue(&.{}));
const body = try reader.allocRemaining(gpa, .unlimited); const body = try reader.allocRemaining(gpa, .unlimited);
defer gpa.free(body); defer gpa.free(body);
if (mem.startsWith(u8, request.head.target, "/get")) { if (mem.startsWith(u8, target, "/get")) {
var response = try request.respondStreaming(&.{}, .{ var response = try request.respondStreaming(&.{}, .{
.content_length = if (mem.indexOf(u8, request.head.target, "?chunked") == null) .content_length = if (mem.indexOf(u8, target, "?chunked") == null)
14 14
else else
null, null,
@ -435,7 +438,7 @@ test "general client/server API coverage" {
try w.writeAll("World!\n"); try w.writeAll("World!\n");
try response.end(); try response.end();
// Writing again would cause an assertion failure. // Writing again would cause an assertion failure.
} else if (mem.startsWith(u8, request.head.target, "/large")) { } else if (mem.startsWith(u8, target, "/large")) {
var response = try request.respondStreaming(&.{}, .{ var response = try request.respondStreaming(&.{}, .{
.content_length = 14 * 1024 + 14 * 10, .content_length = 14 * 1024 + 14 * 10,
}); });
@ -458,7 +461,7 @@ test "general client/server API coverage" {
} }
try response.end(); try response.end();
} else if (mem.eql(u8, request.head.target, "/redirect/1")) { } else if (mem.eql(u8, target, "/redirect/1")) {
var response = try request.respondStreaming(&.{}, .{ var response = try request.respondStreaming(&.{}, .{
.respond_options = .{ .respond_options = .{
.status = .found, .status = .found,
@ -472,14 +475,14 @@ test "general client/server API coverage" {
try w.writeAll("Hello, "); try w.writeAll("Hello, ");
try w.writeAll("Redirected!\n"); try w.writeAll("Redirected!\n");
try response.end(); try response.end();
} else if (mem.eql(u8, request.head.target, "/redirect/2")) { } else if (mem.eql(u8, target, "/redirect/2")) {
try request.respond("Hello, Redirected!\n", .{ try request.respond("Hello, Redirected!\n", .{
.status = .found, .status = .found,
.extra_headers = &.{ .extra_headers = &.{
.{ .name = "location", .value = "/redirect/1" }, .{ .name = "location", .value = "/redirect/1" },
}, },
}); });
} else if (mem.eql(u8, request.head.target, "/redirect/3")) { } else if (mem.eql(u8, target, "/redirect/3")) {
const location = try std.fmt.allocPrint(gpa, "http://127.0.0.1:{d}/redirect/2", .{ const location = try std.fmt.allocPrint(gpa, "http://127.0.0.1:{d}/redirect/2", .{
listen_port, listen_port,
}); });
@ -491,23 +494,23 @@ test "general client/server API coverage" {
.{ .name = "location", .value = location }, .{ .name = "location", .value = location },
}, },
}); });
} else if (mem.eql(u8, request.head.target, "/redirect/4")) { } else if (mem.eql(u8, target, "/redirect/4")) {
try request.respond("Hello, Redirected!\n", .{ try request.respond("Hello, Redirected!\n", .{
.status = .found, .status = .found,
.extra_headers = &.{ .extra_headers = &.{
.{ .name = "location", .value = "/redirect/3" }, .{ .name = "location", .value = "/redirect/3" },
}, },
}); });
} else if (mem.eql(u8, request.head.target, "/redirect/5")) { } else if (mem.eql(u8, target, "/redirect/5")) {
try request.respond("Hello, Redirected!\n", .{ try request.respond("Hello, Redirected!\n", .{
.status = .found, .status = .found,
.extra_headers = &.{ .extra_headers = &.{
.{ .name = "location", .value = "/%2525" }, .{ .name = "location", .value = "/%2525" },
}, },
}); });
} else if (mem.eql(u8, request.head.target, "/%2525")) { } else if (mem.eql(u8, target, "/%2525")) {
try request.respond("Encoded redirect successful!\n", .{}); try request.respond("Encoded redirect successful!\n", .{});
} else if (mem.eql(u8, request.head.target, "/redirect/invalid")) { } else if (mem.eql(u8, target, "/redirect/invalid")) {
const invalid_port = try getUnusedTcpPort(); const invalid_port = try getUnusedTcpPort();
const location = try std.fmt.allocPrint(gpa, "http://127.0.0.1:{d}", .{invalid_port}); const location = try std.fmt.allocPrint(gpa, "http://127.0.0.1:{d}", .{invalid_port});
defer gpa.free(location); defer gpa.free(location);
@ -518,7 +521,7 @@ test "general client/server API coverage" {
.{ .name = "location", .value = location }, .{ .name = "location", .value = location },
}, },
}); });
} else if (mem.eql(u8, request.head.target, "/empty")) { } else if (mem.eql(u8, target, "/empty")) {
try request.respond("", .{ try request.respond("", .{
.extra_headers = &.{ .extra_headers = &.{
.{ .name = "empty", .value = "" }, .{ .name = "empty", .value = "" },
@ -559,11 +562,12 @@ test "general client/server API coverage" {
try req.sendBodiless(); try req.sendBodiless();
var response = try req.receiveHead(&redirect_buffer); var response = try req.receiveHead(&redirect_buffer);
try expectEqualStrings("text/plain", response.head.content_type.?);
const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited); const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited);
defer gpa.free(body); defer gpa.free(body);
try expectEqualStrings("Hello, World!\n", body); try expectEqualStrings("Hello, World!\n", body);
try expectEqualStrings("text/plain", response.head.content_type.?);
} }
// connection has been kept alive // connection has been kept alive
@ -604,12 +608,13 @@ test "general client/server API coverage" {
try req.sendBodiless(); try req.sendBodiless();
var response = try req.receiveHead(&redirect_buffer); var response = try req.receiveHead(&redirect_buffer);
try expectEqualStrings("text/plain", response.head.content_type.?);
try expectEqual(14, response.head.content_length.?);
const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited); const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited);
defer gpa.free(body); defer gpa.free(body);
try expectEqualStrings("", body); try expectEqualStrings("", body);
try expectEqualStrings("text/plain", response.head.content_type.?);
try expectEqual(14, response.head.content_length.?);
} }
// connection has been kept alive // connection has been kept alive
@ -628,11 +633,12 @@ test "general client/server API coverage" {
try req.sendBodiless(); try req.sendBodiless();
var response = try req.receiveHead(&redirect_buffer); var response = try req.receiveHead(&redirect_buffer);
try expectEqualStrings("text/plain", response.head.content_type.?);
const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited); const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited);
defer gpa.free(body); defer gpa.free(body);
try expectEqualStrings("Hello, World!\n", body); try expectEqualStrings("Hello, World!\n", body);
try expectEqualStrings("text/plain", response.head.content_type.?);
} }
// connection has been kept alive // connection has been kept alive
@ -651,12 +657,13 @@ test "general client/server API coverage" {
try req.sendBodiless(); try req.sendBodiless();
var response = try req.receiveHead(&redirect_buffer); var response = try req.receiveHead(&redirect_buffer);
try expectEqualStrings("text/plain", response.head.content_type.?);
try expect(response.head.transfer_encoding == .chunked);
const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited); const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited);
defer gpa.free(body); defer gpa.free(body);
try expectEqualStrings("", body); try expectEqualStrings("", body);
try expectEqualStrings("text/plain", response.head.content_type.?);
try expect(response.head.transfer_encoding == .chunked);
} }
// connection has been kept alive // connection has been kept alive
@ -677,11 +684,12 @@ test "general client/server API coverage" {
try req.sendBodiless(); try req.sendBodiless();
var response = try req.receiveHead(&redirect_buffer); var response = try req.receiveHead(&redirect_buffer);
try expectEqualStrings("text/plain", response.head.content_type.?);
const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited); const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited);
defer gpa.free(body); defer gpa.free(body);
try expectEqualStrings("Hello, World!\n", body); try expectEqualStrings("Hello, World!\n", body);
try expectEqualStrings("text/plain", response.head.content_type.?);
} }
// connection has been closed // connection has been closed
@ -706,11 +714,6 @@ test "general client/server API coverage" {
try std.testing.expectEqual(.ok, response.head.status); try std.testing.expectEqual(.ok, response.head.status);
const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited);
defer gpa.free(body);
try expectEqualStrings("", body);
var it = response.head.iterateHeaders(); var it = response.head.iterateHeaders();
{ {
const header = it.next().?; const header = it.next().?;
@ -718,6 +721,12 @@ test "general client/server API coverage" {
try expectEqualStrings("content-length", header.name); try expectEqualStrings("content-length", header.name);
try expectEqualStrings("0", header.value); try expectEqualStrings("0", header.value);
} }
const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited);
defer gpa.free(body);
try expectEqualStrings("", body);
{ {
const header = it.next().?; const header = it.next().?;
try expect(!it.is_trailer); try expect(!it.is_trailer);