Merge pull request #15299 from truemedian/std-http

std.http: curated error sets and custom Headers
Andrew Kelley 2023-04-18 19:56:24 -07:00 committed by GitHub
commit 0eebc25880
10 changed files with 1150 additions and 581 deletions
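
For a sense of the user-facing change, here is a minimal sketch (not part of the commit itself) of issuing a client request with the new custom Headers, modeled on the fetchAndUnpack hunk at the end of this diff; `client`, `gpa`, and `uri` are assumed to be in scope, and the "accept" header is purely illustrative:

    var headers = std.http.Headers{ .allocator = gpa };
    defer headers.deinit();
    try headers.append("accept", "application/tar+gzip"); // any custom name/value pair

    var req = try client.request(.GET, uri, headers, .{});
    defer req.deinit();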


@@ -27,6 +27,18 @@ pub fn escapeQuery(allocator: std.mem.Allocator, input: []const u8) error{OutOfM
     return escapeStringWithFn(allocator, input, isQueryChar);
 }
 
+pub fn writeEscapedString(writer: anytype, input: []const u8) !void {
+    return writeEscapedStringWithFn(writer, input, isUnreserved);
+}
+
+pub fn writeEscapedPath(writer: anytype, input: []const u8) !void {
+    return writeEscapedStringWithFn(writer, input, isPathChar);
+}
+
+pub fn writeEscapedQuery(writer: anytype, input: []const u8) !void {
+    return writeEscapedStringWithFn(writer, input, isQueryChar);
+}
+
 pub fn escapeStringWithFn(allocator: std.mem.Allocator, input: []const u8, comptime keepUnescaped: fn (c: u8) bool) std.mem.Allocator.Error![]const u8 {
     var outsize: usize = 0;
     for (input) |c| {
@@ -52,6 +64,16 @@ pub fn escapeStringWithFn(allocator: std.mem.Allocator, input: []const u8, compt
     return output;
 }
 
+pub fn writeEscapedStringWithFn(writer: anytype, input: []const u8, comptime keepUnescaped: fn (c: u8) bool) @TypeOf(writer).Error!void {
+    for (input) |c| {
+        if (keepUnescaped(c)) {
+            try writer.writeByte(c);
+        } else {
+            try writer.print("%{X:0>2}", .{c});
+        }
+    }
+}
+
 /// Parses a URI string and unescapes all %XX where XX is a valid hex number. Otherwise, verbatim copies
 /// them to the output.
 pub fn unescapeString(allocator: std.mem.Allocator, input: []const u8) error{OutOfMemory}![]const u8 {
@@ -184,6 +206,60 @@ pub fn parseWithoutScheme(text: []const u8) ParseError!Uri {
     return uri;
 }
 
+pub fn format(
+    uri: Uri,
+    comptime fmt: []const u8,
+    options: std.fmt.FormatOptions,
+    writer: anytype,
+) @TypeOf(writer).Error!void {
+    _ = options;
+
+    const needs_absolute = comptime std.mem.indexOf(u8, fmt, "+") != null;
+    const needs_path = comptime std.mem.indexOf(u8, fmt, "/") != null or fmt.len == 0;
+
+    if (needs_absolute) {
+        try writer.writeAll(uri.scheme);
+        try writer.writeAll(":");
+
+        if (uri.host) |host| {
+            try writer.writeAll("//");
+
+            if (uri.user) |user| {
+                try writer.writeAll(user);
+                if (uri.password) |password| {
+                    try writer.writeAll(":");
+                    try writer.writeAll(password);
+                }
+
+                try writer.writeAll("@");
+            }
+
+            try writer.writeAll(host);
+
+            if (uri.port) |port| {
+                try writer.writeAll(":");
+                try std.fmt.formatInt(port, 10, .lower, .{}, writer);
+            }
+        }
+    }
+
+    if (needs_path) {
+        if (uri.path.len == 0) {
+            try writer.writeAll("/");
+        } else {
+            try Uri.writeEscapedPath(writer, uri.path);
+        }
+
+        if (uri.query) |q| {
+            try writer.writeAll("?");
+            try Uri.writeEscapedQuery(writer, q);
+        }
+
+        if (uri.fragment) |f| {
+            try writer.writeAll("#");
+            try Uri.writeEscapedQuery(writer, f);
+        }
+    }
+}
+
 /// Parses the URI or returns an error.
 /// The return value will contain unescaped strings pointing into the
 /// original `text`. Each component that is provided, will be non-`null`.
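
A sketch of how the new format specifiers compose, going by the flag parsing above ("+" selects the scheme/authority part; "/" or an empty specifier selects the escaped path, query, and fragment):

    const uri = try std.Uri.parse("https://ziglang.org/download/?arch=x86_64");
    std.debug.print("{+/}\n", .{uri}); // full URI: scheme, authority, escaped path, query, fragment
    std.debug.print("{/}\n", .{uri}); // only the escaped path, query, and fragment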


@@ -371,7 +371,9 @@ test "Parsed.checkHostName" {
     try expectEqual(false, Parsed.checkHostName("lang.org", "zig*.org"));
 }
 
-pub fn parse(cert: Certificate) !Parsed {
+pub const ParseError = der.Element.ParseElementError || ParseVersionError || ParseTimeError || ParseEnumError || ParseBitStringError;
+
+pub fn parse(cert: Certificate) ParseError!Parsed {
     const cert_bytes = cert.buffer;
     const certificate = try der.Element.parse(cert_bytes, cert.index);
     const tbs_certificate = try der.Element.parse(cert_bytes, certificate.slice.start);
@@ -514,14 +516,18 @@ pub fn contents(cert: Certificate, elem: der.Element) []const u8 {
     return cert.buffer[elem.slice.start..elem.slice.end];
 }
 
+pub const ParseBitStringError = error{ CertificateFieldHasWrongDataType, CertificateHasInvalidBitString };
+
 pub fn parseBitString(cert: Certificate, elem: der.Element) !der.Element.Slice {
     if (elem.identifier.tag != .bitstring) return error.CertificateFieldHasWrongDataType;
     if (cert.buffer[elem.slice.start] != 0) return error.CertificateHasInvalidBitString;
     return .{ .start = elem.slice.start + 1, .end = elem.slice.end };
 }
 
+pub const ParseTimeError = error{ CertificateTimeInvalid, CertificateFieldHasWrongDataType };
+
 /// Returns number of seconds since epoch.
-pub fn parseTime(cert: Certificate, elem: der.Element) !u64 {
+pub fn parseTime(cert: Certificate, elem: der.Element) ParseTimeError!u64 {
     const bytes = cert.contents(elem);
     switch (elem.identifier.tag) {
         .utc_time => {
@@ -647,34 +653,38 @@ test parseYear4 {
     try expectError(error.CertificateTimeInvalid, parseYear4("crap"));
 }
 
-pub fn parseAlgorithm(bytes: []const u8, element: der.Element) !Algorithm {
+pub fn parseAlgorithm(bytes: []const u8, element: der.Element) ParseEnumError!Algorithm {
     return parseEnum(Algorithm, bytes, element);
 }
 
-pub fn parseAlgorithmCategory(bytes: []const u8, element: der.Element) !AlgorithmCategory {
+pub fn parseAlgorithmCategory(bytes: []const u8, element: der.Element) ParseEnumError!AlgorithmCategory {
     return parseEnum(AlgorithmCategory, bytes, element);
 }
 
-pub fn parseAttribute(bytes: []const u8, element: der.Element) !Attribute {
+pub fn parseAttribute(bytes: []const u8, element: der.Element) ParseEnumError!Attribute {
     return parseEnum(Attribute, bytes, element);
 }
 
-pub fn parseNamedCurve(bytes: []const u8, element: der.Element) !NamedCurve {
+pub fn parseNamedCurve(bytes: []const u8, element: der.Element) ParseEnumError!NamedCurve {
     return parseEnum(NamedCurve, bytes, element);
 }
 
-pub fn parseExtensionId(bytes: []const u8, element: der.Element) !ExtensionId {
+pub fn parseExtensionId(bytes: []const u8, element: der.Element) ParseEnumError!ExtensionId {
     return parseEnum(ExtensionId, bytes, element);
 }
 
-fn parseEnum(comptime E: type, bytes: []const u8, element: der.Element) !E {
+pub const ParseEnumError = error{ CertificateFieldHasWrongDataType, CertificateHasUnrecognizedObjectId };
+
+fn parseEnum(comptime E: type, bytes: []const u8, element: der.Element) ParseEnumError!E {
     if (element.identifier.tag != .object_identifier)
         return error.CertificateFieldHasWrongDataType;
     const oid_bytes = bytes[element.slice.start..element.slice.end];
     return E.map.get(oid_bytes) orelse return error.CertificateHasUnrecognizedObjectId;
 }
 
-pub fn parseVersion(bytes: []const u8, version_elem: der.Element) !Version {
+pub const ParseVersionError = error{ UnsupportedCertificateVersion, CertificateFieldHasInvalidLength };
+
+pub fn parseVersion(bytes: []const u8, version_elem: der.Element) ParseVersionError!Version {
     if (@bitCast(u8, version_elem.identifier) != 0xa0)
         return .v1;
@@ -861,9 +871,9 @@ pub const der = struct {
             pub const empty: Slice = .{ .start = 0, .end = 0 };
         };
 
-        pub const ParseError = error{CertificateFieldHasInvalidLength};
+        pub const ParseElementError = error{CertificateFieldHasInvalidLength};
 
-        pub fn parse(bytes: []const u8, index: u32) ParseError!Element {
+        pub fn parse(bytes: []const u8, index: u32) ParseElementError!Element {
            var i = index;
            const identifier = @bitCast(Identifier, bytes[i]);
            i += 1;
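
The point of naming these sets is that callers can now handle certificate failures exhaustively instead of catching `anyerror`. A sketch of assumed caller code (`handleExpired` and `handleMalformed` are hypothetical):

    const parsed = cert.parse() catch |err| switch (err) {
        error.CertificateTimeInvalid => return handleExpired(), // hypothetical handler
        error.CertificateFieldHasWrongDataType,
        error.CertificateFieldHasInvalidLength,
        => return handleMalformed(), // hypothetical handler
        else => |e| return e,
    };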


@@ -50,11 +50,13 @@ pub fn deinit(cb: *Bundle, gpa: Allocator) void {
     cb.* = undefined;
 }
 
+pub const RescanError = RescanLinuxError || RescanMacError || RescanWindowsError;
+
 /// Clears the set of certificates and then scans the host operating system
 /// file system standard locations for certificates.
 /// For operating systems that do not have standard CA installations to be
 /// found, this function clears the set of certificates.
-pub fn rescan(cb: *Bundle, gpa: Allocator) !void {
+pub fn rescan(cb: *Bundle, gpa: Allocator) RescanError!void {
     switch (builtin.os.tag) {
         .linux => return rescanLinux(cb, gpa),
         .macos => return rescanMac(cb, gpa),
@@ -64,8 +66,11 @@ pub fn rescan(cb: *Bundle, gpa: Allocator) !void {
 }
 
 pub const rescanMac = @import("Bundle/macos.zig").rescanMac;
+pub const RescanMacError = @import("Bundle/macos.zig").RescanMacError;
 
-pub fn rescanLinux(cb: *Bundle, gpa: Allocator) !void {
+pub const RescanLinuxError = AddCertsFromFilePathError || AddCertsFromDirPathError;
+
+pub fn rescanLinux(cb: *Bundle, gpa: Allocator) RescanLinuxError!void {
     // Possible certificate files; stop after finding one.
     const cert_file_paths = [_][]const u8{
         "/etc/ssl/certs/ca-certificates.crt", // Debian/Ubuntu/Gentoo etc.
@@ -107,7 +112,9 @@ pub fn rescanLinux(cb: *Bundle, gpa: Allocator) !void {
     cb.bytes.shrinkAndFree(gpa, cb.bytes.items.len);
 }
 
-pub fn rescanWindows(cb: *Bundle, gpa: Allocator) !void {
+pub const RescanWindowsError = Allocator.Error || ParseCertError || std.os.UnexpectedError || error{FileNotFound};
+
+pub fn rescanWindows(cb: *Bundle, gpa: Allocator) RescanWindowsError!void {
     cb.bytes.clearRetainingCapacity();
     cb.map.clearRetainingCapacity();
@@ -132,12 +139,14 @@ pub fn rescanWindows(cb: *Bundle, gpa: Allocator) !void {
     cb.bytes.shrinkAndFree(gpa, cb.bytes.items.len);
 }
 
+pub const AddCertsFromDirPathError = fs.File.OpenError || AddCertsFromDirError;
+
 pub fn addCertsFromDirPath(
     cb: *Bundle,
     gpa: Allocator,
     dir: fs.Dir,
     sub_dir_path: []const u8,
-) !void {
+) AddCertsFromDirPathError!void {
     var iterable_dir = try dir.openIterableDir(sub_dir_path, .{});
     defer iterable_dir.close();
     return addCertsFromDir(cb, gpa, iterable_dir);
@@ -147,14 +156,16 @@ pub fn addCertsFromDirPathAbsolute(
     cb: *Bundle,
     gpa: Allocator,
     abs_dir_path: []const u8,
-) !void {
+) AddCertsFromDirPathError!void {
     assert(fs.path.isAbsolute(abs_dir_path));
     var iterable_dir = try fs.openIterableDirAbsolute(abs_dir_path, .{});
     defer iterable_dir.close();
     return addCertsFromDir(cb, gpa, iterable_dir);
 }
 
-pub fn addCertsFromDir(cb: *Bundle, gpa: Allocator, iterable_dir: fs.IterableDir) !void {
+pub const AddCertsFromDirError = AddCertsFromFilePathError;
+
+pub fn addCertsFromDir(cb: *Bundle, gpa: Allocator, iterable_dir: fs.IterableDir) AddCertsFromDirError!void {
     var it = iterable_dir.iterate();
     while (try it.next()) |entry| {
         switch (entry.kind) {
@@ -166,11 +177,13 @@ pub fn addCertsFromDir(cb: *Bundle, gpa: Allocator, iterable_dir: fs.IterableDir
     }
 }
 
+pub const AddCertsFromFilePathError = fs.File.OpenError || AddCertsFromFileError;
+
 pub fn addCertsFromFilePathAbsolute(
     cb: *Bundle,
     gpa: Allocator,
     abs_file_path: []const u8,
-) !void {
+) AddCertsFromFilePathError!void {
     assert(fs.path.isAbsolute(abs_file_path));
     var file = try fs.openFileAbsolute(abs_file_path, .{});
     defer file.close();
@@ -182,13 +195,15 @@ pub fn addCertsFromFilePath(
     gpa: Allocator,
     dir: fs.Dir,
     sub_file_path: []const u8,
-) !void {
+) AddCertsFromFilePathError!void {
     var file = try dir.openFile(sub_file_path, .{});
     defer file.close();
     return addCertsFromFile(cb, gpa, file);
 }
 
-pub fn addCertsFromFile(cb: *Bundle, gpa: Allocator, file: fs.File) !void {
+pub const AddCertsFromFileError = Allocator.Error || fs.File.GetSeekPosError || fs.File.ReadError || ParseCertError || std.base64.Error || error{ CertificateAuthorityBundleTooBig, MissingEndCertificateMarker };
+
+pub fn addCertsFromFile(cb: *Bundle, gpa: Allocator, file: fs.File) AddCertsFromFileError!void {
     const size = try file.getEndPos();
 
     // We borrow `bytes` as a temporary buffer for the base64-encoded data.
@@ -222,7 +237,9 @@
     }
 }
 
-pub fn parseCert(cb: *Bundle, gpa: Allocator, decoded_start: u32, now_sec: i64) !void {
+pub const ParseCertError = Allocator.Error || Certificate.ParseError;
+
+pub fn parseCert(cb: *Bundle, gpa: Allocator, decoded_start: u32, now_sec: i64) ParseCertError!void {
     // Even though we could only partially parse the certificate to find
     // the subject name, we pre-parse all of them to make sure and only
     // include in the bundle ones that we know will parse. This way we can
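
With RescanError spelled out, the same pattern applies to the bundle as a whole; a sketch assuming a `gpa` allocator in scope:

    var bundle: std.crypto.Certificate.Bundle = .{};
    defer bundle.deinit(gpa);

    bundle.rescan(gpa) catch |err| switch (err) {
        error.OutOfMemory => |e| return e,
        else => {}, // tolerate scan failure; proceed with an empty bundle
    };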


@@ -5,7 +5,9 @@ const mem = std.mem;
 const Allocator = std.mem.Allocator;
 const Bundle = @import("../Bundle.zig");
 
-pub fn rescanMac(cb: *Bundle, gpa: Allocator) !void {
+pub const RescanMacError = Allocator.Error || fs.File.OpenError || fs.File.ReadError || fs.File.SeekError || Bundle.ParseCertError || error{EndOfStream};
+
+pub fn rescanMac(cb: *Bundle, gpa: Allocator) RescanMacError!void {
     cb.bytes.clearRetainingCapacity();
     cb.map.clearRetainingCapacity();


@@ -1,6 +1,10 @@
 pub const Client = @import("http/Client.zig");
 pub const Server = @import("http/Server.zig");
 pub const protocol = @import("http/protocol.zig");
+const headers = @import("http/Headers.zig");
+
+pub const Headers = headers.Headers;
+pub const Field = headers.Field;
 
 pub const Version = enum {
     @"HTTP/1.0",
@@ -265,11 +269,6 @@ pub const Connection = enum {
     close,
 };
 
-pub const CustomHeader = struct {
-    name: []const u8,
-    value: []const u8,
-};
-
 const std = @import("std.zig");
 
 test {

File diff suppressed because it is too large.

lib/std/http/Headers.zig (new file, 386 lines)

@@ -0,0 +1,386 @@
const std = @import("../std.zig");

const Allocator = std.mem.Allocator;
const testing = std.testing;
const ascii = std.ascii;
const assert = std.debug.assert;

pub const HeaderList = std.ArrayListUnmanaged(Field);
pub const HeaderIndexList = std.ArrayListUnmanaged(usize);
pub const HeaderIndex = std.HashMapUnmanaged([]const u8, HeaderIndexList, CaseInsensitiveStringContext, std.hash_map.default_max_load_percentage);
pub const CaseInsensitiveStringContext = struct {
    pub fn hash(self: @This(), s: []const u8) u64 {
        _ = self;
        var buf: [64]u8 = undefined;
        var i: usize = 0; // usize, not u8: names longer than 255 bytes must not overflow the cursor
        var h = std.hash.Wyhash.init(0);

        // Hash the name in lowercased 64-byte chunks so lookups are case-insensitive.
        while (i < s.len) : (i += 64) {
            const left = @min(64, s.len - i);
            const ret = ascii.lowerString(buf[0..], s[i..][0..left]);

            h.update(ret);
        }

        return h.final();
    }

    pub fn eql(self: @This(), a: []const u8, b: []const u8) bool {
        _ = self;
        return ascii.eqlIgnoreCase(a, b);
    }
};
pub const Field = struct {
    name: []const u8,
    value: []const u8,

    pub fn modify(entry: *Field, allocator: Allocator, new_value: []const u8) !void {
        if (entry.value.len == new_value.len) {
            // Same length: reuse the existing allocation in place.
            std.mem.copy(u8, @constCast(entry.value), new_value);
        } else {
            // Any other length: replace the allocation, since the old buffer
            // either cannot hold the new value or would keep a stale length.
            allocator.free(entry.value);
            entry.value = try allocator.dupe(u8, new_value);
        }
    }

    fn lessThan(ctx: void, a: Field, b: Field) bool {
        _ = ctx;
        if (a.name.ptr == b.name.ptr) return false;

        return ascii.lessThanIgnoreCase(a.name, b.name);
    }
};
pub const Headers = struct {
    allocator: Allocator,
    list: HeaderList = .{},
    index: HeaderIndex = .{},

    /// When this is false, names and values will not be duplicated.
    /// Use with caution.
    owned: bool = true,

    pub fn init(allocator: Allocator) Headers {
        return .{ .allocator = allocator };
    }

    pub fn deinit(headers: *Headers) void {
        var it = headers.index.iterator();
        while (it.next()) |entry| {
            entry.value_ptr.deinit(headers.allocator);

            if (headers.owned) headers.allocator.free(entry.key_ptr.*);
        }

        for (headers.list.items) |entry| {
            if (headers.owned) headers.allocator.free(entry.value);
        }

        headers.index.deinit(headers.allocator);
        headers.list.deinit(headers.allocator);

        headers.* = undefined;
    }

    /// Appends a header to the list. Both name and value are copied.
    pub fn append(headers: *Headers, name: []const u8, value: []const u8) !void {
        const n = headers.list.items.len;

        const value_duped = if (headers.owned) try headers.allocator.dupe(u8, value) else value;
        errdefer if (headers.owned) headers.allocator.free(value_duped);

        var entry = Field{ .name = undefined, .value = value_duped };

        if (headers.index.getEntry(name)) |kv| {
            entry.name = kv.key_ptr.*;

            try kv.value_ptr.append(headers.allocator, n);
        } else {
            const name_duped = if (headers.owned) try headers.allocator.dupe(u8, name) else name;
            errdefer if (headers.owned) headers.allocator.free(name_duped);

            entry.name = name_duped;

            var new_index = try HeaderIndexList.initCapacity(headers.allocator, 1);
            errdefer new_index.deinit(headers.allocator);

            new_index.appendAssumeCapacity(n);
            try headers.index.put(headers.allocator, name_duped, new_index);
        }

        try headers.list.append(headers.allocator, entry);
    }
    pub fn contains(headers: Headers, name: []const u8) bool {
        return headers.index.contains(name);
    }

    pub fn delete(headers: *Headers, name: []const u8) bool {
        if (headers.index.fetchRemove(name)) |kv| {
            var index = kv.value;

            // iterate backwards
            var i = index.items.len;
            while (i > 0) {
                i -= 1;
                const data_index = index.items[i];
                const removed = headers.list.orderedRemove(data_index);

                assert(ascii.eqlIgnoreCase(removed.name, name)); // ensure the index hasn't been corrupted
                if (headers.owned) headers.allocator.free(removed.value);
            }

            if (headers.owned) headers.allocator.free(kv.key);
            index.deinit(headers.allocator);
            headers.rebuildIndex();

            return true;
        } else {
            return false;
        }
    }

    /// Returns the index of the first occurrence of a header with the given name.
    pub fn firstIndexOf(headers: Headers, name: []const u8) ?usize {
        const index = headers.index.get(name) orelse return null;

        return index.items[0];
    }

    /// Returns a list of indices containing headers with the given name.
    pub fn getIndices(headers: Headers, name: []const u8) ?[]const usize {
        const index = headers.index.get(name) orelse return null;

        return index.items;
    }

    /// Returns the entry of the first occurrence of a header with the given name.
    pub fn getFirstEntry(headers: Headers, name: []const u8) ?Field {
        const first_index = headers.firstIndexOf(name) orelse return null;

        return headers.list.items[first_index];
    }

    /// Returns a slice containing each header with the given name.
    /// The caller owns the returned slice, but NOT the values in the slice.
    pub fn getEntries(headers: Headers, allocator: Allocator, name: []const u8) !?[]const Field {
        const indices = headers.getIndices(name) orelse return null;

        const buf = try allocator.alloc(Field, indices.len);
        for (indices, 0..) |idx, n| {
            buf[n] = headers.list.items[idx];
        }

        return buf;
    }

    /// Returns the value in the entry of the first occurrence of a header with the given name.
    pub fn getFirstValue(headers: Headers, name: []const u8) ?[]const u8 {
        const first_index = headers.firstIndexOf(name) orelse return null;

        return headers.list.items[first_index].value;
    }

    /// Returns a slice containing the value of each header with the given name.
    /// The caller owns the returned slice, but NOT the values in the slice.
    pub fn getValues(headers: Headers, allocator: Allocator, name: []const u8) !?[]const []const u8 {
        const indices = headers.getIndices(name) orelse return null;

        const buf = try allocator.alloc([]const u8, indices.len);
        for (indices, 0..) |idx, n| {
            buf[n] = headers.list.items[idx].value;
        }

        return buf;
    }

    fn rebuildIndex(headers: *Headers) void {
        // clear out the indexes
        var it = headers.index.iterator();
        while (it.next()) |entry| {
            entry.value_ptr.shrinkRetainingCapacity(0);
        }

        // fill up indexes again; we know capacity is fine from before
        for (headers.list.items, 0..) |entry, i| {
            headers.index.getEntry(entry.name).?.value_ptr.appendAssumeCapacity(i);
        }
    }
    /// Sorts the headers in lexicographical order.
    pub fn sort(headers: *Headers) void {
        std.sort.sort(Field, headers.list.items, {}, Field.lessThan);
        headers.rebuildIndex();
    }

    /// Writes the headers to the given stream.
    pub fn format(
        headers: Headers,
        comptime fmt: []const u8,
        options: std.fmt.FormatOptions,
        out_stream: anytype,
    ) !void {
        _ = fmt;
        _ = options;

        for (headers.list.items) |entry| {
            if (entry.value.len == 0) continue;

            try out_stream.writeAll(entry.name);
            try out_stream.writeAll(": ");
            try out_stream.writeAll(entry.value);
            try out_stream.writeAll("\r\n");
        }
    }

    /// Writes all of the headers with the given name to the given stream, separated by commas.
    ///
    /// This is useful for headers like `Set-Cookie` which can have multiple values. RFC 9110, Section 5.2
    pub fn formatCommaSeparated(
        headers: Headers,
        name: []const u8,
        out_stream: anytype,
    ) !void {
        const indices = headers.getIndices(name) orelse return;

        try out_stream.writeAll(name);
        try out_stream.writeAll(": ");

        for (indices, 0..) |idx, n| {
            if (n != 0) try out_stream.writeAll(", ");
            try out_stream.writeAll(headers.list.items[idx].value);
        }

        try out_stream.writeAll("\r\n");
    }
};
test "Headers.append" {
    var h = Headers{ .allocator = std.testing.allocator };
    defer h.deinit();

    try h.append("foo", "bar");
    try h.append("hello", "world");

    try testing.expect(h.contains("Foo"));
    try testing.expect(!h.contains("Bar"));
}

test "Headers.delete" {
    var h = Headers{ .allocator = std.testing.allocator };
    defer h.deinit();

    try h.append("foo", "bar");
    try h.append("hello", "world");

    try testing.expect(h.contains("Foo"));

    _ = h.delete("Foo");

    try testing.expect(!h.contains("foo"));
}
test "Headers consistency" {
    var h = Headers{ .allocator = std.testing.allocator };
    defer h.deinit();

    try h.append("foo", "bar");
    try h.append("hello", "world");

    _ = h.delete("Foo");

    try h.append("foo", "bar");
    try h.append("bar", "world");
    try h.append("foo", "baz");
    try h.append("baz", "hello");

    try testing.expectEqual(@as(?usize, 0), h.firstIndexOf("hello"));
    try testing.expectEqual(@as(?usize, 1), h.firstIndexOf("foo"));
    try testing.expectEqual(@as(?usize, 2), h.firstIndexOf("bar"));
    try testing.expectEqual(@as(?usize, 4), h.firstIndexOf("baz"));
    try testing.expectEqual(@as(?usize, null), h.firstIndexOf("pog"));

    try testing.expectEqualSlices(usize, &[_]usize{0}, h.getIndices("hello").?);
    try testing.expectEqualSlices(usize, &[_]usize{ 1, 3 }, h.getIndices("foo").?);
    try testing.expectEqualSlices(usize, &[_]usize{2}, h.getIndices("bar").?);
    try testing.expectEqualSlices(usize, &[_]usize{4}, h.getIndices("baz").?);
    try testing.expectEqual(@as(?[]const usize, null), h.getIndices("pog"));

    try testing.expectEqualStrings("world", h.getFirstEntry("hello").?.value);
    try testing.expectEqualStrings("bar", h.getFirstEntry("foo").?.value);
    try testing.expectEqualStrings("world", h.getFirstEntry("bar").?.value);
    try testing.expectEqualStrings("hello", h.getFirstEntry("baz").?.value);

    const hello_entries = (try h.getEntries(testing.allocator, "hello")).?;
    defer testing.allocator.free(hello_entries);
    try testing.expectEqualDeep(@as([]const Field, &[_]Field{
        .{ .name = "hello", .value = "world" },
    }), hello_entries);

    const foo_entries = (try h.getEntries(testing.allocator, "foo")).?;
    defer testing.allocator.free(foo_entries);
    try testing.expectEqualDeep(@as([]const Field, &[_]Field{
        .{ .name = "foo", .value = "bar" },
        .{ .name = "foo", .value = "baz" },
    }), foo_entries);

    const bar_entries = (try h.getEntries(testing.allocator, "bar")).?;
    defer testing.allocator.free(bar_entries);
    try testing.expectEqualDeep(@as([]const Field, &[_]Field{
        .{ .name = "bar", .value = "world" },
    }), bar_entries);

    const baz_entries = (try h.getEntries(testing.allocator, "baz")).?;
    defer testing.allocator.free(baz_entries);
    try testing.expectEqualDeep(@as([]const Field, &[_]Field{
        .{ .name = "baz", .value = "hello" },
    }), baz_entries);

    const pog_entries = (try h.getEntries(testing.allocator, "pog"));
    try testing.expectEqual(@as(?[]const Field, null), pog_entries);

    try testing.expectEqualStrings("world", h.getFirstValue("hello").?);
    try testing.expectEqualStrings("bar", h.getFirstValue("foo").?);
    try testing.expectEqualStrings("world", h.getFirstValue("bar").?);
    try testing.expectEqualStrings("hello", h.getFirstValue("baz").?);
    try testing.expectEqual(@as(?[]const u8, null), h.getFirstValue("pog"));

    const hello_values = (try h.getValues(testing.allocator, "hello")).?;
    defer testing.allocator.free(hello_values);
    try testing.expectEqualDeep(@as([]const []const u8, &[_][]const u8{"world"}), hello_values);

    const foo_values = (try h.getValues(testing.allocator, "foo")).?;
    defer testing.allocator.free(foo_values);
    try testing.expectEqualDeep(@as([]const []const u8, &[_][]const u8{ "bar", "baz" }), foo_values);

    const bar_values = (try h.getValues(testing.allocator, "bar")).?;
    defer testing.allocator.free(bar_values);
    try testing.expectEqualDeep(@as([]const []const u8, &[_][]const u8{"world"}), bar_values);

    const baz_values = (try h.getValues(testing.allocator, "baz")).?;
    defer testing.allocator.free(baz_values);
    try testing.expectEqualDeep(@as([]const []const u8, &[_][]const u8{"hello"}), baz_values);

    const pog_values = (try h.getValues(testing.allocator, "pog"));
    try testing.expectEqual(@as(?[]const []const u8, null), pog_values);

    h.sort();

    try testing.expectEqualSlices(usize, &[_]usize{0}, h.getIndices("bar").?);
    try testing.expectEqualSlices(usize, &[_]usize{1}, h.getIndices("baz").?);
    try testing.expectEqualSlices(usize, &[_]usize{ 2, 3 }, h.getIndices("foo").?);
    try testing.expectEqualSlices(usize, &[_]usize{4}, h.getIndices("hello").?);

    const formatted_values = try std.fmt.allocPrint(testing.allocator, "{}", .{h});
    defer testing.allocator.free(formatted_values);

    try testing.expectEqualStrings("bar: world\r\nbaz: hello\r\nfoo: bar\r\nfoo: baz\r\nhello: world\r\n", formatted_values);

    var buf: [128]u8 = undefined;
    var fbs = std.io.fixedBufferStream(&buf);
    const writer = fbs.writer();

    try h.formatCommaSeparated("foo", writer);
    try testing.expectEqualStrings("foo: bar, baz\r\n", fbs.getWritten());
}
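
Because Headers implements format, a value can be printed with "{}" anywhere std.fmt is accepted (the test above does this via allocPrint), which is exactly what the server-side do() change below relies on; a small sketch, assuming an `allocator` in scope:

    var h = std.http.Headers{ .allocator = allocator };
    defer h.deinit();
    try h.append("content-type", "text/plain");

    std.debug.print("{}", .{h}); // prints "content-type: text/plain\r\n"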


@@ -23,21 +23,33 @@ pub const Connection = struct {
     pub const Protocol = enum { plain };
 
-    pub fn read(conn: *Connection, buffer: []u8) !usize {
-        switch (conn.protocol) {
-            .plain => return conn.stream.read(buffer),
+    pub fn read(conn: *Connection, buffer: []u8) ReadError!usize {
+        return switch (conn.protocol) {
+            .plain => conn.stream.read(buffer),
             // .tls => return conn.tls_client.read(conn.stream, buffer),
-        }
+        } catch |err| switch (err) {
+            error.ConnectionTimedOut => return error.ConnectionTimedOut,
+            error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer,
+            else => return error.UnexpectedReadFailure,
+        };
     }
 
-    pub fn readAtLeast(conn: *Connection, buffer: []u8, len: usize) !usize {
-        switch (conn.protocol) {
-            .plain => return conn.stream.readAtLeast(buffer, len),
+    pub fn readAtLeast(conn: *Connection, buffer: []u8, len: usize) ReadError!usize {
+        return switch (conn.protocol) {
+            .plain => conn.stream.readAtLeast(buffer, len),
             // .tls => return conn.tls_client.readAtLeast(conn.stream, buffer, len),
-        }
+        } catch |err| switch (err) {
+            error.ConnectionTimedOut => return error.ConnectionTimedOut,
+            error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer,
+            else => return error.UnexpectedReadFailure,
+        };
     }
 
-    pub const ReadError = net.Stream.ReadError;
+    pub const ReadError = error{
+        ConnectionTimedOut,
+        ConnectionResetByPeer,
+        UnexpectedReadFailure,
+    };
 
     pub const Reader = std.io.Reader(*Connection, ReadError, read);
@@ -45,21 +57,31 @@
         return Reader{ .context = conn };
     }
 
-    pub fn writeAll(conn: *Connection, buffer: []const u8) !void {
-        switch (conn.protocol) {
-            .plain => return conn.stream.writeAll(buffer),
+    pub fn writeAll(conn: *Connection, buffer: []const u8) WriteError!void {
+        return switch (conn.protocol) {
+            .plain => conn.stream.writeAll(buffer),
             // .tls => return conn.tls_client.writeAll(conn.stream, buffer),
-        }
+        } catch |err| switch (err) {
+            error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer,
+            else => return error.UnexpectedWriteFailure,
+        };
     }
 
-    pub fn write(conn: *Connection, buffer: []const u8) !usize {
-        switch (conn.protocol) {
-            .plain => return conn.stream.write(buffer),
+    pub fn write(conn: *Connection, buffer: []const u8) WriteError!usize {
+        return switch (conn.protocol) {
+            .plain => conn.stream.write(buffer),
             // .tls => return conn.tls_client.write(conn.stream, buffer),
-        }
+        } catch |err| switch (err) {
+            error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer,
+            else => return error.UnexpectedWriteFailure,
+        };
     }
 
-    pub const WriteError = net.Stream.WriteError || error{};
+    pub const WriteError = error{
+        ConnectionResetByPeer,
+        UnexpectedWriteFailure,
+    };
 
     pub const Writer = std.io.Writer(*Connection, WriteError, write);
 
     pub fn writer(conn: *Connection) Writer {
@@ -155,136 +177,142 @@ pub const BufferedConnection = struct {
     }
 };
 
+/// The mode of transport for responses.
+pub const ResponseTransfer = union(enum) {
+    content_length: u64,
+    chunked: void,
+    none: void,
+};
+
+/// The decompressor for request messages.
+pub const Compression = union(enum) {
+    pub const DeflateDecompressor = std.compress.zlib.ZlibStream(Response.TransferReader);
+    pub const GzipDecompressor = std.compress.gzip.Decompress(Response.TransferReader);
+    pub const ZstdDecompressor = std.compress.zstd.DecompressStream(Response.TransferReader, .{});
+
+    deflate: DeflateDecompressor,
+    gzip: GzipDecompressor,
+    zstd: ZstdDecompressor,
+    none: void,
+};
+
 /// A HTTP request originating from a client.
 pub const Request = struct {
-    pub const Headers = struct {
-        method: http.Method,
-        target: []const u8,
-        version: http.Version,
-        content_length: ?u64 = null,
-        transfer_encoding: ?http.TransferEncoding = null,
-        transfer_compression: ?http.ContentEncoding = null,
-        connection: http.Connection = .close,
-        host: ?[]const u8 = null,
-
-        pub const ParseError = error{
-            ShortHttpStatusLine,
-            BadHttpVersion,
-            UnknownHttpMethod,
-            HttpHeadersInvalid,
-            HttpHeaderContinuationsUnsupported,
-            HttpTransferEncodingUnsupported,
-            HttpConnectionHeaderUnsupported,
-            InvalidCharacter,
-        };
-
-        pub fn parse(bytes: []const u8) !Headers {
-            var it = mem.tokenize(u8, bytes[0 .. bytes.len - 4], "\r\n");
-
-            const first_line = it.next() orelse return error.HttpHeadersInvalid;
-            if (first_line.len < 10)
-                return error.ShortHttpStatusLine;
-
-            const method_end = mem.indexOfScalar(u8, first_line, ' ') orelse return error.HttpHeadersInvalid;
-            const method_str = first_line[0..method_end];
-            const method = std.meta.stringToEnum(http.Method, method_str) orelse return error.UnknownHttpMethod;
-
-            const version_start = mem.lastIndexOfScalar(u8, first_line, ' ') orelse return error.HttpHeadersInvalid;
-            if (version_start == method_end) return error.HttpHeadersInvalid;
-
-            const version_str = first_line[version_start + 1 ..];
-            if (version_str.len != 8) return error.HttpHeadersInvalid;
-            const version: http.Version = switch (int64(version_str[0..8])) {
-                int64("HTTP/1.0") => .@"HTTP/1.0",
-                int64("HTTP/1.1") => .@"HTTP/1.1",
-                else => return error.BadHttpVersion,
-            };
-
-            const target = first_line[method_end + 1 .. version_start];
-
-            var headers: Headers = .{
-                .method = method,
-                .target = target,
-                .version = version,
-            };
-
-            while (it.next()) |line| {
-                if (line.len == 0) return error.HttpHeadersInvalid;
-                switch (line[0]) {
-                    ' ', '\t' => return error.HttpHeaderContinuationsUnsupported,
-                    else => {},
-                }
-
-                var line_it = mem.tokenize(u8, line, ": ");
-                const header_name = line_it.next() orelse return error.HttpHeadersInvalid;
-                const header_value = line_it.rest();
-
-                if (std.ascii.eqlIgnoreCase(header_name, "content-length")) {
-                    if (headers.content_length != null) return error.HttpHeadersInvalid;
-                    headers.content_length = try std.fmt.parseInt(u64, header_value, 10);
-                } else if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) {
-                    // Transfer-Encoding: second, first
-                    // Transfer-Encoding: deflate, chunked
-                    var iter = mem.splitBackwards(u8, header_value, ",");
-
-                    if (iter.next()) |first| {
-                        const trimmed = mem.trim(u8, first, " ");
-
-                        if (std.meta.stringToEnum(http.TransferEncoding, trimmed)) |te| {
-                            if (headers.transfer_encoding != null) return error.HttpHeadersInvalid;
-                            headers.transfer_encoding = te;
-                        } else if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| {
-                            if (headers.transfer_compression != null) return error.HttpHeadersInvalid;
-                            headers.transfer_compression = ce;
-                        } else {
-                            return error.HttpTransferEncodingUnsupported;
-                        }
-                    }
-
-                    if (iter.next()) |second| {
-                        if (headers.transfer_compression != null) return error.HttpTransferEncodingUnsupported;
-
-                        const trimmed = mem.trim(u8, second, " ");
-                        if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| {
-                            headers.transfer_compression = ce;
-                        } else {
-                            return error.HttpTransferEncodingUnsupported;
-                        }
-                    }
-
-                    if (iter.next()) |_| return error.HttpTransferEncodingUnsupported;
-                } else if (std.ascii.eqlIgnoreCase(header_name, "content-encoding")) {
-                    if (headers.transfer_compression != null) return error.HttpHeadersInvalid;
-
-                    const trimmed = mem.trim(u8, header_value, " ");
-                    if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| {
-                        headers.transfer_compression = ce;
-                    } else {
-                        return error.HttpTransferEncodingUnsupported;
-                    }
-                } else if (std.ascii.eqlIgnoreCase(header_name, "connection")) {
-                    if (std.ascii.eqlIgnoreCase(header_value, "keep-alive")) {
-                        headers.connection = .keep_alive;
-                    } else if (std.ascii.eqlIgnoreCase(header_value, "close")) {
-                        headers.connection = .close;
-                    } else {
-                        return error.HttpConnectionHeaderUnsupported;
-                    }
-                } else if (std.ascii.eqlIgnoreCase(header_name, "host")) {
-                    headers.host = header_value;
-                }
-            }
-
-            return headers;
-        }
+    pub const ParseError = Allocator.Error || error{
+        ShortHttpStatusLine,
+        BadHttpVersion,
+        UnknownHttpMethod,
+        HttpHeadersInvalid,
+        HttpHeaderContinuationsUnsupported,
+        HttpTransferEncodingUnsupported,
+        HttpConnectionHeaderUnsupported,
+        InvalidContentLength,
+        CompressionNotSupported,
+    };
+
+    pub fn parse(req: *Request, bytes: []const u8) ParseError!void {
+        var it = mem.tokenize(u8, bytes[0 .. bytes.len - 4], "\r\n");
+
+        const first_line = it.next() orelse return error.HttpHeadersInvalid;
+        if (first_line.len < 10)
+            return error.ShortHttpStatusLine;
+
+        const method_end = mem.indexOfScalar(u8, first_line, ' ') orelse return error.HttpHeadersInvalid;
+        const method_str = first_line[0..method_end];
+        const method = std.meta.stringToEnum(http.Method, method_str) orelse return error.UnknownHttpMethod;
+
+        const version_start = mem.lastIndexOfScalar(u8, first_line, ' ') orelse return error.HttpHeadersInvalid;
+        if (version_start == method_end) return error.HttpHeadersInvalid;
+
+        const version_str = first_line[version_start + 1 ..];
+        if (version_str.len != 8) return error.HttpHeadersInvalid;
+        const version: http.Version = switch (int64(version_str[0..8])) {
+            int64("HTTP/1.0") => .@"HTTP/1.0",
+            int64("HTTP/1.1") => .@"HTTP/1.1",
+            else => return error.BadHttpVersion,
+        };
+
+        const target = first_line[method_end + 1 .. version_start];
+
+        req.method = method;
+        req.target = target;
+        req.version = version;
+
+        while (it.next()) |line| {
+            if (line.len == 0) return error.HttpHeadersInvalid;
+            switch (line[0]) {
+                ' ', '\t' => return error.HttpHeaderContinuationsUnsupported,
+                else => {},
+            }
+
+            var line_it = mem.tokenize(u8, line, ": ");
+            const header_name = line_it.next() orelse return error.HttpHeadersInvalid;
+            const header_value = line_it.rest();
+
+            try req.headers.append(header_name, header_value);
+
+            if (std.ascii.eqlIgnoreCase(header_name, "content-length")) {
+                if (req.content_length != null) return error.HttpHeadersInvalid;
+                req.content_length = std.fmt.parseInt(u64, header_value, 10) catch return error.InvalidContentLength;
+            } else if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) {
+                // Transfer-Encoding: second, first
+                // Transfer-Encoding: deflate, chunked
+                var iter = mem.splitBackwards(u8, header_value, ",");
+
+                if (iter.next()) |first| {
+                    const trimmed = mem.trim(u8, first, " ");
+
+                    if (std.meta.stringToEnum(http.TransferEncoding, trimmed)) |te| {
+                        if (req.transfer_encoding != null) return error.HttpHeadersInvalid;
+                        req.transfer_encoding = te;
+                    } else if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| {
+                        if (req.transfer_compression != null) return error.HttpHeadersInvalid;
+                        req.transfer_compression = ce;
+                    } else {
+                        return error.HttpTransferEncodingUnsupported;
+                    }
+                }
+
+                if (iter.next()) |second| {
+                    if (req.transfer_compression != null) return error.HttpTransferEncodingUnsupported;
+
+                    const trimmed = mem.trim(u8, second, " ");
+                    if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| {
+                        req.transfer_compression = ce;
+                    } else {
+                        return error.HttpTransferEncodingUnsupported;
+                    }
+                }
+
+                if (iter.next()) |_| return error.HttpTransferEncodingUnsupported;
+            } else if (std.ascii.eqlIgnoreCase(header_name, "content-encoding")) {
+                if (req.transfer_compression != null) return error.HttpHeadersInvalid;
+
+                const trimmed = mem.trim(u8, header_value, " ");
+                if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| {
+                    req.transfer_compression = ce;
+                } else {
+                    return error.HttpTransferEncodingUnsupported;
+                }
+            }
+        }
+    }
 
     inline fn int64(array: *const [8]u8) u64 {
         return @bitCast(u64, array.*);
     }
-    };
 
-    headers: Headers = undefined,
+    method: http.Method,
+    target: []const u8,
+    version: http.Version,
+    content_length: ?u64 = null,
+    transfer_encoding: ?http.TransferEncoding = null,
+    transfer_compression: ?http.ContentEncoding = null,
+
+    headers: http.Headers = undefined,
     parser: proto.HeadersParser,
     compression: Compression = .none,
 };
@@ -295,23 +323,17 @@ pub const Request = struct {
 /// Order of operations: accept -> wait -> do [ -> write -> finish][ -> reset /]
 ///                                        \ -> read /
 pub const Response = struct {
-    pub const Headers = struct {
-        version: http.Version = .@"HTTP/1.1",
-        status: http.Status = .ok,
-        reason: ?[]const u8 = null,
-
-        server: ?[]const u8 = "zig (std.http)",
-        connection: http.Connection = .keep_alive,
-        transfer_encoding: RequestTransfer = .none,
-        custom: []const http.CustomHeader = &[_]http.CustomHeader{},
-    };
+    version: http.Version = .@"HTTP/1.1",
+    status: http.Status = .ok,
+    reason: ?[]const u8 = null,
+
+    transfer_encoding: ResponseTransfer = .none,
 
     server: *Server,
     address: net.Address,
     connection: BufferedConnection,
 
-    headers: Headers = .{},
+    headers: http.Headers,
     request: Request,
 
     /// Reset this response to its initial state. This must be called before handling a second request on the same connection.
@@ -341,46 +363,61 @@
         }
     }
 
+    pub const DoError = BufferedConnection.WriteError || error{ UnsupportedTransferEncoding, InvalidContentLength };
+
     /// Send the response headers.
     pub fn do(res: *Response) !void {
         var buffered = std.io.bufferedWriter(res.connection.writer());
         const w = buffered.writer();
 
-        try w.writeAll(@tagName(res.headers.version));
+        try w.writeAll(@tagName(res.version));
         try w.writeByte(' ');
-        try w.print("{d}", .{@enumToInt(res.headers.status)});
+        try w.print("{d}", .{@enumToInt(res.status)});
         try w.writeByte(' ');
-        if (res.headers.reason) |reason| {
+        if (res.reason) |reason| {
            try w.writeAll(reason);
-        } else if (res.headers.status.phrase()) |phrase| {
+        } else if (res.status.phrase()) |phrase| {
            try w.writeAll(phrase);
        }
+        try w.writeAll("\r\n");
 
-        if (res.headers.server) |server| {
-            try w.writeAll("\r\nServer: ");
-            try w.writeAll(server);
+        if (!res.headers.contains("server")) {
+            try w.writeAll("Server: zig (std.http)\r\n");
         }
 
-        if (res.headers.connection == .close) {
-            try w.writeAll("\r\nConnection: close");
-        } else {
-            try w.writeAll("\r\nConnection: keep-alive");
+        if (!res.headers.contains("connection")) {
+            try w.writeAll("Connection: keep-alive\r\n");
         }
 
-        switch (res.headers.transfer_encoding) {
-            .chunked => try w.writeAll("\r\nTransfer-Encoding: chunked"),
-            .content_length => |content_length| try w.print("\r\nContent-Length: {d}", .{content_length}),
-            .none => {},
-        }
+        const has_transfer_encoding = res.headers.contains("transfer-encoding");
+        const has_content_length = res.headers.contains("content-length");
+
+        if (!has_transfer_encoding and !has_content_length) {
+            switch (res.transfer_encoding) {
+                .chunked => try w.writeAll("Transfer-Encoding: chunked\r\n"),
+                .content_length => |content_length| try w.print("Content-Length: {d}\r\n", .{content_length}),
+                .none => {},
+            }
+        } else {
+            if (has_content_length) {
+                const content_length = std.fmt.parseInt(u64, res.headers.getFirstValue("content-length").?, 10) catch return error.InvalidContentLength;
+
+                res.transfer_encoding = .{ .content_length = content_length };
+            } else if (has_transfer_encoding) {
+                const transfer_encoding = res.headers.getFirstValue("transfer-encoding").?;
+                if (std.mem.eql(u8, transfer_encoding, "chunked")) {
+                    res.transfer_encoding = .chunked;
+                } else {
+                    return error.UnsupportedTransferEncoding;
+                }
+            } else {
+                res.transfer_encoding = .none;
+            }
+        }
 
-        for (res.headers.custom) |header| {
-            try w.writeAll("\r\n");
-            try w.writeAll(header.name);
-            try w.writeAll(": ");
-            try w.writeAll(header.value);
-        }
+        try w.print("{}", .{res.headers});
 
-        try w.writeAll("\r\n\r\n");
+        try w.writeAll("\r\n");
 
         try buffered.flush();
     }
@@ -393,23 +430,23 @@
         return .{ .context = res };
     }
 
-    pub fn transferRead(res: *Response, buf: []u8) TransferReadError!usize {
-        if (res.request.parser.isComplete()) return 0;
+    fn transferRead(res: *Response, buf: []u8) TransferReadError!usize {
+        if (res.request.parser.done) return 0;
 
         var index: usize = 0;
         while (index == 0) {
             const amt = try res.request.parser.read(&res.connection, buf[index..], false);
-            if (amt == 0 and res.request.parser.isComplete()) break;
+            if (amt == 0 and res.request.parser.done) break;
             index += amt;
         }
 
         return index;
     }
 
-    pub const WaitForCompleteHeadError = BufferedConnection.ReadError || proto.HeadersParser.WaitForCompleteHeadError || Request.Headers.ParseError || error{ BadHeader, InvalidCompression, StreamTooLong, InvalidWindowSize } || error{CompressionNotSupported};
+    pub const WaitError = BufferedConnection.ReadError || proto.HeadersParser.CheckCompleteHeadError || Request.ParseError || error{ CompressionInitializationFailed, CompressionNotSupported };
 
     /// Wait for the client to send a complete request head.
-    pub fn wait(res: *Response) !void {
+    pub fn wait(res: *Response) WaitError!void {
         while (true) {
             try res.connection.fill();
@@ -419,22 +456,28 @@
             if (res.request.parser.state.isContent()) break;
         }
 
-        res.request.headers = try Request.Headers.parse(res.request.parser.header_bytes.items);
+        res.request.headers = .{ .allocator = res.server.allocator, .owned = true };
+        try res.request.parse(res.request.parser.header_bytes.items);
 
-        if (res.headers.connection == .keep_alive and res.request.headers.connection == .keep_alive) {
+        const res_connection = res.headers.getFirstValue("connection");
+        const res_keepalive = res_connection != null and !std.ascii.eqlIgnoreCase("close", res_connection.?);
+
+        const req_connection = res.request.headers.getFirstValue("connection");
+        const req_keepalive = req_connection != null and !std.ascii.eqlIgnoreCase("close", req_connection.?);
+        if (res_keepalive and req_keepalive) {
             res.connection.conn.closing = false;
         } else {
             res.connection.conn.closing = true;
         }
 
-        if (res.request.headers.transfer_encoding) |te| {
+        if (res.request.transfer_encoding) |te| {
             switch (te) {
                 .chunked => {
                     res.request.parser.next_chunk_length = 0;
                     res.request.parser.state = .chunk_head_size;
                 },
             }
-        } else if (res.request.headers.content_length) |cl| {
+        } else if (res.request.content_length) |cl| {
            res.request.parser.next_chunk_length = cl;
 
            if (cl == 0) res.request.parser.done = true;
@@ -443,13 +486,13 @@
        }
 
        if (!res.request.parser.done) {
-            if (res.request.headers.transfer_compression) |tc| switch (tc) {
+            if (res.request.transfer_compression) |tc| switch (tc) {
                .compress => return error.CompressionNotSupported,
                .deflate => res.request.compression = .{
-                    .deflate = try std.compress.zlib.zlibStream(res.server.allocator, res.transferReader()),
+                    .deflate = std.compress.zlib.zlibStream(res.server.allocator, res.transferReader()) catch return error.CompressionInitializationFailed,
                },
                .gzip => res.request.compression = .{
-                    .gzip = try std.compress.gzip.decompress(res.server.allocator, res.transferReader()),
+                    .gzip = std.compress.gzip.decompress(res.server.allocator, res.transferReader()) catch return error.CompressionInitializationFailed,
                },
                .zstd => res.request.compression = .{
                    .zstd = std.compress.zstd.decompressStream(res.server.allocator, res.transferReader()),
@@ -458,7 +501,7 @@
        }
    }
 
-    pub const ReadError = Compression.DeflateDecompressor.Error || Compression.GzipDecompressor.Error || Compression.ZstdDecompressor.Error || WaitForCompleteHeadError;
+    pub const ReadError = TransferReadError || proto.HeadersParser.CheckCompleteHeadError || error{ DecompressionFailure, InvalidTrailers };
 
     pub const Reader = std.io.Reader(*Response, ReadError, read);
@@ -467,12 +510,33 @@
    }
 
    pub fn read(res: *Response, buffer: []u8) ReadError!usize {
-        return switch (res.request.compression) {
-            .deflate => |*deflate| try deflate.read(buffer),
-            .gzip => |*gzip| try gzip.read(buffer),
-            .zstd => |*zstd| try zstd.read(buffer),
+        const out_index = switch (res.request.compression) {
+            .deflate => |*deflate| deflate.read(buffer) catch return error.DecompressionFailure,
+            .gzip => |*gzip| gzip.read(buffer) catch return error.DecompressionFailure,
+            .zstd => |*zstd| zstd.read(buffer) catch return error.DecompressionFailure,
            else => try res.transferRead(buffer),
        };
+
+        if (out_index == 0) {
+            const has_trail = !res.request.parser.state.isContent();
+
+            while (!res.request.parser.state.isContent()) { // read trailing headers
+                try res.connection.fill();
+
+                const nchecked = try res.request.parser.checkCompleteHead(res.server.allocator, res.connection.peek());
+                res.connection.clear(@intCast(u16, nchecked));
+            }
+
+            if (has_trail) {
+                res.request.headers = http.Headers{ .allocator = res.server.allocator, .owned = false };
+
+                // The response headers before the trailers are already guaranteed
+                // to be valid, so they will always be parsed again and cannot return an error.
+                // This will *only* fail for a malformed trailer.
+                res.request.parse(res.request.parser.header_bytes.items) catch return error.InvalidTrailers;
+            }
+        }
+
+        return out_index;
    }
 
    pub fn readAll(res: *Response, buffer: []u8) !usize {
@@ -495,7 +559,7 @@
    /// Write `bytes` to the server. The `transfer_encoding` request header determines how data will be sent.
    pub fn write(res: *Response, bytes: []const u8) WriteError!usize {
-        switch (res.headers.transfer_encoding) {
+        switch (res.transfer_encoding) {
            .chunked => {
                try res.connection.writer().print("{x}\r\n", .{bytes.len});
                try res.connection.writeAll(bytes);
@@ -514,9 +578,18 @@
        }
    }
 
+    pub fn writeAll(req: *Response, bytes: []const u8) WriteError!void {
+        var index: usize = 0;
+        while (index < bytes.len) {
+            index += try write(req, bytes[index..]);
+        }
+    }
+
+    pub const FinishError = WriteError || error{MessageNotCompleted};
+
    /// Finish the body of a request. This notifies the server that you have no more data to send.
-    pub fn finish(res: *Response) !void {
-        switch (res.headers.transfer_encoding) {
+    pub fn finish(res: *Response) FinishError!void {
+        switch (res.transfer_encoding) {
            .chunked => try res.connection.writeAll("0\r\n\r\n"),
            .content_length => |len| if (len != 0) return error.MessageNotCompleted,
            .none => {},
@@ -524,25 +597,6 @@
    }
 };
-/// The mode of transport for responses.
-pub const RequestTransfer = union(enum) {
-    content_length: u64,
-    chunked: void,
-    none: void,
-};
-
-/// The decompressor for request messages.
-pub const Compression = union(enum) {
-    pub const DeflateDecompressor = std.compress.zlib.ZlibStream(Response.TransferReader);
-    pub const GzipDecompressor = std.compress.gzip.Decompress(Response.TransferReader);
-    pub const ZstdDecompressor = std.compress.zstd.DecompressStream(Response.TransferReader, .{});
-
-    deflate: DeflateDecompressor,
-    gzip: GzipDecompressor,
-    zstd: ZstdDecompressor,
-    none: void,
-};
-
 pub fn init(allocator: Allocator, options: net.StreamServer.Options) Server {
     return .{
         .allocator = allocator,
@@ -588,7 +642,11 @@ pub fn accept(server: *Server, options: HeaderStrategy) AcceptError!*Response {
             .stream = in.stream,
             .protocol = .plain,
         } },
+        .headers = .{ .allocator = server.allocator },
         .request = .{
+            .version = undefined,
+            .method = undefined,
+            .target = undefined,
             .parser = switch (options) {
                 .dynamic => |max| proto.HeadersParser.initDynamic(max),
                 .static => |buf| proto.HeadersParser.initStatic(buf),
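
Taken together, the server-side flow now matches the doc comment above (accept -> wait -> do [-> write -> finish]). A sketch using only functions visible in this diff, with error handling and connection reuse elided; `server` is an assumed, already-listening std.http.Server:

    const res = try server.accept(.{ .dynamic = 8192 });
    try res.wait(); // parses the request head into res.request and res.request.headers

    try res.headers.append("content-type", "text/plain");
    res.transfer_encoding = .{ .content_length = 5 };

    try res.do(); // sends the status line and all headers
    try res.writeAll("hello");
    try res.finish();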


@@ -1,4 +1,4 @@
-const std = @import("std");
+const std = @import("../std.zig");
 const testing = std.testing;
 const mem = std.mem;


@@ -479,9 +479,14 @@ fn fetchAndUnpack(
     };
     defer tmp_directory.closeAndFree(gpa);
 
-    var req = try http_client.request(uri, .{}, .{});
+    var h = std.http.Headers{ .allocator = gpa };
+    defer h.deinit();
+
+    var req = try http_client.request(.GET, uri, h, .{});
     defer req.deinit();
 
+    try req.start();
     try req.do();
 
     if (mem.endsWith(u8, uri.path, ".tar.gz")) {