Mirror of https://codeberg.org/ziglang/zig.git, synced 2025-12-06 13:54:21 +00:00
parallel kangarootwelve: cap memory usage
parent b6fecdc622
commit 13f1fa7798
1 changed file with 61 additions and 83 deletions
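The commit replaces the old strategy of allocating one message buffer large enough to hold every leaf chaining value with a bounded scratch buffer: leaves are hashed in fixed-size super batches, and each batch of chaining values is absorbed into a running final-node state. Below is a minimal stand-alone sketch of that capping pattern, not the actual implementation: chunk_size (8192) and cv_size (32) match KangarooTwelve, but hashLeaf, max_leaves_per_super_batch, and the use of SHA-256 as a stand-in for the TurboSHAKE leaf hash are illustrative assumptions only.

const std = @import("std");

const chunk_size = 8192; // KangarooTwelve leaf size
const cv_size = 32; // chaining value size

// Hypothetical stand-in for the real leaf hash (TurboSHAKE with domain byte 0x0B).
fn hashLeaf(leaf: []const u8, out: *[cv_size]u8) void {
    std.crypto.hash.sha2.Sha256.hash(leaf, out, .{});
}

pub fn main() !void {
    const allocator = std.heap.page_allocator;

    // Sample input: 1 MiB of zeroes, i.e. 128 leaves.
    const message = try allocator.alloc(u8, 1 << 20);
    defer allocator.free(message);
    @memset(message, 0);
    const leaves = message.len / chunk_size;

    // Cap the scratch allocation: room for one super batch of chaining values,
    // no matter how large the input is (the diff uses 256 concurrent batches of
    // leaves_per_batch leaves for the same purpose).
    const max_leaves_per_super_batch = 256;
    const cv_buf = try allocator.alloc(u8, max_leaves_per_super_batch * cv_size);
    defer allocator.free(cv_buf);

    var processed: usize = 0;
    while (processed < leaves) {
        const batch = @min(max_leaves_per_super_batch, leaves - processed);
        for (0..batch) |i| {
            const off = (processed + i) * chunk_size;
            hashLeaf(message[off..][0..chunk_size], cv_buf[i * cv_size ..][0..cv_size]);
        }
        // The real code absorbs cv_buf[0 .. batch * cv_size] into the running
        // final-node state here instead of keeping every CV in memory.
        processed += batch;
    }
    std.debug.print("hashed {d} leaves with a {d}-byte scratch buffer\n", .{ leaves, cv_buf.len });
}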
@@ -819,7 +819,7 @@ fn FinalLeafContext(comptime Variant: type) type {
    };
}

/// Generic multi-threaded implementation
/// Generic multi-threaded implementation with bounded heap allocation.
fn ktMultiThreaded(
    comptime Variant: type,
    allocator: Allocator,
@@ -831,108 +831,86 @@ fn ktMultiThreaded(
    comptime std.debug.assert(bytes_per_batch % (optimal_vector_len * chunk_size) == 0);

    const cv_size = Variant.cv_size;

    // Calculate total leaves after the first chunk
    const StateType = Variant.StateType;
    const leaves_per_batch = bytes_per_batch / chunk_size;
    const remaining_bytes = total_len - chunk_size;
    const total_leaves = std.math.divCeil(usize, remaining_bytes, chunk_size) catch unreachable;

    // Pre-compute suffix: right_encode(n) || terminator
    const n_enc = rightEncode(total_leaves);
    const terminator = [_]u8{ 0xFF, 0xFF };
    var final_state = StateType.init(.{});

    var first_chunk_buffer: [chunk_size]u8 = undefined;
    if (view.tryGetSlice(0, chunk_size)) |first_chunk| {
        final_state.update(first_chunk);
    } else {
        view.copyRange(0, chunk_size, &first_chunk_buffer);
        final_state.update(&first_chunk_buffer);
    }

    const padding = [_]u8{ 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 };
    final_state.update(&padding);

    // Try to get first chunk as contiguous slice to avoid a copy and take advantage of MultiSliceView instead.
    const first_chunk_direct = view.tryGetSlice(0, chunk_size);

    // Calculate buffer size - skip first chunk if we can reference it directly
    const cvs_len = total_leaves * cv_size;
    const suffix_len = n_enc.len + terminator.len;
    const msg_buf_len = if (first_chunk_direct != null)
        padding.len + cvs_len + suffix_len
    else
        chunk_size + padding.len + cvs_len + suffix_len;

    const msg_buf = try allocator.alignedAlloc(u8, std.mem.Alignment.of(u64), msg_buf_len);
    defer allocator.free(msg_buf);

    // Set up buffer layout based on whether we're using zero-copy for first chunk
    const cvs: []align(@alignOf(u64)) u8 = if (first_chunk_direct != null) blk: {
        // Zero-copy layout: padding || cvs || suffix
        @memcpy(msg_buf[0..padding.len], &padding);
        @memcpy(msg_buf[padding.len + cvs_len ..][0..n_enc.len], n_enc.slice());
        @memcpy(msg_buf[padding.len + cvs_len + n_enc.len ..][0..terminator.len], &terminator);
        break :blk @alignCast(msg_buf[padding.len..][0..cvs_len]);
    } else blk: {
        // Fallback layout: first_chunk || padding || cvs || suffix
        view.copyRange(0, chunk_size, msg_buf[0..chunk_size]);
        @memcpy(msg_buf[chunk_size..][0..padding.len], &padding);
        @memcpy(msg_buf[chunk_size + padding.len + cvs_len ..][0..n_enc.len], n_enc.slice());
        @memcpy(msg_buf[chunk_size + padding.len + cvs_len + n_enc.len ..][0..terminator.len], &terminator);
        break :blk @alignCast(msg_buf[chunk_size + padding.len ..][0..cvs_len]);
    };

    // Calculate how many full (complete chunk_size) leaves we have
    const full_leaves = remaining_bytes / chunk_size;

    // Check if there's a partial final leaf (less than chunk_size)
    const has_partial_leaf = (remaining_bytes % chunk_size) != 0;
    const partial_leaf_size = if (has_partial_leaf) remaining_bytes % chunk_size else 0;

    // Number of leaves (chunks) per batch in multi-threaded mode
    const leaves_per_batch = bytes_per_batch / chunk_size;
    const max_concurrent_batches = 256;
    const cvs_per_super_batch = max_concurrent_batches * leaves_per_batch * cv_size;

    // Calculate number of full thread tasks based on complete leaves only
    const num_full_tasks = full_leaves / leaves_per_batch;
    const remaining_full_leaves = full_leaves % leaves_per_batch;
    const cv_buf = try allocator.alignedAlloc(u8, std.mem.Alignment.of(u64), cvs_per_super_batch);
    defer allocator.free(cv_buf);

    var leaves_processed: usize = 0;
    while (leaves_processed < full_leaves) {
        const leaves_in_super_batch = @min(max_concurrent_batches * leaves_per_batch, full_leaves - leaves_processed);
        const num_batches = std.math.divCeil(usize, leaves_in_super_batch, leaves_per_batch) catch unreachable;

        var group: Io.Group = .init;

    // Spawn tasks for full SIMD batches
    for (0..num_full_tasks) |task_id| {
        const start_offset = chunk_size + task_id * bytes_per_batch;
        const cv_start = task_id * leaves_per_batch * cv_size;
        for (0..num_batches) |batch_idx| {
            const batch_start_leaf = leaves_processed + batch_idx * leaves_per_batch;
            const batch_leaves = @min(leaves_per_batch, full_leaves - batch_start_leaf);

            if (batch_leaves == 0) break;

            const start_offset = chunk_size + batch_start_leaf * chunk_size;
            const cv_start = batch_idx * leaves_per_batch * cv_size;

            group.async(io, LeafThreadContext(Variant).process, .{LeafThreadContext(Variant){
                .view = view,
                .start_offset = start_offset,
                .num_leaves = leaves_per_batch,
                .output_cvs = @alignCast(cvs[cv_start..][0 .. leaves_per_batch * cv_size]),
            }});
        }

    // Spawn task for remaining full leaves (if any)
    if (remaining_full_leaves > 0) {
        const start_offset = chunk_size + num_full_tasks * bytes_per_batch;
        const cv_start = num_full_tasks * leaves_per_batch * cv_size;

        group.async(io, LeafThreadContext(Variant).process, .{LeafThreadContext(Variant){
            .view = view,
            .start_offset = start_offset,
            .num_leaves = remaining_full_leaves,
            .output_cvs = @alignCast(cvs[cv_start..][0 .. remaining_full_leaves * cv_size]),
        }});
    }

    // Spawn task for the partial final leaf (if required)
    if (has_partial_leaf) {
        const start_offset = chunk_size + full_leaves * chunk_size;
        const cv_start = full_leaves * cv_size;

        group.async(io, FinalLeafContext(Variant).process, .{FinalLeafContext(Variant){
            .view = view,
            .start_offset = start_offset,
            .leaf_len = partial_leaf_size,
            .output_cv = @alignCast(cvs[cv_start..][0..cv_size]),
            .num_leaves = batch_leaves,
            .output_cvs = @alignCast(cv_buf[cv_start..][0 .. batch_leaves * cv_size]),
        }});
    }

        group.wait(io);

    const final_view = if (first_chunk_direct) |first_chunk|
        MultiSliceView.init(first_chunk, msg_buf, &[_]u8{})
    else
        MultiSliceView.init(msg_buf, &[_]u8{}, &[_]u8{});
    Variant.turboShakeToBuffer(&final_view, 0x06, output);
        final_state.update(cv_buf[0 .. leaves_in_super_batch * cv_size]);
        leaves_processed += leaves_in_super_batch;
    }

    if (has_partial_leaf) {
        var cv_buffer: [64]u8 = undefined;
        var leaf_buffer: [chunk_size]u8 = undefined;

        const start_offset = chunk_size + full_leaves * chunk_size;
        if (view.tryGetSlice(start_offset, start_offset + partial_leaf_size)) |leaf_data| {
            const cv_slice = MultiSliceView.init(leaf_data, &[_]u8{}, &[_]u8{});
            Variant.turboShakeToBuffer(&cv_slice, 0x0B, cv_buffer[0..cv_size]);
        } else {
            view.copyRange(start_offset, start_offset + partial_leaf_size, leaf_buffer[0..partial_leaf_size]);
            const cv_slice = MultiSliceView.init(leaf_buffer[0..partial_leaf_size], &[_]u8{}, &[_]u8{});
            Variant.turboShakeToBuffer(&cv_slice, 0x0B, cv_buffer[0..cv_size]);
        }
        final_state.update(cv_buffer[0..cv_size]);
    }

    const n_enc = rightEncode(total_leaves);
    final_state.update(n_enc.slice());
    const terminator = [_]u8{ 0xFF, 0xFF };
    final_state.update(&terminator);

    final_state.final(output);
}

/// Generic KangarooTwelve hash function builder.
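The final node is closed with rightEncode(total_leaves) followed by the 0xFF 0xFF terminator. rightEncode itself is not part of this diff; the sketch below shows the KangarooTwelve length encoding it presumably corresponds to (the minimal big-endian bytes of n, followed by one byte giving their count). The name lengthEncode and its signature are illustrative, not the std implementation.

const std = @import("std");

/// Sketch of KangarooTwelve's length_encode(x): the minimal big-endian byte
/// representation of x, followed by one byte giving how many value bytes were
/// emitted. length_encode(0) is the single byte 0x00.
fn lengthEncode(x: usize, out: *[@sizeOf(usize) + 1]u8) []const u8 {
    var len: usize = 0;
    var v = x;
    // Count how many bytes the big-endian representation needs.
    while (v != 0) : (v >>= 8) len += 1;
    // Write the value bytes, most significant first.
    var i: usize = 0;
    while (i < len) : (i += 1) {
        out[i] = @truncate(x >> @intCast(8 * (len - 1 - i)));
    }
    out[len] = @intCast(len);
    return out[0 .. len + 1];
}

test "length encoding examples" {
    var buf: [@sizeOf(usize) + 1]u8 = undefined;
    try std.testing.expectEqualSlices(u8, &.{0x00}, lengthEncode(0, &buf));
    try std.testing.expectEqualSlices(u8, &.{ 0x0C, 0x01 }, lengthEncode(12, &buf));
    try std.testing.expectEqualSlices(u8, &.{ 0x01, 0x00, 0x02 }, lengthEncode(256, &buf));
}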