diff --git a/lib/std/Io/Threaded.zig b/lib/std/Io/Threaded.zig index 016d864fcc..2d46ff373e 100644 --- a/lib/std/Io/Threaded.zig +++ b/lib/std/Io/Threaded.zig @@ -25,7 +25,8 @@ run_queue: std.SinglyLinkedList = .{}, join_requested: bool = false, threads: std.ArrayList(std.Thread), stack_size: usize, -cpu_count: std.Thread.CpuCountError!usize, +thread_capacity: std.atomic.Value(ThreadCapacity), +thread_capacity_error: ?std.Thread.CpuCountError, concurrent_count: usize, wsa: if (is_windows) Wsa else struct {} = .{}, @@ -34,6 +35,21 @@ have_signal_handler: bool, old_sig_io: if (have_sig_io) posix.Sigaction else void, old_sig_pipe: if (have_sig_pipe) posix.Sigaction else void, +pub const ThreadCapacity = enum(usize) { + unknown = 0, + _, + + pub fn init(n: usize) ThreadCapacity { + assert(n != 0); + return @enumFromInt(n); + } + + pub fn get(tc: ThreadCapacity) ?usize { + if (tc == .unknown) return null; + return @intFromEnum(tc); + } +}; + threadlocal var current_closure: ?*Closure = null; const max_iovecs_len = 8; @@ -104,18 +120,21 @@ pub fn init( /// here. gpa: Allocator, ) Threaded { + const cpu_count = std.Thread.getCpuCount(); + var t: Threaded = .{ .allocator = gpa, .threads = .empty, .stack_size = std.Thread.SpawnConfig.default_stack_size, - .cpu_count = std.Thread.getCpuCount(), + .thread_capacity = .init(if (cpu_count) |n| .init(n) else |_| .unknown), + .thread_capacity_error = if (cpu_count) |_| null else |e| e, .concurrent_count = 0, .old_sig_io = undefined, .old_sig_pipe = undefined, .have_signal_handler = false, }; - if (t.cpu_count) |n| { + if (cpu_count) |n| { t.threads.ensureTotalCapacityPrecise(gpa, n - 1) catch {}; } else |_| {} @@ -145,7 +164,8 @@ pub const init_single_threaded: Threaded = .{ .allocator = .failing, .threads = .empty, .stack_size = std.Thread.SpawnConfig.default_stack_size, - .cpu_count = 1, + .thread_capacity = .init(.init(1)), + .thread_capacity_error = null, .concurrent_count = 0, .old_sig_io = undefined, .old_sig_pipe = undefined, @@ -166,6 +186,18 @@ pub fn deinit(t: *Threaded) void { t.* = undefined; } +pub fn setThreadCapacity(t: *Threaded, n: usize) void { + t.thread_capacity.store(.init(n), .monotonic); +} + +pub fn getThreadCapacity(t: *Threaded) ?usize { + return t.thread_capacity.load(.monotonic).get(); +} + +pub fn getCurrentThreadId() usize { + @panic("TODO"); +} + fn join(t: *Threaded) void { if (builtin.single_threaded) return; { @@ -497,7 +529,7 @@ fn async( } const t: *Threaded = @ptrCast(@alignCast(userdata)); - const cpu_count = t.cpu_count catch { + const cpu_count = t.getThreadCapacity() orelse { return concurrent(userdata, result.len, result_alignment, context, context_alignment, start) catch { start(context.ptr, result.ptr); return null; @@ -556,7 +588,7 @@ fn concurrent( if (builtin.single_threaded) return error.ConcurrencyUnavailable; const t: *Threaded = @ptrCast(@alignCast(userdata)); - const cpu_count = t.cpu_count catch 1; + const cpu_count = t.getThreadCapacity() orelse 1; const gpa = t.allocator; const ac = AsyncClosure.init(gpa, .concurrent, result_len, result_alignment, context, context_alignment, start) catch { @@ -685,7 +717,7 @@ fn groupAsync( if (builtin.single_threaded) return start(group, context.ptr); const t: *Threaded = @ptrCast(@alignCast(userdata)); - const cpu_count = t.cpu_count catch 1; + const cpu_count = t.getThreadCapacity() orelse 1; const gpa = t.allocator; const gc = GroupClosure.init(gpa, t, group, context, context_alignment, start) catch { diff --git a/lib/std/Io/Threaded/test.zig b/lib/std/Io/Threaded/test.zig index 7e6e687cf2..03582d4d95 100644 --- a/lib/std/Io/Threaded/test.zig +++ b/lib/std/Io/Threaded/test.zig @@ -10,7 +10,7 @@ test "concurrent vs main prevents deadlock via oversubscription" { defer threaded.deinit(); const io = threaded.io(); - threaded.cpu_count = 1; + threaded.setThreadCapacity(1); var queue: Io.Queue(u8) = .init(&.{}); @@ -38,7 +38,7 @@ test "concurrent vs concurrent prevents deadlock via oversubscription" { defer threaded.deinit(); const io = threaded.io(); - threaded.cpu_count = 1; + threaded.setThreadCapacity(1); var queue: Io.Queue(u8) = .init(&.{});