std.http: add TlsAlert descriptions so that they can at least be viewed in err return traces

This commit is contained in:
Nameless 2023-05-28 09:37:56 +02:00
parent 8136123aa7
commit 0e5e6cb10c
No known key found for this signature in database
GPG key ID: A477BC03CAFCCAF7
5 changed files with 141 additions and 68 deletions

View file

@ -138,6 +138,35 @@ pub const AlertLevel = enum(u8) {
}; };
pub const AlertDescription = enum(u8) { pub const AlertDescription = enum(u8) {
pub const Error = error{
TlsAlertUnexpectedMessage,
TlsAlertBadRecordMac,
TlsAlertRecordOverflow,
TlsAlertHandshakeFailure,
TlsAlertBadCertificate,
TlsAlertUnsupportedCertificate,
TlsAlertCertificateRevoked,
TlsAlertCertificateExpired,
TlsAlertCertificateUnknown,
TlsAlertIllegalParameter,
TlsAlertUnknownCa,
TlsAlertAccessDenied,
TlsAlertDecodeError,
TlsAlertDecryptError,
TlsAlertProtocolVersion,
TlsAlertInsufficientSecurity,
TlsAlertInternalError,
TlsAlertInappropriateFallback,
TlsAlertMissingExtension,
TlsAlertUnsupportedExtension,
TlsAlertUnrecognizedName,
TlsAlertBadCertificateStatusResponse,
TlsAlertUnknownPskIdentity,
TlsAlertCertificateRequired,
TlsAlertNoApplicationProtocol,
TlsAlertUnknown,
};
close_notify = 0, close_notify = 0,
unexpected_message = 10, unexpected_message = 10,
bad_record_mac = 20, bad_record_mac = 20,
@ -166,6 +195,39 @@ pub const AlertDescription = enum(u8) {
certificate_required = 116, certificate_required = 116,
no_application_protocol = 120, no_application_protocol = 120,
_, _,
pub fn toError(alert: AlertDescription) Error!void {
return switch (alert) {
.close_notify => {}, // not an error
.unexpected_message => error.TlsAlertUnexpectedMessage,
.bad_record_mac => error.TlsAlertBadRecordMac,
.record_overflow => error.TlsAlertRecordOverflow,
.handshake_failure => error.TlsAlertHandshakeFailure,
.bad_certificate => error.TlsAlertBadCertificate,
.unsupported_certificate => error.TlsAlertUnsupportedCertificate,
.certificate_revoked => error.TlsAlertCertificateRevoked,
.certificate_expired => error.TlsAlertCertificateExpired,
.certificate_unknown => error.TlsAlertCertificateUnknown,
.illegal_parameter => error.TlsAlertIllegalParameter,
.unknown_ca => error.TlsAlertUnknownCa,
.access_denied => error.TlsAlertAccessDenied,
.decode_error => error.TlsAlertDecodeError,
.decrypt_error => error.TlsAlertDecryptError,
.protocol_version => error.TlsAlertProtocolVersion,
.insufficient_security => error.TlsAlertInsufficientSecurity,
.internal_error => error.TlsAlertInternalError,
.inappropriate_fallback => error.TlsAlertInappropriateFallback,
.user_canceled => {}, // not an error
.missing_extension => error.TlsAlertMissingExtension,
.unsupported_extension => error.TlsAlertUnsupportedExtension,
.unrecognized_name => error.TlsAlertUnrecognizedName,
.bad_certificate_status_response => error.TlsAlertBadCertificateStatusResponse,
.unknown_psk_identity => error.TlsAlertUnknownPskIdentity,
.certificate_required => error.TlsAlertCertificateRequired,
.no_application_protocol => error.TlsAlertNoApplicationProtocol,
_ => error.TlsAlertUnknown,
};
}
}; };
pub const SignatureScheme = enum(u16) { pub const SignatureScheme = enum(u16) {

View file

@ -89,12 +89,11 @@ pub const StreamInterface = struct {
}; };
pub fn InitError(comptime Stream: type) type { pub fn InitError(comptime Stream: type) type {
return std.mem.Allocator.Error || Stream.WriteError || Stream.ReadError || error{ return std.mem.Allocator.Error || Stream.WriteError || Stream.ReadError || tls.AlertDescription.Error || error{
InsufficientEntropy, InsufficientEntropy,
DiskQuota, DiskQuota,
LockViolation, LockViolation,
NotOpenForWriting, NotOpenForWriting,
TlsAlert,
TlsUnexpectedMessage, TlsUnexpectedMessage,
TlsIllegalParameter, TlsIllegalParameter,
TlsDecryptFailure, TlsDecryptFailure,
@ -251,8 +250,11 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In
const level = ptd.decode(tls.AlertLevel); const level = ptd.decode(tls.AlertLevel);
const desc = ptd.decode(tls.AlertDescription); const desc = ptd.decode(tls.AlertDescription);
_ = level; _ = level;
_ = desc;
return error.TlsAlert; // if this isn't a error alert, then it's a closure alert, which makes no sense in a handshake
try desc.toError();
// TODO: handle server-side closures
return error.TlsUnexpectedMessage;
}, },
.handshake => { .handshake => {
try ptd.ensure(4); try ptd.ensure(4);
@ -1071,8 +1073,10 @@ pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.os.iovec)
const level = @intToEnum(tls.AlertLevel, frag[in]); const level = @intToEnum(tls.AlertLevel, frag[in]);
const desc = @intToEnum(tls.AlertDescription, frag[in + 1]); const desc = @intToEnum(tls.AlertDescription, frag[in + 1]);
_ = level; _ = level;
_ = desc;
return error.TlsAlert; try desc.toError();
// TODO: handle server-side closures
return error.TlsUnexpectedMessage;
}, },
.application_data => { .application_data => {
const cleartext = switch (c.application_cipher) { const cleartext = switch (c.application_cipher) {
@ -1112,7 +1116,10 @@ pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.os.iovec)
return vp.total; return vp.total;
} }
_ = level; _ = level;
return error.TlsAlert;
try desc.toError();
// TODO: handle server-side closures
return error.TlsUnexpectedMessage;
}, },
.handshake => { .handshake => {
var ct_i: usize = 0; var ct_i: usize = 0;

View file

@ -168,19 +168,23 @@ pub const Connection = struct {
return switch (conn.protocol) { return switch (conn.protocol) {
.plain => conn.stream.readAtLeast(buffer, len), .plain => conn.stream.readAtLeast(buffer, len),
.tls => conn.tls_client.readAtLeast(conn.stream, buffer, len), .tls => conn.tls_client.readAtLeast(conn.stream, buffer, len),
} catch |err| switch (err) { } catch |err| {
error.TlsConnectionTruncated, error.TlsRecordOverflow, error.TlsDecodeError, error.TlsBadRecordMac, error.TlsBadLength, error.TlsIllegalParameter, error.TlsUnexpectedMessage => return error.TlsFailure, // TODO: https://github.com/ziglang/zig/issues/2473
error.TlsAlert => return error.TlsAlert, if (mem.startsWith(u8, @errorName(err), "TlsAlert")) return error.TlsAlert;
error.ConnectionTimedOut => return error.ConnectionTimedOut,
error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer, switch (err) {
else => return error.UnexpectedReadFailure, error.TlsConnectionTruncated, error.TlsRecordOverflow, error.TlsDecodeError, error.TlsBadRecordMac, error.TlsBadLength, error.TlsIllegalParameter, error.TlsUnexpectedMessage => return error.TlsFailure,
error.ConnectionTimedOut => return error.ConnectionTimedOut,
error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer,
else => return error.UnexpectedReadFailure,
}
}; };
} }
pub fn fill(conn: *Connection) ReadError!void { pub fn fill(conn: *Connection) ReadError!void {
if (conn.read_end != conn.read_start) return; if (conn.read_end != conn.read_start) return;
const nread = try conn.conn.read(conn.read_buf[0..]); const nread = try conn.read(conn.read_buf[0..]);
if (nread == 0) return error.EndOfStream; if (nread == 0) return error.EndOfStream;
conn.read_start = 0; conn.read_start = 0;
conn.read_end = @intCast(u16, nread); conn.read_end = @intCast(u16, nread);
@ -204,8 +208,8 @@ pub const Connection = struct {
if (available_read > available_buffer) { // partially read buffered data if (available_read > available_buffer) { // partially read buffered data
@memcpy(buffer[out_index..], conn.read_buf[conn.read_start..][0..available_buffer]); @memcpy(buffer[out_index..], conn.read_buf[conn.read_start..][0..available_buffer]);
out_index += available_buffer; out_index += @intCast(u16, available_buffer);
conn.read_start += available_buffer; conn.read_start += @intCast(u16, available_buffer);
break; break;
} else if (available_read > 0) { // fully read buffered data } else if (available_read > 0) { // fully read buffered data
@ -759,7 +763,7 @@ pub const Request = struct {
try req.connection.data.fill(); try req.connection.data.fill();
const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.data.peek()); const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.data.peek());
req.connection.data.clear(@intCast(u16, nchecked)); req.connection.data.drop(@intCast(u16, nchecked));
} }
if (has_trail) { if (has_trail) {

View file

@ -118,7 +118,7 @@ pub const BufferedConnection = struct {
return bconn.read_buf[bconn.read_start..bconn.read_end]; return bconn.read_buf[bconn.read_start..bconn.read_end];
} }
pub fn clear(bconn: *BufferedConnection, num: u16) void { pub fn drop(bconn: *BufferedConnection, num: u16) void {
bconn.read_start += num; bconn.read_start += num;
} }
@ -545,7 +545,7 @@ pub const Response = struct {
try res.connection.fill(); try res.connection.fill();
const nchecked = try res.request.parser.checkCompleteHead(res.allocator, res.connection.peek()); const nchecked = try res.request.parser.checkCompleteHead(res.allocator, res.connection.peek());
res.connection.clear(@intCast(u16, nchecked)); res.connection.drop(@intCast(u16, nchecked));
if (res.request.parser.state.isContent()) break; if (res.request.parser.state.isContent()) break;
} }
@ -612,7 +612,7 @@ pub const Response = struct {
try res.connection.fill(); try res.connection.fill();
const nchecked = try res.request.parser.checkCompleteHead(res.allocator, res.connection.peek()); const nchecked = try res.request.parser.checkCompleteHead(res.allocator, res.connection.peek());
res.connection.clear(@intCast(u16, nchecked)); res.connection.drop(@intCast(u16, nchecked));
} }
if (has_trail) { if (has_trail) {

View file

@ -513,8 +513,8 @@ pub const HeadersParser = struct {
/// ///
/// If `skip` is true, the buffer will be unused and the body will be skipped. /// If `skip` is true, the buffer will be unused and the body will be skipped.
/// ///
/// See `std.http.Client.BufferedConnection for an example of `bconn`. /// See `std.http.Client.BufferedConnection for an example of `conn`.
pub fn read(r: *HeadersParser, bconn: anytype, buffer: []u8, skip: bool) !usize { pub fn read(r: *HeadersParser, conn: anytype, buffer: []u8, skip: bool) !usize {
assert(r.state.isContent()); assert(r.state.isContent());
if (r.done) return 0; if (r.done) return 0;
@ -526,10 +526,10 @@ pub const HeadersParser = struct {
const data_avail = r.next_chunk_length; const data_avail = r.next_chunk_length;
if (skip) { if (skip) {
try bconn.fill(); try conn.fill();
const nread = @min(bconn.peek().len, data_avail); const nread = @min(conn.peek().len, data_avail);
bconn.clear(@intCast(u16, nread)); conn.drop(@intCast(u16, nread));
r.next_chunk_length -= nread; r.next_chunk_length -= nread;
if (r.next_chunk_length == 0) r.done = true; if (r.next_chunk_length == 0) r.done = true;
@ -539,7 +539,7 @@ pub const HeadersParser = struct {
const out_avail = buffer.len; const out_avail = buffer.len;
const can_read = @intCast(usize, @min(data_avail, out_avail)); const can_read = @intCast(usize, @min(data_avail, out_avail));
const nread = try bconn.read(buffer[0..can_read]); const nread = try conn.read(buffer[0..can_read]);
r.next_chunk_length -= nread; r.next_chunk_length -= nread;
if (r.next_chunk_length == 0) r.done = true; if (r.next_chunk_length == 0) r.done = true;
@ -548,15 +548,15 @@ pub const HeadersParser = struct {
} }
}, },
.chunk_data_suffix, .chunk_data_suffix_r, .chunk_head_size, .chunk_head_ext, .chunk_head_r => { .chunk_data_suffix, .chunk_data_suffix_r, .chunk_head_size, .chunk_head_ext, .chunk_head_r => {
try bconn.fill(); try conn.fill();
const i = r.findChunkedLen(bconn.peek()); const i = r.findChunkedLen(conn.peek());
bconn.clear(@intCast(u16, i)); conn.drop(@intCast(u16, i));
switch (r.state) { switch (r.state) {
.invalid => return error.HttpChunkInvalid, .invalid => return error.HttpChunkInvalid,
.chunk_data => if (r.next_chunk_length == 0) { .chunk_data => if (r.next_chunk_length == 0) {
if (std.mem.eql(u8, bconn.peek(), "\r\n")) { if (std.mem.eql(u8, conn.peek(), "\r\n")) {
r.state = .finished; r.state = .finished;
} else { } else {
// The trailer section is formatted identically to the header section. // The trailer section is formatted identically to the header section.
@ -576,14 +576,14 @@ pub const HeadersParser = struct {
const out_avail = buffer.len - out_index; const out_avail = buffer.len - out_index;
if (skip) { if (skip) {
try bconn.fill(); try conn.fill();
const nread = @min(bconn.peek().len, data_avail); const nread = @min(conn.peek().len, data_avail);
bconn.clear(@intCast(u16, nread)); conn.drop(@intCast(u16, nread));
r.next_chunk_length -= nread; r.next_chunk_length -= nread;
} else { } else {
const can_read = @intCast(usize, @min(data_avail, out_avail)); const can_read = @intCast(usize, @min(data_avail, out_avail));
const nread = try bconn.read(buffer[out_index..][0..can_read]); const nread = try conn.read(buffer[out_index..][0..can_read]);
r.next_chunk_length -= nread; r.next_chunk_length -= nread;
out_index += nread; out_index += nread;
} }
@ -628,74 +628,74 @@ const MockBufferedConnection = struct {
start: u16 = 0, start: u16 = 0,
end: u16 = 0, end: u16 = 0,
pub fn fill(bconn: *MockBufferedConnection) ReadError!void { pub fn fill(conn: *MockBufferedConnection) ReadError!void {
if (bconn.end != bconn.start) return; if (conn.end != conn.start) return;
const nread = try bconn.conn.read(bconn.buf[0..]); const nread = try conn.conn.read(conn.buf[0..]);
if (nread == 0) return error.EndOfStream; if (nread == 0) return error.EndOfStream;
bconn.start = 0; conn.start = 0;
bconn.end = @truncate(u16, nread); conn.end = @truncate(u16, nread);
} }
pub fn peek(bconn: *MockBufferedConnection) []const u8 { pub fn peek(conn: *MockBufferedConnection) []const u8 {
return bconn.buf[bconn.start..bconn.end]; return conn.buf[conn.start..conn.end];
} }
pub fn drop(conn: *MockBufferedConnection, num: u16) void { pub fn drop(conn: *MockBufferedConnection, num: u16) void {
conn.start += num; conn.start += num;
} }
pub fn readAtLeast(bconn: *MockBufferedConnection, buffer: []u8, len: usize) ReadError!usize { pub fn readAtLeast(conn: *MockBufferedConnection, buffer: []u8, len: usize) ReadError!usize {
var out_index: u16 = 0; var out_index: u16 = 0;
while (out_index < len) { while (out_index < len) {
const available = bconn.end - bconn.start; const available = conn.end - conn.start;
const left = buffer.len - out_index; const left = buffer.len - out_index;
if (available > 0) { if (available > 0) {
const can_read = @truncate(u16, @min(available, left)); const can_read = @truncate(u16, @min(available, left));
@memcpy(buffer[out_index..][0..can_read], bconn.buf[bconn.start..][0..can_read]); @memcpy(buffer[out_index..][0..can_read], conn.buf[conn.start..][0..can_read]);
out_index += can_read; out_index += can_read;
bconn.start += can_read; conn.start += can_read;
continue; continue;
} }
if (left > bconn.buf.len) { if (left > conn.buf.len) {
// skip the buffer if the output is large enough // skip the buffer if the output is large enough
return bconn.conn.read(buffer[out_index..]); return conn.conn.read(buffer[out_index..]);
} }
try bconn.fill(); try conn.fill();
} }
return out_index; return out_index;
} }
pub fn read(bconn: *MockBufferedConnection, buffer: []u8) ReadError!usize { pub fn read(conn: *MockBufferedConnection, buffer: []u8) ReadError!usize {
return bconn.readAtLeast(buffer, 1); return conn.readAtLeast(buffer, 1);
} }
pub const ReadError = std.io.FixedBufferStream([]const u8).ReadError || error{EndOfStream}; pub const ReadError = std.io.FixedBufferStream([]const u8).ReadError || error{EndOfStream};
pub const Reader = std.io.Reader(*MockBufferedConnection, ReadError, read); pub const Reader = std.io.Reader(*MockBufferedConnection, ReadError, read);
pub fn reader(bconn: *MockBufferedConnection) Reader { pub fn reader(conn: *MockBufferedConnection) Reader {
return Reader{ .context = bconn }; return Reader{ .context = conn };
} }
pub fn writeAll(bconn: *MockBufferedConnection, buffer: []const u8) WriteError!void { pub fn writeAll(conn: *MockBufferedConnection, buffer: []const u8) WriteError!void {
return bconn.conn.writeAll(buffer); return conn.conn.writeAll(buffer);
} }
pub fn write(bconn: *MockBufferedConnection, buffer: []const u8) WriteError!usize { pub fn write(conn: *MockBufferedConnection, buffer: []const u8) WriteError!usize {
return bconn.conn.write(buffer); return conn.conn.write(buffer);
} }
pub const WriteError = std.io.FixedBufferStream([]const u8).WriteError; pub const WriteError = std.io.FixedBufferStream([]const u8).WriteError;
pub const Writer = std.io.Writer(*MockBufferedConnection, WriteError, write); pub const Writer = std.io.Writer(*MockBufferedConnection, WriteError, write);
pub fn writer(bconn: *MockBufferedConnection) Writer { pub fn writer(conn: *MockBufferedConnection) Writer {
return Writer{ .context = bconn }; return Writer{ .context = conn };
} }
}; };
@ -753,12 +753,12 @@ test "HeadersParser.read length" {
const data = "GET / HTTP/1.1\r\nHost: localhost\r\nContent-Length: 5\r\n\r\nHello"; const data = "GET / HTTP/1.1\r\nHost: localhost\r\nContent-Length: 5\r\n\r\nHello";
var fbs = std.io.fixedBufferStream(data); var fbs = std.io.fixedBufferStream(data);
var bconn = MockBufferedConnection{ var conn = MockBufferedConnection{
.conn = fbs, .conn = fbs,
}; };
while (true) { // read headers while (true) { // read headers
try bconn.fill(); try conn.fill();
const nchecked = try r.checkCompleteHead(std.testing.allocator, conn.peek()); const nchecked = try r.checkCompleteHead(std.testing.allocator, conn.peek());
conn.drop(@intCast(u16, nchecked)); conn.drop(@intCast(u16, nchecked));
@ -769,7 +769,7 @@ test "HeadersParser.read length" {
var buf: [8]u8 = undefined; var buf: [8]u8 = undefined;
r.next_chunk_length = 5; r.next_chunk_length = 5;
const len = try r.read(&bconn, &buf, false); const len = try r.read(&conn, &buf, false);
try std.testing.expectEqual(@as(usize, 5), len); try std.testing.expectEqual(@as(usize, 5), len);
try std.testing.expectEqualStrings("Hello", buf[0..len]); try std.testing.expectEqualStrings("Hello", buf[0..len]);
@ -784,12 +784,12 @@ test "HeadersParser.read chunked" {
const data = "GET / HTTP/1.1\r\nHost: localhost\r\n\r\n2\r\nHe\r\n2\r\nll\r\n1\r\no\r\n0\r\n\r\n"; const data = "GET / HTTP/1.1\r\nHost: localhost\r\n\r\n2\r\nHe\r\n2\r\nll\r\n1\r\no\r\n0\r\n\r\n";
var fbs = std.io.fixedBufferStream(data); var fbs = std.io.fixedBufferStream(data);
var bconn = MockBufferedConnection{ var conn = MockBufferedConnection{
.conn = fbs, .conn = fbs,
}; };
while (true) { // read headers while (true) { // read headers
try bconn.fill(); try conn.fill();
const nchecked = try r.checkCompleteHead(std.testing.allocator, conn.peek()); const nchecked = try r.checkCompleteHead(std.testing.allocator, conn.peek());
conn.drop(@intCast(u16, nchecked)); conn.drop(@intCast(u16, nchecked));
@ -799,7 +799,7 @@ test "HeadersParser.read chunked" {
var buf: [8]u8 = undefined; var buf: [8]u8 = undefined;
r.state = .chunk_head_size; r.state = .chunk_head_size;
const len = try r.read(&bconn, &buf, false); const len = try r.read(&conn, &buf, false);
try std.testing.expectEqual(@as(usize, 5), len); try std.testing.expectEqual(@as(usize, 5), len);
try std.testing.expectEqualStrings("Hello", buf[0..len]); try std.testing.expectEqualStrings("Hello", buf[0..len]);
@ -814,12 +814,12 @@ test "HeadersParser.read chunked trailer" {
const data = "GET / HTTP/1.1\r\nHost: localhost\r\n\r\n2\r\nHe\r\n2\r\nll\r\n1\r\no\r\n0\r\nContent-Type: text/plain\r\n\r\n"; const data = "GET / HTTP/1.1\r\nHost: localhost\r\n\r\n2\r\nHe\r\n2\r\nll\r\n1\r\no\r\n0\r\nContent-Type: text/plain\r\n\r\n";
var fbs = std.io.fixedBufferStream(data); var fbs = std.io.fixedBufferStream(data);
var bconn = MockBufferedConnection{ var conn = MockBufferedConnection{
.conn = fbs, .conn = fbs,
}; };
while (true) { // read headers while (true) { // read headers
try bconn.fill(); try conn.fill();
const nchecked = try r.checkCompleteHead(std.testing.allocator, conn.peek()); const nchecked = try r.checkCompleteHead(std.testing.allocator, conn.peek());
conn.drop(@intCast(u16, nchecked)); conn.drop(@intCast(u16, nchecked));
@ -829,12 +829,12 @@ test "HeadersParser.read chunked trailer" {
var buf: [8]u8 = undefined; var buf: [8]u8 = undefined;
r.state = .chunk_head_size; r.state = .chunk_head_size;
const len = try r.read(&bconn, &buf, false); const len = try r.read(&conn, &buf, false);
try std.testing.expectEqual(@as(usize, 5), len); try std.testing.expectEqual(@as(usize, 5), len);
try std.testing.expectEqualStrings("Hello", buf[0..len]); try std.testing.expectEqualStrings("Hello", buf[0..len]);
while (true) { // read headers while (true) { // read headers
try bconn.fill(); try conn.fill();
const nchecked = try r.checkCompleteHead(std.testing.allocator, conn.peek()); const nchecked = try r.checkCompleteHead(std.testing.allocator, conn.peek());
conn.drop(@intCast(u16, nchecked)); conn.drop(@intCast(u16, nchecked));