zig/lib/std/http/protocol.zig
Nameless 08bdaf3bd6
std.http: add http server
* extract http protocol into protocol.zig, as it is shared between client and server
* coalesce Request and Response back into Client.zig, they don't contain
  any large chunks of code anymore
* http.Server is implemented as basic as possible, a simple example below:

```zig
fn handler(res: *Server.Response) !void {
    while (true) {
        defer res.reset();

        try res.waitForCompleteHead();
        res.headers.transfer_encoding = .{ .content_length = 14 };
        res.headers.connection = res.request.headers.connection;
        try res.sendResponseHead();
        _ = try res.write("Hello, World!\n");

        if (res.connection.closing) break;
    }
}

pub fn main() !void {
    var server = Server.init(std.heap.page_allocator, .{ .reuse_address = true });
    defer server.deinit();

    try server.listen(try net.Address.parseIp("127.0.0.1", 8080));

    while (true) {
        const res = try server.accept(.{ .dynamic = 8192 });

        const thread = try std.Thread.spawn(.{}, handler, .{res});
        thread.detach();
    }
}
```
2023-04-08 09:59:35 -05:00

714 lines
27 KiB
Zig

const std = @import("std");
const testing = std.testing;
const mem = std.mem;
const assert = std.debug.assert;
/// Position of a `HeadersParser` within an HTTP message.
///
/// The first group of tags tracks how much of the header terminator has been
/// seen so far; the second group tracks progress through the framing of a
/// `transfer-encoding: chunked` body.
pub const State = enum {
    // Header parsing states.

    /// Chunk framing was malformed; set by `findChunkedLen`.
    invalid,
    /// No part of a header terminator is pending.
    start,
    /// The most recent byte was a lone `\n`.
    seen_n,
    /// The most recent byte was `\r`.
    seen_r,
    /// The most recent bytes were `\r\n`.
    seen_rn,
    /// The most recent bytes were `\r\n\r`.
    seen_rnr,
    /// The header terminator has been seen.
    finished,

    // Chunked transfer-encoding parsing states.

    /// Reading the hexadecimal digits of a chunk size.
    chunk_head_size,
    /// Skipping the remainder of a chunk-size line.
    chunk_head_ext,
    /// A `\r` ending a chunk-size line was seen; `\n` must follow.
    chunk_head_r,
    /// Inside the payload of a chunk.
    chunk_data,
    /// The payload of a chunk has been consumed; a line break must follow.
    chunk_data_suffix,
    /// The `\r` of the line break after a chunk payload was seen.
    chunk_data_suffix_r,

    /// Reports whether the parser has moved past the header section, i.e.
    /// whether it is positioned somewhere in the message body.
    pub fn isContent(self: State) bool {
        switch (self) {
            .finished,
            .chunk_head_size,
            .chunk_head_ext,
            .chunk_head_r,
            .chunk_data,
            .chunk_data_suffix,
            .chunk_data_suffix_r,
            => return true,

            .invalid,
            .start,
            .seen_n,
            .seen_r,
            .seen_rn,
            .seen_rnr,
            => return false,
        }
    }
};
/// Size in bytes of the parser's internal `read_buffer` (0x4000 = 16 KiB).
const read_buffer_size = 0x4000;
/// Integer type able to represent every value from 0 through
/// `read_buffer_size`; used for the cursors into `read_buffer`.
const ReadBufferIndex = std.math.IntFittingRange(0, read_buffer_size);
pub const HeadersParser = struct {
state: State = .start,
/// Whether `header_bytes` is allocated (and may grow) or was provided as a fixed buffer.
header_bytes_owned: bool,
/// Either a fixed buffer of len `max_header_bytes` or a dynamic buffer that can grow up to `max_header_bytes`.
/// Pointers into this buffer are not stable until after a message is complete.
header_bytes: std.ArrayListUnmanaged(u8),
/// The maximum allowed size of `header_bytes`.
max_header_bytes: usize,
next_chunk_length: u64 = 0,
/// Whether this parser is done parsing a complete message.
/// A message is only done when the entire payload has been read.
done: bool = false,
read_buffer: [read_buffer_size]u8 = undefined,
read_buffer_start: ReadBufferIndex = 0,
read_buffer_len: ReadBufferIndex = 0,
/// Creates a parser whose header storage starts out empty and is marked as
/// owned, so `checkCompleteHead` grows it through the allocator it is given.
/// `max` is the largest number of header bytes the parser will accept.
pub fn initDynamic(max: usize) HeadersParser {
    const empty_storage: std.ArrayListUnmanaged(u8) = .{};

    return HeadersParser{
        .header_bytes_owned = true,
        .max_header_bytes = max,
        .header_bytes = empty_storage,
    };
}
/// Creates a parser that stores header bytes directly in `buf`.
/// The storage starts with zero items and a capacity of `buf.len`, which is
/// also the header size limit; it is marked as not owned, so
/// `checkCompleteHead` never asks an allocator to grow it.
pub fn initStatic(buf: []u8) HeadersParser {
    const fixed_storage: std.ArrayListUnmanaged(u8) = .{
        .items = buf[0..0],
        .capacity = buf.len,
    };

    return HeadersParser{
        .header_bytes_owned = false,
        .max_header_bytes = buf.len,
        .header_bytes = fixed_storage,
    };
}
/// Prepares the parser for the next message.
/// The header storage is emptied but keeps its capacity, and the size limit
/// and ownership flag are carried over. Every other field returns to its
/// declared default, which includes zeroing both read buffer cursors.
pub fn reset(r: *HeadersParser) void {
    r.header_bytes.clearRetainingCapacity();

    // Take copies of what survives before the whole struct is overwritten.
    const kept_storage = r.header_bytes;
    const kept_limit = r.max_header_bytes;
    const kept_owned = r.header_bytes_owned;

    r.* = .{
        .header_bytes = kept_storage,
        .max_header_bytes = kept_limit,
        .header_bytes_owned = kept_owned,
    };
}
/// Scans `bytes` for the blank line that ends an HTTP header section,
/// resuming from whatever partial terminator earlier calls recorded in
/// `r.state`.
///
/// Returns how many bytes are part of HTTP headers. Always less than or
/// equal to bytes.len. If the amount returned is less than bytes.len, it
/// means the headers ended and the first byte after the terminating
/// "\r\n\r\n" is located at `bytes[result]`. A bare "\n\n" is also treated as
/// a terminator where it is noticed. Once the terminator has been seen
/// `r.state` is `.finished`.
///
/// `r.state` must be one of the header states (`.start` through `.finished`);
/// `.invalid` and the chunk states are not allowed here. Only the low 32 bits
/// of `bytes.len` are honoured.
pub fn findHeadersEnd(r: *HeadersParser, bytes: []const u8) u32 {
    const vector_len = 16;
    const len = @truncate(u32, bytes.len);
    var index: u32 = 0;

    // `int16`/`int24`/`int32` load a window of bytes and `intShift` narrows it
    // to its trailing bytes. Each group of switches below therefore tests
    // progressively longer suffixes of the window, and a longer match
    // overrides the state chosen by a shorter one.
    while (true) {
        switch (r.state) {
            .invalid => unreachable,
            .finished => return index,
            .start => switch (len - index) {
                0 => return index,
                1 => {
                    switch (bytes[index]) {
                        '\r' => r.state = .seen_r,
                        '\n' => r.state = .seen_n,
                        else => {},
                    }

                    return index + 1;
                },
                2 => {
                    const b16 = int16(bytes[index..][0..2]);
                    const b8 = intShift(u8, b16);

                    switch (b8) {
                        '\r' => r.state = .seen_r,
                        '\n' => r.state = .seen_n,
                        else => {},
                    }

                    switch (b16) {
                        int16("\r\n") => r.state = .seen_rn,
                        int16("\n\n") => r.state = .finished,
                        else => {},
                    }

                    return index + 2;
                },
                3 => {
                    const b24 = int24(bytes[index..][0..3]);
                    const b16 = intShift(u16, b24);
                    const b8 = intShift(u8, b24);

                    switch (b8) {
                        '\r' => r.state = .seen_r,
                        '\n' => r.state = .seen_n,
                        else => {},
                    }

                    switch (b16) {
                        int16("\r\n") => r.state = .seen_rn,
                        int16("\n\n") => r.state = .finished,
                        else => {},
                    }

                    switch (b24) {
                        int24("\r\n\r") => r.state = .seen_rnr,
                        else => {},
                    }

                    return index + 3;
                },
                4...vector_len - 1 => {
                    const b32 = int32(bytes[index..][0..4]);
                    const b24 = intShift(u24, b32);
                    const b16 = intShift(u16, b32);
                    const b8 = intShift(u8, b32);

                    switch (b8) {
                        '\r' => r.state = .seen_r,
                        '\n' => r.state = .seen_n,
                        else => {},
                    }

                    switch (b16) {
                        int16("\r\n") => r.state = .seen_rn,
                        int16("\n\n") => r.state = .finished,
                        else => {},
                    }

                    switch (b24) {
                        int24("\r\n\r") => r.state = .seen_rnr,
                        else => {},
                    }

                    switch (b32) {
                        int32("\r\n\r\n") => r.state = .finished,
                        else => {},
                    }

                    index += 4;
                    continue;
                },
                else => {
                    // At least `vector_len` bytes remain: count the CR/LF
                    // bytes of the next block with a single vector compare.
                    const Vector = @Vector(vector_len, u8);
                    const BitVector = @Vector(vector_len, u1);
                    const SizeVector = @Vector(vector_len, u8);

                    const chunk = bytes[index..][0..vector_len];
                    const v: Vector = chunk.*;
                    const matches_r = @bitCast(BitVector, v == @splat(vector_len, @as(u8, '\r')));
                    const matches_n = @bitCast(BitVector, v == @splat(vector_len, @as(u8, '\n')));
                    const matches_or: SizeVector = matches_r | matches_n;

                    const matches = @reduce(.Add, matches_or);
                    switch (matches) {
                        // No CR/LF at all: the state stays `.start`.
                        0 => {},
                        1 => switch (chunk[vector_len - 1]) {
                            '\r' => r.state = .seen_r,
                            '\n' => r.state = .seen_n,
                            else => {},
                        },
                        2 => {
                            const b16 = int16(chunk[vector_len - 2 ..][0..2]);
                            const b8 = intShift(u8, b16);

                            switch (b8) {
                                '\r' => r.state = .seen_r,
                                '\n' => r.state = .seen_n,
                                else => {},
                            }

                            switch (b16) {
                                int16("\r\n") => r.state = .seen_rn,
                                int16("\n\n") => r.state = .finished,
                                else => {},
                            }
                        },
                        // The upper bound is inclusive: every byte of the
                        // block may be a CR or LF.
                        3...vector_len => {
                            if (matches >= 4) {
                                // Enough CR/LF bytes for a whole terminator
                                // to end inside this block, so look at every
                                // position. Scanning in increasing order
                                // reports the earliest terminator.
                                for (0..vector_len - 1) |offset_usize| {
                                    const offset = @truncate(u32, offset_usize);

                                    if (int16(chunk[offset..][0..2]) == int16("\n\n")) {
                                        r.state = .finished;
                                        return index + offset + 2;
                                    }

                                    if (offset + 4 <= vector_len and
                                        int32(chunk[offset..][0..4]) == int32("\r\n\r\n"))
                                    {
                                        r.state = .finished;
                                        return index + offset + 4;
                                    }
                                }
                            }

                            // No terminator ended inside the block (or there
                            // were too few CR/LF bytes for "\r\n\r\n").
                            // Classify the tail so that a terminator
                            // straddling the block boundary is still found.
                            const b24 = int24(chunk[vector_len - 3 ..][0..3]);
                            const b16 = intShift(u16, b24);
                            const b8 = intShift(u8, b24);

                            switch (b8) {
                                '\r' => r.state = .seen_r,
                                '\n' => r.state = .seen_n,
                                else => {},
                            }

                            switch (b16) {
                                int16("\r\n") => r.state = .seen_rn,
                                int16("\n\n") => r.state = .finished,
                                else => {},
                            }

                            switch (b24) {
                                int24("\r\n\r") => r.state = .seen_rnr,
                                else => {},
                            }
                        },
                        else => unreachable,
                    }

                    index += vector_len;
                    continue;
                },
            },
            .seen_n => switch (len - index) {
                0 => return index,
                else => {
                    switch (bytes[index]) {
                        '\n' => r.state = .finished,
                        // Remember the CR so a following "\n\r\n" is found.
                        '\r' => r.state = .seen_r,
                        else => r.state = .start,
                    }

                    index += 1;
                    continue;
                },
            },
            .seen_r => switch (len - index) {
                0 => return index,
                1 => {
                    switch (bytes[index]) {
                        '\n' => r.state = .seen_rn,
                        '\r' => r.state = .seen_r,
                        else => r.state = .start,
                    }

                    return index + 1;
                },
                2 => {
                    const b16 = int16(bytes[index..][0..2]);
                    const b8 = intShift(u8, b16);

                    // `b8` is the second byte. A `\n` here only completes
                    // "\r\n" when the first byte is `\r`, which the `b16`
                    // switch handles.
                    switch (b8) {
                        '\r' => r.state = .seen_r,
                        '\n' => r.state = .seen_n,
                        else => r.state = .start,
                    }

                    switch (b16) {
                        int16("\r\n") => r.state = .seen_rn,
                        // Together with the pending `\r` this is "\r\n\r".
                        int16("\n\r") => r.state = .seen_rnr,
                        int16("\n\n") => r.state = .finished,
                        else => {},
                    }

                    return index + 2;
                },
                else => {
                    const b24 = int24(bytes[index..][0..3]);
                    const b16 = intShift(u16, b24);
                    const b8 = intShift(u8, b24);

                    switch (b8) {
                        '\r' => r.state = .seen_r,
                        '\n' => r.state = .seen_n,
                        else => r.state = .start,
                    }

                    switch (b16) {
                        int16("\r\n") => r.state = .seen_rn,
                        int16("\n\n") => r.state = .finished,
                        else => {},
                    }

                    switch (b24) {
                        // Together with the pending `\r` this is "\r\n\r\n".
                        int24("\n\r\n") => r.state = .finished,
                        int24("\r\n\r") => r.state = .seen_rnr,
                        else => {},
                    }

                    index += 3;
                    continue;
                },
            },
            .seen_rn => switch (len - index) {
                0 => return index,
                1 => {
                    switch (bytes[index]) {
                        '\r' => r.state = .seen_rnr,
                        '\n' => r.state = .seen_n,
                        else => r.state = .start,
                    }

                    return index + 1;
                },
                else => {
                    const b16 = int16(bytes[index..][0..2]);
                    const b8 = intShift(u8, b16);

                    // `b8` is the second byte, so a `\r` here starts a new
                    // line ending; it does not extend the pending "\r\n".
                    switch (b8) {
                        '\r' => r.state = .seen_r,
                        '\n' => r.state = .seen_n,
                        else => r.state = .start,
                    }

                    switch (b16) {
                        int16("\r\n") => r.state = .finished,
                        int16("\n\n") => r.state = .finished,
                        else => {},
                    }

                    index += 2;
                    continue;
                },
            },
            .seen_rnr => switch (len - index) {
                0 => return index,
                else => {
                    switch (bytes[index]) {
                        '\n' => r.state = .finished,
                        '\r' => r.state = .seen_r,
                        else => r.state = .start,
                    }

                    index += 1;
                    continue;
                },
            },
            .chunk_head_size => unreachable,
            .chunk_head_ext => unreachable,
            .chunk_head_r => unreachable,
            .chunk_data => unreachable,
            .chunk_data_suffix => unreachable,
            .chunk_data_suffix_r => unreachable,
        }

        // Not reached: every prong above returns or continues.
        return index;
    }
}
/// Parses chunked transfer-encoding framing: the line break that follows a
/// chunk's payload and the chunk-size line of the next chunk. The size is
/// accumulated into `r.next_chunk_length`.
///
/// Returns the number of bytes of `bytes` that were consumed:
/// * right after the `\n` ending a chunk-size line, with `r.state` set to
///   `.chunk_data`;
/// * the offset of the offending byte, with `r.state` set to `.invalid`, when
///   the framing is malformed, a letter outside `A-F`/`a-f` follows the size
///   digits, or the size does not fit in 64 bits;
/// * `bytes.len` (truncated to 32 bits) when more input is needed.
///
/// `r.state` must be one of `.chunk_data_suffix`, `.chunk_data_suffix_r`,
/// `.chunk_head_size`, `.chunk_head_ext` or `.chunk_head_r`.
pub fn findChunkedLen(r: *HeadersParser, bytes: []const u8) u32 {
    const len = @truncate(u32, bytes.len);

    for (bytes[0..], 0..) |c, i| {
        const index = @intCast(u32, i);
        switch (r.state) {
            .chunk_data_suffix => switch (c) {
                '\r' => r.state = .chunk_data_suffix_r,
                '\n' => r.state = .chunk_head_size,
                else => {
                    r.state = .invalid;
                    return index;
                },
            },
            .chunk_data_suffix_r => switch (c) {
                '\n' => r.state = .chunk_head_size,
                else => {
                    r.state = .invalid;
                    return index;
                },
            },
            .chunk_head_size => {
                const digit = switch (c) {
                    '0'...'9' => |b| b - '0',
                    'A'...'F' => |b| b - 'A' + 10,
                    'a'...'f' => |b| b - 'a' + 10,
                    '\r' => {
                        r.state = .chunk_head_r;
                        continue;
                    },
                    '\n' => {
                        r.state = .chunk_data;
                        return index + 1;
                    },
                    // Only hexadecimal digits may form a chunk size; a stray
                    // letter must not be folded into the length.
                    'G'...'Z', 'g'...'z' => {
                        r.state = .invalid;
                        return index;
                    },
                    else => {
                        r.state = .chunk_head_ext;
                        continue;
                    },
                };

                // Reject the digit if shifting it in would not fit in 64 bits.
                if (r.next_chunk_length > std.math.maxInt(u64) >> 4) {
                    r.state = .invalid;
                    return index;
                }

                // Cannot overflow: the product's low four bits are zero and
                // `digit` is at most 15.
                r.next_chunk_length = r.next_chunk_length * 16 + digit;
            },
            .chunk_head_ext => switch (c) {
                '\r' => r.state = .chunk_head_r,
                '\n' => {
                    r.state = .chunk_data;
                    return index + 1;
                },
                else => continue,
            },
            .chunk_head_r => switch (c) {
                '\n' => {
                    r.state = .chunk_data;
                    return index + 1;
                },
                else => {
                    r.state = .invalid;
                    return index;
                },
            },
            else => unreachable,
        }
    }

    return len;
}
/// Reports whether a whole message has been parsed: the body has been read
/// in full (`done`) and the parser is in the `.finished` state, which for a
/// chunked message means any trailing headers have been parsed as well.
pub fn isComplete(r: *HeadersParser) bool {
    if (!r.done) return false;
    return r.state == .finished;
}
/// Errors `checkCompleteHead` can return: allocation failure while growing an
/// owned header buffer, or a header section larger than `max_header_bytes`.
pub const CheckCompleteHeadError = mem.Allocator.Error || error{HttpHeadersExceededSizeLimit};

/// Feeds `in` to the header parser and appends the bytes that belong to the
/// header section to `header_bytes`. Returns how many bytes of `in` were
/// consumed, which is 0 when the parser is already past the header section.
/// `allocator` is only used when the header buffer is owned by the parser.
pub fn checkCompleteHead(r: *HeadersParser, allocator: std.mem.Allocator, in: []const u8) CheckCompleteHeadError!u32 {
    if (r.state.isContent()) return 0;

    const consumed = r.findHeadersEnd(in);
    const head_part = in[0..consumed];

    const total_len = r.header_bytes.items.len + head_part.len;
    if (total_len > r.max_header_bytes) return error.HttpHeadersExceededSizeLimit;

    // A fixed buffer already has capacity for `max_header_bytes`.
    if (r.header_bytes_owned) {
        try r.header_bytes.ensureUnusedCapacity(allocator, head_part.len);
    }
    r.header_bytes.appendSliceAssumeCapacity(head_part);

    return consumed;
}
/// Set of errors that `waitForCompleteHead` can throw except any errors inherited by `reader`
pub const WaitForCompleteHeadError = CheckCompleteHeadError || error{UnexpectedEndOfStream};

/// Waits for the complete head to be available. Keeps refilling the internal
/// read buffer from `reader` and feeding it to `checkCompleteHead` until the
/// parser has moved past the header section or an error occurs.
///
/// Returns immediately when the head is already complete. Returns
/// `error.UnexpectedEndOfStream` if `reader` reports end of stream before the
/// header terminator has been seen. Bytes read beyond the head stay in the
/// internal buffer for `read`.
pub fn waitForCompleteHead(r: *HeadersParser, reader: anytype, allocator: std.mem.Allocator) !void {
    // Loop on the parser state rather than on "some bytes were consumed": a
    // head that arrives in several pieces is not complete after the first.
    while (!r.state.isContent()) {
        if (r.read_buffer_start == r.read_buffer_len) {
            const nread = try reader.read(r.read_buffer[0..]);
            if (nread == 0) return error.UnexpectedEndOfStream;

            r.read_buffer_start = 0;
            r.read_buffer_len = @intCast(ReadBufferIndex, nread);
        }

        const amt = try r.checkCompleteHead(allocator, r.read_buffer[r.read_buffer_start..r.read_buffer_len]);
        r.read_buffer_start += @intCast(ReadBufferIndex, amt);
    }
}
/// Errors named for reading a message body: the stream ended early, the
/// headers exceeded the size limit, or the chunk framing was malformed.
/// NOTE(review): `read` declares an inferred error set (`!usize`) rather than
/// this set — confirm whether callers rely on `ReadError`.
pub const ReadError = error{
    UnexpectedEndOfStream,
    HttpHeadersExceededSizeLimit,
    HttpChunkInvalid,
};
/// Reads the body of the message into `buffer`. If `skip` is true, the buffer will be unused and the body will be
/// skipped. Returns the number of bytes placed in the buffer.
///
/// In the `.finished` state the body is `next_chunk_length` bytes long; in the chunk states the chunked framing is
/// decoded and only payload bytes are delivered. `done` is set once the body has been consumed; after the final
/// chunk the state becomes `.seen_rn` so that the trailer section can be parsed like a header section.
///
/// When `skip` is false and `buffer` is not empty, 0 is only returned once the body is exhausted. After the first
/// byte has been delivered the reader is not consulted again, so a call never waits for input it does not need.
///
/// Returns `error.UnexpectedEndOfStream` if `reader` ends before the body does and `error.HttpChunkInvalid` on
/// malformed chunk framing. The parser must be past the header section.
pub fn read(r: *HeadersParser, reader: anytype, buffer: []u8, skip: bool) !usize {
    assert(r.state.isContent());
    if (r.done) return 0;

    // An empty fixed-length body is complete without touching the reader;
    // anything the reader has belongs to the next message.
    if (r.state == .finished and r.next_chunk_length == 0) {
        r.done = true;
        return 0;
    }

    var out_index: usize = 0;
    var first_pass = true;

    while (true) : (first_pass = false) {
        if (r.read_buffer_start == r.read_buffer_len) {
            // After the first pass, only wait for more input when nothing has
            // been delivered yet; otherwise hand back what there is.
            if (!first_pass and (skip or out_index != 0)) return out_index;

            const nread = try reader.read(r.read_buffer[0..]);
            if (nread == 0) return error.UnexpectedEndOfStream;

            r.read_buffer_start = 0;
            r.read_buffer_len = @intCast(ReadBufferIndex, nread);
        }

        switch (r.state) {
            .invalid, .start, .seen_n, .seen_r, .seen_rn, .seen_rnr => unreachable,
            .finished => {
                const buf_avail = r.read_buffer_len - r.read_buffer_start;
                const data_avail = r.next_chunk_length;

                // TODO https://github.com/ziglang/zig/issues/14039
                const read_available = @intCast(usize, @min(buf_avail, data_avail));
                if (skip) {
                    r.next_chunk_length -= read_available;
                    r.read_buffer_start += @intCast(ReadBufferIndex, read_available);
                } else {
                    const out_avail = buffer.len - out_index;
                    const can_read = @min(read_available, out_avail);

                    r.next_chunk_length -= can_read;

                    mem.copy(u8, buffer[out_index..], r.read_buffer[r.read_buffer_start..][0..can_read]);
                    r.read_buffer_start += @intCast(ReadBufferIndex, can_read);
                    out_index += can_read;
                }

                if (r.next_chunk_length == 0) r.done = true;

                return out_index;
            },
            .chunk_data_suffix, .chunk_data_suffix_r, .chunk_head_size, .chunk_head_ext, .chunk_head_r => {
                const i = r.findChunkedLen(r.read_buffer[r.read_buffer_start..r.read_buffer_len]);
                r.read_buffer_start += @intCast(ReadBufferIndex, i);

                switch (r.state) {
                    .invalid => return error.HttpChunkInvalid,
                    .chunk_data => if (r.next_chunk_length == 0) {
                        // The trailer section is formatted identically to the header section.
                        r.state = .seen_rn;
                        r.done = true;

                        return out_index;
                    },
                    // The framing is incomplete, which means the internal
                    // buffer is exhausted; the check at the top of the loop
                    // decides whether to refill or to return.
                    else => {},
                }

                continue;
            },
            .chunk_data => {
                const buf_avail = r.read_buffer_len - r.read_buffer_start;
                const data_avail = r.next_chunk_length;

                // TODO https://github.com/ziglang/zig/issues/14039
                const read_available = @intCast(usize, @min(buf_avail, data_avail));
                if (skip) {
                    r.next_chunk_length -= read_available;
                    r.read_buffer_start += @intCast(ReadBufferIndex, read_available);
                } else {
                    // Earlier chunks may already occupy the front of `buffer`.
                    const out_avail = buffer.len - out_index;
                    const can_read = @min(read_available, out_avail);

                    r.next_chunk_length -= can_read;

                    mem.copy(u8, buffer[out_index..], r.read_buffer[r.read_buffer_start..][0..can_read]);
                    r.read_buffer_start += @intCast(ReadBufferIndex, can_read);
                    out_index += can_read;
                }

                if (r.next_chunk_length == 0) {
                    r.state = .chunk_data_suffix;
                    continue;
                }

                return out_index;
            },
        }
    }
}
};
/// Reinterprets two bytes as a `u16` in the byte order they have in memory.
inline fn int16(array: *const [2]u8) u16 {
    const raw: [2]u8 = array.*;
    return @bitCast(u16, raw);
}
/// Reinterprets three bytes as a `u24` in the byte order they have in memory.
inline fn int24(array: *const [3]u8) u24 {
    const raw: [3]u8 = array.*;
    return @bitCast(u24, raw);
}
/// Reinterprets four bytes as a `u32` in the byte order they have in memory.
inline fn int32(array: *const [4]u8) u32 {
    const raw: [4]u8 = array.*;
    return @bitCast(u32, raw);
}
/// Narrows `x` to `T`, keeping the high bits of `x` on little-endian targets
/// and the low bits on big-endian targets. For a value loaded with
/// `int16`/`int24`/`int32` those are the bytes that come last in memory.
inline fn intShift(comptime T: type, x: anytype) T {
    const Source = @TypeOf(x);
    const endian = @import("builtin").cpu.arch.endian();

    if (endian == .Big) {
        return @truncate(T, x);
    } else {
        return @truncate(T, x >> (@bitSizeOf(Source) - @bitSizeOf(T)));
    }
}
// Splits a request at every position up to the end of its head and checks
// that two consecutive calls together locate the end of the head.
test "HeadersParser.findHeadersEnd" {
    const data = "GET / HTTP/1.1\r\nHost: localhost\r\n\r\nHello";
    const head_len: usize = 35;

    var split: usize = 0;
    while (split <= head_len) : (split += 1) {
        var parser = HeadersParser.initDynamic(0);

        const before = parser.findHeadersEnd(data[0..split]);
        try testing.expectEqual(@intCast(u32, split), before);

        const after = parser.findHeadersEnd(data[split..]);
        try testing.expectEqual(@intCast(u32, head_len - split), after);
    }
}
// Parses four consecutive chunk-size lines, restarting the size parser
// before each one; the last size is too large for 64 bits.
test "HeadersParser.findChunkedLen" {
    const data = "Ff\r\nf0f000 ; ext\n0\r\nffffffffffffffffffffffffffffffffffffffff\r\n";

    const Case = struct {
        consumed: u32,
        length: u64,
        state: State,
    };
    const cases = [_]Case{
        .{ .consumed = 4, .length = 0xff, .state = .chunk_data },
        .{ .consumed = 13, .length = 0xf0f000, .state = .chunk_data },
        .{ .consumed = 3, .length = 0, .state = .chunk_data },
        .{ .consumed = 16, .length = 0xffffffffffffffff, .state = .invalid },
    };

    var parser = HeadersParser.initDynamic(0);
    var offset: usize = 0;
    for (cases) |case| {
        parser.state = .chunk_head_size;
        parser.next_chunk_length = 0;

        const consumed = parser.findChunkedLen(data[offset..]);
        try testing.expectEqual(case.consumed, consumed);
        try testing.expectEqual(case.length, parser.next_chunk_length);
        try testing.expectEqual(case.state, parser.state);

        offset += consumed;
    }
}
// Reads a five byte body whose length is supplied directly to the parser.
test "HeadersParser.read length" {
    const message = "GET / HTTP/1.1\r\nHost: localhost\r\nContent-Length: 5\r\n\r\nHello";

    var parser = HeadersParser.initDynamic(256);
    defer parser.header_bytes.deinit(testing.allocator);

    var stream = std.io.fixedBufferStream(message);
    try parser.waitForCompleteHead(stream.reader(), testing.allocator);

    parser.next_chunk_length = 5;

    var body: [8]u8 = undefined;
    const body_len = try parser.read(stream.reader(), &body, false);

    try testing.expectEqual(@as(usize, 5), body_len);
    try testing.expectEqualStrings("Hello", body[0..body_len]);
    try testing.expectEqualStrings("GET / HTTP/1.1\r\nHost: localhost\r\nContent-Length: 5\r\n\r\n", parser.header_bytes.items);
}
// Reads a body split into three chunks followed by the terminating chunk.
test "HeadersParser.read chunked" {
    const message = "GET / HTTP/1.1\r\nHost: localhost\r\n\r\n2\r\nHe\r\n2\r\nll\r\n1\r\no\r\n0\r\n\r\n";

    var parser = HeadersParser.initDynamic(256);
    defer parser.header_bytes.deinit(testing.allocator);

    var stream = std.io.fixedBufferStream(message);
    try parser.waitForCompleteHead(stream.reader(), testing.allocator);

    // Switch the parser to chunked decoding by hand.
    parser.state = .chunk_head_size;

    var body: [8]u8 = undefined;
    const body_len = try parser.read(stream.reader(), &body, false);

    try testing.expectEqual(@as(usize, 5), body_len);
    try testing.expectEqualStrings("Hello", body[0..body_len]);
    try testing.expectEqualStrings("GET / HTTP/1.1\r\nHost: localhost\r\n\r\n", parser.header_bytes.items);
}
// Reads a chunked body and then parses the trailer section that follows the
// terminating chunk; the trailer is appended to the stored header bytes.
test "HeadersParser.read chunked trailer" {
    const message = "GET / HTTP/1.1\r\nHost: localhost\r\n\r\n2\r\nHe\r\n2\r\nll\r\n1\r\no\r\n0\r\nContent-Type: text/plain\r\n\r\n";

    var parser = HeadersParser.initDynamic(256);
    defer parser.header_bytes.deinit(testing.allocator);

    var stream = std.io.fixedBufferStream(message);
    try parser.waitForCompleteHead(stream.reader(), testing.allocator);

    // Switch the parser to chunked decoding by hand.
    parser.state = .chunk_head_size;

    var body: [8]u8 = undefined;
    const body_len = try parser.read(stream.reader(), &body, false);

    try testing.expectEqual(@as(usize, 5), body_len);
    try testing.expectEqualStrings("Hello", body[0..body_len]);

    try parser.waitForCompleteHead(stream.reader(), testing.allocator);

    try testing.expectEqualStrings("GET / HTTP/1.1\r\nHost: localhost\r\n\r\nContent-Type: text/plain\r\n\r\n", parser.header_bytes.items);
}