diff --git a/lib/std/crypto/kangarootwelve.zig b/lib/std/crypto/kangarootwelve.zig index afb1fc8c0d..2befdd42b9 100644 --- a/lib/std/crypto/kangarootwelve.zig +++ b/lib/std/crypto/kangarootwelve.zig @@ -819,7 +819,7 @@ fn FinalLeafContext(comptime Variant: type) type { }; } -/// Generic multi-threaded implementation +/// Generic multi-threaded implementation with bounded heap allocation. fn ktMultiThreaded( comptime Variant: type, allocator: Allocator, @@ -831,108 +831,86 @@ fn ktMultiThreaded( comptime std.debug.assert(bytes_per_batch % (optimal_vector_len * chunk_size) == 0); const cv_size = Variant.cv_size; - - // Calculate total leaves after the first chunk + const StateType = Variant.StateType; + const leaves_per_batch = bytes_per_batch / chunk_size; const remaining_bytes = total_len - chunk_size; const total_leaves = std.math.divCeil(usize, remaining_bytes, chunk_size) catch unreachable; - // Pre-compute suffix: right_encode(n) || terminator - const n_enc = rightEncode(total_leaves); - const terminator = [_]u8{ 0xFF, 0xFF }; + var final_state = StateType.init(.{}); + + var first_chunk_buffer: [chunk_size]u8 = undefined; + if (view.tryGetSlice(0, chunk_size)) |first_chunk| { + final_state.update(first_chunk); + } else { + view.copyRange(0, chunk_size, &first_chunk_buffer); + final_state.update(&first_chunk_buffer); + } + const padding = [_]u8{ 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }; + final_state.update(&padding); - // Try to get first chunk as contiguous slice to avoid a copy and take advantage of MultiSliceView instead. - const first_chunk_direct = view.tryGetSlice(0, chunk_size); - - // Calculate buffer size - skip first chunk if we can reference it directly - const cvs_len = total_leaves * cv_size; - const suffix_len = n_enc.len + terminator.len; - const msg_buf_len = if (first_chunk_direct != null) - padding.len + cvs_len + suffix_len - else - chunk_size + padding.len + cvs_len + suffix_len; - - const msg_buf = try allocator.alignedAlloc(u8, std.mem.Alignment.of(u64), msg_buf_len); - defer allocator.free(msg_buf); - - // Set up buffer layout based on whether we're using zero-copy for first chunk - const cvs: []align(@alignOf(u64)) u8 = if (first_chunk_direct != null) blk: { - // Zero-copy layout: padding || cvs || suffix - @memcpy(msg_buf[0..padding.len], &padding); - @memcpy(msg_buf[padding.len + cvs_len ..][0..n_enc.len], n_enc.slice()); - @memcpy(msg_buf[padding.len + cvs_len + n_enc.len ..][0..terminator.len], &terminator); - break :blk @alignCast(msg_buf[padding.len..][0..cvs_len]); - } else blk: { - // Fallback layout: first_chunk || padding || cvs || suffix - view.copyRange(0, chunk_size, msg_buf[0..chunk_size]); - @memcpy(msg_buf[chunk_size..][0..padding.len], &padding); - @memcpy(msg_buf[chunk_size + padding.len + cvs_len ..][0..n_enc.len], n_enc.slice()); - @memcpy(msg_buf[chunk_size + padding.len + cvs_len + n_enc.len ..][0..terminator.len], &terminator); - break :blk @alignCast(msg_buf[chunk_size + padding.len ..][0..cvs_len]); - }; - - // Calculate how many full (complete chunk_size) leaves we have const full_leaves = remaining_bytes / chunk_size; - - // Check if there's a partial final leaf (less than chunk_size) const has_partial_leaf = (remaining_bytes % chunk_size) != 0; const partial_leaf_size = if (has_partial_leaf) remaining_bytes % chunk_size else 0; - // Number of leaves (chunks) per batch in multi-threaded mode - const leaves_per_batch = bytes_per_batch / chunk_size; + const max_concurrent_batches = 256; + const cvs_per_super_batch = max_concurrent_batches * leaves_per_batch * cv_size; - // Calculate number of full thread tasks based on complete leaves only - const num_full_tasks = full_leaves / leaves_per_batch; - const remaining_full_leaves = full_leaves % leaves_per_batch; + const cv_buf = try allocator.alignedAlloc(u8, std.mem.Alignment.of(u64), cvs_per_super_batch); + defer allocator.free(cv_buf); - var group: Io.Group = .init; + var leaves_processed: usize = 0; + while (leaves_processed < full_leaves) { + const leaves_in_super_batch = @min(max_concurrent_batches * leaves_per_batch, full_leaves - leaves_processed); + const num_batches = std.math.divCeil(usize, leaves_in_super_batch, leaves_per_batch) catch unreachable; - // Spawn tasks for full SIMD batches - for (0..num_full_tasks) |task_id| { - const start_offset = chunk_size + task_id * bytes_per_batch; - const cv_start = task_id * leaves_per_batch * cv_size; + var group: Io.Group = .init; - group.async(io, LeafThreadContext(Variant).process, .{LeafThreadContext(Variant){ - .view = view, - .start_offset = start_offset, - .num_leaves = leaves_per_batch, - .output_cvs = @alignCast(cvs[cv_start..][0 .. leaves_per_batch * cv_size]), - }}); + for (0..num_batches) |batch_idx| { + const batch_start_leaf = leaves_processed + batch_idx * leaves_per_batch; + const batch_leaves = @min(leaves_per_batch, full_leaves - batch_start_leaf); + + if (batch_leaves == 0) break; + + const start_offset = chunk_size + batch_start_leaf * chunk_size; + const cv_start = batch_idx * leaves_per_batch * cv_size; + + group.async(io, LeafThreadContext(Variant).process, .{LeafThreadContext(Variant){ + .view = view, + .start_offset = start_offset, + .num_leaves = batch_leaves, + .output_cvs = @alignCast(cv_buf[cv_start..][0 .. batch_leaves * cv_size]), + }}); + } + + group.wait(io); + + final_state.update(cv_buf[0 .. leaves_in_super_batch * cv_size]); + leaves_processed += leaves_in_super_batch; } - // Spawn task for remaining full leaves (if any) - if (remaining_full_leaves > 0) { - const start_offset = chunk_size + num_full_tasks * bytes_per_batch; - const cv_start = num_full_tasks * leaves_per_batch * cv_size; - - group.async(io, LeafThreadContext(Variant).process, .{LeafThreadContext(Variant){ - .view = view, - .start_offset = start_offset, - .num_leaves = remaining_full_leaves, - .output_cvs = @alignCast(cvs[cv_start..][0 .. remaining_full_leaves * cv_size]), - }}); - } - - // Spawn task for the partial final leaf (if required) if (has_partial_leaf) { - const start_offset = chunk_size + full_leaves * chunk_size; - const cv_start = full_leaves * cv_size; + var cv_buffer: [64]u8 = undefined; + var leaf_buffer: [chunk_size]u8 = undefined; - group.async(io, FinalLeafContext(Variant).process, .{FinalLeafContext(Variant){ - .view = view, - .start_offset = start_offset, - .leaf_len = partial_leaf_size, - .output_cv = @alignCast(cvs[cv_start..][0..cv_size]), - }}); + const start_offset = chunk_size + full_leaves * chunk_size; + if (view.tryGetSlice(start_offset, start_offset + partial_leaf_size)) |leaf_data| { + const cv_slice = MultiSliceView.init(leaf_data, &[_]u8{}, &[_]u8{}); + Variant.turboShakeToBuffer(&cv_slice, 0x0B, cv_buffer[0..cv_size]); + } else { + view.copyRange(start_offset, start_offset + partial_leaf_size, leaf_buffer[0..partial_leaf_size]); + const cv_slice = MultiSliceView.init(leaf_buffer[0..partial_leaf_size], &[_]u8{}, &[_]u8{}); + Variant.turboShakeToBuffer(&cv_slice, 0x0B, cv_buffer[0..cv_size]); + } + final_state.update(cv_buffer[0..cv_size]); } - group.wait(io); + const n_enc = rightEncode(total_leaves); + final_state.update(n_enc.slice()); + const terminator = [_]u8{ 0xFF, 0xFF }; + final_state.update(&terminator); - const final_view = if (first_chunk_direct) |first_chunk| - MultiSliceView.init(first_chunk, msg_buf, &[_]u8{}) - else - MultiSliceView.init(msg_buf, &[_]u8{}, &[_]u8{}); - Variant.turboShakeToBuffer(&final_view, 0x06, output); + final_state.final(output); } /// Generic KangarooTwelve hash function builder.