Mirror of https://codeberg.org/ziglang/zig.git, synced 2025-12-06 13:54:21 +00:00
parallel kangarootwelve: cap memory usage
parent b6fecdc622
commit 13f1fa7798

1 changed file with 61 additions and 83 deletions
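The change below replaces the message buffer that grew with the input (first chunk, padding, every leaf chaining value, and the suffix all materialized at once) with a scratch buffer whose size is fixed by a concurrency cap, drained into an incrementally updated final-node state one "super batch" at a time. The following is a minimal, hedged sketch of that bounded-buffer pattern only; leafCv, the Wyhash stand-ins, and the constants are hypothetical placeholders, not the TurboSHAKE, LeafThreadContext, or Io.Group machinery the real implementation uses.

const std = @import("std");

// Hypothetical stand-ins for the real chaining-value machinery.
const cv_size = 32; // bytes per chaining value
const leaves_per_batch = 4; // leaves hashed per spawned batch
const max_concurrent_batches = 8; // at most this many batches of CVs are buffered

fn leafCv(leaf: []const u8, out: *[cv_size]u8) void {
    // Placeholder compression; the real code derives a TurboSHAKE-based CV here.
    var h = std.hash.Wyhash.init(0);
    h.update(leaf);
    std.mem.writeInt(u64, out[0..8], h.final(), .little);
    @memset(out[8..], 0);
}

pub fn main() !void {
    const allocator = std.heap.page_allocator;

    // Pretend input split into fixed-size "leaves".
    const leaf_len = 64;
    const input = try allocator.alloc(u8, 10_000 * leaf_len);
    defer allocator.free(input);
    @memset(input, 0xab);
    const total_leaves = input.len / leaf_len;

    // Scratch buffer sized by the concurrency cap, not by the input length,
    // so heap usage stays constant however large the input gets.
    const cv_buf = try allocator.alloc(u8, max_concurrent_batches * leaves_per_batch * cv_size);
    defer allocator.free(cv_buf);

    // Running "final node" state; the real code feeds an incremental TurboSHAKE state.
    var final_state = std.hash.Wyhash.init(0);

    var leaves_done: usize = 0;
    while (leaves_done < total_leaves) {
        // One super batch: fill the bounded buffer, absorb it, then reuse it.
        const leaves_now = @min(max_concurrent_batches * leaves_per_batch, total_leaves - leaves_done);
        for (0..leaves_now) |i| {
            const leaf = input[(leaves_done + i) * leaf_len ..][0..leaf_len];
            leafCv(leaf, cv_buf[i * cv_size ..][0..cv_size]);
        }
        final_state.update(cv_buf[0 .. leaves_now * cv_size]);
        leaves_done += leaves_now;
    }

    std.debug.print("digest (toy): {x}\n", .{final_state.final()});
}

With the cap fixed (256 concurrent batches in the diff below), the chaining-value buffer occupies max_concurrent_batches * leaves_per_batch * cv_size bytes regardless of input length, at the cost of one group.wait synchronization point per super batch.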
@@ -819,7 +819,7 @@ fn FinalLeafContext(comptime Variant: type) type {
     };
 }
 
-/// Generic multi-threaded implementation
+/// Generic multi-threaded implementation with bounded heap allocation.
 fn ktMultiThreaded(
     comptime Variant: type,
     allocator: Allocator,
@@ -831,108 +831,86 @@ fn ktMultiThreaded(
     comptime std.debug.assert(bytes_per_batch % (optimal_vector_len * chunk_size) == 0);
 
     const cv_size = Variant.cv_size;
-    // Calculate total leaves after the first chunk
+    const StateType = Variant.StateType;
+    const leaves_per_batch = bytes_per_batch / chunk_size;
     const remaining_bytes = total_len - chunk_size;
     const total_leaves = std.math.divCeil(usize, remaining_bytes, chunk_size) catch unreachable;
 
-    // Pre-compute suffix: right_encode(n) || terminator
-    const n_enc = rightEncode(total_leaves);
-    const terminator = [_]u8{ 0xFF, 0xFF };
+    var final_state = StateType.init(.{});
+    var first_chunk_buffer: [chunk_size]u8 = undefined;
+    if (view.tryGetSlice(0, chunk_size)) |first_chunk| {
+        final_state.update(first_chunk);
+    } else {
+        view.copyRange(0, chunk_size, &first_chunk_buffer);
+        final_state.update(&first_chunk_buffer);
+    }
 
     const padding = [_]u8{ 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 };
+    final_state.update(&padding);
 
-    // Try to get first chunk as contiguous slice to avoid a copy and take advantage of MultiSliceView instead.
-    const first_chunk_direct = view.tryGetSlice(0, chunk_size);
-
-    // Calculate buffer size - skip first chunk if we can reference it directly
-    const cvs_len = total_leaves * cv_size;
-    const suffix_len = n_enc.len + terminator.len;
-    const msg_buf_len = if (first_chunk_direct != null)
-        padding.len + cvs_len + suffix_len
-    else
-        chunk_size + padding.len + cvs_len + suffix_len;
-
-    const msg_buf = try allocator.alignedAlloc(u8, std.mem.Alignment.of(u64), msg_buf_len);
-    defer allocator.free(msg_buf);
-
-    // Set up buffer layout based on whether we're using zero-copy for first chunk
-    const cvs: []align(@alignOf(u64)) u8 = if (first_chunk_direct != null) blk: {
-        // Zero-copy layout: padding || cvs || suffix
-        @memcpy(msg_buf[0..padding.len], &padding);
-        @memcpy(msg_buf[padding.len + cvs_len ..][0..n_enc.len], n_enc.slice());
-        @memcpy(msg_buf[padding.len + cvs_len + n_enc.len ..][0..terminator.len], &terminator);
-        break :blk @alignCast(msg_buf[padding.len..][0..cvs_len]);
-    } else blk: {
-        // Fallback layout: first_chunk || padding || cvs || suffix
-        view.copyRange(0, chunk_size, msg_buf[0..chunk_size]);
-        @memcpy(msg_buf[chunk_size..][0..padding.len], &padding);
-        @memcpy(msg_buf[chunk_size + padding.len + cvs_len ..][0..n_enc.len], n_enc.slice());
-        @memcpy(msg_buf[chunk_size + padding.len + cvs_len + n_enc.len ..][0..terminator.len], &terminator);
-        break :blk @alignCast(msg_buf[chunk_size + padding.len ..][0..cvs_len]);
-    };
-
-    // Calculate how many full (complete chunk_size) leaves we have
     const full_leaves = remaining_bytes / chunk_size;
 
-    // Check if there's a partial final leaf (less than chunk_size)
     const has_partial_leaf = (remaining_bytes % chunk_size) != 0;
     const partial_leaf_size = if (has_partial_leaf) remaining_bytes % chunk_size else 0;
 
-    // Number of leaves (chunks) per batch in multi-threaded mode
-    const leaves_per_batch = bytes_per_batch / chunk_size;
-
-    // Calculate number of full thread tasks based on complete leaves only
-    const num_full_tasks = full_leaves / leaves_per_batch;
-    const remaining_full_leaves = full_leaves % leaves_per_batch;
-
-    var group: Io.Group = .init;
-
-    // Spawn tasks for full SIMD batches
-    for (0..num_full_tasks) |task_id| {
-        const start_offset = chunk_size + task_id * bytes_per_batch;
-        const cv_start = task_id * leaves_per_batch * cv_size;
-
-        group.async(io, LeafThreadContext(Variant).process, .{LeafThreadContext(Variant){
-            .view = view,
-            .start_offset = start_offset,
-            .num_leaves = leaves_per_batch,
-            .output_cvs = @alignCast(cvs[cv_start..][0 .. leaves_per_batch * cv_size]),
-        }});
-    }
-
-    // Spawn task for remaining full leaves (if any)
-    if (remaining_full_leaves > 0) {
-        const start_offset = chunk_size + num_full_tasks * bytes_per_batch;
-        const cv_start = num_full_tasks * leaves_per_batch * cv_size;
-
-        group.async(io, LeafThreadContext(Variant).process, .{LeafThreadContext(Variant){
-            .view = view,
-            .start_offset = start_offset,
-            .num_leaves = remaining_full_leaves,
-            .output_cvs = @alignCast(cvs[cv_start..][0 .. remaining_full_leaves * cv_size]),
-        }});
-    }
-
-    // Spawn task for the partial final leaf (if required)
-    if (has_partial_leaf) {
-        const start_offset = chunk_size + full_leaves * chunk_size;
-        const cv_start = full_leaves * cv_size;
-
-        group.async(io, FinalLeafContext(Variant).process, .{FinalLeafContext(Variant){
-            .view = view,
-            .start_offset = start_offset,
-            .leaf_len = partial_leaf_size,
-            .output_cv = @alignCast(cvs[cv_start..][0..cv_size]),
-        }});
-    }
-
-    group.wait(io);
-
-    const final_view = if (first_chunk_direct) |first_chunk|
-        MultiSliceView.init(first_chunk, msg_buf, &[_]u8{})
-    else
-        MultiSliceView.init(msg_buf, &[_]u8{}, &[_]u8{});
-    Variant.turboShakeToBuffer(&final_view, 0x06, output);
+    const max_concurrent_batches = 256;
+    const cvs_per_super_batch = max_concurrent_batches * leaves_per_batch * cv_size;
+
+    const cv_buf = try allocator.alignedAlloc(u8, std.mem.Alignment.of(u64), cvs_per_super_batch);
+    defer allocator.free(cv_buf);
+
+    var leaves_processed: usize = 0;
+    while (leaves_processed < full_leaves) {
+        const leaves_in_super_batch = @min(max_concurrent_batches * leaves_per_batch, full_leaves - leaves_processed);
+        const num_batches = std.math.divCeil(usize, leaves_in_super_batch, leaves_per_batch) catch unreachable;
+
+        var group: Io.Group = .init;
+
+        for (0..num_batches) |batch_idx| {
+            const batch_start_leaf = leaves_processed + batch_idx * leaves_per_batch;
+            const batch_leaves = @min(leaves_per_batch, full_leaves - batch_start_leaf);
+
+            if (batch_leaves == 0) break;
+
+            const start_offset = chunk_size + batch_start_leaf * chunk_size;
+            const cv_start = batch_idx * leaves_per_batch * cv_size;
+
+            group.async(io, LeafThreadContext(Variant).process, .{LeafThreadContext(Variant){
+                .view = view,
+                .start_offset = start_offset,
+                .num_leaves = batch_leaves,
+                .output_cvs = @alignCast(cv_buf[cv_start..][0 .. batch_leaves * cv_size]),
+            }});
+        }
+
+        group.wait(io);
+
+        final_state.update(cv_buf[0 .. leaves_in_super_batch * cv_size]);
+        leaves_processed += leaves_in_super_batch;
+    }
+
+    if (has_partial_leaf) {
+        var cv_buffer: [64]u8 = undefined;
+        var leaf_buffer: [chunk_size]u8 = undefined;
+
+        const start_offset = chunk_size + full_leaves * chunk_size;
+        if (view.tryGetSlice(start_offset, start_offset + partial_leaf_size)) |leaf_data| {
+            const cv_slice = MultiSliceView.init(leaf_data, &[_]u8{}, &[_]u8{});
+            Variant.turboShakeToBuffer(&cv_slice, 0x0B, cv_buffer[0..cv_size]);
+        } else {
+            view.copyRange(start_offset, start_offset + partial_leaf_size, leaf_buffer[0..partial_leaf_size]);
+            const cv_slice = MultiSliceView.init(leaf_buffer[0..partial_leaf_size], &[_]u8{}, &[_]u8{});
+            Variant.turboShakeToBuffer(&cv_slice, 0x0B, cv_buffer[0..cv_size]);
+        }
+        final_state.update(cv_buffer[0..cv_size]);
+    }
+
+    const n_enc = rightEncode(total_leaves);
+    final_state.update(n_enc.slice());
+    const terminator = [_]u8{ 0xFF, 0xFF };
+    final_state.update(&terminator);
+
+    final_state.final(output);
 }
 
 /// Generic KangarooTwelve hash function builder.