mirror of
https://codeberg.org/ziglang/zig.git
synced 2025-12-06 13:54:21 +00:00
std.http: rework connection pool into its own type
This commit is contained in:
parent
634e715504
commit
524e0cd987
4 changed files with 134 additions and 87 deletions
|
|
@ -16,6 +16,9 @@ const testing = std.testing;
|
|||
pub const Request = @import("Client/Request.zig");
|
||||
pub const Response = @import("Client/Response.zig");
|
||||
|
||||
pub const default_connection_pool_size = 32;
|
||||
const connection_pool_size = std.options.http_connection_pool_size;
|
||||
|
||||
/// Used for tcpConnectToHost and storing HTTP headers when an externally
|
||||
/// managed buffer is not provided.
|
||||
allocator: Allocator,
|
||||
|
|
@ -24,39 +27,115 @@ ca_bundle: std.crypto.Certificate.Bundle = .{},
|
|||
/// it will first rescan the system for root certificates.
|
||||
next_https_rescan_certs: bool = true,
|
||||
|
||||
connection_mutex: std.Thread.Mutex = .{},
|
||||
connection_pool: ConnectionPool = .{},
|
||||
connection_used: ConnectionPool = .{},
|
||||
|
||||
pub const ConnectionPool = std.TailQueue(Connection);
|
||||
pub const ConnectionNode = ConnectionPool.Node;
|
||||
pub const ConnectionPool = struct {
|
||||
pub const Criteria = struct {
|
||||
host: []const u8,
|
||||
port: u16,
|
||||
is_tls: bool,
|
||||
};
|
||||
|
||||
/// Acquires an existing connection from the connection pool. This function is threadsafe.
|
||||
/// If the caller already holds the connection mutex, it should pass `true` for `held`.
|
||||
pub fn acquire(client: *Client, node: *ConnectionNode, held: bool) void {
|
||||
if (!held) client.connection_mutex.lock();
|
||||
defer if (!held) client.connection_mutex.unlock();
|
||||
const Queue = std.TailQueue(Connection);
|
||||
pub const Node = Queue.Node;
|
||||
|
||||
client.connection_pool.remove(node);
|
||||
client.connection_used.append(node);
|
||||
}
|
||||
mutex: std.Thread.Mutex = .{},
|
||||
used: Queue = .{},
|
||||
free: Queue = .{},
|
||||
free_len: usize = 0,
|
||||
free_size: usize = default_connection_pool_size,
|
||||
|
||||
/// Tries to release a connection back to the connection pool. This function is threadsafe.
|
||||
/// If the connection is marked as closing, it will be closed instead.
|
||||
pub fn release(client: *Client, node: *ConnectionNode) void {
|
||||
client.connection_mutex.lock();
|
||||
defer client.connection_mutex.unlock();
|
||||
/// Finds and acquires a connection from the connection pool matching the criteria. This function is threadsafe.
|
||||
/// If no connection is found, null is returned.
|
||||
pub fn findConnection(pool: *ConnectionPool, criteria: Criteria) ?*Node {
|
||||
pool.mutex.lock();
|
||||
defer pool.mutex.unlock();
|
||||
|
||||
client.connection_used.remove(node);
|
||||
var next = pool.free.last;
|
||||
while (next) |node| : (next = node.prev) {
|
||||
if ((node.data.protocol == .tls) != criteria.is_tls) continue;
|
||||
if (node.data.port != criteria.port) continue;
|
||||
if (std.mem.eql(u8, node.data.host, criteria.host)) continue;
|
||||
|
||||
if (node.data.closing) {
|
||||
node.data.close(client);
|
||||
pool.acquireUnsafe(node);
|
||||
return node;
|
||||
}
|
||||
|
||||
return client.allocator.destroy(node);
|
||||
return null;
|
||||
}
|
||||
|
||||
client.connection_pool.append(node);
|
||||
}
|
||||
/// Acquires an existing connection from the connection pool. This function is not threadsafe.
|
||||
pub fn acquireUnsafe(pool: *ConnectionPool, node: *Node) void {
|
||||
pool.free.remove(node);
|
||||
pool.free_len -= 1;
|
||||
|
||||
pool.used.append(node);
|
||||
}
|
||||
|
||||
/// Acquires an existing connection from the connection pool. This function is threadsafe.
|
||||
pub fn acquire(pool: *ConnectionPool, node: *Node) void {
|
||||
pool.mutex.lock();
|
||||
defer pool.mutex.unlock();
|
||||
|
||||
return pool.acquireUnsafe(node);
|
||||
}
|
||||
|
||||
/// Tries to release a connection back to the connection pool. This function is threadsafe.
|
||||
/// If the connection is marked as closing, it will be closed instead.
|
||||
pub fn release(pool: *ConnectionPool, client: *Client, node: *Node) void {
|
||||
pool.mutex.lock();
|
||||
defer pool.mutex.unlock();
|
||||
|
||||
pool.used.remove(node);
|
||||
|
||||
if (node.data.closing) {
|
||||
node.data.close(client);
|
||||
|
||||
return client.allocator.destroy(node);
|
||||
}
|
||||
|
||||
if (pool.free_len + 1 >= pool.free_size) {
|
||||
const popped = pool.free.popFirst() orelse unreachable;
|
||||
|
||||
popped.data.close(client);
|
||||
|
||||
return client.allocator.destroy(popped);
|
||||
}
|
||||
|
||||
pool.free.append(node);
|
||||
pool.free_len += 1;
|
||||
}
|
||||
|
||||
/// Adds a newly created node to the pool of used connections. This function is threadsafe.
|
||||
pub fn addUsed(pool: *ConnectionPool, node: *Node) void {
|
||||
pool.mutex.lock();
|
||||
defer pool.mutex.unlock();
|
||||
|
||||
pool.used.append(node);
|
||||
}
|
||||
|
||||
pub fn deinit(pool: *ConnectionPool, client: *Client) void {
|
||||
pool.mutex.lock();
|
||||
|
||||
var next = pool.free.first;
|
||||
while (next) |node| {
|
||||
defer client.allocator.destroy(node);
|
||||
next = node.next;
|
||||
|
||||
node.data.close(client);
|
||||
}
|
||||
|
||||
next = pool.used.first;
|
||||
while (next) |node| {
|
||||
defer client.allocator.destroy(node);
|
||||
next = node.next;
|
||||
|
||||
node.data.close(client);
|
||||
}
|
||||
|
||||
pool.* = undefined;
|
||||
}
|
||||
};
|
||||
|
||||
pub const DeflateDecompressor = std.compress.zlib.ZlibStream(Request.ReaderRaw);
|
||||
pub const GzipDecompressor = std.compress.gzip.Decompress(Request.ReaderRaw);
|
||||
|
|
@ -142,25 +221,7 @@ pub const Connection = struct {
|
|||
};
|
||||
|
||||
pub fn deinit(client: *Client) void {
|
||||
client.connection_mutex.lock();
|
||||
|
||||
var next = client.connection_pool.first;
|
||||
while (next) |node| {
|
||||
next = node.next;
|
||||
|
||||
node.data.close(client);
|
||||
|
||||
client.allocator.destroy(node);
|
||||
}
|
||||
|
||||
next = client.connection_used.first;
|
||||
while (next) |node| {
|
||||
next = node.next;
|
||||
|
||||
node.data.close(client);
|
||||
|
||||
client.allocator.destroy(node);
|
||||
}
|
||||
client.connection_pool.deinit(client);
|
||||
|
||||
client.ca_bundle.deinit(client.allocator);
|
||||
client.* = undefined;
|
||||
|
|
@ -168,36 +229,25 @@ pub fn deinit(client: *Client) void {
|
|||
|
||||
pub const ConnectError = std.mem.Allocator.Error || net.TcpConnectToHostError || std.crypto.tls.Client.InitError(net.Stream);
|
||||
|
||||
pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectError!*ConnectionNode {
|
||||
{ // Search through the connection pool for a potential connection.
|
||||
client.connection_mutex.lock();
|
||||
defer client.connection_mutex.unlock();
|
||||
pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectError!*ConnectionPool.Node {
|
||||
if (client.connection_pool.findConnection(.{
|
||||
.host = host,
|
||||
.port = port,
|
||||
.is_tls = protocol == .tls,
|
||||
})) |node|
|
||||
return node;
|
||||
|
||||
var potential = client.connection_pool.last;
|
||||
while (potential) |node| {
|
||||
const same_host = mem.eql(u8, node.data.host, host);
|
||||
const same_port = node.data.port == port;
|
||||
const same_protocol = node.data.protocol == protocol;
|
||||
|
||||
if (same_host and same_port and same_protocol) {
|
||||
client.acquire(node, true);
|
||||
return node;
|
||||
}
|
||||
|
||||
potential = node.prev;
|
||||
}
|
||||
}
|
||||
|
||||
const conn = try client.allocator.create(ConnectionNode);
|
||||
const conn = try client.allocator.create(ConnectionPool.Node);
|
||||
errdefer client.allocator.destroy(conn);
|
||||
conn.* = .{ .data = undefined };
|
||||
|
||||
conn.* = .{ .data = .{
|
||||
conn.data = .{
|
||||
.stream = try net.tcpConnectToHost(client.allocator, host, port),
|
||||
.tls_client = undefined,
|
||||
.protocol = protocol,
|
||||
.host = try client.allocator.dupe(u8, host),
|
||||
.port = port,
|
||||
} };
|
||||
};
|
||||
|
||||
switch (protocol) {
|
||||
.plain => {},
|
||||
|
|
@ -210,12 +260,7 @@ pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connectio
|
|||
},
|
||||
}
|
||||
|
||||
{
|
||||
client.connection_mutex.lock();
|
||||
defer client.connection_mutex.unlock();
|
||||
|
||||
client.connection_used.append(conn);
|
||||
}
|
||||
client.connection_pool.addUsed(conn);
|
||||
|
||||
return conn;
|
||||
}
|
||||
|
|
@ -247,8 +292,8 @@ pub fn request(client: *Client, uri: Uri, headers: Request.Headers, options: Req
|
|||
const host = uri.host orelse return error.UriMissingHost;
|
||||
|
||||
if (client.next_https_rescan_certs and protocol == .tls) {
|
||||
client.connection_mutex.lock(); // TODO: this could be so much better than reusing the connection pool mutex.
|
||||
defer client.connection_mutex.unlock();
|
||||
client.connection_pool.mutex.lock(); // TODO: this could be so much better than reusing the connection pool mutex.
|
||||
defer client.connection_pool.mutex.unlock();
|
||||
|
||||
if (client.next_https_rescan_certs) {
|
||||
try client.ca_bundle.rescan(client.allocator);
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ const assert = std.debug.assert;
|
|||
|
||||
const Client = @import("../Client.zig");
|
||||
const Connection = Client.Connection;
|
||||
const ConnectionNode = Client.ConnectionNode;
|
||||
const ConnectionNode = Client.ConnectionPool.Node;
|
||||
const Response = @import("Response.zig");
|
||||
|
||||
const Request = @This();
|
||||
|
|
@ -85,7 +85,7 @@ pub fn deinit(req: *Request) void {
|
|||
if (!req.response.done) {
|
||||
// If the response wasn't fully read, then we need to close the connection.
|
||||
req.connection.data.closing = true;
|
||||
req.client.release(req.connection);
|
||||
req.client.connection_pool.release(req.client, req.connection);
|
||||
}
|
||||
|
||||
req.arena.deinit();
|
||||
|
|
@ -135,7 +135,7 @@ fn checkForCompleteHead(req: *Request, buffer: []u8) !usize {
|
|||
if (req.response.state == .finished) {
|
||||
req.response.headers = try Response.Headers.parse(req.response.header_bytes.items);
|
||||
|
||||
if (req.response.upgrade) |_| {
|
||||
if (req.response.headers.upgrade) |_| {
|
||||
req.connection.data.closing = false;
|
||||
req.response.done = true;
|
||||
return i;
|
||||
|
|
@ -226,7 +226,7 @@ fn readRawAdvanced(req: *Request, buffer: []u8) !usize {
|
|||
req.response.next_chunk_length -= can_read;
|
||||
|
||||
if (req.response.next_chunk_length == 0) {
|
||||
req.client.release(req.connection);
|
||||
req.client.connection_pool.release(req.client, req.connection);
|
||||
req.connection = undefined;
|
||||
req.response.done = true;
|
||||
}
|
||||
|
|
@ -241,7 +241,7 @@ fn readRawAdvanced(req: *Request, buffer: []u8) !usize {
|
|||
req.read_buffer_start += @intCast(ReadBufferIndex, can_read);
|
||||
|
||||
if (req.response.next_chunk_length == 0) {
|
||||
req.client.release(req.connection);
|
||||
req.client.connection_pool.release(req.client, req.connection);
|
||||
req.connection = undefined;
|
||||
req.response.done = true;
|
||||
}
|
||||
|
|
@ -293,7 +293,7 @@ fn readRawAdvanced(req: *Request, buffer: []u8) !usize {
|
|||
.chunk_data => {
|
||||
if (req.response.next_chunk_length == 0) {
|
||||
req.response.done = true;
|
||||
req.client.release(req.connection);
|
||||
req.client.connection_pool.release(req.client, req.connection);
|
||||
req.connection = undefined;
|
||||
|
||||
return out_index;
|
||||
|
|
@ -317,7 +317,7 @@ fn readRawAdvanced(req: *Request, buffer: []u8) !usize {
|
|||
req.response.next_chunk_length -= can_read;
|
||||
|
||||
if (req.response.next_chunk_length == 0) {
|
||||
req.client.release(req.connection);
|
||||
req.client.connection_pool.release(req.client, req.connection);
|
||||
req.connection = undefined;
|
||||
req.response.done = true;
|
||||
continue;
|
||||
|
|
@ -345,13 +345,7 @@ fn readRawAdvanced(req: *Request, buffer: []u8) !usize {
|
|||
}
|
||||
}
|
||||
|
||||
pub const ReadError = Client.DeflateDecompressor.Error || Client.GzipDecompressor.Error || Client.ZstdDecompressor.Error || WaitForCompleteHeadError || error{
|
||||
BadHeader,
|
||||
InvalidCompression,
|
||||
StreamTooLong,
|
||||
InvalidWindowSize,
|
||||
CompressionNotSupported
|
||||
};
|
||||
pub const ReadError = Client.DeflateDecompressor.Error || Client.GzipDecompressor.Error || Client.ZstdDecompressor.Error || WaitForCompleteHeadError || error{ BadHeader, InvalidCompression, StreamTooLong, InvalidWindowSize, CompressionNotSupported };
|
||||
|
||||
pub const Reader = std.io.Reader(*Request, ReadError, read);
|
||||
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ pub const Headers = struct {
|
|||
transfer_encoding: ?http.TransferEncoding = null,
|
||||
transfer_compression: ?http.ContentEncoding = null,
|
||||
connection: http.Connection = .close,
|
||||
upgrade: ?[]const u8 = null,
|
||||
|
||||
number_of_headers: usize = 0,
|
||||
|
||||
|
|
@ -93,7 +94,7 @@ pub const Headers = struct {
|
|||
|
||||
if (iter.next()) |second| {
|
||||
if (headers.transfer_compression != null) return error.HttpTransferEncodingUnsupported;
|
||||
|
||||
|
||||
const trimmed = std.mem.trim(u8, second, " ");
|
||||
|
||||
if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| {
|
||||
|
|
@ -122,6 +123,8 @@ pub const Headers = struct {
|
|||
} else {
|
||||
return error.HttpConnectionHeaderUnsupported;
|
||||
}
|
||||
} else if (std.ascii.eqlIgnoreCase(header_name, "upgrade")) {
|
||||
headers.upgrade = header_value;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -185,6 +185,11 @@ pub const options = struct {
|
|||
options_override.keep_sigpipe
|
||||
else
|
||||
false;
|
||||
|
||||
pub const http_connection_pool_size = if (@hasDecl(options_override, "http_connection_pool_size"))
|
||||
options_override.http_connection_pool_size
|
||||
else
|
||||
http.Client.default_connection_pool_size;
|
||||
};
|
||||
|
||||
// This forces the start.zig file to be imported, and the comptime logic inside that
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue