Merge pull request #13980 from ziglang/std.net

networking: delete std.x; add std.crypto.tls and std.http.Client
This commit is contained in:
Andrew Kelley 2023-01-03 02:43:50 -05:00 committed by GitHub
commit c9ef277fa7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
38 changed files with 3917 additions and 3707 deletions

98
lib/std/Url.zig Normal file
View file

@ -0,0 +1,98 @@
scheme: []const u8,
host: []const u8,
path: []const u8,
port: ?u16,
/// TODO: redo this implementation according to RFC 1738. This code is only a
/// placeholder for now.
pub fn parse(s: []const u8) !Url {
var scheme_end: usize = 0;
var host_start: usize = 0;
var host_end: usize = 0;
var path_start: usize = 0;
var port_start: usize = 0;
var port_end: usize = 0;
var state: enum {
scheme,
scheme_slash1,
scheme_slash2,
host,
port,
path,
} = .scheme;
for (s) |b, i| switch (state) {
.scheme => switch (b) {
':' => {
state = .scheme_slash1;
scheme_end = i;
},
else => {},
},
.scheme_slash1 => switch (b) {
'/' => {
state = .scheme_slash2;
},
else => return error.InvalidUrl,
},
.scheme_slash2 => switch (b) {
'/' => {
state = .host;
host_start = i + 1;
},
else => return error.InvalidUrl,
},
.host => switch (b) {
':' => {
state = .port;
host_end = i;
port_start = i + 1;
},
'/' => {
state = .path;
host_end = i;
path_start = i;
},
else => {},
},
.port => switch (b) {
'/' => {
port_end = i;
state = .path;
path_start = i;
},
else => {},
},
.path => {},
};
const port_slice = s[port_start..port_end];
const port = if (port_slice.len == 0) null else try std.fmt.parseInt(u16, port_slice, 10);
return .{
.scheme = s[0..scheme_end],
.host = s[host_start..host_end],
.path = s[path_start..],
.port = port,
};
}
const Url = @This();
const std = @import("std.zig");
const testing = std.testing;
test "basic" {
const parsed = try parse("https://ziglang.org/download");
try testing.expectEqualStrings("https", parsed.scheme);
try testing.expectEqualStrings("ziglang.org", parsed.host);
try testing.expectEqualStrings("/download", parsed.path);
try testing.expectEqual(@as(?u16, null), parsed.port);
}
test "with port" {
const parsed = try parse("http://example:1337/");
try testing.expectEqualStrings("http", parsed.scheme);
try testing.expectEqualStrings("example", parsed.host);
try testing.expectEqualStrings("/", parsed.path);
try testing.expectEqual(@as(?u16, 1337), parsed.port);
}

View file

@ -206,7 +206,7 @@ pub extern "c" fn sendto(
dest_addr: ?*const c.sockaddr, dest_addr: ?*const c.sockaddr,
addrlen: c.socklen_t, addrlen: c.socklen_t,
) isize; ) isize;
pub extern "c" fn sendmsg(sockfd: c.fd_t, msg: *const std.x.os.Socket.Message, flags: c_int) isize; pub extern "c" fn sendmsg(sockfd: c.fd_t, msg: *const c.msghdr_const, flags: u32) isize;
pub extern "c" fn recv(sockfd: c.fd_t, arg1: ?*anyopaque, arg2: usize, arg3: c_int) isize; pub extern "c" fn recv(sockfd: c.fd_t, arg1: ?*anyopaque, arg2: usize, arg3: c_int) isize;
pub extern "c" fn recvfrom( pub extern "c" fn recvfrom(
@ -217,7 +217,7 @@ pub extern "c" fn recvfrom(
noalias src_addr: ?*c.sockaddr, noalias src_addr: ?*c.sockaddr,
noalias addrlen: ?*c.socklen_t, noalias addrlen: ?*c.socklen_t,
) isize; ) isize;
pub extern "c" fn recvmsg(sockfd: c.fd_t, msg: *std.x.os.Socket.Message, flags: c_int) isize; pub extern "c" fn recvmsg(sockfd: c.fd_t, msg: *c.msghdr, flags: u32) isize;
pub extern "c" fn kill(pid: c.pid_t, sig: c_int) c_int; pub extern "c" fn kill(pid: c.pid_t, sig: c_int) c_int;
pub extern "c" fn getdirentries(fd: c.fd_t, buf_ptr: [*]u8, nbytes: usize, basep: *i64) isize; pub extern "c" fn getdirentries(fd: c.fd_t, buf_ptr: [*]u8, nbytes: usize, basep: *i64) isize;

View file

@ -1007,7 +1007,16 @@ pub const sockaddr = extern struct {
data: [14]u8, data: [14]u8,
pub const SS_MAXSIZE = 128; pub const SS_MAXSIZE = 128;
pub const storage = std.x.os.Socket.Address.Native.Storage; pub const storage = extern struct {
len: u8 align(8),
family: sa_family_t,
padding: [126]u8 = undefined,
comptime {
assert(@sizeOf(storage) == SS_MAXSIZE);
assert(@alignOf(storage) == 8);
}
};
pub const in = extern struct { pub const in = extern struct {
len: u8 = @sizeOf(in), len: u8 = @sizeOf(in),
family: sa_family_t = AF.INET, family: sa_family_t = AF.INET,

View file

@ -1,5 +1,6 @@
const builtin = @import("builtin"); const builtin = @import("builtin");
const std = @import("../std.zig"); const std = @import("../std.zig");
const assert = std.debug.assert;
const maxInt = std.math.maxInt; const maxInt = std.math.maxInt;
const iovec = std.os.iovec; const iovec = std.os.iovec;
@ -478,11 +479,20 @@ pub const CLOCK = struct {
pub const sockaddr = extern struct { pub const sockaddr = extern struct {
len: u8, len: u8,
family: u8, family: sa_family_t,
data: [14]u8, data: [14]u8,
pub const SS_MAXSIZE = 128; pub const SS_MAXSIZE = 128;
pub const storage = std.x.os.Socket.Address.Native.Storage; pub const storage = extern struct {
len: u8 align(8),
family: sa_family_t,
padding: [126]u8 = undefined,
comptime {
assert(@sizeOf(storage) == SS_MAXSIZE);
assert(@alignOf(storage) == 8);
}
};
pub const in = extern struct { pub const in = extern struct {
len: u8 = @sizeOf(in), len: u8 = @sizeOf(in),

View file

@ -1,4 +1,5 @@
const std = @import("../std.zig"); const std = @import("../std.zig");
const assert = std.debug.assert;
const builtin = @import("builtin"); const builtin = @import("builtin");
const maxInt = std.math.maxInt; const maxInt = std.math.maxInt;
const iovec = std.os.iovec; const iovec = std.os.iovec;
@ -404,7 +405,16 @@ pub const sockaddr = extern struct {
data: [14]u8, data: [14]u8,
pub const SS_MAXSIZE = 128; pub const SS_MAXSIZE = 128;
pub const storage = std.x.os.Socket.Address.Native.Storage; pub const storage = extern struct {
len: u8 align(8),
family: sa_family_t,
padding: [126]u8 = undefined,
comptime {
assert(@sizeOf(storage) == SS_MAXSIZE);
assert(@alignOf(storage) == 8);
}
};
pub const in = extern struct { pub const in = extern struct {
len: u8 = @sizeOf(in), len: u8 = @sizeOf(in),

View file

@ -1,4 +1,5 @@
const std = @import("../std.zig"); const std = @import("../std.zig");
const assert = std.debug.assert;
const builtin = @import("builtin"); const builtin = @import("builtin");
const maxInt = std.math.maxInt; const maxInt = std.math.maxInt;
const iovec = std.os.iovec; const iovec = std.os.iovec;
@ -339,7 +340,16 @@ pub const sockaddr = extern struct {
data: [14]u8, data: [14]u8,
pub const SS_MAXSIZE = 128; pub const SS_MAXSIZE = 128;
pub const storage = std.x.os.Socket.Address.Native.Storage; pub const storage = extern struct {
len: u8 align(8),
family: sa_family_t,
padding: [126]u8 = undefined,
comptime {
assert(@sizeOf(storage) == SS_MAXSIZE);
assert(@alignOf(storage) == 8);
}
};
pub const in = extern struct { pub const in = extern struct {
len: u8 = @sizeOf(in), len: u8 = @sizeOf(in),

View file

@ -1,4 +1,5 @@
const std = @import("../std.zig"); const std = @import("../std.zig");
const assert = std.debug.assert;
const builtin = @import("builtin"); const builtin = @import("builtin");
const maxInt = std.math.maxInt; const maxInt = std.math.maxInt;
const iovec = std.os.iovec; const iovec = std.os.iovec;
@ -481,7 +482,16 @@ pub const sockaddr = extern struct {
data: [14]u8, data: [14]u8,
pub const SS_MAXSIZE = 128; pub const SS_MAXSIZE = 128;
pub const storage = std.x.os.Socket.Address.Native.Storage; pub const storage = extern struct {
len: u8 align(8),
family: sa_family_t,
padding: [126]u8 = undefined,
comptime {
assert(@sizeOf(storage) == SS_MAXSIZE);
assert(@alignOf(storage) == 8);
}
};
pub const in = extern struct { pub const in = extern struct {
len: u8 = @sizeOf(in), len: u8 = @sizeOf(in),

View file

@ -1,4 +1,5 @@
const std = @import("../std.zig"); const std = @import("../std.zig");
const assert = std.debug.assert;
const maxInt = std.math.maxInt; const maxInt = std.math.maxInt;
const builtin = @import("builtin"); const builtin = @import("builtin");
const iovec = std.os.iovec; const iovec = std.os.iovec;
@ -372,7 +373,16 @@ pub const sockaddr = extern struct {
data: [14]u8, data: [14]u8,
pub const SS_MAXSIZE = 256; pub const SS_MAXSIZE = 256;
pub const storage = std.x.os.Socket.Address.Native.Storage; pub const storage = extern struct {
len: u8 align(8),
family: sa_family_t,
padding: [254]u8 = undefined,
comptime {
assert(@sizeOf(storage) == SS_MAXSIZE);
assert(@alignOf(storage) == 8);
}
};
pub const in = extern struct { pub const in = extern struct {
len: u8 = @sizeOf(in), len: u8 = @sizeOf(in),

View file

@ -1,4 +1,5 @@
const std = @import("../std.zig"); const std = @import("../std.zig");
const assert = std.debug.assert;
const builtin = @import("builtin"); const builtin = @import("builtin");
const maxInt = std.math.maxInt; const maxInt = std.math.maxInt;
const iovec = std.os.iovec; const iovec = std.os.iovec;
@ -435,7 +436,15 @@ pub const sockaddr = extern struct {
data: [14]u8, data: [14]u8,
pub const SS_MAXSIZE = 256; pub const SS_MAXSIZE = 256;
pub const storage = std.x.os.Socket.Address.Native.Storage; pub const storage = extern struct {
family: sa_family_t align(8),
padding: [254]u8 = undefined,
comptime {
assert(@sizeOf(storage) == SS_MAXSIZE);
assert(@alignOf(storage) == 8);
}
};
pub const in = extern struct { pub const in = extern struct {
family: sa_family_t = AF.INET, family: sa_family_t = AF.INET,

View file

@ -176,6 +176,9 @@ const std = @import("std.zig");
pub const errors = @import("crypto/errors.zig"); pub const errors = @import("crypto/errors.zig");
pub const tls = @import("crypto/tls.zig");
pub const Certificate = @import("crypto/Certificate.zig");
test { test {
_ = aead.aegis.Aegis128L; _ = aead.aegis.Aegis128L;
_ = aead.aegis.Aegis256; _ = aead.aegis.Aegis256;
@ -264,6 +267,8 @@ test {
_ = utils; _ = utils;
_ = random; _ = random;
_ = errors; _ = errors;
_ = tls;
_ = Certificate;
} }
test "CSPRNG" { test "CSPRNG" {

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,189 @@
//! A set of certificates. Typically pre-installed on every operating system,
//! these are "Certificate Authorities" used to validate SSL certificates.
//! This data structure stores certificates in DER-encoded form, all of them
//! concatenated together in the `bytes` array. The `map` field contains an
//! index from the DER-encoded subject name to the index of the containing
//! certificate within `bytes`.
/// The key is the contents slice of the subject.
map: std.HashMapUnmanaged(der.Element.Slice, u32, MapContext, std.hash_map.default_max_load_percentage) = .{},
bytes: std.ArrayListUnmanaged(u8) = .{},
pub const VerifyError = Certificate.Parsed.VerifyError || error{
CertificateIssuerNotFound,
};
pub fn verify(cb: Bundle, subject: Certificate.Parsed, now_sec: i64) VerifyError!void {
const bytes_index = cb.find(subject.issuer()) orelse return error.CertificateIssuerNotFound;
const issuer_cert: Certificate = .{
.buffer = cb.bytes.items,
.index = bytes_index,
};
// Every certificate in the bundle is pre-parsed before adding it, ensuring
// that parsing will succeed here.
const issuer = issuer_cert.parse() catch unreachable;
try subject.verify(issuer, now_sec);
}
/// The returned bytes become invalid after calling any of the rescan functions
/// or add functions.
pub fn find(cb: Bundle, subject_name: []const u8) ?u32 {
const Adapter = struct {
cb: Bundle,
pub fn hash(ctx: @This(), k: []const u8) u64 {
_ = ctx;
return std.hash_map.hashString(k);
}
pub fn eql(ctx: @This(), a: []const u8, b_key: der.Element.Slice) bool {
const b = ctx.cb.bytes.items[b_key.start..b_key.end];
return mem.eql(u8, a, b);
}
};
return cb.map.getAdapted(subject_name, Adapter{ .cb = cb });
}
pub fn deinit(cb: *Bundle, gpa: Allocator) void {
cb.map.deinit(gpa);
cb.bytes.deinit(gpa);
cb.* = undefined;
}
/// Clears the set of certificates and then scans the host operating system
/// file system standard locations for certificates.
/// For operating systems that do not have standard CA installations to be
/// found, this function clears the set of certificates.
pub fn rescan(cb: *Bundle, gpa: Allocator) !void {
switch (builtin.os.tag) {
.linux => return rescanLinux(cb, gpa),
.windows => {
// TODO
},
.macos => {
// TODO
},
else => {},
}
}
pub fn rescanLinux(cb: *Bundle, gpa: Allocator) !void {
var dir = fs.openIterableDirAbsolute("/etc/ssl/certs", .{}) catch |err| switch (err) {
error.FileNotFound => return,
else => |e| return e,
};
defer dir.close();
cb.bytes.clearRetainingCapacity();
cb.map.clearRetainingCapacity();
var it = dir.iterate();
while (try it.next()) |entry| {
switch (entry.kind) {
.File, .SymLink => {},
else => continue,
}
try addCertsFromFile(cb, gpa, dir.dir, entry.name);
}
cb.bytes.shrinkAndFree(gpa, cb.bytes.items.len);
}
pub fn addCertsFromFile(
cb: *Bundle,
gpa: Allocator,
dir: fs.Dir,
sub_file_path: []const u8,
) !void {
var file = try dir.openFile(sub_file_path, .{});
defer file.close();
const size = try file.getEndPos();
// We borrow `bytes` as a temporary buffer for the base64-encoded data.
// This is possible by computing the decoded length and reserving the space
// for the decoded bytes first.
const decoded_size_upper_bound = size / 4 * 3;
const needed_capacity = std.math.cast(u32, decoded_size_upper_bound + size) orelse
return error.CertificateAuthorityBundleTooBig;
try cb.bytes.ensureUnusedCapacity(gpa, needed_capacity);
const end_reserved = @intCast(u32, cb.bytes.items.len + decoded_size_upper_bound);
const buffer = cb.bytes.allocatedSlice()[end_reserved..];
const end_index = try file.readAll(buffer);
const encoded_bytes = buffer[0..end_index];
const begin_marker = "-----BEGIN CERTIFICATE-----";
const end_marker = "-----END CERTIFICATE-----";
const now_sec = std.time.timestamp();
var start_index: usize = 0;
while (mem.indexOfPos(u8, encoded_bytes, start_index, begin_marker)) |begin_marker_start| {
const cert_start = begin_marker_start + begin_marker.len;
const cert_end = mem.indexOfPos(u8, encoded_bytes, cert_start, end_marker) orelse
return error.MissingEndCertificateMarker;
start_index = cert_end + end_marker.len;
const encoded_cert = mem.trim(u8, encoded_bytes[cert_start..cert_end], " \t\r\n");
const decoded_start = @intCast(u32, cb.bytes.items.len);
const dest_buf = cb.bytes.allocatedSlice()[decoded_start..];
cb.bytes.items.len += try base64.decode(dest_buf, encoded_cert);
// Even though we could only partially parse the certificate to find
// the subject name, we pre-parse all of them to make sure and only
// include in the bundle ones that we know will parse. This way we can
// use `catch unreachable` later.
const parsed_cert = try Certificate.parse(.{
.buffer = cb.bytes.items,
.index = decoded_start,
});
if (now_sec > parsed_cert.validity.not_after) {
// Ignore expired cert.
cb.bytes.items.len = decoded_start;
continue;
}
const gop = try cb.map.getOrPutContext(gpa, parsed_cert.subject_slice, .{ .cb = cb });
if (gop.found_existing) {
cb.bytes.items.len = decoded_start;
} else {
gop.value_ptr.* = decoded_start;
}
}
}
const builtin = @import("builtin");
const std = @import("../../std.zig");
const fs = std.fs;
const mem = std.mem;
const crypto = std.crypto;
const Allocator = std.mem.Allocator;
const Certificate = std.crypto.Certificate;
const der = Certificate.der;
const Bundle = @This();
const base64 = std.base64.standard.decoderWithIgnore(" \t\r\n");
const MapContext = struct {
cb: *const Bundle,
pub fn hash(ctx: MapContext, k: der.Element.Slice) u64 {
return std.hash_map.hashString(ctx.cb.bytes.items[k.start..k.end]);
}
pub fn eql(ctx: MapContext, a: der.Element.Slice, b: der.Element.Slice) bool {
const bytes = ctx.cb.bytes.items;
return mem.eql(
u8,
bytes[a.start..a.end],
bytes[b.start..b.end],
);
}
};
test "scan for OS-provided certificates" {
if (builtin.os.tag == .wasi) return error.SkipZigTest;
var bundle: Bundle = .{};
defer bundle.deinit(std.testing.allocator);
try bundle.rescan(std.testing.allocator);
}

View file

@ -174,7 +174,7 @@ pub const Aegis128L = struct {
acc |= (computed_tag[j] ^ tag[j]); acc |= (computed_tag[j] ^ tag[j]);
} }
if (acc != 0) { if (acc != 0) {
mem.set(u8, m, 0xaa); @memset(m.ptr, undefined, m.len);
return error.AuthenticationFailed; return error.AuthenticationFailed;
} }
} }
@ -343,7 +343,7 @@ pub const Aegis256 = struct {
acc |= (computed_tag[j] ^ tag[j]); acc |= (computed_tag[j] ^ tag[j]);
} }
if (acc != 0) { if (acc != 0) {
mem.set(u8, m, 0xaa); @memset(m.ptr, undefined, m.len);
return error.AuthenticationFailed; return error.AuthenticationFailed;
} }
} }

View file

@ -91,7 +91,7 @@ fn AesGcm(comptime Aes: anytype) type {
acc |= (computed_tag[p] ^ tag[p]); acc |= (computed_tag[p] ^ tag[p]);
} }
if (acc != 0) { if (acc != 0) {
mem.set(u8, m, 0xaa); @memset(m.ptr, undefined, m.len);
return error.AuthenticationFailed; return error.AuthenticationFailed;
} }

View file

@ -142,6 +142,11 @@ fn Sha2x32(comptime params: Sha2Params32) type {
d.total_len += b.len; d.total_len += b.len;
} }
pub fn peek(d: Self) [digest_length]u8 {
var copy = d;
return copy.finalResult();
}
pub fn final(d: *Self, out: *[digest_length]u8) void { pub fn final(d: *Self, out: *[digest_length]u8) void {
// The buffer here will never be completely full. // The buffer here will never be completely full.
mem.set(u8, d.buf[d.buf_len..], 0); mem.set(u8, d.buf[d.buf_len..], 0);
@ -175,6 +180,12 @@ fn Sha2x32(comptime params: Sha2Params32) type {
} }
} }
pub fn finalResult(d: *Self) [digest_length]u8 {
var result: [digest_length]u8 = undefined;
d.final(&result);
return result;
}
const W = [64]u32{ const W = [64]u32{
0x428A2F98, 0x71374491, 0xB5C0FBCF, 0xE9B5DBA5, 0x3956C25B, 0x59F111F1, 0x923F82A4, 0xAB1C5ED5, 0x428A2F98, 0x71374491, 0xB5C0FBCF, 0xE9B5DBA5, 0x3956C25B, 0x59F111F1, 0x923F82A4, 0xAB1C5ED5,
0xD807AA98, 0x12835B01, 0x243185BE, 0x550C7DC3, 0x72BE5D74, 0x80DEB1FE, 0x9BDC06A7, 0xC19BF174, 0xD807AA98, 0x12835B01, 0x243185BE, 0x550C7DC3, 0x72BE5D74, 0x80DEB1FE, 0x9BDC06A7, 0xC19BF174,
@ -621,6 +632,11 @@ fn Sha2x64(comptime params: Sha2Params64) type {
d.total_len += b.len; d.total_len += b.len;
} }
pub fn peek(d: Self) [digest_length]u8 {
var copy = d;
return copy.finalResult();
}
pub fn final(d: *Self, out: *[digest_length]u8) void { pub fn final(d: *Self, out: *[digest_length]u8) void {
// The buffer here will never be completely full. // The buffer here will never be completely full.
mem.set(u8, d.buf[d.buf_len..], 0); mem.set(u8, d.buf[d.buf_len..], 0);
@ -654,6 +670,12 @@ fn Sha2x64(comptime params: Sha2Params64) type {
} }
} }
pub fn finalResult(d: *Self) [digest_length]u8 {
var result: [digest_length]u8 = undefined;
d.final(&result);
return result;
}
fn round(d: *Self, b: *const [128]u8) void { fn round(d: *Self, b: *const [128]u8) void {
var s: [80]u64 = undefined; var s: [80]u64 = undefined;

494
lib/std/crypto/tls.zig Normal file
View file

@ -0,0 +1,494 @@
//! Plaintext:
//! * type: ContentType
//! * legacy_record_version: u16 = 0x0303,
//! * length: u16,
//! - The length (in bytes) of the following TLSPlaintext.fragment. The
//! length MUST NOT exceed 2^14 bytes.
//! * fragment: opaque
//! - the data being transmitted
//!
//! Ciphertext
//! * ContentType opaque_type = application_data; /* 23 */
//! * ProtocolVersion legacy_record_version = 0x0303; /* TLS v1.2 */
//! * uint16 length;
//! * opaque encrypted_record[TLSCiphertext.length];
//!
//! Handshake:
//! * type: HandshakeType
//! * length: u24
//! * data: opaque
//!
//! ServerHello:
//! * ProtocolVersion legacy_version = 0x0303;
//! * Random random;
//! * opaque legacy_session_id_echo<0..32>;
//! * CipherSuite cipher_suite;
//! * uint8 legacy_compression_method = 0;
//! * Extension extensions<6..2^16-1>;
//!
//! Extension:
//! * ExtensionType extension_type;
//! * opaque extension_data<0..2^16-1>;
const std = @import("../std.zig");
const Tls = @This();
const net = std.net;
const mem = std.mem;
const crypto = std.crypto;
const assert = std.debug.assert;
pub const Client = @import("tls/Client.zig");
pub const record_header_len = 5;
pub const max_ciphertext_len = (1 << 14) + 256;
pub const max_ciphertext_record_len = max_ciphertext_len + record_header_len;
pub const hello_retry_request_sequence = [32]u8{
0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, 0xBE, 0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91,
0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E, 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C,
};
pub const close_notify_alert = [_]u8{
@enumToInt(AlertLevel.warning),
@enumToInt(AlertDescription.close_notify),
};
pub const ProtocolVersion = enum(u16) {
tls_1_2 = 0x0303,
tls_1_3 = 0x0304,
_,
};
pub const ContentType = enum(u8) {
invalid = 0,
change_cipher_spec = 20,
alert = 21,
handshake = 22,
application_data = 23,
_,
};
pub const HandshakeType = enum(u8) {
client_hello = 1,
server_hello = 2,
new_session_ticket = 4,
end_of_early_data = 5,
encrypted_extensions = 8,
certificate = 11,
certificate_request = 13,
certificate_verify = 15,
finished = 20,
key_update = 24,
message_hash = 254,
_,
};
pub const ExtensionType = enum(u16) {
/// RFC 6066
server_name = 0,
/// RFC 6066
max_fragment_length = 1,
/// RFC 6066
status_request = 5,
/// RFC 8422, 7919
supported_groups = 10,
/// RFC 8446
signature_algorithms = 13,
/// RFC 5764
use_srtp = 14,
/// RFC 6520
heartbeat = 15,
/// RFC 7301
application_layer_protocol_negotiation = 16,
/// RFC 6962
signed_certificate_timestamp = 18,
/// RFC 7250
client_certificate_type = 19,
/// RFC 7250
server_certificate_type = 20,
/// RFC 7685
padding = 21,
/// RFC 8446
pre_shared_key = 41,
/// RFC 8446
early_data = 42,
/// RFC 8446
supported_versions = 43,
/// RFC 8446
cookie = 44,
/// RFC 8446
psk_key_exchange_modes = 45,
/// RFC 8446
certificate_authorities = 47,
/// RFC 8446
oid_filters = 48,
/// RFC 8446
post_handshake_auth = 49,
/// RFC 8446
signature_algorithms_cert = 50,
/// RFC 8446
key_share = 51,
_,
};
pub const AlertLevel = enum(u8) {
warning = 1,
fatal = 2,
_,
};
pub const AlertDescription = enum(u8) {
close_notify = 0,
unexpected_message = 10,
bad_record_mac = 20,
record_overflow = 22,
handshake_failure = 40,
bad_certificate = 42,
unsupported_certificate = 43,
certificate_revoked = 44,
certificate_expired = 45,
certificate_unknown = 46,
illegal_parameter = 47,
unknown_ca = 48,
access_denied = 49,
decode_error = 50,
decrypt_error = 51,
protocol_version = 70,
insufficient_security = 71,
internal_error = 80,
inappropriate_fallback = 86,
user_canceled = 90,
missing_extension = 109,
unsupported_extension = 110,
unrecognized_name = 112,
bad_certificate_status_response = 113,
unknown_psk_identity = 115,
certificate_required = 116,
no_application_protocol = 120,
_,
};
pub const SignatureScheme = enum(u16) {
// RSASSA-PKCS1-v1_5 algorithms
rsa_pkcs1_sha256 = 0x0401,
rsa_pkcs1_sha384 = 0x0501,
rsa_pkcs1_sha512 = 0x0601,
// ECDSA algorithms
ecdsa_secp256r1_sha256 = 0x0403,
ecdsa_secp384r1_sha384 = 0x0503,
ecdsa_secp521r1_sha512 = 0x0603,
// RSASSA-PSS algorithms with public key OID rsaEncryption
rsa_pss_rsae_sha256 = 0x0804,
rsa_pss_rsae_sha384 = 0x0805,
rsa_pss_rsae_sha512 = 0x0806,
// EdDSA algorithms
ed25519 = 0x0807,
ed448 = 0x0808,
// RSASSA-PSS algorithms with public key OID RSASSA-PSS
rsa_pss_pss_sha256 = 0x0809,
rsa_pss_pss_sha384 = 0x080a,
rsa_pss_pss_sha512 = 0x080b,
// Legacy algorithms
rsa_pkcs1_sha1 = 0x0201,
ecdsa_sha1 = 0x0203,
_,
};
pub const NamedGroup = enum(u16) {
// Elliptic Curve Groups (ECDHE)
secp256r1 = 0x0017,
secp384r1 = 0x0018,
secp521r1 = 0x0019,
x25519 = 0x001D,
x448 = 0x001E,
// Finite Field Groups (DHE)
ffdhe2048 = 0x0100,
ffdhe3072 = 0x0101,
ffdhe4096 = 0x0102,
ffdhe6144 = 0x0103,
ffdhe8192 = 0x0104,
_,
};
pub const CipherSuite = enum(u16) {
AES_128_GCM_SHA256 = 0x1301,
AES_256_GCM_SHA384 = 0x1302,
CHACHA20_POLY1305_SHA256 = 0x1303,
AES_128_CCM_SHA256 = 0x1304,
AES_128_CCM_8_SHA256 = 0x1305,
AEGIS_256_SHA384 = 0x1306,
AEGIS_128L_SHA256 = 0x1307,
_,
};
pub const CertificateType = enum(u8) {
X509 = 0,
RawPublicKey = 2,
_,
};
pub const KeyUpdateRequest = enum(u8) {
update_not_requested = 0,
update_requested = 1,
_,
};
pub fn HandshakeCipherT(comptime AeadType: type, comptime HashType: type) type {
return struct {
pub const AEAD = AeadType;
pub const Hash = HashType;
pub const Hmac = crypto.auth.hmac.Hmac(Hash);
pub const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac);
handshake_secret: [Hkdf.prk_length]u8,
master_secret: [Hkdf.prk_length]u8,
client_handshake_key: [AEAD.key_length]u8,
server_handshake_key: [AEAD.key_length]u8,
client_finished_key: [Hmac.key_length]u8,
server_finished_key: [Hmac.key_length]u8,
client_handshake_iv: [AEAD.nonce_length]u8,
server_handshake_iv: [AEAD.nonce_length]u8,
transcript_hash: Hash,
};
}
pub const HandshakeCipher = union(enum) {
AES_128_GCM_SHA256: HandshakeCipherT(crypto.aead.aes_gcm.Aes128Gcm, crypto.hash.sha2.Sha256),
AES_256_GCM_SHA384: HandshakeCipherT(crypto.aead.aes_gcm.Aes256Gcm, crypto.hash.sha2.Sha384),
CHACHA20_POLY1305_SHA256: HandshakeCipherT(crypto.aead.chacha_poly.ChaCha20Poly1305, crypto.hash.sha2.Sha256),
AEGIS_256_SHA384: HandshakeCipherT(crypto.aead.aegis.Aegis256, crypto.hash.sha2.Sha384),
AEGIS_128L_SHA256: HandshakeCipherT(crypto.aead.aegis.Aegis128L, crypto.hash.sha2.Sha256),
};
pub fn ApplicationCipherT(comptime AeadType: type, comptime HashType: type) type {
return struct {
pub const AEAD = AeadType;
pub const Hash = HashType;
pub const Hmac = crypto.auth.hmac.Hmac(Hash);
pub const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac);
client_secret: [Hash.digest_length]u8,
server_secret: [Hash.digest_length]u8,
client_key: [AEAD.key_length]u8,
server_key: [AEAD.key_length]u8,
client_iv: [AEAD.nonce_length]u8,
server_iv: [AEAD.nonce_length]u8,
};
}
/// Encryption parameters for application traffic.
pub const ApplicationCipher = union(enum) {
AES_128_GCM_SHA256: ApplicationCipherT(crypto.aead.aes_gcm.Aes128Gcm, crypto.hash.sha2.Sha256),
AES_256_GCM_SHA384: ApplicationCipherT(crypto.aead.aes_gcm.Aes256Gcm, crypto.hash.sha2.Sha384),
CHACHA20_POLY1305_SHA256: ApplicationCipherT(crypto.aead.chacha_poly.ChaCha20Poly1305, crypto.hash.sha2.Sha256),
AEGIS_256_SHA384: ApplicationCipherT(crypto.aead.aegis.Aegis256, crypto.hash.sha2.Sha384),
AEGIS_128L_SHA256: ApplicationCipherT(crypto.aead.aegis.Aegis128L, crypto.hash.sha2.Sha256),
};
pub fn hkdfExpandLabel(
comptime Hkdf: type,
key: [Hkdf.prk_length]u8,
label: []const u8,
context: []const u8,
comptime len: usize,
) [len]u8 {
const max_label_len = 255;
const max_context_len = 255;
const tls13 = "tls13 ";
var buf: [2 + 1 + tls13.len + max_label_len + 1 + max_context_len]u8 = undefined;
mem.writeIntBig(u16, buf[0..2], len);
buf[2] = @intCast(u8, tls13.len + label.len);
buf[3..][0..tls13.len].* = tls13.*;
var i: usize = 3 + tls13.len;
mem.copy(u8, buf[i..], label);
i += label.len;
buf[i] = @intCast(u8, context.len);
i += 1;
mem.copy(u8, buf[i..], context);
i += context.len;
var result: [len]u8 = undefined;
Hkdf.expand(&result, buf[0..i], key);
return result;
}
pub fn emptyHash(comptime Hash: type) [Hash.digest_length]u8 {
var result: [Hash.digest_length]u8 = undefined;
Hash.hash(&.{}, &result, .{});
return result;
}
pub fn hmac(comptime Hmac: type, message: []const u8, key: [Hmac.key_length]u8) [Hmac.mac_length]u8 {
var result: [Hmac.mac_length]u8 = undefined;
Hmac.create(&result, message, &key);
return result;
}
pub inline fn extension(comptime et: ExtensionType, bytes: anytype) [2 + 2 + bytes.len]u8 {
return int2(@enumToInt(et)) ++ array(1, bytes);
}
pub inline fn array(comptime elem_size: comptime_int, bytes: anytype) [2 + bytes.len]u8 {
comptime assert(bytes.len % elem_size == 0);
return int2(bytes.len) ++ bytes;
}
pub inline fn enum_array(comptime E: type, comptime tags: []const E) [2 + @sizeOf(E) * tags.len]u8 {
assert(@sizeOf(E) == 2);
var result: [tags.len * 2]u8 = undefined;
for (tags) |elem, i| {
result[i * 2] = @truncate(u8, @enumToInt(elem) >> 8);
result[i * 2 + 1] = @truncate(u8, @enumToInt(elem));
}
return array(2, result);
}
pub inline fn int2(x: u16) [2]u8 {
return .{
@truncate(u8, x >> 8),
@truncate(u8, x),
};
}
pub inline fn int3(x: u24) [3]u8 {
return .{
@truncate(u8, x >> 16),
@truncate(u8, x >> 8),
@truncate(u8, x),
};
}
/// An abstraction to ensure that protocol-parsing code does not perform an
/// out-of-bounds read.
pub const Decoder = struct {
buf: []u8,
/// Points to the next byte in buffer that will be decoded.
idx: usize = 0,
/// Up to this point in `buf` we have already checked that `cap` is greater than it.
our_end: usize = 0,
/// Beyond this point in `buf` is extra tag-along bytes beyond the amount we
/// requested with `readAtLeast`.
their_end: usize = 0,
/// Points to the end within buffer that has been filled. Beyond this point
/// in buf is undefined bytes.
cap: usize = 0,
/// Debug helper to prevent illegal calls to read functions.
disable_reads: bool = false,
pub fn fromTheirSlice(buf: []u8) Decoder {
return .{
.buf = buf,
.their_end = buf.len,
.cap = buf.len,
.disable_reads = true,
};
}
/// Use this function to increase `their_end`.
pub fn readAtLeast(d: *Decoder, stream: anytype, their_amt: usize) !void {
assert(!d.disable_reads);
const existing_amt = d.cap - d.idx;
d.their_end = d.idx + their_amt;
if (their_amt <= existing_amt) return;
const request_amt = their_amt - existing_amt;
const dest = d.buf[d.cap..];
if (request_amt > dest.len) return error.TlsRecordOverflow;
const actual_amt = try stream.readAtLeast(dest, request_amt);
if (actual_amt < request_amt) return error.TlsConnectionTruncated;
d.cap += actual_amt;
}
/// Same as `readAtLeast` but also increases `our_end` by exactly `our_amt`.
/// Use when `our_amt` is calculated by us, not by them.
pub fn readAtLeastOurAmt(d: *Decoder, stream: anytype, our_amt: usize) !void {
assert(!d.disable_reads);
try readAtLeast(d, stream, our_amt);
d.our_end = d.idx + our_amt;
}
/// Use this function to increase `our_end`.
/// This should always be called with an amount provided by us, not them.
pub fn ensure(d: *Decoder, amt: usize) !void {
d.our_end = @max(d.idx + amt, d.our_end);
if (d.our_end > d.their_end) return error.TlsDecodeError;
}
/// Use this function to increase `idx`.
pub fn decode(d: *Decoder, comptime T: type) T {
switch (@typeInfo(T)) {
.Int => |info| switch (info.bits) {
8 => {
skip(d, 1);
return d.buf[d.idx - 1];
},
16 => {
skip(d, 2);
const b0: u16 = d.buf[d.idx - 2];
const b1: u16 = d.buf[d.idx - 1];
return (b0 << 8) | b1;
},
24 => {
skip(d, 3);
const b0: u24 = d.buf[d.idx - 3];
const b1: u24 = d.buf[d.idx - 2];
const b2: u24 = d.buf[d.idx - 1];
return (b0 << 16) | (b1 << 8) | b2;
},
else => @compileError("unsupported int type: " ++ @typeName(T)),
},
.Enum => |info| {
const int = d.decode(info.tag_type);
if (info.is_exhaustive) @compileError("exhaustive enum cannot be used");
return @intToEnum(T, int);
},
else => @compileError("unsupported type: " ++ @typeName(T)),
}
}
/// Use this function to increase `idx`.
pub fn array(d: *Decoder, comptime len: usize) *[len]u8 {
skip(d, len);
return d.buf[d.idx - len ..][0..len];
}
/// Use this function to increase `idx`.
pub fn slice(d: *Decoder, len: usize) []u8 {
skip(d, len);
return d.buf[d.idx - len ..][0..len];
}
/// Use this function to increase `idx`.
pub fn skip(d: *Decoder, amt: usize) void {
d.idx += amt;
assert(d.idx <= d.our_end); // insufficient ensured bytes
}
pub fn eof(d: Decoder) bool {
assert(d.our_end <= d.their_end);
assert(d.idx <= d.our_end);
return d.idx == d.their_end;
}
/// Provide the length they claim, and receive a sub-decoder specific to that slice.
/// The parent decoder is advanced to the end.
pub fn sub(d: *Decoder, their_len: usize) !Decoder {
const end = d.idx + their_len;
if (end > d.their_end) return error.TlsDecodeError;
const sub_buf = d.buf[d.idx..end];
d.idx = end;
d.our_end = end;
return fromTheirSlice(sub_buf);
}
pub fn rest(d: Decoder) []u8 {
return d.buf[d.idx..d.cap];
}
};

File diff suppressed because it is too large Load diff

View file

@ -1,8 +1,301 @@
pub const Client = @import("http/Client.zig");
/// https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods
/// https://datatracker.ietf.org/doc/html/rfc7231#section-4 Initial definiton
/// https://datatracker.ietf.org/doc/html/rfc5789#section-2 PATCH
pub const Method = enum {
GET,
HEAD,
POST,
PUT,
DELETE,
CONNECT,
OPTIONS,
TRACE,
PATCH,
/// Returns true if a request of this method is allowed to have a body
/// Actual behavior from servers may vary and should still be checked
pub fn requestHasBody(self: Method) bool {
return switch (self) {
.POST, .PUT, .PATCH => true,
.GET, .HEAD, .DELETE, .CONNECT, .OPTIONS, .TRACE => false,
};
}
/// Returns true if a response to this method is allowed to have a body
/// Actual behavior from clients may vary and should still be checked
pub fn responseHasBody(self: Method) bool {
return switch (self) {
.GET, .POST, .DELETE, .CONNECT, .OPTIONS, .PATCH => true,
.HEAD, .PUT, .TRACE => false,
};
}
/// An HTTP method is safe if it doesn't alter the state of the server.
/// https://developer.mozilla.org/en-US/docs/Glossary/Safe/HTTP
/// https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.1
pub fn safe(self: Method) bool {
return switch (self) {
.GET, .HEAD, .OPTIONS, .TRACE => true,
.POST, .PUT, .DELETE, .CONNECT, .PATCH => false,
};
}
/// An HTTP method is idempotent if an identical request can be made once or several times in a row with the same effect while leaving the server in the same state.
/// https://developer.mozilla.org/en-US/docs/Glossary/Idempotent
/// https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.2
pub fn idempotent(self: Method) bool {
return switch (self) {
.GET, .HEAD, .PUT, .DELETE, .OPTIONS, .TRACE => true,
.CONNECT, .POST, .PATCH => false,
};
}
/// A cacheable response is an HTTP response that can be cached, that is stored to be retrieved and used later, saving a new request to the server.
/// https://developer.mozilla.org/en-US/docs/Glossary/cacheable
/// https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.3
pub fn cacheable(self: Method) bool {
return switch (self) {
.GET, .HEAD => true,
.POST, .PUT, .DELETE, .CONNECT, .OPTIONS, .TRACE, .PATCH => false,
};
}
};
/// https://developer.mozilla.org/en-US/docs/Web/HTTP/Status
pub const Status = enum(u10) {
@"continue" = 100, // RFC7231, Section 6.2.1
switching_protocols = 101, // RFC7231, Section 6.2.2
processing = 102, // RFC2518
early_hints = 103, // RFC8297
ok = 200, // RFC7231, Section 6.3.1
created = 201, // RFC7231, Section 6.3.2
accepted = 202, // RFC7231, Section 6.3.3
non_authoritative_info = 203, // RFC7231, Section 6.3.4
no_content = 204, // RFC7231, Section 6.3.5
reset_content = 205, // RFC7231, Section 6.3.6
partial_content = 206, // RFC7233, Section 4.1
multi_status = 207, // RFC4918
already_reported = 208, // RFC5842
im_used = 226, // RFC3229
multiple_choice = 300, // RFC7231, Section 6.4.1
moved_permanently = 301, // RFC7231, Section 6.4.2
found = 302, // RFC7231, Section 6.4.3
see_other = 303, // RFC7231, Section 6.4.4
not_modified = 304, // RFC7232, Section 4.1
use_proxy = 305, // RFC7231, Section 6.4.5
temporary_redirect = 307, // RFC7231, Section 6.4.7
permanent_redirect = 308, // RFC7538
bad_request = 400, // RFC7231, Section 6.5.1
unauthorized = 401, // RFC7235, Section 3.1
payment_required = 402, // RFC7231, Section 6.5.2
forbidden = 403, // RFC7231, Section 6.5.3
not_found = 404, // RFC7231, Section 6.5.4
method_not_allowed = 405, // RFC7231, Section 6.5.5
not_acceptable = 406, // RFC7231, Section 6.5.6
proxy_auth_required = 407, // RFC7235, Section 3.2
request_timeout = 408, // RFC7231, Section 6.5.7
conflict = 409, // RFC7231, Section 6.5.8
gone = 410, // RFC7231, Section 6.5.9
length_required = 411, // RFC7231, Section 6.5.10
precondition_failed = 412, // RFC7232, Section 4.2][RFC8144, Section 3.2
payload_too_large = 413, // RFC7231, Section 6.5.11
uri_too_long = 414, // RFC7231, Section 6.5.12
unsupported_media_type = 415, // RFC7231, Section 6.5.13][RFC7694, Section 3
range_not_satisfiable = 416, // RFC7233, Section 4.4
expectation_failed = 417, // RFC7231, Section 6.5.14
teapot = 418, // RFC 7168, 2.3.3
misdirected_request = 421, // RFC7540, Section 9.1.2
unprocessable_entity = 422, // RFC4918
locked = 423, // RFC4918
failed_dependency = 424, // RFC4918
too_early = 425, // RFC8470
upgrade_required = 426, // RFC7231, Section 6.5.15
precondition_required = 428, // RFC6585
too_many_requests = 429, // RFC6585
header_fields_too_large = 431, // RFC6585
unavailable_for_legal_reasons = 451, // RFC7725
internal_server_error = 500, // RFC7231, Section 6.6.1
not_implemented = 501, // RFC7231, Section 6.6.2
bad_gateway = 502, // RFC7231, Section 6.6.3
service_unavailable = 503, // RFC7231, Section 6.6.4
gateway_timeout = 504, // RFC7231, Section 6.6.5
http_version_not_supported = 505, // RFC7231, Section 6.6.6
variant_also_negotiates = 506, // RFC2295
insufficient_storage = 507, // RFC4918
loop_detected = 508, // RFC5842
not_extended = 510, // RFC2774
network_authentication_required = 511, // RFC6585
_,
pub fn phrase(self: Status) ?[]const u8 {
return switch (self) {
// 1xx statuses
.@"continue" => "Continue",
.switching_protocols => "Switching Protocols",
.processing => "Processing",
.early_hints => "Early Hints",
// 2xx statuses
.ok => "OK",
.created => "Created",
.accepted => "Accepted",
.non_authoritative_info => "Non-Authoritative Information",
.no_content => "No Content",
.reset_content => "Reset Content",
.partial_content => "Partial Content",
.multi_status => "Multi-Status",
.already_reported => "Already Reported",
.im_used => "IM Used",
// 3xx statuses
.multiple_choice => "Multiple Choice",
.moved_permanently => "Moved Permanently",
.found => "Found",
.see_other => "See Other",
.not_modified => "Not Modified",
.use_proxy => "Use Proxy",
.temporary_redirect => "Temporary Redirect",
.permanent_redirect => "Permanent Redirect",
// 4xx statuses
.bad_request => "Bad Request",
.unauthorized => "Unauthorized",
.payment_required => "Payment Required",
.forbidden => "Forbidden",
.not_found => "Not Found",
.method_not_allowed => "Method Not Allowed",
.not_acceptable => "Not Acceptable",
.proxy_auth_required => "Proxy Authentication Required",
.request_timeout => "Request Timeout",
.conflict => "Conflict",
.gone => "Gone",
.length_required => "Length Required",
.precondition_failed => "Precondition Failed",
.payload_too_large => "Payload Too Large",
.uri_too_long => "URI Too Long",
.unsupported_media_type => "Unsupported Media Type",
.range_not_satisfiable => "Range Not Satisfiable",
.expectation_failed => "Expectation Failed",
.teapot => "I'm a teapot",
.misdirected_request => "Misdirected Request",
.unprocessable_entity => "Unprocessable Entity",
.locked => "Locked",
.failed_dependency => "Failed Dependency",
.too_early => "Too Early",
.upgrade_required => "Upgrade Required",
.precondition_required => "Precondition Required",
.too_many_requests => "Too Many Requests",
.header_fields_too_large => "Request Header Fields Too Large",
.unavailable_for_legal_reasons => "Unavailable For Legal Reasons",
// 5xx statuses
.internal_server_error => "Internal Server Error",
.not_implemented => "Not Implemented",
.bad_gateway => "Bad Gateway",
.service_unavailable => "Service Unavailable",
.gateway_timeout => "Gateway Timeout",
.http_version_not_supported => "HTTP Version Not Supported",
.variant_also_negotiates => "Variant Also Negotiates",
.insufficient_storage => "Insufficient Storage",
.loop_detected => "Loop Detected",
.not_extended => "Not Extended",
.network_authentication_required => "Network Authentication Required",
else => return null,
};
}
pub const Class = enum {
informational,
success,
redirect,
client_error,
server_error,
};
pub fn class(self: Status) ?Class {
return switch (@enumToInt(self)) {
100...199 => .informational,
200...299 => .success,
300...399 => .redirect,
400...499 => .client_error,
500...599 => .server_error,
else => null,
};
}
test {
try std.testing.expectEqualStrings("OK", Status.ok.phrase().?);
try std.testing.expectEqualStrings("Not Found", Status.not_found.phrase().?);
}
test {
try std.testing.expectEqual(@as(?Status.Class, Status.Class.success), Status.ok.class());
try std.testing.expectEqual(@as(?Status.Class, Status.Class.client_error), Status.not_found.class());
}
};
pub const Headers = struct {
state: State = .start,
invalid_index: u32 = undefined,
pub const State = enum { invalid, start, line, nl_r, nl_n, nl2_r, finished };
/// Returns how many bytes are processed into headers. Always less than or
/// equal to bytes.len. If the amount returned is less than bytes.len, it
/// means the headers ended and the first byte after the double \r\n\r\n is
/// located at `bytes[result]`.
pub fn feed(h: *Headers, bytes: []const u8) usize {
for (bytes) |b, i| {
switch (h.state) {
.start => switch (b) {
'\r' => h.state = .nl_r,
'\n' => return invalid(h, i),
else => {},
},
.nl_r => switch (b) {
'\n' => h.state = .nl_n,
else => return invalid(h, i),
},
.nl_n => switch (b) {
'\r' => h.state = .nl2_r,
else => h.state = .line,
},
.nl2_r => switch (b) {
'\n' => h.state = .finished,
else => return invalid(h, i),
},
.line => switch (b) {
'\r' => h.state = .nl_r,
'\n' => return invalid(h, i),
else => {},
},
.invalid => return i,
.finished => return i,
}
}
return bytes.len;
}
fn invalid(h: *Headers, i: usize) usize {
h.invalid_index = @intCast(u32, i);
h.state = .invalid;
return i;
}
};
const std = @import("std.zig"); const std = @import("std.zig");
pub const Method = @import("http/method.zig").Method;
pub const Status = @import("http/status.zig").Status;
test { test {
std.testing.refAllDecls(@This()); _ = Client;
_ = Method;
_ = Status;
_ = Headers;
} }

181
lib/std/http/Client.zig Normal file
View file

@ -0,0 +1,181 @@
//! This API is a barely-touched, barely-functional http client, just the
//! absolute minimum thing I needed in order to test `std.crypto.tls`. Bear
//! with me and I promise the API will become useful and streamlined.
const std = @import("../std.zig");
const assert = std.debug.assert;
const http = std.http;
const net = std.net;
const Client = @This();
const Url = std.Url;
allocator: std.mem.Allocator,
headers: std.ArrayListUnmanaged(u8) = .{},
active_requests: usize = 0,
ca_bundle: std.crypto.Certificate.Bundle = .{},
/// TODO: emit error.UnexpectedEndOfStream or something like that when the read
/// data does not match the content length. This is necessary since HTTPS disables
/// close_notify protection on underlying TLS streams.
pub const Request = struct {
client: *Client,
stream: net.Stream,
headers: std.ArrayListUnmanaged(u8) = .{},
tls_client: std.crypto.tls.Client,
protocol: Protocol,
response_headers: http.Headers = .{},
pub const Protocol = enum { http, https };
pub const Options = struct {
method: http.Method = .GET,
};
pub fn deinit(req: *Request) void {
req.client.active_requests -= 1;
req.headers.deinit(req.client.allocator);
req.* = undefined;
}
pub fn addHeader(req: *Request, name: []const u8, value: []const u8) !void {
const gpa = req.client.allocator;
// Ensure an extra +2 for the \r\n in end()
try req.headers.ensureUnusedCapacity(gpa, name.len + value.len + 6);
req.headers.appendSliceAssumeCapacity(name);
req.headers.appendSliceAssumeCapacity(": ");
req.headers.appendSliceAssumeCapacity(value);
req.headers.appendSliceAssumeCapacity("\r\n");
}
pub fn end(req: *Request) !void {
req.headers.appendSliceAssumeCapacity("\r\n");
switch (req.protocol) {
.http => {
try req.stream.writeAll(req.headers.items);
},
.https => {
try req.tls_client.writeAll(req.stream, req.headers.items);
},
}
}
pub fn readAll(req: *Request, buffer: []u8) !usize {
return readAtLeast(req, buffer, buffer.len);
}
pub fn read(req: *Request, buffer: []u8) !usize {
return readAtLeast(req, buffer, 1);
}
pub fn readAtLeast(req: *Request, buffer: []u8, len: usize) !usize {
assert(len <= buffer.len);
var index: usize = 0;
while (index < len) {
const headers_finished = req.response_headers.state == .finished;
const amt = try readAdvanced(req, buffer[index..]);
if (amt == 0 and headers_finished) break;
index += amt;
}
return index;
}
/// This one can return 0 without meaning EOF.
/// TODO change to readvAdvanced
pub fn readAdvanced(req: *Request, buffer: []u8) !usize {
if (req.response_headers.state == .finished) return readRaw(req, buffer);
const amt = try readRaw(req, buffer);
const data = buffer[0..amt];
const i = req.response_headers.feed(data);
if (req.response_headers.state == .invalid) return error.InvalidHttpHeaders;
if (i < data.len) {
const rest = data[i..];
std.mem.copy(u8, buffer, rest);
return rest.len;
}
return 0;
}
/// Only abstracts over http/https.
fn readRaw(req: *Request, buffer: []u8) !usize {
switch (req.protocol) {
.http => return req.stream.read(buffer),
.https => return req.tls_client.read(req.stream, buffer),
}
}
/// Only abstracts over http/https.
fn readAtLeastRaw(req: *Request, buffer: []u8, len: usize) !usize {
switch (req.protocol) {
.http => return req.stream.readAtLeast(buffer, len),
.https => return req.tls_client.readAtLeast(req.stream, buffer, len),
}
}
};
pub fn deinit(client: *Client) void {
assert(client.active_requests == 0);
client.headers.deinit(client.allocator);
client.* = undefined;
}
pub fn request(client: *Client, url: Url, options: Request.Options) !Request {
const protocol = std.meta.stringToEnum(Request.Protocol, url.scheme) orelse
return error.UnsupportedUrlScheme;
const port: u16 = url.port orelse switch (protocol) {
.http => 80,
.https => 443,
};
var req: Request = .{
.client = client,
.stream = try net.tcpConnectToHost(client.allocator, url.host, port),
.protocol = protocol,
.tls_client = undefined,
};
client.active_requests += 1;
errdefer req.deinit();
switch (protocol) {
.http => {},
.https => {
req.tls_client = try std.crypto.tls.Client.init(req.stream, client.ca_bundle, url.host);
// This is appropriate for HTTPS because the HTTP headers contain
// the content length which is used to detect truncation attacks.
req.tls_client.allow_truncation_attacks = true;
},
}
try req.headers.ensureUnusedCapacity(
client.allocator,
@tagName(options.method).len +
1 +
url.path.len +
" HTTP/1.1\r\nHost: ".len +
url.host.len +
"\r\nUpgrade-Insecure-Requests: 1\r\n".len +
client.headers.items.len +
2, // for the \r\n at the end of headers
);
req.headers.appendSliceAssumeCapacity(@tagName(options.method));
req.headers.appendSliceAssumeCapacity(" ");
req.headers.appendSliceAssumeCapacity(url.path);
req.headers.appendSliceAssumeCapacity(" HTTP/1.1\r\nHost: ");
req.headers.appendSliceAssumeCapacity(url.host);
switch (protocol) {
.https => req.headers.appendSliceAssumeCapacity("\r\nUpgrade-Insecure-Requests: 1\r\n"),
.http => req.headers.appendSliceAssumeCapacity("\r\n"),
}
req.headers.appendSliceAssumeCapacity(client.headers.items);
return req;
}
pub fn addHeader(client: *Client, name: []const u8, value: []const u8) !void {
const gpa = client.allocator;
try client.headers.ensureUnusedCapacity(gpa, name.len + value.len + 4);
client.headers.appendSliceAssumeCapacity(name);
client.headers.appendSliceAssumeCapacity(": ");
client.headers.appendSliceAssumeCapacity(value);
client.headers.appendSliceAssumeCapacity("\r\n");
}

View file

@ -1,65 +0,0 @@
//! HTTP Methods
//! https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods
// Style guide is violated here so that @tagName can be used effectively
/// https://datatracker.ietf.org/doc/html/rfc7231#section-4 Initial definiton
/// https://datatracker.ietf.org/doc/html/rfc5789#section-2 PATCH
pub const Method = enum {
GET,
HEAD,
POST,
PUT,
DELETE,
CONNECT,
OPTIONS,
TRACE,
PATCH,
/// Returns true if a request of this method is allowed to have a body
/// Actual behavior from servers may vary and should still be checked
pub fn requestHasBody(self: Method) bool {
return switch (self) {
.POST, .PUT, .PATCH => true,
.GET, .HEAD, .DELETE, .CONNECT, .OPTIONS, .TRACE => false,
};
}
/// Returns true if a response to this method is allowed to have a body
/// Actual behavior from clients may vary and should still be checked
pub fn responseHasBody(self: Method) bool {
return switch (self) {
.GET, .POST, .DELETE, .CONNECT, .OPTIONS, .PATCH => true,
.HEAD, .PUT, .TRACE => false,
};
}
/// An HTTP method is safe if it doesn't alter the state of the server.
/// https://developer.mozilla.org/en-US/docs/Glossary/Safe/HTTP
/// https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.1
pub fn safe(self: Method) bool {
return switch (self) {
.GET, .HEAD, .OPTIONS, .TRACE => true,
.POST, .PUT, .DELETE, .CONNECT, .PATCH => false,
};
}
/// An HTTP method is idempotent if an identical request can be made once or several times in a row with the same effect while leaving the server in the same state.
/// https://developer.mozilla.org/en-US/docs/Glossary/Idempotent
/// https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.2
pub fn idempotent(self: Method) bool {
return switch (self) {
.GET, .HEAD, .PUT, .DELETE, .OPTIONS, .TRACE => true,
.CONNECT, .POST, .PATCH => false,
};
}
/// A cacheable response is an HTTP response that can be cached, that is stored to be retrieved and used later, saving a new request to the server.
/// https://developer.mozilla.org/en-US/docs/Glossary/cacheable
/// https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.3
pub fn cacheable(self: Method) bool {
return switch (self) {
.GET, .HEAD => true,
.POST, .PUT, .DELETE, .CONNECT, .OPTIONS, .TRACE, .PATCH => false,
};
}
};

View file

@ -1,182 +0,0 @@
//! HTTP Status
//! https://developer.mozilla.org/en-US/docs/Web/HTTP/Status
const std = @import("../std.zig");
pub const Status = enum(u10) {
@"continue" = 100, // RFC7231, Section 6.2.1
switching_protocols = 101, // RFC7231, Section 6.2.2
processing = 102, // RFC2518
early_hints = 103, // RFC8297
ok = 200, // RFC7231, Section 6.3.1
created = 201, // RFC7231, Section 6.3.2
accepted = 202, // RFC7231, Section 6.3.3
non_authoritative_info = 203, // RFC7231, Section 6.3.4
no_content = 204, // RFC7231, Section 6.3.5
reset_content = 205, // RFC7231, Section 6.3.6
partial_content = 206, // RFC7233, Section 4.1
multi_status = 207, // RFC4918
already_reported = 208, // RFC5842
im_used = 226, // RFC3229
multiple_choice = 300, // RFC7231, Section 6.4.1
moved_permanently = 301, // RFC7231, Section 6.4.2
found = 302, // RFC7231, Section 6.4.3
see_other = 303, // RFC7231, Section 6.4.4
not_modified = 304, // RFC7232, Section 4.1
use_proxy = 305, // RFC7231, Section 6.4.5
temporary_redirect = 307, // RFC7231, Section 6.4.7
permanent_redirect = 308, // RFC7538
bad_request = 400, // RFC7231, Section 6.5.1
unauthorized = 401, // RFC7235, Section 3.1
payment_required = 402, // RFC7231, Section 6.5.2
forbidden = 403, // RFC7231, Section 6.5.3
not_found = 404, // RFC7231, Section 6.5.4
method_not_allowed = 405, // RFC7231, Section 6.5.5
not_acceptable = 406, // RFC7231, Section 6.5.6
proxy_auth_required = 407, // RFC7235, Section 3.2
request_timeout = 408, // RFC7231, Section 6.5.7
conflict = 409, // RFC7231, Section 6.5.8
gone = 410, // RFC7231, Section 6.5.9
length_required = 411, // RFC7231, Section 6.5.10
precondition_failed = 412, // RFC7232, Section 4.2][RFC8144, Section 3.2
payload_too_large = 413, // RFC7231, Section 6.5.11
uri_too_long = 414, // RFC7231, Section 6.5.12
unsupported_media_type = 415, // RFC7231, Section 6.5.13][RFC7694, Section 3
range_not_satisfiable = 416, // RFC7233, Section 4.4
expectation_failed = 417, // RFC7231, Section 6.5.14
teapot = 418, // RFC 7168, 2.3.3
misdirected_request = 421, // RFC7540, Section 9.1.2
unprocessable_entity = 422, // RFC4918
locked = 423, // RFC4918
failed_dependency = 424, // RFC4918
too_early = 425, // RFC8470
upgrade_required = 426, // RFC7231, Section 6.5.15
precondition_required = 428, // RFC6585
too_many_requests = 429, // RFC6585
header_fields_too_large = 431, // RFC6585
unavailable_for_legal_reasons = 451, // RFC7725
internal_server_error = 500, // RFC7231, Section 6.6.1
not_implemented = 501, // RFC7231, Section 6.6.2
bad_gateway = 502, // RFC7231, Section 6.6.3
service_unavailable = 503, // RFC7231, Section 6.6.4
gateway_timeout = 504, // RFC7231, Section 6.6.5
http_version_not_supported = 505, // RFC7231, Section 6.6.6
variant_also_negotiates = 506, // RFC2295
insufficient_storage = 507, // RFC4918
loop_detected = 508, // RFC5842
not_extended = 510, // RFC2774
network_authentication_required = 511, // RFC6585
_,
pub fn phrase(self: Status) ?[]const u8 {
return switch (self) {
// 1xx statuses
.@"continue" => "Continue",
.switching_protocols => "Switching Protocols",
.processing => "Processing",
.early_hints => "Early Hints",
// 2xx statuses
.ok => "OK",
.created => "Created",
.accepted => "Accepted",
.non_authoritative_info => "Non-Authoritative Information",
.no_content => "No Content",
.reset_content => "Reset Content",
.partial_content => "Partial Content",
.multi_status => "Multi-Status",
.already_reported => "Already Reported",
.im_used => "IM Used",
// 3xx statuses
.multiple_choice => "Multiple Choice",
.moved_permanently => "Moved Permanently",
.found => "Found",
.see_other => "See Other",
.not_modified => "Not Modified",
.use_proxy => "Use Proxy",
.temporary_redirect => "Temporary Redirect",
.permanent_redirect => "Permanent Redirect",
// 4xx statuses
.bad_request => "Bad Request",
.unauthorized => "Unauthorized",
.payment_required => "Payment Required",
.forbidden => "Forbidden",
.not_found => "Not Found",
.method_not_allowed => "Method Not Allowed",
.not_acceptable => "Not Acceptable",
.proxy_auth_required => "Proxy Authentication Required",
.request_timeout => "Request Timeout",
.conflict => "Conflict",
.gone => "Gone",
.length_required => "Length Required",
.precondition_failed => "Precondition Failed",
.payload_too_large => "Payload Too Large",
.uri_too_long => "URI Too Long",
.unsupported_media_type => "Unsupported Media Type",
.range_not_satisfiable => "Range Not Satisfiable",
.expectation_failed => "Expectation Failed",
.teapot => "I'm a teapot",
.misdirected_request => "Misdirected Request",
.unprocessable_entity => "Unprocessable Entity",
.locked => "Locked",
.failed_dependency => "Failed Dependency",
.too_early => "Too Early",
.upgrade_required => "Upgrade Required",
.precondition_required => "Precondition Required",
.too_many_requests => "Too Many Requests",
.header_fields_too_large => "Request Header Fields Too Large",
.unavailable_for_legal_reasons => "Unavailable For Legal Reasons",
// 5xx statuses
.internal_server_error => "Internal Server Error",
.not_implemented => "Not Implemented",
.bad_gateway => "Bad Gateway",
.service_unavailable => "Service Unavailable",
.gateway_timeout => "Gateway Timeout",
.http_version_not_supported => "HTTP Version Not Supported",
.variant_also_negotiates => "Variant Also Negotiates",
.insufficient_storage => "Insufficient Storage",
.loop_detected => "Loop Detected",
.not_extended => "Not Extended",
.network_authentication_required => "Network Authentication Required",
else => return null,
};
}
pub const Class = enum {
informational,
success,
redirect,
client_error,
server_error,
};
pub fn class(self: Status) ?Class {
return switch (@enumToInt(self)) {
100...199 => .informational,
200...299 => .success,
300...399 => .redirect,
400...499 => .client_error,
500...599 => .server_error,
else => null,
};
}
};
test {
try std.testing.expectEqualStrings("OK", Status.ok.phrase().?);
try std.testing.expectEqualStrings("Not Found", Status.not_found.phrase().?);
}
test {
try std.testing.expectEqual(@as(?Status.Class, Status.Class.success), Status.ok.class());
try std.testing.expectEqual(@as(?Status.Class, Status.Class.client_error), Status.not_found.class());
}

View file

@ -810,21 +810,25 @@ test "std.meta.activeTag" {
const TagPayloadType = TagPayload; const TagPayloadType = TagPayload;
///Given a tagged union type, and an enum, return the type of the union pub fn TagPayloadByName(comptime U: type, comptime tag_name: []const u8) type {
/// field corresponding to the enum tag.
pub fn TagPayload(comptime U: type, comptime tag: Tag(U)) type {
comptime debug.assert(trait.is(.Union)(U)); comptime debug.assert(trait.is(.Union)(U));
const info = @typeInfo(U).Union; const info = @typeInfo(U).Union;
inline for (info.fields) |field_info| { inline for (info.fields) |field_info| {
if (comptime mem.eql(u8, field_info.name, @tagName(tag))) if (comptime mem.eql(u8, field_info.name, tag_name))
return field_info.type; return field_info.type;
} }
unreachable; unreachable;
} }
/// Given a tagged union type, and an enum, return the type of the union field
/// corresponding to the enum tag.
pub fn TagPayload(comptime U: type, comptime tag: Tag(U)) type {
return TagPayloadByName(U, @tagName(tag));
}
test "std.meta.TagPayload" { test "std.meta.TagPayload" {
const Event = union(enum) { const Event = union(enum) {
Moved: struct { Moved: struct {

View file

@ -1672,6 +1672,40 @@ pub const Stream = struct {
} }
} }
pub fn readv(s: Stream, iovecs: []const os.iovec) ReadError!usize {
if (builtin.os.tag == .windows) {
// TODO improve this to use ReadFileScatter
if (iovecs.len == 0) return @as(usize, 0);
const first = iovecs[0];
return os.windows.ReadFile(s.handle, first.iov_base[0..first.iov_len], null, io.default_mode);
}
return os.readv(s.handle, iovecs);
}
/// Returns the number of bytes read. If the number read is smaller than
/// `buffer.len`, it means the stream reached the end. Reaching the end of
/// a stream is not an error condition.
pub fn readAll(s: Stream, buffer: []u8) ReadError!usize {
return readAtLeast(s, buffer, buffer.len);
}
/// Returns the number of bytes read, calling the underlying read function
/// the minimal number of times until the buffer has at least `len` bytes
/// filled. If the number read is less than `len` it means the stream
/// reached the end. Reaching the end of the stream is not an error
/// condition.
pub fn readAtLeast(s: Stream, buffer: []u8, len: usize) ReadError!usize {
assert(len <= buffer.len);
var index: usize = 0;
while (index < len) {
const amt = try s.read(buffer[index..]);
if (amt == 0) break;
index += amt;
}
return index;
}
/// TODO in evented I/O mode, this implementation incorrectly uses the event loop's /// TODO in evented I/O mode, this implementation incorrectly uses the event loop's
/// file system thread instead of non-blocking. It needs to be reworked to properly /// file system thread instead of non-blocking. It needs to be reworked to properly
/// use non-blocking I/O. /// use non-blocking I/O.
@ -1687,6 +1721,13 @@ pub const Stream = struct {
} }
} }
pub fn writeAll(self: Stream, bytes: []const u8) WriteError!void {
var index: usize = 0;
while (index < bytes.len) {
index += try self.write(bytes[index..]);
}
}
/// See https://github.com/ziglang/zig/issues/7699 /// See https://github.com/ziglang/zig/issues/7699
/// See equivalent function: `std.fs.File.writev`. /// See equivalent function: `std.fs.File.writev`.
pub fn writev(self: Stream, iovecs: []const os.iovec_const) WriteError!usize { pub fn writev(self: Stream, iovecs: []const os.iovec_const) WriteError!usize {

View file

@ -767,6 +767,7 @@ pub fn readv(fd: fd_t, iov: []const iovec) ReadError!usize {
.ISDIR => return error.IsDir, .ISDIR => return error.IsDir,
.NOBUFS => return error.SystemResources, .NOBUFS => return error.SystemResources,
.NOMEM => return error.SystemResources, .NOMEM => return error.SystemResources,
.CONNRESET => return error.ConnectionResetByPeer,
else => |err| return unexpectedErrno(err), else => |err| return unexpectedErrno(err),
} }
} }
@ -5685,11 +5686,11 @@ pub fn sendmsg(
/// The file descriptor of the sending socket. /// The file descriptor of the sending socket.
sockfd: socket_t, sockfd: socket_t,
/// Message header and iovecs /// Message header and iovecs
msg: msghdr_const, msg: *const msghdr_const,
flags: u32, flags: u32,
) SendMsgError!usize { ) SendMsgError!usize {
while (true) { while (true) {
const rc = system.sendmsg(sockfd, @ptrCast(*const std.x.os.Socket.Message, &msg), @intCast(c_int, flags)); const rc = system.sendmsg(sockfd, msg, flags);
if (builtin.os.tag == .windows) { if (builtin.os.tag == .windows) {
if (rc == windows.ws2_32.SOCKET_ERROR) { if (rc == windows.ws2_32.SOCKET_ERROR) {
switch (windows.ws2_32.WSAGetLastError()) { switch (windows.ws2_32.WSAGetLastError()) {

View file

@ -1226,11 +1226,14 @@ pub fn getsockopt(fd: i32, level: u32, optname: u32, noalias optval: [*]u8, noal
return syscall5(.getsockopt, @bitCast(usize, @as(isize, fd)), level, optname, @ptrToInt(optval), @ptrToInt(optlen)); return syscall5(.getsockopt, @bitCast(usize, @as(isize, fd)), level, optname, @ptrToInt(optval), @ptrToInt(optlen));
} }
pub fn sendmsg(fd: i32, msg: *const std.x.os.Socket.Message, flags: c_int) usize { pub fn sendmsg(fd: i32, msg: *const msghdr_const, flags: u32) usize {
const fd_usize = @bitCast(usize, @as(isize, fd));
const msg_usize = @ptrToInt(msg);
if (native_arch == .x86) { if (native_arch == .x86) {
return socketcall(SC.sendmsg, &[3]usize{ @bitCast(usize, @as(isize, fd)), @ptrToInt(msg), @bitCast(usize, @as(isize, flags)) }); return socketcall(SC.sendmsg, &[3]usize{ fd_usize, msg_usize, flags });
} else {
return syscall3(.sendmsg, fd_usize, msg_usize, flags);
} }
return syscall3(.sendmsg, @bitCast(usize, @as(isize, fd)), @ptrToInt(msg), @bitCast(usize, @as(isize, flags)));
} }
pub fn sendmmsg(fd: i32, msgvec: [*]mmsghdr_const, vlen: u32, flags: u32) usize { pub fn sendmmsg(fd: i32, msgvec: [*]mmsghdr_const, vlen: u32, flags: u32) usize {
@ -1274,24 +1277,42 @@ pub fn sendmmsg(fd: i32, msgvec: [*]mmsghdr_const, vlen: u32, flags: u32) usize
} }
pub fn connect(fd: i32, addr: *const anyopaque, len: socklen_t) usize { pub fn connect(fd: i32, addr: *const anyopaque, len: socklen_t) usize {
const fd_usize = @bitCast(usize, @as(isize, fd));
const addr_usize = @ptrToInt(addr);
if (native_arch == .x86) { if (native_arch == .x86) {
return socketcall(SC.connect, &[3]usize{ @bitCast(usize, @as(isize, fd)), @ptrToInt(addr), len }); return socketcall(SC.connect, &[3]usize{ fd_usize, addr_usize, len });
} else {
return syscall3(.connect, fd_usize, addr_usize, len);
} }
return syscall3(.connect, @bitCast(usize, @as(isize, fd)), @ptrToInt(addr), len);
} }
pub fn recvmsg(fd: i32, msg: *std.x.os.Socket.Message, flags: c_int) usize { pub fn recvmsg(fd: i32, msg: *msghdr, flags: u32) usize {
const fd_usize = @bitCast(usize, @as(isize, fd));
const msg_usize = @ptrToInt(msg);
if (native_arch == .x86) { if (native_arch == .x86) {
return socketcall(SC.recvmsg, &[3]usize{ @bitCast(usize, @as(isize, fd)), @ptrToInt(msg), @bitCast(usize, @as(isize, flags)) }); return socketcall(SC.recvmsg, &[3]usize{ fd_usize, msg_usize, flags });
} else {
return syscall3(.recvmsg, fd_usize, msg_usize, flags);
} }
return syscall3(.recvmsg, @bitCast(usize, @as(isize, fd)), @ptrToInt(msg), @bitCast(usize, @as(isize, flags)));
} }
pub fn recvfrom(fd: i32, noalias buf: [*]u8, len: usize, flags: u32, noalias addr: ?*sockaddr, noalias alen: ?*socklen_t) usize { pub fn recvfrom(
fd: i32,
noalias buf: [*]u8,
len: usize,
flags: u32,
noalias addr: ?*sockaddr,
noalias alen: ?*socklen_t,
) usize {
const fd_usize = @bitCast(usize, @as(isize, fd));
const buf_usize = @ptrToInt(buf);
const addr_usize = @ptrToInt(addr);
const alen_usize = @ptrToInt(alen);
if (native_arch == .x86) { if (native_arch == .x86) {
return socketcall(SC.recvfrom, &[6]usize{ @bitCast(usize, @as(isize, fd)), @ptrToInt(buf), len, flags, @ptrToInt(addr), @ptrToInt(alen) }); return socketcall(SC.recvfrom, &[6]usize{ fd_usize, buf_usize, len, flags, addr_usize, alen_usize });
} else {
return syscall6(.recvfrom, fd_usize, buf_usize, len, flags, addr_usize, alen_usize);
} }
return syscall6(.recvfrom, @bitCast(usize, @as(isize, fd)), @ptrToInt(buf), len, flags, @ptrToInt(addr), @ptrToInt(alen));
} }
pub fn shutdown(fd: i32, how: i32) usize { pub fn shutdown(fd: i32, how: i32) usize {
@ -3219,7 +3240,15 @@ pub const sockaddr = extern struct {
data: [14]u8, data: [14]u8,
pub const SS_MAXSIZE = 128; pub const SS_MAXSIZE = 128;
pub const storage = std.x.os.Socket.Address.Native.Storage; pub const storage = extern struct {
family: sa_family_t align(8),
padding: [SS_MAXSIZE - @sizeOf(sa_family_t)]u8 = undefined,
comptime {
assert(@sizeOf(storage) == SS_MAXSIZE);
assert(@alignOf(storage) == 8);
}
};
/// IPv4 socket address /// IPv4 socket address
pub const in = extern struct { pub const in = extern struct {

View file

@ -6,16 +6,14 @@
//! isn't that useful for general-purpose applications, and so a mode that //! isn't that useful for general-purpose applications, and so a mode that
//! utilizes user-supplied filters mode was added. //! utilizes user-supplied filters mode was added.
//! //!
//! Seccomp filters are classic BPF programs, which means that all the //! Seccomp filters are classic BPF programs. Conceptually, a seccomp program
//! information under `std.x.net.bpf` applies here as well. Conceptually, a //! is attached to the kernel and is executed on each syscall. The "packet"
//! seccomp program is attached to the kernel and is executed on each syscall. //! being validated is the `data` structure, and the verdict is an action that
//! The "packet" being validated is the `data` structure, and the verdict is an //! the kernel performs on the calling process. The actions are variations on a
//! action that the kernel performs on the calling process. The actions are //! "pass" or "fail" result, where a pass allows the syscall to continue and a
//! variations on a "pass" or "fail" result, where a pass allows the syscall to //! fail blocks the syscall and returns some sort of error value. See the full
//! continue and a fail blocks the syscall and returns some sort of error value. //! list of actions under ::RET for more information. Finally, only word-sized,
//! See the full list of actions under ::RET for more information. Finally, only //! absolute loads (`ld [k]`) are supported to read from the `data` structure.
//! word-sized, absolute loads (`ld [k]`) are supported to read from the `data`
//! structure.
//! //!
//! There are some issues with the filter API that have traditionally made //! There are some issues with the filter API that have traditionally made
//! writing them a pain: //! writing them a pain:

View file

@ -1,4 +1,5 @@
const std = @import("../../std.zig"); const std = @import("../../std.zig");
const assert = std.debug.assert;
const windows = std.os.windows; const windows = std.os.windows;
const WINAPI = windows.WINAPI; const WINAPI = windows.WINAPI;
@ -1106,7 +1107,15 @@ pub const sockaddr = extern struct {
data: [14]u8, data: [14]u8,
pub const SS_MAXSIZE = 128; pub const SS_MAXSIZE = 128;
pub const storage = std.x.os.Socket.Address.Native.Storage; pub const storage = extern struct {
family: ADDRESS_FAMILY align(8),
padding: [SS_MAXSIZE - @sizeOf(ADDRESS_FAMILY)]u8 = undefined,
comptime {
assert(@sizeOf(storage) == SS_MAXSIZE);
assert(@alignOf(storage) == 8);
}
};
/// IPv4 socket address /// IPv4 socket address
pub const in = extern struct { pub const in = extern struct {
@ -1207,7 +1216,7 @@ pub const LPFN_GETACCEPTEXSOCKADDRS = *const fn (
pub const LPFN_WSASENDMSG = *const fn ( pub const LPFN_WSASENDMSG = *const fn (
s: SOCKET, s: SOCKET,
lpMsg: *const std.x.os.Socket.Message, lpMsg: *const WSAMSG_const,
dwFlags: u32, dwFlags: u32,
lpNumberOfBytesSent: ?*u32, lpNumberOfBytesSent: ?*u32,
lpOverlapped: ?*OVERLAPPED, lpOverlapped: ?*OVERLAPPED,
@ -1216,7 +1225,7 @@ pub const LPFN_WSASENDMSG = *const fn (
pub const LPFN_WSARECVMSG = *const fn ( pub const LPFN_WSARECVMSG = *const fn (
s: SOCKET, s: SOCKET,
lpMsg: *std.x.os.Socket.Message, lpMsg: *WSAMSG,
lpdwNumberOfBytesRecv: ?*u32, lpdwNumberOfBytesRecv: ?*u32,
lpOverlapped: ?*OVERLAPPED, lpOverlapped: ?*OVERLAPPED,
lpCompletionRoutine: ?LPWSAOVERLAPPED_COMPLETION_ROUTINE, lpCompletionRoutine: ?LPWSAOVERLAPPED_COMPLETION_ROUTINE,
@ -2090,7 +2099,7 @@ pub extern "ws2_32" fn WSASend(
pub extern "ws2_32" fn WSASendMsg( pub extern "ws2_32" fn WSASendMsg(
s: SOCKET, s: SOCKET,
lpMsg: *const std.x.os.Socket.Message, lpMsg: *WSAMSG_const,
dwFlags: u32, dwFlags: u32,
lpNumberOfBytesSent: ?*u32, lpNumberOfBytesSent: ?*u32,
lpOverlapped: ?*OVERLAPPED, lpOverlapped: ?*OVERLAPPED,
@ -2099,7 +2108,7 @@ pub extern "ws2_32" fn WSASendMsg(
pub extern "ws2_32" fn WSARecvMsg( pub extern "ws2_32" fn WSARecvMsg(
s: SOCKET, s: SOCKET,
lpMsg: *std.x.os.Socket.Message, lpMsg: *WSAMSG,
lpdwNumberOfBytesRecv: ?*u32, lpdwNumberOfBytesRecv: ?*u32,
lpOverlapped: ?*OVERLAPPED, lpOverlapped: ?*OVERLAPPED,
lpCompletionRoutine: ?LPWSAOVERLAPPED_COMPLETION_ROUTINE, lpCompletionRoutine: ?LPWSAOVERLAPPED_COMPLETION_ROUTINE,

View file

@ -42,6 +42,7 @@ pub const Target = @import("target.zig").Target;
pub const Thread = @import("Thread.zig"); pub const Thread = @import("Thread.zig");
pub const Treap = @import("treap.zig").Treap; pub const Treap = @import("treap.zig").Treap;
pub const Tz = tz.Tz; pub const Tz = tz.Tz;
pub const Url = @import("Url.zig");
pub const array_hash_map = @import("array_hash_map.zig"); pub const array_hash_map = @import("array_hash_map.zig");
pub const atomic = @import("atomic.zig"); pub const atomic = @import("atomic.zig");
@ -90,7 +91,6 @@ pub const tz = @import("tz.zig");
pub const unicode = @import("unicode.zig"); pub const unicode = @import("unicode.zig");
pub const valgrind = @import("valgrind.zig"); pub const valgrind = @import("valgrind.zig");
pub const wasm = @import("wasm.zig"); pub const wasm = @import("wasm.zig");
pub const x = @import("x.zig");
pub const zig = @import("zig.zig"); pub const zig = @import("zig.zig");
pub const start = @import("start.zig"); pub const start = @import("start.zig");

View file

@ -1,19 +0,0 @@
const std = @import("std.zig");
pub const os = struct {
pub const Socket = @import("x/os/socket.zig").Socket;
pub usingnamespace @import("x/os/io.zig");
pub usingnamespace @import("x/os/net.zig");
};
pub const net = struct {
pub const ip = @import("x/net/ip.zig");
pub const tcp = @import("x/net/tcp.zig");
pub const bpf = @import("x/net/bpf.zig");
};
test {
inline for (.{ os, net }) |module| {
std.testing.refAllDecls(module);
}
}

File diff suppressed because it is too large Load diff

View file

@ -1,57 +0,0 @@
const std = @import("../../std.zig");
const fmt = std.fmt;
const IPv4 = std.x.os.IPv4;
const IPv6 = std.x.os.IPv6;
const Socket = std.x.os.Socket;
/// A generic IP abstraction.
const ip = @This();
/// A union of all eligible types of IP addresses.
pub const Address = union(enum) {
ipv4: IPv4.Address,
ipv6: IPv6.Address,
/// Instantiate a new address with a IPv4 host and port.
pub fn initIPv4(host: IPv4, port: u16) Address {
return .{ .ipv4 = .{ .host = host, .port = port } };
}
/// Instantiate a new address with a IPv6 host and port.
pub fn initIPv6(host: IPv6, port: u16) Address {
return .{ .ipv6 = .{ .host = host, .port = port } };
}
/// Re-interpret a generic socket address into an IP address.
pub fn from(address: Socket.Address) ip.Address {
return switch (address) {
.ipv4 => |ipv4_address| .{ .ipv4 = ipv4_address },
.ipv6 => |ipv6_address| .{ .ipv6 = ipv6_address },
};
}
/// Re-interpret an IP address into a generic socket address.
pub fn into(self: ip.Address) Socket.Address {
return switch (self) {
.ipv4 => |ipv4_address| .{ .ipv4 = ipv4_address },
.ipv6 => |ipv6_address| .{ .ipv6 = ipv6_address },
};
}
/// Implements the `std.fmt.format` API.
pub fn format(
self: ip.Address,
comptime layout: []const u8,
opts: fmt.FormatOptions,
writer: anytype,
) !void {
if (layout.len != 0) std.fmt.invalidFmtError(layout, self);
_ = opts;
switch (self) {
.ipv4 => |address| try fmt.format(writer, "{}:{}", .{ address.host, address.port }),
.ipv6 => |address| try fmt.format(writer, "{}:{}", .{ address.host, address.port }),
}
}
};

View file

@ -1,447 +0,0 @@
const std = @import("../../std.zig");
const builtin = @import("builtin");
const io = std.io;
const os = std.os;
const ip = std.x.net.ip;
const fmt = std.fmt;
const mem = std.mem;
const testing = std.testing;
const native_os = builtin.os;
const IPv4 = std.x.os.IPv4;
const IPv6 = std.x.os.IPv6;
const Socket = std.x.os.Socket;
const Buffer = std.x.os.Buffer;
/// A generic TCP socket abstraction.
const tcp = @This();
/// A TCP client-address pair.
pub const Connection = struct {
client: tcp.Client,
address: ip.Address,
/// Enclose a TCP client and address into a client-address pair.
pub fn from(conn: Socket.Connection) tcp.Connection {
return .{
.client = tcp.Client.from(conn.socket),
.address = ip.Address.from(conn.address),
};
}
/// Unravel a TCP client-address pair into a socket-address pair.
pub fn into(self: tcp.Connection) Socket.Connection {
return .{
.socket = self.client.socket,
.address = self.address.into(),
};
}
/// Closes the underlying client of the connection.
pub fn deinit(self: tcp.Connection) void {
self.client.deinit();
}
};
/// Possible domains that a TCP client/listener may operate over.
pub const Domain = enum(u16) {
ip = os.AF.INET,
ipv6 = os.AF.INET6,
};
/// A TCP client.
pub const Client = struct {
socket: Socket,
/// Implements `std.io.Reader`.
pub const Reader = struct {
client: Client,
flags: u32,
/// Implements `readFn` for `std.io.Reader`.
pub fn read(self: Client.Reader, buffer: []u8) !usize {
return self.client.read(buffer, self.flags);
}
};
/// Implements `std.io.Writer`.
pub const Writer = struct {
client: Client,
flags: u32,
/// Implements `writeFn` for `std.io.Writer`.
pub fn write(self: Client.Writer, buffer: []const u8) !usize {
return self.client.write(buffer, self.flags);
}
};
/// Opens a new client.
pub fn init(domain: tcp.Domain, flags: std.enums.EnumFieldStruct(Socket.InitFlags, bool, false)) !Client {
return Client{
.socket = try Socket.init(
@enumToInt(domain),
os.SOCK.STREAM,
os.IPPROTO.TCP,
flags,
),
};
}
/// Enclose a TCP client over an existing socket.
pub fn from(socket: Socket) Client {
return Client{ .socket = socket };
}
/// Closes the client.
pub fn deinit(self: Client) void {
self.socket.deinit();
}
/// Shutdown either the read side, write side, or all sides of the client's underlying socket.
pub fn shutdown(self: Client, how: os.ShutdownHow) !void {
return self.socket.shutdown(how);
}
/// Have the client attempt to the connect to an address.
pub fn connect(self: Client, address: ip.Address) !void {
return self.socket.connect(address.into());
}
/// Extracts the error set of a function.
/// TODO: remove after Socket.{read, write} error unions are well-defined across different platforms
fn ErrorSetOf(comptime Function: anytype) type {
return @typeInfo(@typeInfo(@TypeOf(Function)).Fn.return_type.?).ErrorUnion.error_set;
}
/// Wrap `tcp.Client` into `std.io.Reader`.
pub fn reader(self: Client, flags: u32) io.Reader(Client.Reader, ErrorSetOf(Client.Reader.read), Client.Reader.read) {
return .{ .context = .{ .client = self, .flags = flags } };
}
/// Wrap `tcp.Client` into `std.io.Writer`.
pub fn writer(self: Client, flags: u32) io.Writer(Client.Writer, ErrorSetOf(Client.Writer.write), Client.Writer.write) {
return .{ .context = .{ .client = self, .flags = flags } };
}
/// Read data from the socket into the buffer provided with a set of flags
/// specified. It returns the number of bytes read into the buffer provided.
pub fn read(self: Client, buf: []u8, flags: u32) !usize {
return self.socket.read(buf, flags);
}
/// Write a buffer of data provided to the socket with a set of flags specified.
/// It returns the number of bytes that are written to the socket.
pub fn write(self: Client, buf: []const u8, flags: u32) !usize {
return self.socket.write(buf, flags);
}
/// Writes multiple I/O vectors with a prepended message header to the socket
/// with a set of flags specified. It returns the number of bytes that are
/// written to the socket.
pub fn writeMessage(self: Client, msg: Socket.Message, flags: u32) !usize {
return self.socket.writeMessage(msg, flags);
}
/// Read multiple I/O vectors with a prepended message header from the socket
/// with a set of flags specified. It returns the number of bytes that were
/// read into the buffer provided.
pub fn readMessage(self: Client, msg: *Socket.Message, flags: u32) !usize {
return self.socket.readMessage(msg, flags);
}
/// Query and return the latest cached error on the client's underlying socket.
pub fn getError(self: Client) !void {
return self.socket.getError();
}
/// Query the read buffer size of the client's underlying socket.
pub fn getReadBufferSize(self: Client) !u32 {
return self.socket.getReadBufferSize();
}
/// Query the write buffer size of the client's underlying socket.
pub fn getWriteBufferSize(self: Client) !u32 {
return self.socket.getWriteBufferSize();
}
/// Query the address that the client's socket is locally bounded to.
pub fn getLocalAddress(self: Client) !ip.Address {
return ip.Address.from(try self.socket.getLocalAddress());
}
/// Query the address that the socket is connected to.
pub fn getRemoteAddress(self: Client) !ip.Address {
return ip.Address.from(try self.socket.getRemoteAddress());
}
/// Have close() or shutdown() syscalls block until all queued messages in the client have been successfully
/// sent, or if the timeout specified in seconds has been reached. It returns `error.UnsupportedSocketOption`
/// if the host does not support the option for a socket to linger around up until a timeout specified in
/// seconds.
pub fn setLinger(self: Client, timeout_seconds: ?u16) !void {
return self.socket.setLinger(timeout_seconds);
}
/// Have keep-alive messages be sent periodically. The timing in which keep-alive messages are sent are
/// dependant on operating system settings. It returns `error.UnsupportedSocketOption` if the host does
/// not support periodically sending keep-alive messages on connection-oriented sockets.
pub fn setKeepAlive(self: Client, enabled: bool) !void {
return self.socket.setKeepAlive(enabled);
}
/// Disable Nagle's algorithm on a TCP socket. It returns `error.UnsupportedSocketOption` if
/// the host does not support sockets disabling Nagle's algorithm.
pub fn setNoDelay(self: Client, enabled: bool) !void {
if (@hasDecl(os.TCP, "NODELAY")) {
const bytes = mem.asBytes(&@as(usize, @boolToInt(enabled)));
return self.socket.setOption(os.IPPROTO.TCP, os.TCP.NODELAY, bytes);
}
return error.UnsupportedSocketOption;
}
/// Enables TCP Quick ACK on a TCP socket to immediately send rather than delay ACKs when necessary. It returns
/// `error.UnsupportedSocketOption` if the host does not support TCP Quick ACK.
pub fn setQuickACK(self: Client, enabled: bool) !void {
if (@hasDecl(os.TCP, "QUICKACK")) {
return self.socket.setOption(os.IPPROTO.TCP, os.TCP.QUICKACK, mem.asBytes(&@as(u32, @boolToInt(enabled))));
}
return error.UnsupportedSocketOption;
}
/// Set the write buffer size of the socket.
pub fn setWriteBufferSize(self: Client, size: u32) !void {
return self.socket.setWriteBufferSize(size);
}
/// Set the read buffer size of the socket.
pub fn setReadBufferSize(self: Client, size: u32) !void {
return self.socket.setReadBufferSize(size);
}
/// Set a timeout on the socket that is to occur if no messages are successfully written
/// to its bound destination after a specified number of milliseconds. A subsequent write
/// to the socket will thereafter return `error.WouldBlock` should the timeout be exceeded.
pub fn setWriteTimeout(self: Client, milliseconds: u32) !void {
return self.socket.setWriteTimeout(milliseconds);
}
/// Set a timeout on the socket that is to occur if no messages are successfully read
/// from its bound destination after a specified number of milliseconds. A subsequent
/// read from the socket will thereafter return `error.WouldBlock` should the timeout be
/// exceeded.
pub fn setReadTimeout(self: Client, milliseconds: u32) !void {
return self.socket.setReadTimeout(milliseconds);
}
};
/// A TCP listener.
pub const Listener = struct {
socket: Socket,
/// Opens a new listener.
pub fn init(domain: tcp.Domain, flags: std.enums.EnumFieldStruct(Socket.InitFlags, bool, false)) !Listener {
return Listener{
.socket = try Socket.init(
@enumToInt(domain),
os.SOCK.STREAM,
os.IPPROTO.TCP,
flags,
),
};
}
/// Closes the listener.
pub fn deinit(self: Listener) void {
self.socket.deinit();
}
/// Shuts down the underlying listener's socket. The next subsequent call, or
/// a current pending call to accept() after shutdown is called will return
/// an error.
pub fn shutdown(self: Listener) !void {
return self.socket.shutdown(.recv);
}
/// Binds the listener's socket to an address.
pub fn bind(self: Listener, address: ip.Address) !void {
return self.socket.bind(address.into());
}
/// Start listening for incoming connections.
pub fn listen(self: Listener, max_backlog_size: u31) !void {
return self.socket.listen(max_backlog_size);
}
/// Accept a pending incoming connection queued to the kernel backlog
/// of the listener's socket.
pub fn accept(self: Listener, flags: std.enums.EnumFieldStruct(Socket.InitFlags, bool, false)) !tcp.Connection {
return tcp.Connection.from(try self.socket.accept(flags));
}
/// Query and return the latest cached error on the listener's underlying socket.
pub fn getError(self: Client) !void {
return self.socket.getError();
}
/// Query the address that the listener's socket is locally bounded to.
pub fn getLocalAddress(self: Listener) !ip.Address {
return ip.Address.from(try self.socket.getLocalAddress());
}
/// Allow multiple sockets on the same host to listen on the same address. It returns `error.UnsupportedSocketOption` if
/// the host does not support sockets listening the same address.
pub fn setReuseAddress(self: Listener, enabled: bool) !void {
return self.socket.setReuseAddress(enabled);
}
/// Allow multiple sockets on the same host to listen on the same port. It returns `error.UnsupportedSocketOption` if
/// the host does not supports sockets listening on the same port.
pub fn setReusePort(self: Listener, enabled: bool) !void {
return self.socket.setReusePort(enabled);
}
/// Enables TCP Fast Open (RFC 7413) on a TCP socket. It returns `error.UnsupportedSocketOption` if the host does not
/// support TCP Fast Open.
pub fn setFastOpen(self: Listener, enabled: bool) !void {
if (@hasDecl(os.TCP, "FASTOPEN")) {
return self.socket.setOption(os.IPPROTO.TCP, os.TCP.FASTOPEN, mem.asBytes(&@as(u32, @boolToInt(enabled))));
}
return error.UnsupportedSocketOption;
}
/// Set a timeout on the listener that is to occur if no new incoming connections come in
/// after a specified number of milliseconds. A subsequent accept call to the listener
/// will thereafter return `error.WouldBlock` should the timeout be exceeded.
pub fn setAcceptTimeout(self: Listener, milliseconds: usize) !void {
return self.socket.setReadTimeout(milliseconds);
}
};
test "tcp: create client/listener pair" {
if (native_os.tag == .wasi) return error.SkipZigTest;
const listener = try tcp.Listener.init(.ip, .{ .close_on_exec = true });
defer listener.deinit();
try listener.bind(ip.Address.initIPv4(IPv4.unspecified, 0));
try listener.listen(128);
var binded_address = try listener.getLocalAddress();
switch (binded_address) {
.ipv4 => |*ipv4| ipv4.host = IPv4.localhost,
.ipv6 => |*ipv6| ipv6.host = IPv6.localhost,
}
const client = try tcp.Client.init(.ip, .{ .close_on_exec = true });
defer client.deinit();
try client.connect(binded_address);
const conn = try listener.accept(.{ .close_on_exec = true });
defer conn.deinit();
}
test "tcp/client: 1ms read timeout" {
if (native_os.tag == .wasi) return error.SkipZigTest;
const listener = try tcp.Listener.init(.ip, .{ .close_on_exec = true });
defer listener.deinit();
try listener.bind(ip.Address.initIPv4(IPv4.unspecified, 0));
try listener.listen(128);
var binded_address = try listener.getLocalAddress();
switch (binded_address) {
.ipv4 => |*ipv4| ipv4.host = IPv4.localhost,
.ipv6 => |*ipv6| ipv6.host = IPv6.localhost,
}
const client = try tcp.Client.init(.ip, .{ .close_on_exec = true });
defer client.deinit();
try client.connect(binded_address);
try client.setReadTimeout(1);
const conn = try listener.accept(.{ .close_on_exec = true });
defer conn.deinit();
var buf: [1]u8 = undefined;
try testing.expectError(error.WouldBlock, client.reader(0).read(&buf));
}
test "tcp/client: read and write multiple vectors" {
if (native_os.tag == .wasi) return error.SkipZigTest;
if (builtin.os.tag == .windows) {
// https://github.com/ziglang/zig/issues/13893
return error.SkipZigTest;
}
const listener = try tcp.Listener.init(.ip, .{ .close_on_exec = true });
defer listener.deinit();
try listener.bind(ip.Address.initIPv4(IPv4.unspecified, 0));
try listener.listen(128);
var binded_address = try listener.getLocalAddress();
switch (binded_address) {
.ipv4 => |*ipv4| ipv4.host = IPv4.localhost,
.ipv6 => |*ipv6| ipv6.host = IPv6.localhost,
}
const client = try tcp.Client.init(.ip, .{ .close_on_exec = true });
defer client.deinit();
try client.connect(binded_address);
const conn = try listener.accept(.{ .close_on_exec = true });
defer conn.deinit();
const message = "hello world";
_ = try conn.client.writeMessage(Socket.Message.fromBuffers(&[_]Buffer{
Buffer.from(message[0 .. message.len / 2]),
Buffer.from(message[message.len / 2 ..]),
}), 0);
var buf: [message.len + 1]u8 = undefined;
var msg = Socket.Message.fromBuffers(&[_]Buffer{
Buffer.from(buf[0 .. message.len / 2]),
Buffer.from(buf[message.len / 2 ..]),
});
_ = try client.readMessage(&msg, 0);
try testing.expectEqualStrings(message, buf[0..message.len]);
}
test "tcp/listener: bind to unspecified ipv4 address" {
if (native_os.tag == .wasi) return error.SkipZigTest;
const listener = try tcp.Listener.init(.ip, .{ .close_on_exec = true });
defer listener.deinit();
try listener.bind(ip.Address.initIPv4(IPv4.unspecified, 0));
try listener.listen(128);
const address = try listener.getLocalAddress();
try testing.expect(address == .ipv4);
}
test "tcp/listener: bind to unspecified ipv6 address" {
if (native_os.tag == .wasi) return error.SkipZigTest;
if (builtin.os.tag == .windows) {
// https://github.com/ziglang/zig/issues/13893
return error.SkipZigTest;
}
const listener = try tcp.Listener.init(.ipv6, .{ .close_on_exec = true });
defer listener.deinit();
try listener.bind(ip.Address.initIPv6(IPv6.unspecified, 0));
try listener.listen(128);
const address = try listener.getLocalAddress();
try testing.expect(address == .ipv6);
}

View file

@ -1,224 +0,0 @@
const std = @import("../../std.zig");
const builtin = @import("builtin");
const os = std.os;
const mem = std.mem;
const testing = std.testing;
const native_os = builtin.os;
const linux = std.os.linux;
/// POSIX `iovec`, or Windows `WSABUF`. The difference between the two are the ordering
/// of fields, alongside the length being represented as either a ULONG or a size_t.
pub const Buffer = if (native_os.tag == .windows)
extern struct {
len: c_ulong,
ptr: usize,
pub fn from(slice: []const u8) Buffer {
return .{ .len = @intCast(c_ulong, slice.len), .ptr = @ptrToInt(slice.ptr) };
}
pub fn into(self: Buffer) []const u8 {
return @intToPtr([*]const u8, self.ptr)[0..self.len];
}
pub fn intoMutable(self: Buffer) []u8 {
return @intToPtr([*]u8, self.ptr)[0..self.len];
}
}
else
extern struct {
ptr: usize,
len: usize,
pub fn from(slice: []const u8) Buffer {
return .{ .ptr = @ptrToInt(slice.ptr), .len = slice.len };
}
pub fn into(self: Buffer) []const u8 {
return @intToPtr([*]const u8, self.ptr)[0..self.len];
}
pub fn intoMutable(self: Buffer) []u8 {
return @intToPtr([*]u8, self.ptr)[0..self.len];
}
};
pub const Reactor = struct {
pub const InitFlags = enum {
close_on_exec,
};
pub const Event = struct {
data: usize,
is_error: bool,
is_hup: bool,
is_readable: bool,
is_writable: bool,
};
pub const Interest = struct {
hup: bool = false,
oneshot: bool = false,
readable: bool = false,
writable: bool = false,
};
fd: os.fd_t,
pub fn init(flags: std.enums.EnumFieldStruct(Reactor.InitFlags, bool, false)) !Reactor {
var raw_flags: u32 = 0;
const set = std.EnumSet(Reactor.InitFlags).init(flags);
if (set.contains(.close_on_exec)) raw_flags |= linux.EPOLL.CLOEXEC;
return Reactor{ .fd = try os.epoll_create1(raw_flags) };
}
pub fn deinit(self: Reactor) void {
os.close(self.fd);
}
pub fn update(self: Reactor, fd: os.fd_t, identifier: usize, interest: Reactor.Interest) !void {
var flags: u32 = 0;
flags |= if (interest.oneshot) linux.EPOLL.ONESHOT else linux.EPOLL.ET;
if (interest.hup) flags |= linux.EPOLL.RDHUP;
if (interest.readable) flags |= linux.EPOLL.IN;
if (interest.writable) flags |= linux.EPOLL.OUT;
const event = &linux.epoll_event{
.events = flags,
.data = .{ .ptr = identifier },
};
os.epoll_ctl(self.fd, linux.EPOLL.CTL_MOD, fd, event) catch |err| switch (err) {
error.FileDescriptorNotRegistered => try os.epoll_ctl(self.fd, linux.EPOLL.CTL_ADD, fd, event),
else => return err,
};
}
pub fn remove(self: Reactor, fd: os.fd_t) !void {
// directly from man epoll_ctl BUGS section
// In kernel versions before 2.6.9, the EPOLL_CTL_DEL operation re
// quired a non-null pointer in event, even though this argument is
// ignored. Since Linux 2.6.9, event can be specified as NULL when
// using EPOLL_CTL_DEL. Applications that need to be portable to
// kernels before 2.6.9 should specify a non-null pointer in event.
var event = linux.epoll_event{
.events = 0,
.data = .{ .ptr = 0 },
};
return os.epoll_ctl(self.fd, linux.EPOLL.CTL_DEL, fd, &event);
}
pub fn poll(self: Reactor, comptime max_num_events: comptime_int, closure: anytype, timeout_milliseconds: ?u64) !void {
var events: [max_num_events]linux.epoll_event = undefined;
const num_events = os.epoll_wait(self.fd, &events, if (timeout_milliseconds) |ms| @intCast(i32, ms) else -1);
for (events[0..num_events]) |ev| {
const is_error = ev.events & linux.EPOLL.ERR != 0;
const is_hup = ev.events & (linux.EPOLL.HUP | linux.EPOLL.RDHUP) != 0;
const is_readable = ev.events & linux.EPOLL.IN != 0;
const is_writable = ev.events & linux.EPOLL.OUT != 0;
try closure.call(Reactor.Event{
.data = ev.data.ptr,
.is_error = is_error,
.is_hup = is_hup,
.is_readable = is_readable,
.is_writable = is_writable,
});
}
}
};
test "reactor/linux: drive async tcp client/listener pair" {
if (native_os.tag != .linux) return error.SkipZigTest;
const ip = std.x.net.ip;
const tcp = std.x.net.tcp;
const IPv4 = std.x.os.IPv4;
const IPv6 = std.x.os.IPv6;
const reactor = try Reactor.init(.{ .close_on_exec = true });
defer reactor.deinit();
const listener = try tcp.Listener.init(.ip, .{
.close_on_exec = true,
.nonblocking = true,
});
defer listener.deinit();
try reactor.update(listener.socket.fd, 0, .{ .readable = true });
try reactor.poll(1, struct {
fn call(event: Reactor.Event) !void {
try testing.expectEqual(Reactor.Event{
.data = 0,
.is_error = false,
.is_hup = true,
.is_readable = false,
.is_writable = false,
}, event);
}
}, null);
try listener.bind(ip.Address.initIPv4(IPv4.unspecified, 0));
try listener.listen(128);
var binded_address = try listener.getLocalAddress();
switch (binded_address) {
.ipv4 => |*ipv4| ipv4.host = IPv4.localhost,
.ipv6 => |*ipv6| ipv6.host = IPv6.localhost,
}
const client = try tcp.Client.init(.ip, .{
.close_on_exec = true,
.nonblocking = true,
});
defer client.deinit();
try reactor.update(client.socket.fd, 1, .{ .readable = true, .writable = true });
try reactor.poll(1, struct {
fn call(event: Reactor.Event) !void {
try testing.expectEqual(Reactor.Event{
.data = 1,
.is_error = false,
.is_hup = true,
.is_readable = false,
.is_writable = true,
}, event);
}
}, null);
client.connect(binded_address) catch |err| switch (err) {
error.WouldBlock => {},
else => return err,
};
try reactor.poll(1, struct {
fn call(event: Reactor.Event) !void {
try testing.expectEqual(Reactor.Event{
.data = 1,
.is_error = false,
.is_hup = false,
.is_readable = false,
.is_writable = true,
}, event);
}
}, null);
try reactor.poll(1, struct {
fn call(event: Reactor.Event) !void {
try testing.expectEqual(Reactor.Event{
.data = 0,
.is_error = false,
.is_hup = false,
.is_readable = true,
.is_writable = false,
}, event);
}
}, null);
try reactor.remove(client.socket.fd);
try reactor.remove(listener.socket.fd);
}

View file

@ -1,605 +0,0 @@
const std = @import("../../std.zig");
const builtin = @import("builtin");
const os = std.os;
const fmt = std.fmt;
const mem = std.mem;
const math = std.math;
const testing = std.testing;
const native_os = builtin.os;
const have_ifnamesize = @hasDecl(os.system, "IFNAMESIZE");
pub const ResolveScopeIdError = error{
NameTooLong,
PermissionDenied,
AddressFamilyNotSupported,
ProtocolFamilyNotAvailable,
ProcessFdQuotaExceeded,
SystemFdQuotaExceeded,
SystemResources,
ProtocolNotSupported,
SocketTypeNotSupported,
InterfaceNotFound,
FileSystem,
Unexpected,
};
/// Resolves a network interface name into a scope/zone ID. It returns
/// an error if either resolution fails, or if the interface name is
/// too long.
pub fn resolveScopeId(name: []const u8) ResolveScopeIdError!u32 {
if (have_ifnamesize) {
if (name.len >= os.IFNAMESIZE) return error.NameTooLong;
if (native_os.tag == .windows or comptime native_os.tag.isDarwin()) {
var interface_name: [os.IFNAMESIZE:0]u8 = undefined;
mem.copy(u8, &interface_name, name);
interface_name[name.len] = 0;
const rc = blk: {
if (native_os.tag == .windows) {
break :blk os.windows.ws2_32.if_nametoindex(@ptrCast([*:0]const u8, &interface_name));
} else {
const index = os.system.if_nametoindex(@ptrCast([*:0]const u8, &interface_name));
break :blk @bitCast(u32, index);
}
};
if (rc == 0) {
return error.InterfaceNotFound;
}
return rc;
}
if (native_os.tag == .linux) {
const fd = try os.socket(os.AF.INET, os.SOCK.DGRAM, 0);
defer os.closeSocket(fd);
var f: os.ifreq = undefined;
mem.copy(u8, &f.ifrn.name, name);
f.ifrn.name[name.len] = 0;
try os.ioctl_SIOCGIFINDEX(fd, &f);
return @bitCast(u32, f.ifru.ivalue);
}
}
return error.InterfaceNotFound;
}
/// An IPv4 address comprised of 4 bytes.
pub const IPv4 = extern struct {
/// A IPv4 host-port pair.
pub const Address = extern struct {
host: IPv4,
port: u16,
};
/// Octets of a IPv4 address designating the local host.
pub const localhost_octets = [_]u8{ 127, 0, 0, 1 };
/// The IPv4 address of the local host.
pub const localhost: IPv4 = .{ .octets = localhost_octets };
/// Octets of an unspecified IPv4 address.
pub const unspecified_octets = [_]u8{0} ** 4;
/// An unspecified IPv4 address.
pub const unspecified: IPv4 = .{ .octets = unspecified_octets };
/// Octets of a broadcast IPv4 address.
pub const broadcast_octets = [_]u8{255} ** 4;
/// An IPv4 broadcast address.
pub const broadcast: IPv4 = .{ .octets = broadcast_octets };
/// The prefix octet pattern of a link-local IPv4 address.
pub const link_local_prefix = [_]u8{ 169, 254 };
/// The prefix octet patterns of IPv4 addresses intended for
/// documentation.
pub const documentation_prefixes = [_][]const u8{
&[_]u8{ 192, 0, 2 },
&[_]u8{ 198, 51, 100 },
&[_]u8{ 203, 0, 113 },
};
octets: [4]u8,
/// Returns whether or not the two addresses are equal to, less than, or
/// greater than each other.
pub fn cmp(self: IPv4, other: IPv4) math.Order {
return mem.order(u8, &self.octets, &other.octets);
}
/// Returns true if both addresses are semantically equivalent.
pub fn eql(self: IPv4, other: IPv4) bool {
return mem.eql(u8, &self.octets, &other.octets);
}
/// Returns true if the address is a loopback address.
pub fn isLoopback(self: IPv4) bool {
return self.octets[0] == 127;
}
/// Returns true if the address is an unspecified IPv4 address.
pub fn isUnspecified(self: IPv4) bool {
return mem.eql(u8, &self.octets, &unspecified_octets);
}
/// Returns true if the address is a private IPv4 address.
pub fn isPrivate(self: IPv4) bool {
return self.octets[0] == 10 or
(self.octets[0] == 172 and self.octets[1] >= 16 and self.octets[1] <= 31) or
(self.octets[0] == 192 and self.octets[1] == 168);
}
/// Returns true if the address is a link-local IPv4 address.
pub fn isLinkLocal(self: IPv4) bool {
return mem.startsWith(u8, &self.octets, &link_local_prefix);
}
/// Returns true if the address is a multicast IPv4 address.
pub fn isMulticast(self: IPv4) bool {
return self.octets[0] >= 224 and self.octets[0] <= 239;
}
/// Returns true if the address is a IPv4 broadcast address.
pub fn isBroadcast(self: IPv4) bool {
return mem.eql(u8, &self.octets, &broadcast_octets);
}
/// Returns true if the address is in a range designated for documentation. Refer
/// to IETF RFC 5737 for more details.
pub fn isDocumentation(self: IPv4) bool {
inline for (documentation_prefixes) |prefix| {
if (mem.startsWith(u8, &self.octets, prefix)) {
return true;
}
}
return false;
}
/// Implements the `std.fmt.format` API.
pub fn format(
self: IPv4,
comptime layout: []const u8,
opts: fmt.FormatOptions,
writer: anytype,
) !void {
_ = opts;
if (layout.len != 0) std.fmt.invalidFmtError(layout, self);
try fmt.format(writer, "{}.{}.{}.{}", .{
self.octets[0],
self.octets[1],
self.octets[2],
self.octets[3],
});
}
/// Set of possible errors that may encountered when parsing an IPv4
/// address.
pub const ParseError = error{
UnexpectedEndOfOctet,
TooManyOctets,
OctetOverflow,
UnexpectedToken,
IncompleteAddress,
};
/// Parses an arbitrary IPv4 address.
pub fn parse(buf: []const u8) ParseError!IPv4 {
var octets: [4]u8 = undefined;
var octet: u8 = 0;
var index: u8 = 0;
var saw_any_digits: bool = false;
for (buf) |c| {
switch (c) {
'.' => {
if (!saw_any_digits) return error.UnexpectedEndOfOctet;
if (index == 3) return error.TooManyOctets;
octets[index] = octet;
index += 1;
octet = 0;
saw_any_digits = false;
},
'0'...'9' => {
saw_any_digits = true;
octet = math.mul(u8, octet, 10) catch return error.OctetOverflow;
octet = math.add(u8, octet, c - '0') catch return error.OctetOverflow;
},
else => return error.UnexpectedToken,
}
}
if (index == 3 and saw_any_digits) {
octets[index] = octet;
return IPv4{ .octets = octets };
}
return error.IncompleteAddress;
}
/// Maps the address to its IPv6 equivalent. In most cases, you would
/// want to map the address to its IPv6 equivalent rather than directly
/// re-interpreting the address.
pub fn mapToIPv6(self: IPv4) IPv6 {
var octets: [16]u8 = undefined;
mem.copy(u8, octets[0..12], &IPv6.v4_mapped_prefix);
mem.copy(u8, octets[12..], &self.octets);
return IPv6{ .octets = octets, .scope_id = IPv6.no_scope_id };
}
/// Directly re-interprets the address to its IPv6 equivalent. In most
/// cases, you would want to map the address to its IPv6 equivalent rather
/// than directly re-interpreting the address.
pub fn toIPv6(self: IPv4) IPv6 {
var octets: [16]u8 = undefined;
mem.set(u8, octets[0..12], 0);
mem.copy(u8, octets[12..], &self.octets);
return IPv6{ .octets = octets, .scope_id = IPv6.no_scope_id };
}
};
/// An IPv6 address comprised of 16 bytes for an address, and 4 bytes
/// for a scope ID; cumulatively summing to 20 bytes in total.
pub const IPv6 = extern struct {
/// A IPv6 host-port pair.
pub const Address = extern struct {
host: IPv6,
port: u16,
};
/// Octets of a IPv6 address designating the local host.
pub const localhost_octets = [_]u8{0} ** 15 ++ [_]u8{0x01};
/// The IPv6 address of the local host.
pub const localhost: IPv6 = .{
.octets = localhost_octets,
.scope_id = no_scope_id,
};
/// Octets of an unspecified IPv6 address.
pub const unspecified_octets = [_]u8{0} ** 16;
/// An unspecified IPv6 address.
pub const unspecified: IPv6 = .{
.octets = unspecified_octets,
.scope_id = no_scope_id,
};
/// The prefix of a IPv6 address that is mapped to a IPv4 address.
pub const v4_mapped_prefix = [_]u8{0} ** 10 ++ [_]u8{0xFF} ** 2;
/// A marker value used to designate an IPv6 address with no
/// associated scope ID.
pub const no_scope_id = math.maxInt(u32);
octets: [16]u8,
scope_id: u32,
/// Returns whether or not the two addresses are equal to, less than, or
/// greater than each other.
pub fn cmp(self: IPv6, other: IPv6) math.Order {
return switch (mem.order(u8, self.octets, other.octets)) {
.eq => math.order(self.scope_id, other.scope_id),
else => |order| order,
};
}
/// Returns true if both addresses are semantically equivalent.
pub fn eql(self: IPv6, other: IPv6) bool {
return self.scope_id == other.scope_id and mem.eql(u8, &self.octets, &other.octets);
}
/// Returns true if the address is an unspecified IPv6 address.
pub fn isUnspecified(self: IPv6) bool {
return mem.eql(u8, &self.octets, &unspecified_octets);
}
/// Returns true if the address is a loopback address.
pub fn isLoopback(self: IPv6) bool {
return mem.eql(u8, self.octets[0..3], &[_]u8{ 0, 0, 0 }) and
mem.eql(u8, self.octets[12..], &[_]u8{ 0, 0, 0, 1 });
}
/// Returns true if the address maps to an IPv4 address.
pub fn mapsToIPv4(self: IPv6) bool {
return mem.startsWith(u8, &self.octets, &v4_mapped_prefix);
}
/// Returns an IPv4 address representative of the address should
/// it the address be mapped to an IPv4 address. It returns null
/// otherwise.
pub fn toIPv4(self: IPv6) ?IPv4 {
if (!self.mapsToIPv4()) return null;
return IPv4{ .octets = self.octets[12..][0..4].* };
}
/// Returns true if the address is a multicast IPv6 address.
pub fn isMulticast(self: IPv6) bool {
return self.octets[0] == 0xFF;
}
/// Returns true if the address is a unicast link local IPv6 address.
pub fn isLinkLocal(self: IPv6) bool {
return self.octets[0] == 0xFE and self.octets[1] & 0xC0 == 0x80;
}
/// Returns true if the address is a deprecated unicast site local
/// IPv6 address. Refer to IETF RFC 3879 for more details as to
/// why they are deprecated.
pub fn isSiteLocal(self: IPv6) bool {
return self.octets[0] == 0xFE and self.octets[1] & 0xC0 == 0xC0;
}
/// IPv6 multicast address scopes.
pub const Scope = enum(u8) {
interface = 1,
link = 2,
realm = 3,
admin = 4,
site = 5,
organization = 8,
global = 14,
unknown = 0xFF,
};
/// Returns the multicast scope of the address.
pub fn scope(self: IPv6) Scope {
if (!self.isMulticast()) return .unknown;
return switch (self.octets[0] & 0x0F) {
1 => .interface,
2 => .link,
3 => .realm,
4 => .admin,
5 => .site,
8 => .organization,
14 => .global,
else => .unknown,
};
}
/// Implements the `std.fmt.format` API. Specifying 'x' or 's' formats the
/// address lower-cased octets, while specifying 'X' or 'S' formats the
/// address using upper-cased ASCII octets.
///
/// The default specifier is 'x'.
pub fn format(
self: IPv6,
comptime layout: []const u8,
opts: fmt.FormatOptions,
writer: anytype,
) !void {
_ = opts;
const specifier = comptime &[_]u8{if (layout.len == 0) 'x' else switch (layout[0]) {
'x', 'X' => |specifier| specifier,
's' => 'x',
'S' => 'X',
else => std.fmt.invalidFmtError(layout, self),
}};
if (mem.startsWith(u8, &self.octets, &v4_mapped_prefix)) {
return fmt.format(writer, "::{" ++ specifier ++ "}{" ++ specifier ++ "}:{}.{}.{}.{}", .{
0xFF,
0xFF,
self.octets[12],
self.octets[13],
self.octets[14],
self.octets[15],
});
}
const zero_span: struct { from: usize, to: usize } = span: {
var i: usize = 0;
while (i < self.octets.len) : (i += 2) {
if (self.octets[i] == 0 and self.octets[i + 1] == 0) break;
} else break :span .{ .from = 0, .to = 0 };
const from = i;
while (i < self.octets.len) : (i += 2) {
if (self.octets[i] != 0 or self.octets[i + 1] != 0) break;
}
break :span .{ .from = from, .to = i };
};
var i: usize = 0;
while (i != 16) : (i += 2) {
if (zero_span.from != zero_span.to and i == zero_span.from) {
try writer.writeAll("::");
} else if (i >= zero_span.from and i < zero_span.to) {} else {
if (i != 0 and i != zero_span.to) try writer.writeAll(":");
const val = @as(u16, self.octets[i]) << 8 | self.octets[i + 1];
try fmt.formatIntValue(val, specifier, .{}, writer);
}
}
if (self.scope_id != no_scope_id and self.scope_id != 0) {
try fmt.format(writer, "%{d}", .{self.scope_id});
}
}
/// Set of possible errors that may encountered when parsing an IPv6
/// address.
pub const ParseError = error{
MalformedV4Mapping,
InterfaceNotFound,
UnknownScopeId,
} || IPv4.ParseError;
/// Parses an arbitrary IPv6 address, including link-local addresses.
pub fn parse(buf: []const u8) ParseError!IPv6 {
if (mem.lastIndexOfScalar(u8, buf, '%')) |index| {
const ip_slice = buf[0..index];
const scope_id_slice = buf[index + 1 ..];
if (scope_id_slice.len == 0) return error.UnknownScopeId;
const scope_id: u32 = switch (scope_id_slice[0]) {
'0'...'9' => fmt.parseInt(u32, scope_id_slice, 10),
else => resolveScopeId(scope_id_slice) catch |err| switch (err) {
error.InterfaceNotFound => return error.InterfaceNotFound,
else => err,
},
} catch return error.UnknownScopeId;
return parseWithScopeID(ip_slice, scope_id);
}
return parseWithScopeID(buf, no_scope_id);
}
/// Parses an IPv6 address with a pre-specified scope ID. Presumes
/// that the address is not a link-local address.
pub fn parseWithScopeID(buf: []const u8, scope_id: u32) ParseError!IPv6 {
var octets: [16]u8 = undefined;
var octet: u16 = 0;
var tail: [16]u8 = undefined;
var out: []u8 = &octets;
var index: u8 = 0;
var saw_any_digits: bool = false;
var abbrv: bool = false;
for (buf) |c, i| {
switch (c) {
':' => {
if (!saw_any_digits) {
if (abbrv) return error.UnexpectedToken;
if (i != 0) abbrv = true;
mem.set(u8, out[index..], 0);
out = &tail;
index = 0;
continue;
}
if (index == 14) return error.TooManyOctets;
out[index] = @truncate(u8, octet >> 8);
index += 1;
out[index] = @truncate(u8, octet);
index += 1;
octet = 0;
saw_any_digits = false;
},
'.' => {
if (!abbrv or out[0] != 0xFF and out[1] != 0xFF) {
return error.MalformedV4Mapping;
}
const start_index = mem.lastIndexOfScalar(u8, buf[0..i], ':').? + 1;
const v4 = try IPv4.parse(buf[start_index..]);
octets[10] = 0xFF;
octets[11] = 0xFF;
mem.copy(u8, octets[12..], &v4.octets);
return IPv6{ .octets = octets, .scope_id = scope_id };
},
else => {
saw_any_digits = true;
const digit = fmt.charToDigit(c, 16) catch return error.UnexpectedToken;
octet = math.mul(u16, octet, 16) catch return error.OctetOverflow;
octet = math.add(u16, octet, digit) catch return error.OctetOverflow;
},
}
}
if (!saw_any_digits and !abbrv) {
return error.IncompleteAddress;
}
if (index == 14) {
out[14] = @truncate(u8, octet >> 8);
out[15] = @truncate(u8, octet);
} else {
out[index] = @truncate(u8, octet >> 8);
index += 1;
out[index] = @truncate(u8, octet);
index += 1;
mem.copy(u8, octets[16 - index ..], out[0..index]);
}
return IPv6{ .octets = octets, .scope_id = scope_id };
}
};
test {
testing.refAllDecls(@This());
}
test "ip: convert to and from ipv6" {
try testing.expectFmt("::7f00:1", "{}", .{IPv4.localhost.toIPv6()});
try testing.expect(!IPv4.localhost.toIPv6().mapsToIPv4());
try testing.expectFmt("::ffff:127.0.0.1", "{}", .{IPv4.localhost.mapToIPv6()});
try testing.expect(IPv4.localhost.mapToIPv6().mapsToIPv4());
try testing.expect(IPv4.localhost.toIPv6().toIPv4() == null);
try testing.expectFmt("127.0.0.1", "{?}", .{IPv4.localhost.mapToIPv6().toIPv4()});
}
test "ipv4: parse & format" {
const cases = [_][]const u8{
"0.0.0.0",
"255.255.255.255",
"1.2.3.4",
"123.255.0.91",
"127.0.0.1",
};
for (cases) |case| {
try testing.expectFmt(case, "{}", .{try IPv4.parse(case)});
}
}
test "ipv6: parse & format" {
const inputs = [_][]const u8{
"FF01:0:0:0:0:0:0:FB",
"FF01::Fb",
"::1",
"::",
"2001:db8::",
"::1234:5678",
"2001:db8::1234:5678",
"::ffff:123.5.123.5",
};
const outputs = [_][]const u8{
"ff01::fb",
"ff01::fb",
"::1",
"::",
"2001:db8::",
"::1234:5678",
"2001:db8::1234:5678",
"::ffff:123.5.123.5",
};
for (inputs) |input, i| {
try testing.expectFmt(outputs[i], "{}", .{try IPv6.parse(input)});
}
}
test "ipv6: parse & format addresses with scope ids" {
if (!have_ifnamesize) return error.SkipZigTest;
const iface = if (native_os.tag == .linux)
"lo"
else
"lo0";
const input = "FF01::FB%" ++ iface;
const output = "ff01::fb%1";
const parsed = IPv6.parse(input) catch |err| switch (err) {
error.InterfaceNotFound => return,
else => return err,
};
try testing.expectFmt(output, "{}", .{parsed});
}

View file

@ -1,320 +0,0 @@
const std = @import("../../std.zig");
const builtin = @import("builtin");
const net = @import("net.zig");
const os = std.os;
const fmt = std.fmt;
const mem = std.mem;
const time = std.time;
const meta = std.meta;
const native_os = builtin.os;
const native_endian = builtin.cpu.arch.endian();
const Buffer = std.x.os.Buffer;
const assert = std.debug.assert;
/// A generic, cross-platform socket abstraction.
pub const Socket = struct {
/// A socket-address pair.
pub const Connection = struct {
socket: Socket,
address: Socket.Address,
/// Enclose a socket and address into a socket-address pair.
pub fn from(socket: Socket, address: Socket.Address) Socket.Connection {
return .{ .socket = socket, .address = address };
}
};
/// A generic socket address abstraction. It is safe to directly access and modify
/// the fields of a `Socket.Address`.
pub const Address = union(enum) {
pub const Native = struct {
pub const requires_prepended_length = native_os.getVersionRange() == .semver;
pub const Length = if (requires_prepended_length) u8 else [0]u8;
pub const Family = if (requires_prepended_length) u8 else c_ushort;
/// POSIX `sockaddr.storage`. The expected size and alignment is specified in IETF RFC 2553.
pub const Storage = extern struct {
pub const expected_size = os.sockaddr.SS_MAXSIZE;
pub const expected_alignment = 8;
pub const padding_size = expected_size -
mem.alignForward(@sizeOf(Address.Native.Length), expected_alignment) -
mem.alignForward(@sizeOf(Address.Native.Family), expected_alignment);
len: Address.Native.Length align(expected_alignment) = undefined,
family: Address.Native.Family align(expected_alignment) = undefined,
padding: [padding_size]u8 align(expected_alignment) = undefined,
comptime {
assert(@sizeOf(Storage) == Storage.expected_size);
assert(@alignOf(Storage) == Storage.expected_alignment);
}
};
};
ipv4: net.IPv4.Address,
ipv6: net.IPv6.Address,
/// Instantiate a new address with a IPv4 host and port.
pub fn initIPv4(host: net.IPv4, port: u16) Socket.Address {
return .{ .ipv4 = .{ .host = host, .port = port } };
}
/// Instantiate a new address with a IPv6 host and port.
pub fn initIPv6(host: net.IPv6, port: u16) Socket.Address {
return .{ .ipv6 = .{ .host = host, .port = port } };
}
/// Parses a `sockaddr` into a generic socket address.
pub fn fromNative(address: *align(4) const os.sockaddr) Socket.Address {
switch (address.family) {
os.AF.INET => {
const info = @ptrCast(*const os.sockaddr.in, address);
const host = net.IPv4{ .octets = @bitCast([4]u8, info.addr) };
const port = mem.bigToNative(u16, info.port);
return Socket.Address.initIPv4(host, port);
},
os.AF.INET6 => {
const info = @ptrCast(*const os.sockaddr.in6, address);
const host = net.IPv6{ .octets = info.addr, .scope_id = info.scope_id };
const port = mem.bigToNative(u16, info.port);
return Socket.Address.initIPv6(host, port);
},
else => unreachable,
}
}
/// Encodes a generic socket address into an extern union that may be reliably
/// casted into a `sockaddr` which may be passed into socket syscalls.
pub fn toNative(self: Socket.Address) extern union {
ipv4: os.sockaddr.in,
ipv6: os.sockaddr.in6,
} {
return switch (self) {
.ipv4 => |address| .{
.ipv4 = .{
.addr = @bitCast(u32, address.host.octets),
.port = mem.nativeToBig(u16, address.port),
},
},
.ipv6 => |address| .{
.ipv6 = .{
.addr = address.host.octets,
.port = mem.nativeToBig(u16, address.port),
.scope_id = address.host.scope_id,
.flowinfo = 0,
},
},
};
}
/// Returns the number of bytes that make up the `sockaddr` equivalent to the address.
pub fn getNativeSize(self: Socket.Address) u32 {
return switch (self) {
.ipv4 => @sizeOf(os.sockaddr.in),
.ipv6 => @sizeOf(os.sockaddr.in6),
};
}
/// Implements the `std.fmt.format` API.
pub fn format(
self: Socket.Address,
comptime layout: []const u8,
opts: fmt.FormatOptions,
writer: anytype,
) !void {
if (layout.len != 0) std.fmt.invalidFmtError(layout, self);
_ = opts;
switch (self) {
.ipv4 => |address| try fmt.format(writer, "{}:{}", .{ address.host, address.port }),
.ipv6 => |address| try fmt.format(writer, "{}:{}", .{ address.host, address.port }),
}
}
};
/// POSIX `msghdr`. Denotes a destination address, set of buffers, control data, and flags. Ported
/// directly from musl.
pub const Message = if (native_os.isAtLeast(.windows, .vista) != null and native_os.isAtLeast(.windows, .vista).?)
extern struct {
name: usize = @ptrToInt(@as(?[*]u8, null)),
name_len: c_int = 0,
buffers: usize = undefined,
buffers_len: c_ulong = undefined,
control: Buffer = .{
.ptr = @ptrToInt(@as(?[*]u8, null)),
.len = 0,
},
flags: c_ulong = 0,
pub usingnamespace MessageMixin(Message);
}
else if (native_os.tag == .windows)
extern struct {
name: usize = @ptrToInt(@as(?[*]u8, null)),
name_len: c_int = 0,
buffers: usize = undefined,
buffers_len: u32 = undefined,
control: Buffer = .{
.ptr = @ptrToInt(@as(?[*]u8, null)),
.len = 0,
},
flags: u32 = 0,
pub usingnamespace MessageMixin(Message);
}
else if (@sizeOf(usize) > 4 and native_endian == .Big)
extern struct {
name: usize = @ptrToInt(@as(?[*]u8, null)),
name_len: c_uint = 0,
buffers: usize = undefined,
_pad_1: c_int = 0,
buffers_len: c_int = undefined,
control: usize = @ptrToInt(@as(?[*]u8, null)),
_pad_2: c_int = 0,
control_len: c_uint = 0,
flags: c_int = 0,
pub usingnamespace MessageMixin(Message);
}
else if (@sizeOf(usize) > 4 and native_endian == .Little)
extern struct {
name: usize = @ptrToInt(@as(?[*]u8, null)),
name_len: c_uint = 0,
buffers: usize = undefined,
buffers_len: c_int = undefined,
_pad_1: c_int = 0,
control: usize = @ptrToInt(@as(?[*]u8, null)),
control_len: c_uint = 0,
_pad_2: c_int = 0,
flags: c_int = 0,
pub usingnamespace MessageMixin(Message);
}
else
extern struct {
name: usize = @ptrToInt(@as(?[*]u8, null)),
name_len: c_uint = 0,
buffers: usize = undefined,
buffers_len: c_int = undefined,
control: usize = @ptrToInt(@as(?[*]u8, null)),
control_len: c_uint = 0,
flags: c_int = 0,
pub usingnamespace MessageMixin(Message);
};
fn MessageMixin(comptime Self: type) type {
return struct {
pub fn fromBuffers(buffers: []const Buffer) Self {
var self: Self = .{};
self.setBuffers(buffers);
return self;
}
pub fn setName(self: *Self, name: []const u8) void {
self.name = @ptrToInt(name.ptr);
self.name_len = @intCast(meta.fieldInfo(Self, .name_len).type, name.len);
}
pub fn setBuffers(self: *Self, buffers: []const Buffer) void {
self.buffers = @ptrToInt(buffers.ptr);
self.buffers_len = @intCast(meta.fieldInfo(Self, .buffers_len).type, buffers.len);
}
pub fn setControl(self: *Self, control: []const u8) void {
if (native_os.tag == .windows) {
self.control = Buffer.from(control);
} else {
self.control = @ptrToInt(control.ptr);
self.control_len = @intCast(meta.fieldInfo(Self, .control_len).type, control.len);
}
}
pub fn setFlags(self: *Self, flags: u32) void {
self.flags = @intCast(meta.fieldInfo(Self, .flags).type, flags);
}
pub fn getName(self: Self) []const u8 {
return @intToPtr([*]const u8, self.name)[0..@intCast(usize, self.name_len)];
}
pub fn getBuffers(self: Self) []const Buffer {
return @intToPtr([*]const Buffer, self.buffers)[0..@intCast(usize, self.buffers_len)];
}
pub fn getControl(self: Self) []const u8 {
if (native_os.tag == .windows) {
return self.control.into();
} else {
return @intToPtr([*]const u8, self.control)[0..@intCast(usize, self.control_len)];
}
}
pub fn getFlags(self: Self) u32 {
return @intCast(u32, self.flags);
}
};
}
/// POSIX `linger`, denoting the linger settings of a socket.
///
/// Microsoft's documentation and glibc denote the fields to be unsigned
/// short's on Windows, whereas glibc and musl denote the fields to be
/// int's on every other platform.
pub const Linger = extern struct {
pub const Field = switch (native_os.tag) {
.windows => c_ushort,
else => c_int,
};
enabled: Field,
timeout_seconds: Field,
pub fn init(timeout_seconds: ?u16) Socket.Linger {
return .{
.enabled = @intCast(Socket.Linger.Field, @boolToInt(timeout_seconds != null)),
.timeout_seconds = if (timeout_seconds) |seconds| @intCast(Socket.Linger.Field, seconds) else 0,
};
}
};
/// Possible set of flags to initialize a socket with.
pub const InitFlags = enum {
// Initialize a socket to be non-blocking.
nonblocking,
// Have a socket close itself on exec syscalls.
close_on_exec,
};
/// The underlying handle of a socket.
fd: os.socket_t,
/// Enclose a socket abstraction over an existing socket file descriptor.
pub fn from(fd: os.socket_t) Socket {
return Socket{ .fd = fd };
}
/// Mix in socket syscalls depending on the platform we are compiling against.
pub usingnamespace switch (native_os.tag) {
.windows => @import("socket_windows.zig"),
else => @import("socket_posix.zig"),
}.Mixin(Socket);
};

View file

@ -1,275 +0,0 @@
const std = @import("../../std.zig");
const os = std.os;
const mem = std.mem;
const time = std.time;
pub fn Mixin(comptime Socket: type) type {
return struct {
/// Open a new socket.
pub fn init(domain: u32, socket_type: u32, protocol: u32, flags: std.enums.EnumFieldStruct(Socket.InitFlags, bool, false)) !Socket {
var raw_flags: u32 = socket_type;
const set = std.EnumSet(Socket.InitFlags).init(flags);
if (set.contains(.close_on_exec)) raw_flags |= os.SOCK.CLOEXEC;
if (set.contains(.nonblocking)) raw_flags |= os.SOCK.NONBLOCK;
return Socket{ .fd = try os.socket(domain, raw_flags, protocol) };
}
/// Closes the socket.
pub fn deinit(self: Socket) void {
os.closeSocket(self.fd);
}
/// Shutdown either the read side, write side, or all side of the socket.
pub fn shutdown(self: Socket, how: os.ShutdownHow) !void {
return os.shutdown(self.fd, how);
}
/// Binds the socket to an address.
pub fn bind(self: Socket, address: Socket.Address) !void {
return os.bind(self.fd, @ptrCast(*const os.sockaddr, &address.toNative()), address.getNativeSize());
}
/// Start listening for incoming connections on the socket.
pub fn listen(self: Socket, max_backlog_size: u31) !void {
return os.listen(self.fd, max_backlog_size);
}
/// Have the socket attempt to the connect to an address.
pub fn connect(self: Socket, address: Socket.Address) !void {
return os.connect(self.fd, @ptrCast(*const os.sockaddr, &address.toNative()), address.getNativeSize());
}
/// Accept a pending incoming connection queued to the kernel backlog
/// of the socket.
pub fn accept(self: Socket, flags: std.enums.EnumFieldStruct(Socket.InitFlags, bool, false)) !Socket.Connection {
var address: Socket.Address.Native.Storage = undefined;
var address_len: u32 = @sizeOf(Socket.Address.Native.Storage);
var raw_flags: u32 = 0;
const set = std.EnumSet(Socket.InitFlags).init(flags);
if (set.contains(.close_on_exec)) raw_flags |= os.SOCK.CLOEXEC;
if (set.contains(.nonblocking)) raw_flags |= os.SOCK.NONBLOCK;
const socket = Socket{ .fd = try os.accept(self.fd, @ptrCast(*os.sockaddr, &address), &address_len, raw_flags) };
const socket_address = Socket.Address.fromNative(@ptrCast(*os.sockaddr, &address));
return Socket.Connection.from(socket, socket_address);
}
/// Read data from the socket into the buffer provided with a set of flags
/// specified. It returns the number of bytes read into the buffer provided.
pub fn read(self: Socket, buf: []u8, flags: u32) !usize {
return os.recv(self.fd, buf, flags);
}
/// Write a buffer of data provided to the socket with a set of flags specified.
/// It returns the number of bytes that are written to the socket.
pub fn write(self: Socket, buf: []const u8, flags: u32) !usize {
return os.send(self.fd, buf, flags);
}
/// Writes multiple I/O vectors with a prepended message header to the socket
/// with a set of flags specified. It returns the number of bytes that are
/// written to the socket.
pub fn writeMessage(self: Socket, msg: Socket.Message, flags: u32) !usize {
while (true) {
const rc = os.system.sendmsg(self.fd, &msg, @intCast(c_int, flags));
return switch (os.errno(rc)) {
.SUCCESS => return @intCast(usize, rc),
.ACCES => error.AccessDenied,
.AGAIN => error.WouldBlock,
.ALREADY => error.FastOpenAlreadyInProgress,
.BADF => unreachable, // always a race condition
.CONNRESET => error.ConnectionResetByPeer,
.DESTADDRREQ => unreachable, // The socket is not connection-mode, and no peer address is set.
.FAULT => unreachable, // An invalid user space address was specified for an argument.
.INTR => continue,
.INVAL => unreachable, // Invalid argument passed.
.ISCONN => unreachable, // connection-mode socket was connected already but a recipient was specified
.MSGSIZE => error.MessageTooBig,
.NOBUFS => error.SystemResources,
.NOMEM => error.SystemResources,
.NOTSOCK => unreachable, // The file descriptor sockfd does not refer to a socket.
.OPNOTSUPP => unreachable, // Some bit in the flags argument is inappropriate for the socket type.
.PIPE => error.BrokenPipe,
.AFNOSUPPORT => error.AddressFamilyNotSupported,
.LOOP => error.SymLinkLoop,
.NAMETOOLONG => error.NameTooLong,
.NOENT => error.FileNotFound,
.NOTDIR => error.NotDir,
.HOSTUNREACH => error.NetworkUnreachable,
.NETUNREACH => error.NetworkUnreachable,
.NOTCONN => error.SocketNotConnected,
.NETDOWN => error.NetworkSubsystemFailed,
else => |err| os.unexpectedErrno(err),
};
}
}
/// Read multiple I/O vectors with a prepended message header from the socket
/// with a set of flags specified. It returns the number of bytes that were
/// read into the buffer provided.
pub fn readMessage(self: Socket, msg: *Socket.Message, flags: u32) !usize {
while (true) {
const rc = os.system.recvmsg(self.fd, msg, @intCast(c_int, flags));
return switch (os.errno(rc)) {
.SUCCESS => @intCast(usize, rc),
.BADF => unreachable, // always a race condition
.FAULT => unreachable,
.INVAL => unreachable,
.NOTCONN => unreachable,
.NOTSOCK => unreachable,
.INTR => continue,
.AGAIN => error.WouldBlock,
.NOMEM => error.SystemResources,
.CONNREFUSED => error.ConnectionRefused,
.CONNRESET => error.ConnectionResetByPeer,
else => |err| os.unexpectedErrno(err),
};
}
}
/// Query the address that the socket is locally bounded to.
pub fn getLocalAddress(self: Socket) !Socket.Address {
var address: Socket.Address.Native.Storage = undefined;
var address_len: u32 = @sizeOf(Socket.Address.Native.Storage);
try os.getsockname(self.fd, @ptrCast(*os.sockaddr, &address), &address_len);
return Socket.Address.fromNative(@ptrCast(*os.sockaddr, &address));
}
/// Query the address that the socket is connected to.
pub fn getRemoteAddress(self: Socket) !Socket.Address {
var address: Socket.Address.Native.Storage = undefined;
var address_len: u32 = @sizeOf(Socket.Address.Native.Storage);
try os.getpeername(self.fd, @ptrCast(*os.sockaddr, &address), &address_len);
return Socket.Address.fromNative(@ptrCast(*os.sockaddr, &address));
}
/// Query and return the latest cached error on the socket.
pub fn getError(self: Socket) !void {
return os.getsockoptError(self.fd);
}
/// Query the read buffer size of the socket.
pub fn getReadBufferSize(self: Socket) !u32 {
var value: u32 = undefined;
var value_len: u32 = @sizeOf(u32);
const rc = os.system.getsockopt(self.fd, os.SOL.SOCKET, os.SO.RCVBUF, mem.asBytes(&value), &value_len);
return switch (os.errno(rc)) {
.SUCCESS => value,
.BADF => error.BadFileDescriptor,
.FAULT => error.InvalidAddressSpace,
.INVAL => error.InvalidSocketOption,
.NOPROTOOPT => error.UnknownSocketOption,
.NOTSOCK => error.NotASocket,
else => |err| os.unexpectedErrno(err),
};
}
/// Query the write buffer size of the socket.
pub fn getWriteBufferSize(self: Socket) !u32 {
var value: u32 = undefined;
var value_len: u32 = @sizeOf(u32);
const rc = os.system.getsockopt(self.fd, os.SOL.SOCKET, os.SO.SNDBUF, mem.asBytes(&value), &value_len);
return switch (os.errno(rc)) {
.SUCCESS => value,
.BADF => error.BadFileDescriptor,
.FAULT => error.InvalidAddressSpace,
.INVAL => error.InvalidSocketOption,
.NOPROTOOPT => error.UnknownSocketOption,
.NOTSOCK => error.NotASocket,
else => |err| os.unexpectedErrno(err),
};
}
/// Set a socket option.
pub fn setOption(self: Socket, level: u32, code: u32, value: []const u8) !void {
return os.setsockopt(self.fd, level, code, value);
}
/// Have close() or shutdown() syscalls block until all queued messages in the socket have been successfully
/// sent, or if the timeout specified in seconds has been reached. It returns `error.UnsupportedSocketOption`
/// if the host does not support the option for a socket to linger around up until a timeout specified in
/// seconds.
pub fn setLinger(self: Socket, timeout_seconds: ?u16) !void {
if (@hasDecl(os.SO, "LINGER")) {
const settings = Socket.Linger.init(timeout_seconds);
return self.setOption(os.SOL.SOCKET, os.SO.LINGER, mem.asBytes(&settings));
}
return error.UnsupportedSocketOption;
}
/// On connection-oriented sockets, have keep-alive messages be sent periodically. The timing in which keep-alive
/// messages are sent are dependant on operating system settings. It returns `error.UnsupportedSocketOption` if
/// the host does not support periodically sending keep-alive messages on connection-oriented sockets.
pub fn setKeepAlive(self: Socket, enabled: bool) !void {
if (@hasDecl(os.SO, "KEEPALIVE")) {
return self.setOption(os.SOL.SOCKET, os.SO.KEEPALIVE, mem.asBytes(&@as(u32, @boolToInt(enabled))));
}
return error.UnsupportedSocketOption;
}
/// Allow multiple sockets on the same host to listen on the same address. It returns `error.UnsupportedSocketOption` if
/// the host does not support sockets listening the same address.
pub fn setReuseAddress(self: Socket, enabled: bool) !void {
if (@hasDecl(os.SO, "REUSEADDR")) {
return self.setOption(os.SOL.SOCKET, os.SO.REUSEADDR, mem.asBytes(&@as(u32, @boolToInt(enabled))));
}
return error.UnsupportedSocketOption;
}
/// Allow multiple sockets on the same host to listen on the same port. It returns `error.UnsupportedSocketOption` if
/// the host does not supports sockets listening on the same port.
pub fn setReusePort(self: Socket, enabled: bool) !void {
if (@hasDecl(os.SO, "REUSEPORT")) {
return self.setOption(os.SOL.SOCKET, os.SO.REUSEPORT, mem.asBytes(&@as(u32, @boolToInt(enabled))));
}
return error.UnsupportedSocketOption;
}
/// Set the write buffer size of the socket.
pub fn setWriteBufferSize(self: Socket, size: u32) !void {
return self.setOption(os.SOL.SOCKET, os.SO.SNDBUF, mem.asBytes(&size));
}
/// Set the read buffer size of the socket.
pub fn setReadBufferSize(self: Socket, size: u32) !void {
return self.setOption(os.SOL.SOCKET, os.SO.RCVBUF, mem.asBytes(&size));
}
/// WARNING: Timeouts only affect blocking sockets. It is undefined behavior if a timeout is
/// set on a non-blocking socket.
///
/// Set a timeout on the socket that is to occur if no messages are successfully written
/// to its bound destination after a specified number of milliseconds. A subsequent write
/// to the socket will thereafter return `error.WouldBlock` should the timeout be exceeded.
pub fn setWriteTimeout(self: Socket, milliseconds: usize) !void {
const timeout = os.timeval{
.tv_sec = @intCast(i32, milliseconds / time.ms_per_s),
.tv_usec = @intCast(i32, (milliseconds % time.ms_per_s) * time.us_per_ms),
};
return self.setOption(os.SOL.SOCKET, os.SO.SNDTIMEO, mem.asBytes(&timeout));
}
/// WARNING: Timeouts only affect blocking sockets. It is undefined behavior if a timeout is
/// set on a non-blocking socket.
///
/// Set a timeout on the socket that is to occur if no messages are successfully read
/// from its bound destination after a specified number of milliseconds. A subsequent
/// read from the socket will thereafter return `error.WouldBlock` should the timeout be
/// exceeded.
pub fn setReadTimeout(self: Socket, milliseconds: usize) !void {
const timeout = os.timeval{
.tv_sec = @intCast(i32, milliseconds / time.ms_per_s),
.tv_usec = @intCast(i32, (milliseconds % time.ms_per_s) * time.us_per_ms),
};
return self.setOption(os.SOL.SOCKET, os.SO.RCVTIMEO, mem.asBytes(&timeout));
}
};
}

View file

@ -1,458 +0,0 @@
const std = @import("../../std.zig");
const net = @import("net.zig");
const os = std.os;
const mem = std.mem;
const windows = std.os.windows;
const ws2_32 = windows.ws2_32;
pub fn Mixin(comptime Socket: type) type {
return struct {
/// Open a new socket.
pub fn init(domain: u32, socket_type: u32, protocol: u32, flags: std.enums.EnumFieldStruct(Socket.InitFlags, bool, false)) !Socket {
var raw_flags: u32 = ws2_32.WSA_FLAG_OVERLAPPED;
const set = std.EnumSet(Socket.InitFlags).init(flags);
if (set.contains(.close_on_exec)) raw_flags |= ws2_32.WSA_FLAG_NO_HANDLE_INHERIT;
const fd = ws2_32.WSASocketW(
@intCast(i32, domain),
@intCast(i32, socket_type),
@intCast(i32, protocol),
null,
0,
raw_flags,
);
if (fd == ws2_32.INVALID_SOCKET) {
return switch (ws2_32.WSAGetLastError()) {
.WSANOTINITIALISED => {
_ = try windows.WSAStartup(2, 2);
return init(domain, socket_type, protocol, flags);
},
.WSAEAFNOSUPPORT => error.AddressFamilyNotSupported,
.WSAEMFILE => error.ProcessFdQuotaExceeded,
.WSAENOBUFS => error.SystemResources,
.WSAEPROTONOSUPPORT => error.ProtocolNotSupported,
else => |err| windows.unexpectedWSAError(err),
};
}
if (set.contains(.nonblocking)) {
var enabled: c_ulong = 1;
const rc = ws2_32.ioctlsocket(fd, ws2_32.FIONBIO, &enabled);
if (rc == ws2_32.SOCKET_ERROR) {
return windows.unexpectedWSAError(ws2_32.WSAGetLastError());
}
}
return Socket{ .fd = fd };
}
/// Closes the socket.
pub fn deinit(self: Socket) void {
_ = ws2_32.closesocket(self.fd);
}
/// Shutdown either the read side, write side, or all side of the socket.
pub fn shutdown(self: Socket, how: os.ShutdownHow) !void {
const rc = ws2_32.shutdown(self.fd, switch (how) {
.recv => ws2_32.SD_RECEIVE,
.send => ws2_32.SD_SEND,
.both => ws2_32.SD_BOTH,
});
if (rc == ws2_32.SOCKET_ERROR) {
return switch (ws2_32.WSAGetLastError()) {
.WSAECONNABORTED => return error.ConnectionAborted,
.WSAECONNRESET => return error.ConnectionResetByPeer,
.WSAEINPROGRESS => return error.BlockingOperationInProgress,
.WSAEINVAL => unreachable,
.WSAENETDOWN => return error.NetworkSubsystemFailed,
.WSAENOTCONN => return error.SocketNotConnected,
.WSAENOTSOCK => unreachable,
.WSANOTINITIALISED => unreachable,
else => |err| return windows.unexpectedWSAError(err),
};
}
}
/// Binds the socket to an address.
pub fn bind(self: Socket, address: Socket.Address) !void {
const rc = ws2_32.bind(self.fd, @ptrCast(*const ws2_32.sockaddr, &address.toNative()), @intCast(c_int, address.getNativeSize()));
if (rc == ws2_32.SOCKET_ERROR) {
return switch (ws2_32.WSAGetLastError()) {
.WSAENETDOWN => error.NetworkSubsystemFailed,
.WSAEACCES => error.AccessDenied,
.WSAEADDRINUSE => error.AddressInUse,
.WSAEADDRNOTAVAIL => error.AddressNotAvailable,
.WSAEFAULT => error.BadAddress,
.WSAEINPROGRESS => error.WouldBlock,
.WSAEINVAL => error.AlreadyBound,
.WSAENOBUFS => error.NoEphemeralPortsAvailable,
.WSAENOTSOCK => error.NotASocket,
else => |err| windows.unexpectedWSAError(err),
};
}
}
/// Start listening for incoming connections on the socket.
pub fn listen(self: Socket, max_backlog_size: u31) !void {
const rc = ws2_32.listen(self.fd, max_backlog_size);
if (rc == ws2_32.SOCKET_ERROR) {
return switch (ws2_32.WSAGetLastError()) {
.WSAENETDOWN => error.NetworkSubsystemFailed,
.WSAEADDRINUSE => error.AddressInUse,
.WSAEISCONN => error.AlreadyConnected,
.WSAEINVAL => error.SocketNotBound,
.WSAEMFILE, .WSAENOBUFS => error.SystemResources,
.WSAENOTSOCK => error.FileDescriptorNotASocket,
.WSAEOPNOTSUPP => error.OperationNotSupported,
.WSAEINPROGRESS => error.WouldBlock,
else => |err| windows.unexpectedWSAError(err),
};
}
}
/// Have the socket attempt to the connect to an address.
pub fn connect(self: Socket, address: Socket.Address) !void {
const rc = ws2_32.connect(self.fd, @ptrCast(*const ws2_32.sockaddr, &address.toNative()), @intCast(c_int, address.getNativeSize()));
if (rc == ws2_32.SOCKET_ERROR) {
return switch (ws2_32.WSAGetLastError()) {
.WSAEADDRINUSE => error.AddressInUse,
.WSAEADDRNOTAVAIL => error.AddressNotAvailable,
.WSAECONNREFUSED => error.ConnectionRefused,
.WSAETIMEDOUT => error.ConnectionTimedOut,
.WSAEFAULT => error.BadAddress,
.WSAEINVAL => error.ListeningSocket,
.WSAEISCONN => error.AlreadyConnected,
.WSAENOTSOCK => error.NotASocket,
.WSAEACCES => error.BroadcastNotEnabled,
.WSAENOBUFS => error.SystemResources,
.WSAEAFNOSUPPORT => error.AddressFamilyNotSupported,
.WSAEINPROGRESS, .WSAEWOULDBLOCK => error.WouldBlock,
.WSAEHOSTUNREACH, .WSAENETUNREACH => error.NetworkUnreachable,
else => |err| windows.unexpectedWSAError(err),
};
}
}
/// Accept a pending incoming connection queued to the kernel backlog
/// of the socket.
pub fn accept(self: Socket, flags: std.enums.EnumFieldStruct(Socket.InitFlags, bool, false)) !Socket.Connection {
var address: Socket.Address.Native.Storage = undefined;
var address_len: c_int = @sizeOf(Socket.Address.Native.Storage);
const fd = ws2_32.accept(self.fd, @ptrCast(*ws2_32.sockaddr, &address), &address_len);
if (fd == ws2_32.INVALID_SOCKET) {
return switch (ws2_32.WSAGetLastError()) {
.WSANOTINITIALISED => unreachable,
.WSAECONNRESET => error.ConnectionResetByPeer,
.WSAEFAULT => unreachable,
.WSAEINVAL => error.SocketNotListening,
.WSAEMFILE => error.ProcessFdQuotaExceeded,
.WSAENETDOWN => error.NetworkSubsystemFailed,
.WSAENOBUFS => error.FileDescriptorNotASocket,
.WSAEOPNOTSUPP => error.OperationNotSupported,
.WSAEWOULDBLOCK => error.WouldBlock,
else => |err| windows.unexpectedWSAError(err),
};
}
const socket = Socket.from(fd);
errdefer socket.deinit();
const socket_address = Socket.Address.fromNative(@ptrCast(*ws2_32.sockaddr, &address));
const set = std.EnumSet(Socket.InitFlags).init(flags);
if (set.contains(.nonblocking)) {
var enabled: c_ulong = 1;
const rc = ws2_32.ioctlsocket(fd, ws2_32.FIONBIO, &enabled);
if (rc == ws2_32.SOCKET_ERROR) {
return windows.unexpectedWSAError(ws2_32.WSAGetLastError());
}
}
return Socket.Connection.from(socket, socket_address);
}
/// Read data from the socket into the buffer provided with a set of flags
/// specified. It returns the number of bytes read into the buffer provided.
pub fn read(self: Socket, buf: []u8, flags: u32) !usize {
var bufs = &[_]ws2_32.WSABUF{.{ .len = @intCast(u32, buf.len), .buf = buf.ptr }};
var num_bytes: u32 = undefined;
var flags_ = flags;
const rc = ws2_32.WSARecv(self.fd, bufs, 1, &num_bytes, &flags_, null, null);
if (rc == ws2_32.SOCKET_ERROR) {
return switch (ws2_32.WSAGetLastError()) {
.WSAECONNABORTED => error.ConnectionAborted,
.WSAECONNRESET => error.ConnectionResetByPeer,
.WSAEDISCON => error.ConnectionClosedByPeer,
.WSAEFAULT => error.BadBuffer,
.WSAEINPROGRESS,
.WSAEWOULDBLOCK,
.WSA_IO_PENDING,
.WSAETIMEDOUT,
=> error.WouldBlock,
.WSAEINTR => error.Cancelled,
.WSAEINVAL => error.SocketNotBound,
.WSAEMSGSIZE => error.MessageTooLarge,
.WSAENETDOWN => error.NetworkSubsystemFailed,
.WSAENETRESET => error.NetworkReset,
.WSAENOTCONN => error.SocketNotConnected,
.WSAENOTSOCK => error.FileDescriptorNotASocket,
.WSAEOPNOTSUPP => error.OperationNotSupported,
.WSAESHUTDOWN => error.AlreadyShutdown,
.WSA_OPERATION_ABORTED => error.OperationAborted,
else => |err| windows.unexpectedWSAError(err),
};
}
return @intCast(usize, num_bytes);
}
/// Write a buffer of data provided to the socket with a set of flags specified.
/// It returns the number of bytes that are written to the socket.
pub fn write(self: Socket, buf: []const u8, flags: u32) !usize {
var bufs = &[_]ws2_32.WSABUF{.{ .len = @intCast(u32, buf.len), .buf = @intToPtr([*]u8, @ptrToInt(buf.ptr)) }};
var num_bytes: u32 = undefined;
const rc = ws2_32.WSASend(self.fd, bufs, 1, &num_bytes, flags, null, null);
if (rc == ws2_32.SOCKET_ERROR) {
return switch (ws2_32.WSAGetLastError()) {
.WSAECONNABORTED => error.ConnectionAborted,
.WSAECONNRESET => error.ConnectionResetByPeer,
.WSAEFAULT => error.BadBuffer,
.WSAEINPROGRESS,
.WSAEWOULDBLOCK,
.WSA_IO_PENDING,
.WSAETIMEDOUT,
=> error.WouldBlock,
.WSAEINTR => error.Cancelled,
.WSAEINVAL => error.SocketNotBound,
.WSAEMSGSIZE => error.MessageTooLarge,
.WSAENETDOWN => error.NetworkSubsystemFailed,
.WSAENETRESET => error.NetworkReset,
.WSAENOBUFS => error.BufferDeadlock,
.WSAENOTCONN => error.SocketNotConnected,
.WSAENOTSOCK => error.FileDescriptorNotASocket,
.WSAEOPNOTSUPP => error.OperationNotSupported,
.WSAESHUTDOWN => error.AlreadyShutdown,
.WSA_OPERATION_ABORTED => error.OperationAborted,
else => |err| windows.unexpectedWSAError(err),
};
}
return @intCast(usize, num_bytes);
}
/// Writes multiple I/O vectors with a prepended message header to the socket
/// with a set of flags specified. It returns the number of bytes that are
/// written to the socket.
pub fn writeMessage(self: Socket, msg: Socket.Message, flags: u32) !usize {
const call = try windows.loadWinsockExtensionFunction(ws2_32.LPFN_WSASENDMSG, self.fd, ws2_32.WSAID_WSASENDMSG);
var num_bytes: u32 = undefined;
const rc = call(self.fd, &msg, flags, &num_bytes, null, null);
if (rc == ws2_32.SOCKET_ERROR) {
return switch (ws2_32.WSAGetLastError()) {
.WSAECONNABORTED => error.ConnectionAborted,
.WSAECONNRESET => error.ConnectionResetByPeer,
.WSAEFAULT => error.BadBuffer,
.WSAEINPROGRESS,
.WSAEWOULDBLOCK,
.WSA_IO_PENDING,
.WSAETIMEDOUT,
=> error.WouldBlock,
.WSAEINTR => error.Cancelled,
.WSAEINVAL => error.SocketNotBound,
.WSAEMSGSIZE => error.MessageTooLarge,
.WSAENETDOWN => error.NetworkSubsystemFailed,
.WSAENETRESET => error.NetworkReset,
.WSAENOBUFS => error.BufferDeadlock,
.WSAENOTCONN => error.SocketNotConnected,
.WSAENOTSOCK => error.FileDescriptorNotASocket,
.WSAEOPNOTSUPP => error.OperationNotSupported,
.WSAESHUTDOWN => error.AlreadyShutdown,
.WSA_OPERATION_ABORTED => error.OperationAborted,
else => |err| windows.unexpectedWSAError(err),
};
}
return @intCast(usize, num_bytes);
}
/// Read multiple I/O vectors with a prepended message header from the socket
/// with a set of flags specified. It returns the number of bytes that were
/// read into the buffer provided.
pub fn readMessage(self: Socket, msg: *Socket.Message, flags: u32) !usize {
_ = flags;
const call = try windows.loadWinsockExtensionFunction(ws2_32.LPFN_WSARECVMSG, self.fd, ws2_32.WSAID_WSARECVMSG);
var num_bytes: u32 = undefined;
const rc = call(self.fd, msg, &num_bytes, null, null);
if (rc == ws2_32.SOCKET_ERROR) {
return switch (ws2_32.WSAGetLastError()) {
.WSAECONNABORTED => error.ConnectionAborted,
.WSAECONNRESET => error.ConnectionResetByPeer,
.WSAEDISCON => error.ConnectionClosedByPeer,
.WSAEFAULT => error.BadBuffer,
.WSAEINPROGRESS,
.WSAEWOULDBLOCK,
.WSA_IO_PENDING,
.WSAETIMEDOUT,
=> error.WouldBlock,
.WSAEINTR => error.Cancelled,
.WSAEINVAL => error.SocketNotBound,
.WSAEMSGSIZE => error.MessageTooLarge,
.WSAENETDOWN => error.NetworkSubsystemFailed,
.WSAENETRESET => error.NetworkReset,
.WSAENOTCONN => error.SocketNotConnected,
.WSAENOTSOCK => error.FileDescriptorNotASocket,
.WSAEOPNOTSUPP => error.OperationNotSupported,
.WSAESHUTDOWN => error.AlreadyShutdown,
.WSA_OPERATION_ABORTED => error.OperationAborted,
else => |err| windows.unexpectedWSAError(err),
};
}
return @intCast(usize, num_bytes);
}
/// Query the address that the socket is locally bounded to.
pub fn getLocalAddress(self: Socket) !Socket.Address {
var address: Socket.Address.Native.Storage = undefined;
var address_len: c_int = @sizeOf(Socket.Address.Native.Storage);
const rc = ws2_32.getsockname(self.fd, @ptrCast(*ws2_32.sockaddr, &address), &address_len);
if (rc == ws2_32.SOCKET_ERROR) {
return switch (ws2_32.WSAGetLastError()) {
.WSANOTINITIALISED => unreachable,
.WSAEFAULT => unreachable,
.WSAENETDOWN => error.NetworkSubsystemFailed,
.WSAENOTSOCK => error.FileDescriptorNotASocket,
.WSAEINVAL => error.SocketNotBound,
else => |err| windows.unexpectedWSAError(err),
};
}
return Socket.Address.fromNative(@ptrCast(*ws2_32.sockaddr, &address));
}
/// Query the address that the socket is connected to.
pub fn getRemoteAddress(self: Socket) !Socket.Address {
var address: Socket.Address.Native.Storage = undefined;
var address_len: c_int = @sizeOf(Socket.Address.Native.Storage);
const rc = ws2_32.getpeername(self.fd, @ptrCast(*ws2_32.sockaddr, &address), &address_len);
if (rc == ws2_32.SOCKET_ERROR) {
return switch (ws2_32.WSAGetLastError()) {
.WSANOTINITIALISED => unreachable,
.WSAEFAULT => unreachable,
.WSAENETDOWN => error.NetworkSubsystemFailed,
.WSAENOTSOCK => error.FileDescriptorNotASocket,
.WSAEINVAL => error.SocketNotBound,
else => |err| windows.unexpectedWSAError(err),
};
}
return Socket.Address.fromNative(@ptrCast(*ws2_32.sockaddr, &address));
}
/// Query and return the latest cached error on the socket.
pub fn getError(self: Socket) !void {
_ = self;
return {};
}
/// Query the read buffer size of the socket.
pub fn getReadBufferSize(self: Socket) !u32 {
_ = self;
return 0;
}
/// Query the write buffer size of the socket.
pub fn getWriteBufferSize(self: Socket) !u32 {
_ = self;
return 0;
}
/// Set a socket option.
pub fn setOption(self: Socket, level: u32, code: u32, value: []const u8) !void {
const rc = ws2_32.setsockopt(self.fd, @intCast(i32, level), @intCast(i32, code), value.ptr, @intCast(i32, value.len));
if (rc == ws2_32.SOCKET_ERROR) {
return switch (ws2_32.WSAGetLastError()) {
.WSANOTINITIALISED => unreachable,
.WSAENETDOWN => return error.NetworkSubsystemFailed,
.WSAEFAULT => unreachable,
.WSAENOTSOCK => return error.FileDescriptorNotASocket,
.WSAEINVAL => return error.SocketNotBound,
else => |err| windows.unexpectedWSAError(err),
};
}
}
/// Have close() or shutdown() syscalls block until all queued messages in the socket have been successfully
/// sent, or if the timeout specified in seconds has been reached. It returns `error.UnsupportedSocketOption`
/// if the host does not support the option for a socket to linger around up until a timeout specified in
/// seconds.
pub fn setLinger(self: Socket, timeout_seconds: ?u16) !void {
const settings = Socket.Linger.init(timeout_seconds);
return self.setOption(ws2_32.SOL.SOCKET, ws2_32.SO.LINGER, mem.asBytes(&settings));
}
/// On connection-oriented sockets, have keep-alive messages be sent periodically. The timing in which keep-alive
/// messages are sent are dependant on operating system settings. It returns `error.UnsupportedSocketOption` if
/// the host does not support periodically sending keep-alive messages on connection-oriented sockets.
pub fn setKeepAlive(self: Socket, enabled: bool) !void {
return self.setOption(ws2_32.SOL.SOCKET, ws2_32.SO.KEEPALIVE, mem.asBytes(&@as(u32, @boolToInt(enabled))));
}
/// Allow multiple sockets on the same host to listen on the same address. It returns `error.UnsupportedSocketOption` if
/// the host does not support sockets listening the same address.
pub fn setReuseAddress(self: Socket, enabled: bool) !void {
return self.setOption(ws2_32.SOL.SOCKET, ws2_32.SO.REUSEADDR, mem.asBytes(&@as(u32, @boolToInt(enabled))));
}
/// Allow multiple sockets on the same host to listen on the same port. It returns `error.UnsupportedSocketOption` if
/// the host does not supports sockets listening on the same port.
///
/// TODO: verify if this truly mimicks SO.REUSEPORT behavior, or if SO.REUSE_UNICASTPORT provides the correct behavior
pub fn setReusePort(self: Socket, enabled: bool) !void {
try self.setOption(ws2_32.SOL.SOCKET, ws2_32.SO.BROADCAST, mem.asBytes(&@as(u32, @boolToInt(enabled))));
try self.setReuseAddress(enabled);
}
/// Set the write buffer size of the socket.
pub fn setWriteBufferSize(self: Socket, size: u32) !void {
return self.setOption(ws2_32.SOL.SOCKET, ws2_32.SO.SNDBUF, mem.asBytes(&size));
}
/// Set the read buffer size of the socket.
pub fn setReadBufferSize(self: Socket, size: u32) !void {
return self.setOption(ws2_32.SOL.SOCKET, ws2_32.SO.RCVBUF, mem.asBytes(&size));
}
/// WARNING: Timeouts only affect blocking sockets. It is undefined behavior if a timeout is
/// set on a non-blocking socket.
///
/// Set a timeout on the socket that is to occur if no messages are successfully written
/// to its bound destination after a specified number of milliseconds. A subsequent write
/// to the socket will thereafter return `error.WouldBlock` should the timeout be exceeded.
pub fn setWriteTimeout(self: Socket, milliseconds: u32) !void {
return self.setOption(ws2_32.SOL.SOCKET, ws2_32.SO.SNDTIMEO, mem.asBytes(&milliseconds));
}
/// WARNING: Timeouts only affect blocking sockets. It is undefined behavior if a timeout is
/// set on a non-blocking socket.
///
/// Set a timeout on the socket that is to occur if no messages are successfully read
/// from its bound destination after a specified number of milliseconds. A subsequent
/// read from the socket will thereafter return `error.WouldBlock` should the timeout be
/// exceeded.
pub fn setReadTimeout(self: Socket, milliseconds: u32) !void {
return self.setOption(ws2_32.SOL.SOCKET, ws2_32.SO.RCVTIMEO, mem.asBytes(&milliseconds));
}
};
}

View file

@ -584,7 +584,17 @@ pub const AllErrors = struct {
Message.HashContext, Message.HashContext,
std.hash_map.default_max_load_percentage, std.hash_map.default_max_load_percentage,
).init(allocator); ).init(allocator);
const err_source = try module_err_msg.src_loc.file_scope.getSource(module.gpa); const err_source = module_err_msg.src_loc.file_scope.getSource(module.gpa) catch |err| {
const file_path = try module_err_msg.src_loc.file_scope.fullPath(allocator);
try errors.append(.{
.plain = .{
.msg = try std.fmt.allocPrint(allocator, "unable to load '{s}': {s}", .{
file_path, @errorName(err),
}),
},
});
return;
};
const err_span = try module_err_msg.src_loc.span(module.gpa); const err_span = try module_err_msg.src_loc.span(module.gpa);
const err_loc = std.zig.findLineColumn(err_source.bytes, err_span.main); const err_loc = std.zig.findLineColumn(err_source.bytes, err_span.main);