mirror of
https://codeberg.org/ziglang/zig.git
synced 2025-12-06 13:54:21 +00:00
1048 lines
38 KiB
Zig
1048 lines
38 KiB
Zig
//! Connecting and opening requests are threadsafe. Individual requests are not.
|
|
|
|
const std = @import("../std.zig");
|
|
const testing = std.testing;
|
|
const http = std.http;
|
|
const mem = std.mem;
|
|
const net = std.net;
|
|
const Uri = std.Uri;
|
|
const Allocator = mem.Allocator;
|
|
const assert = std.debug.assert;
|
|
|
|
const Client = @This();
|
|
const proto = @import("protocol.zig");
|
|
|
|
/// Default connection pool size.
pub const default_connection_pool_size = 32;
/// Connection pool size taken from `std.options.http_connection_pool_size`.
pub const connection_pool_size = std.options.http_connection_pool_size;

/// Allocator used for everything the client creates: pool nodes, copied host
/// names and TLS client state.
allocator: Allocator,
/// Root certificates handed to TLS connections when they are initialised.
ca_bundle: std.crypto.Certificate.Bundle = .{},
/// Held while `ca_bundle` is being rescanned.
ca_bundle_mutex: std.Thread.Mutex = .{},
/// When this is `true`, the next time this client performs an HTTPS request,
/// it will first rescan the system for root certificates.
next_https_rescan_certs: bool = true,

/// The pool of connections that can be reused (and currently in use).
connection_pool: ConnectionPool = .{},

/// When set, `connect` opens new connections to this proxy instead of the
/// requested host.
proxy: ?HttpProxy = null,
|
|
|
|
/// A set of linked lists of connections that can be reused.
///
/// `used` holds the connections currently handed out, `free` the idle ones.
/// `free_len` always equals the number of nodes in `free`.
pub const ConnectionPool = struct {
    /// The criteria for a connection to be considered a match.
    pub const Criteria = struct {
        host: []const u8,
        port: u16,
        is_tls: bool,
    };

    /// A pooled connection together with the endpoint it was opened to.
    pub const StoredConnection = struct {
        buffered: BufferedConnection,
        /// Freed with the client's allocator in `deinit`.
        host: []u8,
        port: u16,

        proxied: bool = false,
        /// When set, `release` closes the connection instead of keeping it for reuse.
        closing: bool = false,

        /// Closes the connection and frees `host`.
        pub fn deinit(self: *StoredConnection, client: *Client) void {
            self.buffered.close(client);
            client.allocator.free(self.host);
        }
    };

    const Queue = std.TailQueue(StoredConnection);
    pub const Node = Queue.Node;

    mutex: std.Thread.Mutex = .{},
    /// Open connections that are currently in use.
    used: Queue = .{},
    /// Open connections that are not currently in use.
    free: Queue = .{},
    free_len: usize = 0,
    /// Maximum number of idle connections kept in `free`.
    free_size: usize = connection_pool_size,

    /// Finds and acquires a connection from the connection pool matching the criteria. This function is threadsafe.
    /// If no connection is found, null is returned.
    pub fn findConnection(pool: *ConnectionPool, criteria: Criteria) ?*Node {
        pool.mutex.lock();
        defer pool.mutex.unlock();

        // Scan from the back: `release` appends direct connections and prepends
        // proxied ones, so direct connections are considered first.
        var next = pool.free.last;
        while (next) |node| : (next = node.prev) {
            if ((node.data.buffered.conn.protocol == .tls) != criteria.is_tls) continue;
            if (node.data.port != criteria.port) continue;
            // Skip connections that were opened to a different host.
            if (!mem.eql(u8, node.data.host, criteria.host)) continue;

            pool.acquireUnsafe(node);
            return node;
        }

        return null;
    }

    /// Acquires an existing connection from the connection pool. This function is not threadsafe.
    pub fn acquireUnsafe(pool: *ConnectionPool, node: *Node) void {
        pool.free.remove(node);
        pool.free_len -= 1;

        pool.used.append(node);
    }

    /// Acquires an existing connection from the connection pool. This function is threadsafe.
    pub fn acquire(pool: *ConnectionPool, node: *Node) void {
        pool.mutex.lock();
        defer pool.mutex.unlock();

        return pool.acquireUnsafe(node);
    }

    /// Tries to release a connection back to the connection pool. This function is threadsafe.
    /// If the connection is marked as closing, it will be closed instead.
    /// If the free list is already full, the connection at its front is closed to make room.
    pub fn release(pool: *ConnectionPool, client: *Client, node: *Node) void {
        pool.mutex.lock();
        defer pool.mutex.unlock();

        pool.used.remove(node);

        // A pool that may not keep any idle connection closes everything it gets back.
        if (node.data.closing or pool.free_size == 0) {
            node.data.deinit(client);

            return client.allocator.destroy(node);
        }

        if (pool.free_len >= pool.free_size) {
            // free_len >= free_size >= 1 here, so the list cannot be empty.
            const popped = pool.free.popFirst() orelse unreachable;
            pool.free_len -= 1;

            popped.data.deinit(client);
            client.allocator.destroy(popped);
        }

        if (node.data.proxied) {
            pool.free.prepend(node); // proxied connections go to the end of the queue, always try direct connections first
        } else {
            pool.free.append(node);
        }

        pool.free_len += 1;
    }

    /// Adds a newly created node to the pool of used connections. This function is threadsafe.
    pub fn addUsed(pool: *ConnectionPool, node: *Node) void {
        pool.mutex.lock();
        defer pool.mutex.unlock();

        pool.used.append(node);
    }

    /// Closes and destroys every connection, idle or in use, and invalidates the pool.
    pub fn deinit(pool: *ConnectionPool, client: *Client) void {
        // The mutex is never unlocked: the whole pool, mutex included, is
        // overwritten with `undefined` at the end of this function.
        pool.mutex.lock();

        var next = pool.free.first;
        while (next) |node| {
            defer client.allocator.destroy(node);
            next = node.next;

            node.data.deinit(client);
        }

        next = pool.used.first;
        while (next) |node| {
            defer client.allocator.destroy(node);
            next = node.next;

            node.data.deinit(client);
        }

        pool.* = undefined;
    }
};
|
|
|
|
/// An interface to either a plain or TLS connection.
pub const Connection = struct {
    stream: net.Stream,
    /// undefined unless protocol is tls.
    tls_client: *std.crypto.tls.Client,
    protocol: Protocol,

    pub const Protocol = enum { plain, tls };

    /// Reads into `buffer` from the stream, going through the TLS client for TLS connections.
    /// Errors of the underlying stream are collapsed into `ReadError`.
    pub fn read(conn: *Connection, buffer: []u8) ReadError!usize {
        return switch (conn.protocol) {
            .plain => conn.stream.read(buffer),
            .tls => conn.tls_client.read(conn.stream, buffer),
        } catch |err| switch (err) {
            error.TlsConnectionTruncated, error.TlsRecordOverflow, error.TlsDecodeError, error.TlsBadRecordMac, error.TlsBadLength, error.TlsIllegalParameter, error.TlsUnexpectedMessage => return error.TlsFailure,
            error.TlsAlert => return error.TlsAlert,
            error.ConnectionTimedOut => return error.ConnectionTimedOut,
            error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer,
            else => return error.UnexpectedReadFailure,
        };
    }

    /// Same as `read`, but forwards to the `readAtLeast` of the stream or TLS client with `len`.
    pub fn readAtLeast(conn: *Connection, buffer: []u8, len: usize) ReadError!usize {
        return switch (conn.protocol) {
            .plain => conn.stream.readAtLeast(buffer, len),
            .tls => conn.tls_client.readAtLeast(conn.stream, buffer, len),
        } catch |err| switch (err) {
            error.TlsConnectionTruncated, error.TlsRecordOverflow, error.TlsDecodeError, error.TlsBadRecordMac, error.TlsBadLength, error.TlsIllegalParameter, error.TlsUnexpectedMessage => return error.TlsFailure,
            error.TlsAlert => return error.TlsAlert,
            error.ConnectionTimedOut => return error.ConnectionTimedOut,
            error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer,
            else => return error.UnexpectedReadFailure,
        };
    }

    pub const ReadError = error{
        TlsFailure,
        TlsAlert,
        ConnectionTimedOut,
        ConnectionResetByPeer,
        UnexpectedReadFailure,
    };

    pub const Reader = std.io.Reader(*Connection, ReadError, read);

    pub fn reader(conn: *Connection) Reader {
        return Reader{ .context = conn };
    }

    /// Writes all of `buffer`. Errors of the underlying stream are collapsed into `WriteError`.
    pub fn writeAll(conn: *Connection, buffer: []const u8) WriteError!void {
        return switch (conn.protocol) {
            .plain => conn.stream.writeAll(buffer),
            .tls => conn.tls_client.writeAll(conn.stream, buffer),
        } catch |err| switch (err) {
            error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer,
            else => return error.UnexpectedWriteFailure,
        };
    }

    /// Writes some of `buffer` and returns the number of bytes written.
    /// Errors of the underlying stream are collapsed into `WriteError`.
    pub fn write(conn: *Connection, buffer: []const u8) WriteError!usize {
        return switch (conn.protocol) {
            .plain => conn.stream.write(buffer),
            .tls => conn.tls_client.write(conn.stream, buffer),
        } catch |err| switch (err) {
            error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer,
            else => return error.UnexpectedWriteFailure,
        };
    }

    pub const WriteError = error{
        ConnectionResetByPeer,
        UnexpectedWriteFailure,
    };

    pub const Writer = std.io.Writer(*Connection, WriteError, write);

    pub fn writer(conn: *Connection) Writer {
        return Writer{ .context = conn };
    }

    /// Closes the stream. For TLS connections the TLS client, which was
    /// allocated with the client's allocator, is destroyed as well.
    pub fn close(conn: *Connection, client: *const Client) void {
        if (conn.protocol == .tls) {
            // try to cleanly close the TLS connection, for any server that cares.
            // Best effort only: a failure here is deliberately ignored.
            _ = conn.tls_client.writeEnd(conn.stream, "", true) catch {};
            client.allocator.destroy(conn.tls_client);
        }

        conn.stream.close();
    }
};
|
|
|
|
/// A buffered (and peekable) Connection.
///
/// `buf[start..end]` is the data that has been read from the connection but not
/// yet consumed.
pub const BufferedConnection = struct {
    pub const buffer_size = 0x2000;

    conn: Connection,
    buf: [buffer_size]u8 = undefined,
    start: u16 = 0,
    end: u16 = 0,

    /// Refills the buffer from the connection if, and only if, it is empty.
    /// Returns `error.EndOfStream` when the connection has no more data.
    pub fn fill(bconn: *BufferedConnection) ReadError!void {
        if (bconn.end != bconn.start) return;

        const nread = try bconn.conn.read(bconn.buf[0..]);
        if (nread == 0) return error.EndOfStream;
        bconn.start = 0;
        // nread <= buffer_size, which fits in a u16.
        bconn.end = @intCast(u16, nread);
    }

    /// Returns the buffered, not yet consumed data without consuming it.
    pub fn peek(bconn: *BufferedConnection) []const u8 {
        return bconn.buf[bconn.start..bconn.end];
    }

    /// Marks the first `num` buffered bytes as consumed.
    pub fn clear(bconn: *BufferedConnection, num: u16) void {
        bconn.start += num;
    }

    /// Copies buffered data into `buffer`, refilling from the connection until
    /// `len` bytes were produced. `len` must not exceed `buffer.len`.
    ///
    /// When the buffer is empty and the remaining space in `buffer` is larger
    /// than the internal buffer, a single read is issued directly into
    /// `buffer` and the function returns afterwards, even if fewer than `len`
    /// bytes were produced.
    pub fn readAtLeast(bconn: *BufferedConnection, buffer: []u8, len: usize) ReadError!usize {
        // Otherwise the loop below could never make progress.
        std.debug.assert(len <= buffer.len);

        // usize, not u16: `buffer` may be larger than 64 KiB.
        var out_index: usize = 0;
        while (out_index < len) {
            const available = bconn.end - bconn.start;
            const left = buffer.len - out_index;

            if (available > 0) {
                const can_read = @truncate(u16, @min(available, left));

                std.mem.copy(u8, buffer[out_index..], bconn.buf[bconn.start..][0..can_read]);
                out_index += can_read;
                bconn.start += can_read;

                continue;
            }

            if (left > bconn.buf.len) {
                // skip the buffer if the output is large enough
                const nread = try bconn.conn.read(buffer[out_index..]);
                // Account for what was already copied out of the buffer.
                return out_index + nread;
            }

            try bconn.fill();
        }

        return out_index;
    }

    pub fn read(bconn: *BufferedConnection, buffer: []u8) ReadError!usize {
        return bconn.readAtLeast(buffer, 1);
    }

    pub const ReadError = Connection.ReadError || error{EndOfStream};
    pub const Reader = std.io.Reader(*BufferedConnection, ReadError, read);

    pub fn reader(bconn: *BufferedConnection) Reader {
        return Reader{ .context = bconn };
    }

    /// Writes are not buffered; this forwards to the connection.
    pub fn writeAll(bconn: *BufferedConnection, buffer: []const u8) WriteError!void {
        return bconn.conn.writeAll(buffer);
    }

    /// Writes are not buffered; this forwards to the connection.
    pub fn write(bconn: *BufferedConnection, buffer: []const u8) WriteError!usize {
        return bconn.conn.write(buffer);
    }

    pub const WriteError = Connection.WriteError;
    pub const Writer = std.io.Writer(*BufferedConnection, WriteError, write);

    pub fn writer(bconn: *BufferedConnection) Writer {
        return Writer{ .context = bconn };
    }

    pub fn close(bconn: *BufferedConnection, client: *const Client) void {
        bconn.conn.close(client);
    }
};
|
|
|
|
/// The mode of transport for requests.
pub const RequestTransfer = union(enum) {
    /// The body has a known size; holds the number of bytes still to be written.
    content_length: u64,
    /// The body is sent with chunked transfer encoding.
    chunked,
    /// The request has no body.
    none,
};
|
|
|
|
/// The decompressor for response messages.
/// Every decompressor reads its input through `Request.TransferReader`.
pub const Compression = union(enum) {
    pub const DeflateDecompressor = std.compress.zlib.ZlibStream(Request.TransferReader);
    pub const GzipDecompressor = std.compress.gzip.Decompress(Request.TransferReader);
    pub const ZstdDecompressor = std.compress.zstd.DecompressStream(Request.TransferReader, .{});

    deflate: DeflateDecompressor,
    gzip: GzipDecompressor,
    zstd: ZstdDecompressor,
    /// The body is passed through unchanged.
    none,
};
|
|
|
|
/// A HTTP response originating from a server.
pub const Response = struct {
    pub const ParseError = Allocator.Error || error{
        ShortHttpStatusLine,
        BadHttpVersion,
        HttpHeadersInvalid,
        HttpHeaderContinuationsUnsupported,
        HttpTransferEncodingUnsupported,
        HttpConnectionHeaderUnsupported,
        InvalidContentLength,
        CompressionNotSupported,
    };

    /// Parses the status line and the header fields contained in `bytes` into `res`.
    ///
    /// The last four bytes of `bytes` are not parsed; they are taken to be the
    /// blank line terminating the head. Every header is appended to
    /// `res.headers`; Content-Length, Transfer-Encoding and Content-Encoding
    /// are additionally decoded into the corresponding fields, which must be
    /// `null` on entry for those headers to be accepted.
    pub fn parse(res: *Response, bytes: []const u8) ParseError!void {
        // Too short to even contain the terminator that is stripped below.
        if (bytes.len < 4) return error.HttpHeadersInvalid;

        var it = mem.tokenize(u8, bytes[0 .. bytes.len - 4], "\r\n");

        const first_line = it.next() orelse return error.HttpHeadersInvalid;
        if (first_line.len < 12)
            return error.ShortHttpStatusLine;

        const version: http.Version = switch (int64(first_line[0..8])) {
            int64("HTTP/1.0") => .@"HTTP/1.0",
            int64("HTTP/1.1") => .@"HTTP/1.1",
            else => return error.BadHttpVersion,
        };
        if (first_line[8] != ' ') return error.HttpHeadersInvalid;

        // parseInt3 wraps silently on non-digits, so reject them up front.
        for (first_line[9..12]) |c| {
            if (!std.ascii.isDigit(c)) return error.HttpHeadersInvalid;
        }
        const status = @intToEnum(http.Status, parseInt3(first_line[9..12].*));
        const reason = mem.trimLeft(u8, first_line[12..], " ");

        res.version = version;
        res.status = status;
        res.reason = reason;

        while (it.next()) |line| {
            if (line.len == 0) return error.HttpHeadersInvalid;
            switch (line[0]) {
                ' ', '\t' => return error.HttpHeaderContinuationsUnsupported,
                else => {},
            }

            var line_it = mem.tokenize(u8, line, ": ");
            const header_name = line_it.next() orelse return error.HttpHeadersInvalid;
            const header_value = line_it.rest();

            try res.headers.append(header_name, header_value);

            if (std.ascii.eqlIgnoreCase(header_name, "content-length")) {
                if (res.content_length != null) return error.HttpHeadersInvalid;
                res.content_length = std.fmt.parseInt(u64, header_value, 10) catch return error.InvalidContentLength;
            } else if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) {
                // Transfer-Encoding: second, first
                // Transfer-Encoding: deflate, chunked
                var iter = mem.splitBackwards(u8, header_value, ",");

                if (iter.next()) |first| {
                    const trimmed = mem.trim(u8, first, " ");

                    if (std.meta.stringToEnum(http.TransferEncoding, trimmed)) |te| {
                        if (res.transfer_encoding != null) return error.HttpHeadersInvalid;
                        res.transfer_encoding = te;
                    } else if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| {
                        if (res.transfer_compression != null) return error.HttpHeadersInvalid;
                        res.transfer_compression = ce;
                    } else {
                        return error.HttpTransferEncodingUnsupported;
                    }
                }

                if (iter.next()) |second| {
                    if (res.transfer_compression != null) return error.HttpTransferEncodingUnsupported;

                    const trimmed = mem.trim(u8, second, " ");

                    if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| {
                        res.transfer_compression = ce;
                    } else {
                        return error.HttpTransferEncodingUnsupported;
                    }
                }

                // At most two codings are supported.
                if (iter.next()) |_| return error.HttpTransferEncodingUnsupported;
            } else if (std.ascii.eqlIgnoreCase(header_name, "content-encoding")) {
                if (res.transfer_compression != null) return error.HttpHeadersInvalid;

                const trimmed = mem.trim(u8, header_value, " ");

                if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| {
                    res.transfer_compression = ce;
                } else {
                    return error.HttpTransferEncodingUnsupported;
                }
            }
        }
    }

    /// Reinterprets eight bytes as an integer so that they can be used in a `switch`.
    inline fn int64(array: *const [8]u8) u64 {
        return @bitCast(u64, array.*);
    }

    /// Converts three ASCII digits into their numeric value.
    /// The arithmetic wraps: input that is not three digits gives a meaningless result.
    fn parseInt3(nnn: @Vector(3, u8)) u10 {
        const zero: @Vector(3, u8) = .{ '0', '0', '0' };
        const mmm: @Vector(3, u10) = .{ 100, 10, 1 };
        return @reduce(.Add, @as(@Vector(3, u10), nnn -% zero) *% mmm);
    }

    test parseInt3 {
        const expectEqual = testing.expectEqual;
        try expectEqual(@as(u10, 0), parseInt3("000".*));
        try expectEqual(@as(u10, 418), parseInt3("418".*));
        try expectEqual(@as(u10, 999), parseInt3("999".*));
    }

    version: http.Version,
    status: http.Status,
    reason: []const u8,

    content_length: ?u64 = null,
    transfer_encoding: ?http.TransferEncoding = null,
    transfer_compression: ?http.ContentEncoding = null,

    headers: http.Headers,
    parser: proto.HeadersParser,
    compression: Compression = .none,
    /// Passed to the parser when reading the body; set while a redirect body is discarded.
    skip: bool = false,
};
|
|
|
|
/// A HTTP request that has been sent.
///
/// Order of operations: request[ -> write -> finish] -> do -> read
pub const Request = struct {
    uri: Uri,
    client: *Client,
    connection: *ConnectionPool.Node,

    method: http.Method,
    version: http.Version = .@"HTTP/1.1",
    headers: http.Headers,
    transfer_encoding: RequestTransfer = .none,

    redirects_left: u32,
    handle_redirects: bool,

    response: Response,

    /// Used as a allocator for resolving redirects locations.
    arena: std.heap.ArenaAllocator,

    /// Frees all resources associated with the request.
    /// The connection is handed back to the client's pool; it is marked as
    /// closing first if the response was not read to the end.
    pub fn deinit(req: *Request) void {
        switch (req.response.compression) {
            .none => {},
            .deflate => |*deflate| deflate.deinit(),
            .gzip => |*gzip| gzip.deinit(),
            .zstd => |*zstd| zstd.deinit(),
        }

        if (req.response.parser.header_bytes_owned) {
            req.response.headers.deinit();
            req.response.parser.header_bytes.deinit(req.client.allocator);
        }

        if (!req.response.parser.done) {
            // If the response wasn't fully read, then we need to close the connection.
            req.connection.data.closing = true;
        }

        req.client.connection_pool.release(req.client, req.connection);

        req.arena.deinit();
        req.* = undefined;
    }

    pub const StartError = BufferedConnection.WriteError || error{ InvalidContentLength, UnsupportedTransferEncoding };

    /// Send the request to the server.
    ///
    /// Writes the request line, default headers for everything the caller did
    /// not supply (Host, User-Agent, Connection, Accept-Encoding, TE) and the
    /// caller's headers. `transfer_encoding` is derived from a caller supplied
    /// Content-Length or Transfer-Encoding header when there is one; otherwise
    /// the matching header is generated from `transfer_encoding`.
    pub fn start(req: *Request) StartError!void {
        var buffered = std.io.bufferedWriter(req.connection.data.buffered.writer());
        const w = buffered.writer();

        try w.writeAll(@tagName(req.method));
        try w.writeByte(' ');

        if (req.method == .CONNECT) {
            try w.writeAll(req.uri.host.?);
            try w.writeByte(':');
            try w.print("{}", .{req.uri.port.?});
        } else if (req.connection.data.proxied) {
            // proxied connections require the full uri
            try w.print("{+/}", .{req.uri});
        } else {
            try w.print("{/}", .{req.uri});
        }

        try w.writeByte(' ');
        try w.writeAll(@tagName(req.version));
        try w.writeAll("\r\n");

        if (!req.headers.contains("host")) {
            try w.writeAll("Host: ");
            try w.writeAll(req.uri.host.?);
            try w.writeAll("\r\n");
        }

        if (!req.headers.contains("user-agent")) {
            try w.writeAll("User-Agent: zig/");
            try w.writeAll(@import("builtin").zig_version_string);
            try w.writeAll(" (std.http)\r\n");
        }

        if (!req.headers.contains("connection")) {
            try w.writeAll("Connection: keep-alive\r\n");
        }

        if (!req.headers.contains("accept-encoding")) {
            try w.writeAll("Accept-Encoding: gzip, deflate, zstd\r\n");
        }

        if (!req.headers.contains("te")) {
            try w.writeAll("TE: gzip, deflate, trailers\r\n");
        }

        const has_transfer_encoding = req.headers.contains("transfer-encoding");
        const has_content_length = req.headers.contains("content-length");

        if (has_content_length) {
            const content_length = std.fmt.parseInt(u64, req.headers.getFirstValue("content-length").?, 10) catch return error.InvalidContentLength;

            req.transfer_encoding = .{ .content_length = content_length };
        } else if (has_transfer_encoding) {
            const transfer_encoding = req.headers.getFirstValue("transfer-encoding").?;
            if (std.mem.eql(u8, transfer_encoding, "chunked")) {
                req.transfer_encoding = .chunked;
            } else {
                return error.UnsupportedTransferEncoding;
            }
        } else {
            // The caller supplied neither header: announce the configured mode.
            switch (req.transfer_encoding) {
                .chunked => try w.writeAll("Transfer-Encoding: chunked\r\n"),
                .content_length => |content_length| try w.print("Content-Length: {d}\r\n", .{content_length}),
                .none => {},
            }
        }

        try w.print("{}", .{req.headers});

        try w.writeAll("\r\n");

        try buffered.flush();
    }

    pub const TransferReadError = BufferedConnection.ReadError || proto.HeadersParser.ReadError;

    pub const TransferReader = std.io.Reader(*Request, TransferReadError, transferRead);

    pub fn transferReader(req: *Request) TransferReader {
        return .{ .context = req };
    }

    /// Reads body data as delivered by the response parser, before any decompression.
    /// Returns 0 only once the parser reports the body as done.
    pub fn transferRead(req: *Request, buf: []u8) TransferReadError!usize {
        if (req.response.parser.done) return 0;

        var index: usize = 0;
        while (index == 0) {
            const amt = try req.response.parser.read(&req.connection.data.buffered, buf[index..], req.response.skip);
            if (amt == 0 and req.response.parser.done) break;
            index += amt;
        }

        return index;
    }

    pub const DoError = RequestError || TransferReadError || proto.HeadersParser.CheckCompleteHeadError || Response.ParseError || Uri.ParseError || error{ TooManyHttpRedirects, HttpRedirectMissingLocation, CompressionInitializationFailed, CompressionNotSupported };

    /// Waits for a response from the server and parses any headers that are sent.
    /// This function will block until the final response is received.
    ///
    /// If `handle_redirects` is true, then this function will automatically follow
    /// redirects.
    pub fn do(req: *Request) DoError!void {
        while (true) { // handle redirects
            while (true) { // read headers
                try req.connection.data.buffered.fill();

                const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.data.buffered.peek());
                req.connection.data.buffered.clear(@intCast(u16, nchecked));

                if (req.response.parser.state.isContent()) break;
            }

            req.response.headers = http.Headers{ .allocator = req.client.allocator, .owned = false };
            try req.response.parse(req.response.parser.header_bytes.items);

            if (req.response.status == .switching_protocols) {
                req.connection.data.closing = false;
                req.response.parser.done = true;
            }

            if (req.method == .CONNECT and req.response.status == .ok) {
                req.connection.data.closing = false;
                req.connection.data.proxied = true;
                req.response.parser.done = true;
            }

            // `start` sends "Connection: keep-alive" on its own when the caller
            // supplied no Connection header, so a missing request header means
            // keep-alive. The server has to ask for keep-alive explicitly.
            const req_keepalive = if (req.headers.getFirstValue("connection")) |value|
                !std.ascii.eqlIgnoreCase("close", value)
            else
                true;
            const res_keepalive = if (req.response.headers.getFirstValue("connection")) |value|
                !std.ascii.eqlIgnoreCase("close", value)
            else
                false;
            req.connection.data.closing = !(req_keepalive and res_keepalive);

            if (req.response.transfer_encoding) |te| {
                switch (te) {
                    .chunked => {
                        req.response.parser.next_chunk_length = 0;
                        req.response.parser.state = .chunk_head_size;
                    },
                }
            } else if (req.response.content_length) |cl| {
                req.response.parser.next_chunk_length = cl;

                if (cl == 0) req.response.parser.done = true;
            } else {
                req.response.parser.done = true;
            }

            if (req.response.status.class() == .redirect and req.handle_redirects) {
                req.response.skip = true;

                const empty = @as([*]u8, undefined)[0..0];
                assert(try req.transferRead(empty) == 0); // we're skipping, no buffer is necessary

                if (req.redirects_left == 0) return error.TooManyHttpRedirects;

                const location = req.response.headers.getFirstValue("location") orelse
                    return error.HttpRedirectMissingLocation;

                // Everything the follow-up request needs lives in this arena,
                // which the follow-up request takes ownership of below. Until
                // then it is released on every error path.
                var new_arena = std.heap.ArenaAllocator.init(req.client.allocator);
                errdefer new_arena.deinit();

                // `location` points into this request's header bytes, which are
                // freed (or reused by the follow-up request) below, so copy it.
                const owned_location = try new_arena.allocator().dupe(u8, location);
                const new_url = Uri.parse(owned_location) catch try Uri.parseWithoutScheme(owned_location);
                // NOTE(review): components `resolve` takes over from `req.uri`
                // may still point at memory of this request; confirm against Uri.resolve.
                const resolved_url = try req.uri.resolve(new_url, false, new_arena.allocator());

                var new_req = try req.client.request(req.method, resolved_url, req.headers, .{
                    .version = req.version,
                    .max_redirects = req.redirects_left - 1,
                    .header_strategy = if (req.response.parser.header_bytes_owned) .{
                        .dynamic = req.response.parser.max_header_bytes,
                    } else .{
                        .static = req.response.parser.header_bytes.items.ptr[0..req.response.parser.max_header_bytes],
                    },
                });

                // The arena `request` created has not been used for anything yet;
                // replace it by the one that holds the resolved url.
                new_req.arena.deinit();
                new_req.arena = new_arena;

                req.deinit();
                req.* = new_req;
            } else {
                req.response.skip = false;
                if (!req.response.parser.done) {
                    if (req.response.transfer_compression) |tc| switch (tc) {
                        .compress => return error.CompressionNotSupported,
                        .deflate => req.response.compression = .{
                            .deflate = std.compress.zlib.zlibStream(req.client.allocator, req.transferReader()) catch return error.CompressionInitializationFailed,
                        },
                        .gzip => req.response.compression = .{
                            .gzip = std.compress.gzip.decompress(req.client.allocator, req.transferReader()) catch return error.CompressionInitializationFailed,
                        },
                        .zstd => req.response.compression = .{
                            .zstd = std.compress.zstd.decompressStream(req.client.allocator, req.transferReader()),
                        },
                    };
                }

                break;
            }
        }
    }

    pub const ReadError = TransferReadError || proto.HeadersParser.CheckCompleteHeadError || error{ DecompressionFailure, InvalidTrailers };

    pub const Reader = std.io.Reader(*Request, ReadError, read);

    pub fn reader(req: *Request) Reader {
        return .{ .context = req };
    }

    /// Reads data from the response body. Must be called after `do`.
    ///
    /// When the end of the body is reached and the parser is positioned on
    /// trailing headers, those are read and the response headers are parsed again
    /// so that they include the trailers.
    pub fn read(req: *Request, buffer: []u8) ReadError!usize {
        const out_index = switch (req.response.compression) {
            .deflate => |*deflate| deflate.read(buffer) catch return error.DecompressionFailure,
            .gzip => |*gzip| gzip.read(buffer) catch return error.DecompressionFailure,
            .zstd => |*zstd| zstd.read(buffer) catch return error.DecompressionFailure,
            else => try req.transferRead(buffer),
        };

        if (out_index == 0) {
            const has_trail = !req.response.parser.state.isContent();

            while (!req.response.parser.state.isContent()) { // read trailing headers
                try req.connection.data.buffered.fill();

                const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.data.buffered.peek());
                req.connection.data.buffered.clear(@intCast(u16, nchecked));
            }

            if (has_trail) {
                // Release the header list built by `do` before starting over.
                req.response.headers.deinit();
                req.response.headers = http.Headers{ .allocator = req.client.allocator, .owned = false };

                // `parse` rejects a header that would overwrite one of these
                // fields, and the original headers are about to be seen again.
                req.response.content_length = null;
                req.response.transfer_encoding = null;
                req.response.transfer_compression = null;

                // The response headers before the trailers are already guaranteed to be valid, so they will always be parsed again and cannot return an error.
                // This will *only* fail for a malformed trailer.
                req.response.parse(req.response.parser.header_bytes.items) catch return error.InvalidTrailers;
            }
        }

        return out_index;
    }

    /// Reads data from the response body. Must be called after `do`.
    /// Keeps reading until `buffer` is full or the body has ended.
    pub fn readAll(req: *Request, buffer: []u8) !usize {
        var index: usize = 0;
        while (index < buffer.len) {
            const amt = try read(req, buffer[index..]);
            if (amt == 0) break;
            index += amt;
        }
        return index;
    }

    pub const WriteError = BufferedConnection.WriteError || error{ NotWriteable, MessageTooLong };

    pub const Writer = std.io.Writer(*Request, WriteError, write);

    pub fn writer(req: *Request) Writer {
        return .{ .context = req };
    }

    /// Write `bytes` to the server. The `transfer_encoding` request header determines how data will be sent.
    ///
    /// Returns `error.NotWriteable` when the request has no body and
    /// `error.MessageTooLong` when `bytes` exceeds the remaining content length.
    pub fn write(req: *Request, bytes: []const u8) WriteError!usize {
        switch (req.transfer_encoding) {
            .chunked => {
                // A chunk of size zero terminates the body; that is `finish`'s job.
                if (bytes.len == 0) return 0;

                try req.connection.data.buffered.writer().print("{x}\r\n", .{bytes.len});
                try req.connection.data.buffered.writeAll(bytes);
                try req.connection.data.buffered.writeAll("\r\n");

                return bytes.len;
            },
            .content_length => |*len| {
                if (len.* < bytes.len) return error.MessageTooLong;

                const amt = try req.connection.data.buffered.write(bytes);
                len.* -= amt;
                return amt;
            },
            .none => return error.NotWriteable,
        }
    }

    pub fn writeAll(req: *Request, bytes: []const u8) WriteError!void {
        var index: usize = 0;
        while (index < bytes.len) {
            index += try write(req, bytes[index..]);
        }
    }

    pub const FinishError = WriteError || error{MessageNotCompleted};

    /// Finish the body of a request. This notifies the server that you have no more data to send.
    /// Returns `error.MessageNotCompleted` when fewer bytes than the announced content length were written.
    pub fn finish(req: *Request) FinishError!void {
        switch (req.transfer_encoding) {
            .chunked => try req.connection.data.buffered.writeAll("0\r\n\r\n"),
            .content_length => |len| if (len != 0) return error.MessageNotCompleted,
            .none => {},
        }
    }
};
|
|
|
|
/// Describes the proxy that `connect` opens connections to.
pub const HttpProxy = struct {
    pub const ProxyAuthentication = union(enum) {
        basic: []const u8,
        custom: []const u8,
    };

    /// Protocol spoken with the proxy itself.
    protocol: Connection.Protocol,
    host: []const u8,
    /// When null, `connect` uses 80 for plain and 443 for tls.
    port: ?u16 = null,

    /// The value for the Proxy-Authorization header.
    auth: ?ProxyAuthentication = null,
};
|
|
|
|
/// Release all associated resources with the client: every pooled connection
/// and the certificate bundle. The client must not be used afterwards.
/// TODO: currently leaks all request allocated data
pub fn deinit(client: *Client) void {
    // Runs last, after the pool and the bundle have been torn down.
    defer client.* = undefined;

    client.connection_pool.deinit(client);
    client.ca_bundle.deinit(client.allocator);
}
|
|
|
|
pub const ConnectUnproxiedError = Allocator.Error || error{
    ConnectionRefused,
    NetworkUnreachable,
    ConnectionTimedOut,
    ConnectionResetByPeer,
    TemporaryNameServerFailure,
    NameServerFailure,
    UnknownHostName,
    HostLacksNetworkAddresses,
    UnexpectedConnectFailure,
    TlsInitializationFailed,
};

/// Connect to `host:port` using the specified protocol. This will reuse a connection if one is already open.
/// A newly opened connection is registered with the pool as in use.
/// This function is threadsafe.
pub fn connectUnproxied(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectUnproxiedError!*ConnectionPool.Node {
    const wanted = ConnectionPool.Criteria{
        .host = host,
        .port = port,
        .is_tls = protocol == .tls,
    };
    if (client.connection_pool.findConnection(wanted)) |reused| return reused;

    const node = try client.allocator.create(ConnectionPool.Node);
    errdefer client.allocator.destroy(node);
    node.* = .{ .data = undefined };

    // Only the errors that are part of ConnectUnproxiedError are passed on as they are.
    const stream = net.tcpConnectToHost(client.allocator, host, port) catch |err| return switch (err) {
        error.ConnectionRefused => error.ConnectionRefused,
        error.NetworkUnreachable => error.NetworkUnreachable,
        error.ConnectionTimedOut => error.ConnectionTimedOut,
        error.ConnectionResetByPeer => error.ConnectionResetByPeer,
        error.TemporaryNameServerFailure => error.TemporaryNameServerFailure,
        error.NameServerFailure => error.NameServerFailure,
        error.UnknownHostName => error.UnknownHostName,
        error.HostLacksNetworkAddresses => error.HostLacksNetworkAddresses,
        else => error.UnexpectedConnectFailure,
    };
    errdefer stream.close();

    // The pool keeps its own copy of the host name.
    const owned_host = try client.allocator.dupe(u8, host);
    errdefer client.allocator.free(owned_host);

    node.data = .{
        .buffered = .{ .conn = .{
            .stream = stream,
            .tls_client = undefined,
            .protocol = protocol,
        } },
        .host = owned_host,
        .port = port,
    };

    if (protocol == .tls) {
        const tls_client = try client.allocator.create(std.crypto.tls.Client);
        errdefer client.allocator.destroy(tls_client);

        tls_client.* = std.crypto.tls.Client.init(stream, client.ca_bundle, host) catch return error.TlsInitializationFailed;
        // This is appropriate for HTTPS because the HTTP headers contain
        // the content length which is used to detect truncation attacks.
        tls_client.allow_truncation_attacks = true;

        node.data.buffered.conn.tls_client = tls_client;
    }

    client.connection_pool.addUsed(node);

    return node;
}
|
|
|
|
// Prevents a dependency loop in request()
const ConnectErrorPartial = ConnectUnproxiedError || error{ UnsupportedUrlScheme, ConnectionRefused };
pub const ConnectError = ConnectErrorPartial || RequestError;

/// Returns a connection for `host:port`: an idle pooled connection to that
/// endpoint if there is one, otherwise a connection to the client's proxy
/// (marked as proxied) when a proxy is configured, otherwise a direct connection.
pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectError!*ConnectionPool.Node {
    const wanted = ConnectionPool.Criteria{
        .host = host,
        .port = port,
        .is_tls = protocol == .tls,
    };
    if (client.connection_pool.findConnection(wanted)) |reused| return reused;

    const proxy = client.proxy orelse return client.connectUnproxied(host, port, protocol);

    const default_port: u16 = switch (proxy.protocol) {
        .plain => 80,
        .tls => 443,
    };

    const node = try client.connectUnproxied(proxy.host, proxy.port orelse default_port, proxy.protocol);
    node.data.proxied = true;

    return node;
}
|
|
|
|
/// Everything `request` can fail with.
pub const RequestError = ConnectUnproxiedError ||
    ConnectErrorPartial ||
    Request.StartError ||
    std.fmt.ParseIntError ||
    BufferedConnection.WriteError ||
    error{
    UnsupportedUrlScheme,
    UriMissingHost,

    CertificateBundleLoadFailure,
    UnsupportedTransferEncoding,
};
|
|
|
|
/// Per-request settings accepted by `request`.
pub const Options = struct {
    version: http.Version = .@"HTTP/1.1",

    /// Whether `Request.do` follows redirect responses on its own.
    handle_redirects: bool = true,
    /// Number of redirects that may be followed before `error.TooManyHttpRedirects` is returned.
    max_redirects: u32 = 3,
    header_strategy: HeaderStrategy = .{ .dynamic = 16 * 1024 },

    /// Must be an already acquired connection.
    connection: ?*ConnectionPool.Node = null,

    pub const HeaderStrategy = union(enum) {
        /// In this case, the client's Allocator will be used to store the
        /// entire HTTP header. This value is the maximum total size of
        /// HTTP headers allowed, otherwise
        /// error.HttpHeadersExceededSizeLimit is returned from read().
        dynamic: usize,
        /// This is used to store the entire HTTP header. If the HTTP
        /// header is too big to fit, `error.HttpHeadersExceededSizeLimit`
        /// is returned from read(). When this is used, `error.OutOfMemory`
        /// cannot be returned from `read()`.
        static: []u8,
    };
};
|
|
|
|
/// Maps the supported URI schemes to the protocol used for the connection.
pub const protocol_map = std.ComptimeStringMap(Connection.Protocol, .{
    // unencrypted
    .{ "http", .plain },
    .{ "ws", .plain },
    // over TLS
    .{ "https", .tls },
    .{ "wss", .tls },
});
|
|
|
|
/// Form and send a http request to a server.
///
/// The connection is `options.connection` when given, otherwise one is obtained
/// through `connect`. Before the first HTTPS request the certificate bundle is
/// rescanned.
/// This function is threadsafe.
pub fn request(client: *Client, method: http.Method, uri: Uri, headers: http.Headers, options: Options) RequestError!Request {
    const protocol = protocol_map.get(uri.scheme) orelse return error.UnsupportedUrlScheme;
    const host = uri.host orelse return error.UriMissingHost;

    const default_port: u16 = switch (protocol) {
        .plain => 80,
        .tls => 443,
    };
    const port: u16 = uri.port orelse default_port;

    if (protocol == .tls and @atomicLoad(bool, &client.next_https_rescan_certs, .Acquire)) {
        client.ca_bundle_mutex.lock();
        defer client.ca_bundle_mutex.unlock();

        // Check again now that the mutex is held: another thread may have
        // completed the rescan in the meantime.
        if (client.next_https_rescan_certs) {
            client.ca_bundle.rescan(client.allocator) catch return error.CertificateBundleLoadFailure;
            @atomicStore(bool, &client.next_https_rescan_certs, false, .Release);
        }
    }

    const conn = options.connection orelse try client.connect(host, port, protocol);

    // Nothing below can fail, so no cleanup has to be registered for `conn`.
    const parser = switch (options.header_strategy) {
        .dynamic => |max| proto.HeadersParser.initDynamic(max),
        .static => |buf| proto.HeadersParser.initStatic(buf),
    };

    return Request{
        .uri = uri,
        .client = client,
        .connection = conn,
        .headers = headers,
        .method = method,
        .version = options.version,
        .redirects_left = options.max_redirects,
        .handle_redirects = options.handle_redirects,
        .response = .{
            .status = undefined,
            .reason = undefined,
            .version = undefined,
            .headers = undefined,
            .parser = parser,
        },
        .arena = std.heap.ArenaAllocator.init(client.allocator),
    };
}
|
|
|
|
test {
    const builtin = @import("builtin");
    const native_endian = comptime builtin.cpu.arch.endian();

    // https://github.com/ziglang/zig/issues/13782
    const hits_llvm_big_endian_issue = builtin.zig_backend == .stage2_llvm and native_endian == .Big;
    if (hits_llvm_big_endian_issue) return error.SkipZigTest;

    if (builtin.os.tag == .wasi) return error.SkipZigTest;

    // Pull in every declaration of this file so that nested tests run as well.
    std.testing.refAllDecls(@This());
}
|