parallel kangarootwelve: cap memory usage

This commit is contained in:
Frank Denis 2025-11-25 23:14:20 +01:00
parent b6fecdc622
commit 13f1fa7798

View file

@ -819,7 +819,7 @@ fn FinalLeafContext(comptime Variant: type) type {
}; };
} }
/// Generic multi-threaded implementation /// Generic multi-threaded implementation with bounded heap allocation.
fn ktMultiThreaded( fn ktMultiThreaded(
comptime Variant: type, comptime Variant: type,
allocator: Allocator, allocator: Allocator,
@ -831,108 +831,86 @@ fn ktMultiThreaded(
comptime std.debug.assert(bytes_per_batch % (optimal_vector_len * chunk_size) == 0); comptime std.debug.assert(bytes_per_batch % (optimal_vector_len * chunk_size) == 0);
const cv_size = Variant.cv_size; const cv_size = Variant.cv_size;
const StateType = Variant.StateType;
// Calculate total leaves after the first chunk const leaves_per_batch = bytes_per_batch / chunk_size;
const remaining_bytes = total_len - chunk_size; const remaining_bytes = total_len - chunk_size;
const total_leaves = std.math.divCeil(usize, remaining_bytes, chunk_size) catch unreachable; const total_leaves = std.math.divCeil(usize, remaining_bytes, chunk_size) catch unreachable;
// Pre-compute suffix: right_encode(n) || terminator var final_state = StateType.init(.{});
const n_enc = rightEncode(total_leaves);
const terminator = [_]u8{ 0xFF, 0xFF }; var first_chunk_buffer: [chunk_size]u8 = undefined;
if (view.tryGetSlice(0, chunk_size)) |first_chunk| {
final_state.update(first_chunk);
} else {
view.copyRange(0, chunk_size, &first_chunk_buffer);
final_state.update(&first_chunk_buffer);
}
const padding = [_]u8{ 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }; const padding = [_]u8{ 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 };
final_state.update(&padding);
// Try to get first chunk as contiguous slice to avoid a copy and take advantage of MultiSliceView instead.
const first_chunk_direct = view.tryGetSlice(0, chunk_size);
// Calculate buffer size - skip first chunk if we can reference it directly
const cvs_len = total_leaves * cv_size;
const suffix_len = n_enc.len + terminator.len;
const msg_buf_len = if (first_chunk_direct != null)
padding.len + cvs_len + suffix_len
else
chunk_size + padding.len + cvs_len + suffix_len;
const msg_buf = try allocator.alignedAlloc(u8, std.mem.Alignment.of(u64), msg_buf_len);
defer allocator.free(msg_buf);
// Set up buffer layout based on whether we're using zero-copy for first chunk
const cvs: []align(@alignOf(u64)) u8 = if (first_chunk_direct != null) blk: {
// Zero-copy layout: padding || cvs || suffix
@memcpy(msg_buf[0..padding.len], &padding);
@memcpy(msg_buf[padding.len + cvs_len ..][0..n_enc.len], n_enc.slice());
@memcpy(msg_buf[padding.len + cvs_len + n_enc.len ..][0..terminator.len], &terminator);
break :blk @alignCast(msg_buf[padding.len..][0..cvs_len]);
} else blk: {
// Fallback layout: first_chunk || padding || cvs || suffix
view.copyRange(0, chunk_size, msg_buf[0..chunk_size]);
@memcpy(msg_buf[chunk_size..][0..padding.len], &padding);
@memcpy(msg_buf[chunk_size + padding.len + cvs_len ..][0..n_enc.len], n_enc.slice());
@memcpy(msg_buf[chunk_size + padding.len + cvs_len + n_enc.len ..][0..terminator.len], &terminator);
break :blk @alignCast(msg_buf[chunk_size + padding.len ..][0..cvs_len]);
};
// Calculate how many full (complete chunk_size) leaves we have
const full_leaves = remaining_bytes / chunk_size; const full_leaves = remaining_bytes / chunk_size;
// Check if there's a partial final leaf (less than chunk_size)
const has_partial_leaf = (remaining_bytes % chunk_size) != 0; const has_partial_leaf = (remaining_bytes % chunk_size) != 0;
const partial_leaf_size = if (has_partial_leaf) remaining_bytes % chunk_size else 0; const partial_leaf_size = if (has_partial_leaf) remaining_bytes % chunk_size else 0;
// Number of leaves (chunks) per batch in multi-threaded mode const max_concurrent_batches = 256;
const leaves_per_batch = bytes_per_batch / chunk_size; const cvs_per_super_batch = max_concurrent_batches * leaves_per_batch * cv_size;
// Calculate number of full thread tasks based on complete leaves only const cv_buf = try allocator.alignedAlloc(u8, std.mem.Alignment.of(u64), cvs_per_super_batch);
const num_full_tasks = full_leaves / leaves_per_batch; defer allocator.free(cv_buf);
const remaining_full_leaves = full_leaves % leaves_per_batch;
var leaves_processed: usize = 0;
while (leaves_processed < full_leaves) {
const leaves_in_super_batch = @min(max_concurrent_batches * leaves_per_batch, full_leaves - leaves_processed);
const num_batches = std.math.divCeil(usize, leaves_in_super_batch, leaves_per_batch) catch unreachable;
var group: Io.Group = .init; var group: Io.Group = .init;
// Spawn tasks for full SIMD batches for (0..num_batches) |batch_idx| {
for (0..num_full_tasks) |task_id| { const batch_start_leaf = leaves_processed + batch_idx * leaves_per_batch;
const start_offset = chunk_size + task_id * bytes_per_batch; const batch_leaves = @min(leaves_per_batch, full_leaves - batch_start_leaf);
const cv_start = task_id * leaves_per_batch * cv_size;
if (batch_leaves == 0) break;
const start_offset = chunk_size + batch_start_leaf * chunk_size;
const cv_start = batch_idx * leaves_per_batch * cv_size;
group.async(io, LeafThreadContext(Variant).process, .{LeafThreadContext(Variant){ group.async(io, LeafThreadContext(Variant).process, .{LeafThreadContext(Variant){
.view = view, .view = view,
.start_offset = start_offset, .start_offset = start_offset,
.num_leaves = leaves_per_batch, .num_leaves = batch_leaves,
.output_cvs = @alignCast(cvs[cv_start..][0 .. leaves_per_batch * cv_size]), .output_cvs = @alignCast(cv_buf[cv_start..][0 .. batch_leaves * cv_size]),
}});
}
// Spawn task for remaining full leaves (if any)
if (remaining_full_leaves > 0) {
const start_offset = chunk_size + num_full_tasks * bytes_per_batch;
const cv_start = num_full_tasks * leaves_per_batch * cv_size;
group.async(io, LeafThreadContext(Variant).process, .{LeafThreadContext(Variant){
.view = view,
.start_offset = start_offset,
.num_leaves = remaining_full_leaves,
.output_cvs = @alignCast(cvs[cv_start..][0 .. remaining_full_leaves * cv_size]),
}});
}
// Spawn task for the partial final leaf (if required)
if (has_partial_leaf) {
const start_offset = chunk_size + full_leaves * chunk_size;
const cv_start = full_leaves * cv_size;
group.async(io, FinalLeafContext(Variant).process, .{FinalLeafContext(Variant){
.view = view,
.start_offset = start_offset,
.leaf_len = partial_leaf_size,
.output_cv = @alignCast(cvs[cv_start..][0..cv_size]),
}}); }});
} }
group.wait(io); group.wait(io);
const final_view = if (first_chunk_direct) |first_chunk| final_state.update(cv_buf[0 .. leaves_in_super_batch * cv_size]);
MultiSliceView.init(first_chunk, msg_buf, &[_]u8{}) leaves_processed += leaves_in_super_batch;
else }
MultiSliceView.init(msg_buf, &[_]u8{}, &[_]u8{});
Variant.turboShakeToBuffer(&final_view, 0x06, output); if (has_partial_leaf) {
var cv_buffer: [64]u8 = undefined;
var leaf_buffer: [chunk_size]u8 = undefined;
const start_offset = chunk_size + full_leaves * chunk_size;
if (view.tryGetSlice(start_offset, start_offset + partial_leaf_size)) |leaf_data| {
const cv_slice = MultiSliceView.init(leaf_data, &[_]u8{}, &[_]u8{});
Variant.turboShakeToBuffer(&cv_slice, 0x0B, cv_buffer[0..cv_size]);
} else {
view.copyRange(start_offset, start_offset + partial_leaf_size, leaf_buffer[0..partial_leaf_size]);
const cv_slice = MultiSliceView.init(leaf_buffer[0..partial_leaf_size], &[_]u8{}, &[_]u8{});
Variant.turboShakeToBuffer(&cv_slice, 0x0B, cv_buffer[0..cv_size]);
}
final_state.update(cv_buffer[0..cv_size]);
}
const n_enc = rightEncode(total_leaves);
final_state.update(n_enc.slice());
const terminator = [_]u8{ 0xFF, 0xFF };
final_state.update(&terminator);
final_state.final(output);
} }
/// Generic KangarooTwelve hash function builder. /// Generic KangarooTwelve hash function builder.