std.http: rework connection pool into its own type

Nameless 2023-03-08 11:27:13 -06:00
parent 634e715504
commit 524e0cd987
4 changed files with 134 additions and 87 deletions

lib/std/http/Client.zig

@@ -16,6 +16,9 @@ const testing = std.testing;
 pub const Request = @import("Client/Request.zig");
 pub const Response = @import("Client/Response.zig");
 
+pub const default_connection_pool_size = 32;
+const connection_pool_size = std.options.http_connection_pool_size;
+
 /// Used for tcpConnectToHost and storing HTTP headers when an externally
 /// managed buffer is not provided.
 allocator: Allocator,
@@ -24,39 +27,115 @@ ca_bundle: std.crypto.Certificate.Bundle = .{},
 /// it will first rescan the system for root certificates.
 next_https_rescan_certs: bool = true,
 
-connection_mutex: std.Thread.Mutex = .{},
 connection_pool: ConnectionPool = .{},
-connection_used: ConnectionPool = .{},
 
-pub const ConnectionPool = std.TailQueue(Connection);
-pub const ConnectionNode = ConnectionPool.Node;
-
-/// Acquires an existing connection from the connection pool. This function is threadsafe.
-/// If the caller already holds the connection mutex, it should pass `true` for `held`.
-pub fn acquire(client: *Client, node: *ConnectionNode, held: bool) void {
-    if (!held) client.connection_mutex.lock();
-    defer if (!held) client.connection_mutex.unlock();
-
-    client.connection_pool.remove(node);
-    client.connection_used.append(node);
-}
-
-/// Tries to release a connection back to the connection pool. This function is threadsafe.
-/// If the connection is marked as closing, it will be closed instead.
-pub fn release(client: *Client, node: *ConnectionNode) void {
-    client.connection_mutex.lock();
-    defer client.connection_mutex.unlock();
-
-    client.connection_used.remove(node);
-
-    if (node.data.closing) {
-        node.data.close(client);
-        return client.allocator.destroy(node);
-    }
-
-    client.connection_pool.append(node);
-}
+pub const ConnectionPool = struct {
+    pub const Criteria = struct {
+        host: []const u8,
+        port: u16,
+        is_tls: bool,
+    };
+
+    const Queue = std.TailQueue(Connection);
+    pub const Node = Queue.Node;
+
+    mutex: std.Thread.Mutex = .{},
+    used: Queue = .{},
+    free: Queue = .{},
+    free_len: usize = 0,
+    free_size: usize = default_connection_pool_size,
+
+    /// Finds and acquires a connection from the connection pool matching the criteria. This function is threadsafe.
+    /// If no connection is found, null is returned.
+    pub fn findConnection(pool: *ConnectionPool, criteria: Criteria) ?*Node {
+        pool.mutex.lock();
+        defer pool.mutex.unlock();
+
+        var next = pool.free.last;
+        while (next) |node| : (next = node.prev) {
+            if ((node.data.protocol == .tls) != criteria.is_tls) continue;
+            if (node.data.port != criteria.port) continue;
+            if (!std.mem.eql(u8, node.data.host, criteria.host)) continue;
+
+            pool.acquireUnsafe(node);
+            return node;
+        }
+
+        return null;
+    }
+
+    /// Acquires an existing connection from the connection pool. This function is not threadsafe.
+    pub fn acquireUnsafe(pool: *ConnectionPool, node: *Node) void {
+        pool.free.remove(node);
+        pool.free_len -= 1;
+
+        pool.used.append(node);
+    }
+
+    /// Acquires an existing connection from the connection pool. This function is threadsafe.
+    pub fn acquire(pool: *ConnectionPool, node: *Node) void {
+        pool.mutex.lock();
+        defer pool.mutex.unlock();
+
+        return pool.acquireUnsafe(node);
+    }
+
+    /// Tries to release a connection back to the connection pool. This function is threadsafe.
+    /// If the connection is marked as closing, it will be closed instead.
+    pub fn release(pool: *ConnectionPool, client: *Client, node: *Node) void {
+        pool.mutex.lock();
+        defer pool.mutex.unlock();
+
+        pool.used.remove(node);
+
+        if (node.data.closing) {
+            node.data.close(client);
+            return client.allocator.destroy(node);
+        }
+
+        if (pool.free_len + 1 >= pool.free_size) {
+            const popped = pool.free.popFirst() orelse unreachable;
+            popped.data.close(client);
+            return client.allocator.destroy(popped);
+        }
+
+        pool.free.append(node);
+        pool.free_len += 1;
+    }
+
+    /// Adds a newly created node to the pool of used connections. This function is threadsafe.
+    pub fn addUsed(pool: *ConnectionPool, node: *Node) void {
+        pool.mutex.lock();
+        defer pool.mutex.unlock();
+
+        pool.used.append(node);
+    }
+
+    pub fn deinit(pool: *ConnectionPool, client: *Client) void {
+        pool.mutex.lock();
+
+        var next = pool.free.first;
+        while (next) |node| {
+            defer client.allocator.destroy(node);
+            next = node.next;
+
+            node.data.close(client);
+        }
+
+        next = pool.used.first;
+        while (next) |node| {
+            defer client.allocator.destroy(node);
+            next = node.next;
+
+            node.data.close(client);
+        }
+
+        pool.* = undefined;
+    }
+};
 
 pub const DeflateDecompressor = std.compress.zlib.ZlibStream(Request.ReaderRaw);
 pub const GzipDecompressor = std.compress.gzip.Decompress(Request.ReaderRaw);
@@ -142,25 +221,7 @@ pub const Connection = struct {
 };
 
 pub fn deinit(client: *Client) void {
-    client.connection_mutex.lock();
-
-    var next = client.connection_pool.first;
-    while (next) |node| {
-        next = node.next;
-
-        node.data.close(client);
-        client.allocator.destroy(node);
-    }
-
-    next = client.connection_used.first;
-    while (next) |node| {
-        next = node.next;
-
-        node.data.close(client);
-        client.allocator.destroy(node);
-    }
+    client.connection_pool.deinit(client);
 
     client.ca_bundle.deinit(client.allocator);
     client.* = undefined;
@@ -168,36 +229,25 @@ pub fn deinit(client: *Client) void {
 
 pub const ConnectError = std.mem.Allocator.Error || net.TcpConnectToHostError || std.crypto.tls.Client.InitError(net.Stream);
 
-pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectError!*ConnectionNode {
-    { // Search through the connection pool for a potential connection.
-        client.connection_mutex.lock();
-        defer client.connection_mutex.unlock();
-
-        var potential = client.connection_pool.last;
-        while (potential) |node| {
-            const same_host = mem.eql(u8, node.data.host, host);
-            const same_port = node.data.port == port;
-            const same_protocol = node.data.protocol == protocol;
-
-            if (same_host and same_port and same_protocol) {
-                client.acquire(node, true);
-                return node;
-            }
-
-            potential = node.prev;
-        }
-    }
+pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectError!*ConnectionPool.Node {
+    if (client.connection_pool.findConnection(.{
+        .host = host,
+        .port = port,
+        .is_tls = protocol == .tls,
+    })) |node|
+        return node;
 
-    const conn = try client.allocator.create(ConnectionNode);
+    const conn = try client.allocator.create(ConnectionPool.Node);
     errdefer client.allocator.destroy(conn);
+    conn.* = .{ .data = undefined };
 
-    conn.* = .{ .data = .{
+    conn.data = .{
         .stream = try net.tcpConnectToHost(client.allocator, host, port),
         .tls_client = undefined,
         .protocol = protocol,
         .host = try client.allocator.dupe(u8, host),
         .port = port,
-    } };
+    };
 
     switch (protocol) {
         .plain => {},
@@ -210,12 +260,7 @@ pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectError!*ConnectionNode {
         },
     }
 
-    {
-        client.connection_mutex.lock();
-        defer client.connection_mutex.unlock();
-
-        client.connection_used.append(conn);
-    }
+    client.connection_pool.addUsed(conn);
 
     return conn;
 }
@@ -247,8 +292,8 @@ pub fn request(client: *Client, uri: Uri, headers: Request.Headers, options: Req
     const host = uri.host orelse return error.UriMissingHost;
 
     if (client.next_https_rescan_certs and protocol == .tls) {
-        client.connection_mutex.lock(); // TODO: this could be so much better than reusing the connection pool mutex.
-        defer client.connection_mutex.unlock();
+        client.connection_pool.mutex.lock(); // TODO: this could be so much better than reusing the connection pool mutex.
+        defer client.connection_pool.mutex.unlock();
 
         if (client.next_https_rescan_certs) {
             try client.ca_bundle.rescan(client.allocator);

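Taken together, the pool is now a self-contained type: it owns the mutex and both queues, and callers go through `findConnection`, `addUsed`, and `release` instead of locking `connection_mutex` around raw `std.TailQueue` operations. A minimal sketch of the resulting call pattern, assuming an already-initialized `client: *Client` and using only names introduced in this diff:

    // Reuse a pooled connection when one matches; otherwise dial a new one.
    // `findConnection` acquires the node on success, and `connect` registers
    // new nodes via `addUsed`, so either way the node ends up in `used`.
    const node = client.connection_pool.findConnection(.{
        .host = "example.com",
        .port = 443,
        .is_tls = true,
    }) orelse try client.connect("example.com", 443, .tls);

    // ... perform a request over node.data ...

    // Hand the node back: `release` returns it to the `free` queue
    // (or closes and destroys it when `closing` is set).
    client.connection_pool.release(client, node);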
lib/std/http/Client/Request.zig

@@ -6,7 +6,7 @@ const assert = std.debug.assert;
 const Client = @import("../Client.zig");
 const Connection = Client.Connection;
-const ConnectionNode = Client.ConnectionNode;
+const ConnectionNode = Client.ConnectionPool.Node;
 
 const Response = @import("Response.zig");
 
 const Request = @This();
@@ -85,7 +85,7 @@ pub fn deinit(req: *Request) void {
     if (!req.response.done) {
         // If the response wasn't fully read, then we need to close the connection.
         req.connection.data.closing = true;
-        req.client.release(req.connection);
+        req.client.connection_pool.release(req.client, req.connection);
     }
 
     req.arena.deinit();
@@ -135,7 +135,7 @@ fn checkForCompleteHead(req: *Request, buffer: []u8) !usize {
     if (req.response.state == .finished) {
         req.response.headers = try Response.Headers.parse(req.response.header_bytes.items);
 
-        if (req.response.upgrade) |_| {
+        if (req.response.headers.upgrade) |_| {
             req.connection.data.closing = false;
             req.response.done = true;
             return i;
@@ -226,7 +226,7 @@ fn readRawAdvanced(req: *Request, buffer: []u8) !usize {
                 req.response.next_chunk_length -= can_read;
 
                 if (req.response.next_chunk_length == 0) {
-                    req.client.release(req.connection);
+                    req.client.connection_pool.release(req.client, req.connection);
                     req.connection = undefined;
                     req.response.done = true;
                 }
@@ -241,7 +241,7 @@ fn readRawAdvanced(req: *Request, buffer: []u8) !usize {
                 req.read_buffer_start += @intCast(ReadBufferIndex, can_read);
 
                 if (req.response.next_chunk_length == 0) {
-                    req.client.release(req.connection);
+                    req.client.connection_pool.release(req.client, req.connection);
                     req.connection = undefined;
                     req.response.done = true;
                 }
@@ -293,7 +293,7 @@ fn readRawAdvanced(req: *Request, buffer: []u8) !usize {
             .chunk_data => {
                 if (req.response.next_chunk_length == 0) {
                     req.response.done = true;
-                    req.client.release(req.connection);
+                    req.client.connection_pool.release(req.client, req.connection);
                     req.connection = undefined;
 
                     return out_index;
@@ -317,7 +317,7 @@ fn readRawAdvanced(req: *Request, buffer: []u8) !usize {
                 req.response.next_chunk_length -= can_read;
 
                 if (req.response.next_chunk_length == 0) {
-                    req.client.release(req.connection);
+                    req.client.connection_pool.release(req.client, req.connection);
                     req.connection = undefined;
                     req.response.done = true;
 
                     continue;
@@ -345,13 +345,7 @@ fn readRawAdvanced(req: *Request, buffer: []u8) !usize {
     }
 }
 
-pub const ReadError = Client.DeflateDecompressor.Error || Client.GzipDecompressor.Error || Client.ZstdDecompressor.Error || WaitForCompleteHeadError || error{
-    BadHeader,
-    InvalidCompression,
-    StreamTooLong,
-    InvalidWindowSize,
-    CompressionNotSupported
-};
+pub const ReadError = Client.DeflateDecompressor.Error || Client.GzipDecompressor.Error || Client.ZstdDecompressor.Error || WaitForCompleteHeadError || error{ BadHeader, InvalidCompression, StreamTooLong, InvalidWindowSize, CompressionNotSupported };
 
 pub const Reader = std.io.Reader(*Request, ReadError, read);

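All of the `release` call sites above sit in the body-reading path, so draining a response is what returns its connection to the pool. Sketched from the caller's side (hypothetical driver code; `reader()` is assumed to be the usual accessor for the `Reader` type declared above):

    var req = try client.request(uri, .{}, .{});
    // If the body is never fully read, deinit marks the connection
    // `closing` and releases it, so it is closed rather than reused.
    defer req.deinit();

    // Reading the last body byte makes readRawAdvanced release the
    // connection back to the pool; no explicit return step is needed.
    const body = try req.reader().readAllAlloc(client.allocator, 8 * 1024 * 1024);
    defer client.allocator.free(body);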
lib/std/http/Client/Response.zig

@@ -32,6 +32,7 @@ pub const Headers = struct {
     transfer_encoding: ?http.TransferEncoding = null,
     transfer_compression: ?http.ContentEncoding = null,
     connection: http.Connection = .close,
+    upgrade: ?[]const u8 = null,
 
     number_of_headers: usize = 0,
@@ -93,7 +94,7 @@
             if (iter.next()) |second| {
                 if (headers.transfer_compression != null) return error.HttpTransferEncodingUnsupported;
 
                 const trimmed = std.mem.trim(u8, second, " ");
 
                 if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| {
@@ -122,6 +123,8 @@
             } else {
                 return error.HttpConnectionHeaderUnsupported;
             }
+        } else if (std.ascii.eqlIgnoreCase(header_name, "upgrade")) {
+            headers.upgrade = header_value;
         }
     }

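The new field records the raw `Upgrade` header value, which the `checkForCompleteHead` change in Request.zig uses to keep the connection open instead of reading a body. As an illustration, a response head like the following (made-up data) leaves `headers.upgrade` set to "websocket" and the connection's `closing` flag false:

    HTTP/1.1 101 Switching Protocols
    Connection: upgrade
    Upgrade: websocket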
lib/std/std.zig

@@ -185,6 +185,11 @@ pub const options = struct {
         options_override.keep_sigpipe
     else
         false;
+
+    pub const http_connection_pool_size = if (@hasDecl(options_override, "http_connection_pool_size"))
+        options_override.http_connection_pool_size
+    else
+        http.Client.default_connection_pool_size;
 };
 
 // This forces the start.zig file to be imported, and the comptime logic inside that
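The new option follows the same `options_override` pattern as `keep_sigpipe`, so a program can size the per-client cache of idle connections from its root source file. A sketch, assuming `options_override` resolves to the usual root-level override struct in this version of the standard library:

    // Application root source file (assumption: this is the declaration
    // that `options_override` picks up, as for the other options above).
    pub const std_options = struct {
        // Cache at most 8 idle connections per client instead of the
        // default 32 (http.Client.default_connection_pool_size).
        pub const http_connection_pool_size: usize = 8;
    };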