writeSplat

This commit is contained in:
Andrew Kelley 2025-02-15 21:02:42 -08:00
parent b26aceba7d
commit 35824e4822
6 changed files with 121 additions and 171 deletions

View file

@ -1591,8 +1591,7 @@ pub fn writer(file: File) std.io.Writer {
return .{
.context = interface.handleToOpaque(file.handle),
.vtable = &.{
.writev = interface.writev,
.splat = interface.splat,
.writeSplat = interface.writeSplat,
.writeFile = interface.writeFile,
},
};
@ -1610,45 +1609,20 @@ const interface = struct {
/// vectors through the underlying write calls as possible.
const max_buffers_len = 16;
fn writev(context: *anyopaque, data: []const []const u8) anyerror!usize {
fn writeSplat(context: *anyopaque, data: []const []const u8, splat: usize) anyerror!usize {
const file = opaqueToHandle(context);
if (is_windows) {
// TODO improve this to use WriteFileScatter
if (data.len == 0) return 0;
const first = data[0];
return windows.WriteFile(file, first.base[0..first.len], null);
if (data.len == 1 and splat == 0) return 0;
return windows.WriteFile(file, data[0], null);
}
var iovecs_buffer: [max_buffers_len]std.posix.iovec_const = undefined;
const iovecs = iovecs_buffer[0..@min(iovecs_buffer.len, data.len)];
for (iovecs, data[0..iovecs.len]) |*v, d| v.* = .{
.base = if (d.len == 0) "" else d.ptr, // OS sadly checks ptr addr before length.
.len = d.len,
};
return std.posix.writev(file, iovecs);
}
fn splat(context: *anyopaque, headers: []const []const u8, pattern: []const u8, n: usize) anyerror!usize {
const file = opaqueToHandle(context);
if (is_windows) {
// TODO improve this to use WriteFileScatter
if (headers.len > 0) {
const first = headers[0];
return windows.WriteFile(file, first, null);
}
if (n > 0) return windows.WriteFile(file, pattern, null);
return 0;
}
var iovecs_buffer: [max_buffers_len]std.posix.iovec_const = undefined;
const iovecs = iovecs_buffer[0..@min(iovecs_buffer.len, headers.len)];
for (iovecs, headers[0..iovecs.len]) |*v, d| v.* = .{
.base = if (d.len == 0) "" else d.ptr, // OS sadly checks ptr addr before length.
.len = d.len,
};
return std.posix.writev(file, iovecs);
const send_iovecs = if (splat == 0) iovecs[0 .. iovecs.len - 1] else iovecs;
return std.posix.writev(file, send_iovecs);
}
fn writeFile(
@ -1662,7 +1636,7 @@ const interface = struct {
const out_fd = opaqueToHandle(context);
const in_fd = in_file.handle;
const len_int = switch (in_len) {
.zero => return interface.writev(context, headers_and_trailers),
.zero => return interface.writeSplat(context, headers_and_trailers, 1),
.entire_file => 0,
else => in_len.int(),
};

View file

@ -344,22 +344,16 @@ pub const tty = @import("io/tty.zig");
pub const null_writer: Writer = .{
.context = undefined,
.vtable = &.{
.writev = null_writev,
.splat = null_splat,
.writeSplat = null_writeSplat,
.writeFile = null_writeFile,
},
};
fn null_writev(context: *anyopaque, data: []const []const u8) anyerror!usize {
fn null_writeSplat(context: *anyopaque, data: []const []const u8, splat: usize) anyerror!usize {
_ = context;
var written: usize = 0;
for (data) |bytes| written += bytes.len;
return written;
}
fn null_splat(context: *anyopaque, headers: []const []const u8, pattern: []const u8, n: usize) anyerror!usize {
_ = context;
var written: usize = pattern.len * n;
const headers = data[0 .. data.len - 1];
const pattern = data[headers.len..];
var written: usize = pattern.len * splat;
for (headers) |bytes| written += bytes.len;
return written;
}

View file

@ -21,7 +21,7 @@ allocator: std.mem.Allocator,
buffered_writer: std.io.BufferedWriter,
const vtable: std.io.Writer.VTable = .{
.writev = writev,
.writeSplat = writeSplat,
.writeFile = writeFile,
};
@ -98,35 +98,37 @@ pub fn clearRetainingCapacity(aw: *AllocatingWriter) void {
aw.written.len = 0;
}
fn writev(context: *anyopaque, data: []const []const u8) anyerror!usize {
return splat(context, data, &.{}, 0);
}
fn splat(context: *anyopaque, headers: []const []const u8, pattern: []const u8, n: usize) anyerror!usize {
fn writeSplat(context: *anyopaque, data: []const []const u8, splat: usize) anyerror!usize {
const aw: *AllocatingWriter = @alignCast(@ptrCast(context));
const start_len = aw.written.len;
const bw = &aw.buffered_writer;
assert(headers[0].ptr == aw.written.ptr + start_len);
const skip_first = data[0].ptr == aw.written.ptr + start_len;
const items_len = if (skip_first) start_len + data[0].len else start_len;
var list: std.ArrayListUnmanaged(u8) = .{
.items = aw.written.ptr[0 .. start_len + headers[0].len],
.items = aw.written.ptr[0..items_len],
.capacity = start_len + bw.buffer.len,
};
defer setArrayList(aw, list);
const rest = headers[1..];
var new_capacity: usize = list.capacity + pattern.len * n;
const rest = data[1 .. data.len - 1];
const pattern = data[data.len - 1];
var new_capacity: usize = list.capacity + pattern.len * splat;
for (rest) |bytes| new_capacity += bytes.len;
try list.ensureTotalCapacity(aw.allocator, new_capacity + 1);
for (rest) |bytes| list.appendSliceAssumeCapacity(bytes);
if (pattern.len == 1) {
list.appendNTimesAssumeCapacity(pattern[0], n);
} else {
for (0..n) |_| list.appendSliceAssumeCapacity(pattern);
}
appendPatternAssumeCapacity(&list, pattern, splat);
aw.written = list.items;
bw.buffer = list.unusedCapacitySlice();
return list.items.len - start_len;
}
fn appendPatternAssumeCapacity(list: *std.ArrayListUnmanaged(u8), pattern: []const u8, splat: usize) void {
if (pattern.len == 1) {
list.appendNTimesAssumeCapacity(pattern[0], splat);
} else {
for (0..splat) |_| list.appendSliceAssumeCapacity(pattern);
}
}
fn writeFile(
context: *anyopaque,
file: std.fs.File,

View file

@ -33,16 +33,14 @@ pub fn writer(bw: *BufferedWriter) Writer {
return .{
.context = bw,
.vtable = &.{
.writev = passthru_writev,
.splat = passthru_splat,
.write = passthru_writeSplat,
.writeFile = passthru_writeFile,
},
};
}
const fixed_vtable: Writer.VTable = .{
.writev = fixed_writev,
.splat = fixed_splat,
.writeSplat = fixed_writeSplat,
.writeFile = fixed_writeFile,
};
@ -81,7 +79,7 @@ pub fn flush(bw: *BufferedWriter) anyerror!void {
pub fn writevAll(bw: *BufferedWriter, data: []const []const u8) anyerror!void {
var i: usize = 0;
while (true) {
var n = try writev(bw, data[i..]);
var n = try passthru_writeSplat(bw, data[i..], 1);
while (n >= data[i].len) {
n -= data[i].len;
i += 1;
@ -91,14 +89,16 @@ pub fn writevAll(bw: *BufferedWriter, data: []const []const u8) anyerror!void {
}
}
pub fn writev(bw: *BufferedWriter, data: []const []const u8) anyerror!usize {
return passthru_writev(bw, data);
pub fn writeSplat(bw: *BufferedWriter, data: []const []const u8, splat: usize) anyerror!usize {
return passthru_writeSplat(bw, data, splat);
}
fn passthru_writev(context: *anyopaque, data: []const []const u8) anyerror!usize {
fn passthru_writeSplat(context: *anyopaque, data: []const []const u8, splat: usize) anyerror!usize {
const bw: *BufferedWriter = @alignCast(@ptrCast(context));
const buffer = bw.buffer;
const start_end = bw.end;
var buffers: [max_buffers_len][]const u8 = undefined;
var end = bw.end;
for (data, 0..) |bytes, i| {
const new_end = end + bytes.len;
@ -108,14 +108,17 @@ fn passthru_writev(context: *anyopaque, data: []const []const u8) anyerror!usize
end = new_end;
continue;
}
if (end == 0) return bw.unbuffered_writer.writev(data);
var buffers: [max_buffers_len][]const u8 = undefined;
if (end == 0) return bw.unbuffered_writer.writeSplat(data, splat);
buffers[0] = buffer[0..end];
const remaining_data = data[i..];
const remaining_buffers = buffers[1..];
const len: usize = @min(remaining_data.len, remaining_buffers.len);
@memcpy(remaining_buffers[0..len], remaining_data[0..len]);
const n = try bw.unbuffered_writer.writev(buffers[0 .. len + 1]);
const send_buffers = buffers[0 .. len + 1];
if (len >= remaining_data.len) {
@branchHint(.likely);
// Made it past the headers, so we can enable splatting.
const n = try bw.unbuffered_writer.writeSplat(send_buffers, splat);
if (n < end) {
@branchHint(.unlikely);
const remainder = buffer[n..end];
@ -126,57 +129,29 @@ fn passthru_writev(context: *anyopaque, data: []const []const u8) anyerror!usize
bw.end = 0;
return n - start_end;
}
const n = try bw.unbuffered_writer.writeSplat(send_buffers, 1);
if (n < end) {
@branchHint(.unlikely);
const remainder = buffer[n..end];
std.mem.copyForwards(u8, buffer[0..remainder.len], remainder);
bw.end = remainder.len;
return end - start_end;
}
bw.end = 0;
return n - start_end;
}
const pattern = data[data.len - 1];
if (splat == 0) {
@branchHint(.unlikely);
// It was added in the loop above; undo it here.
end -= pattern.len;
bw.end = end;
return end - start_end;
}
fn passthru_splat(context: *anyopaque, headers: []const []const u8, pattern: []const u8, n: usize) anyerror!usize {
const bw: *BufferedWriter = @alignCast(@ptrCast(context));
const buffer = bw.buffer;
const start_end = bw.end;
var end = bw.end;
for (headers, 0..) |bytes, i| {
const new_end = end + bytes.len;
if (new_end <= buffer.len) {
@branchHint(.likely);
@memcpy(buffer[end..new_end], bytes);
end = new_end;
continue;
}
if (end == 0) return bw.unbuffered_writer.splat(headers, pattern, n);
var buffers: [max_buffers_len][]const u8 = undefined;
buffers[0] = buffer[0..end];
const remaining_headers = headers[i..];
const remaining_buffers = buffers[1..];
const len: usize = @min(remaining_headers.len, remaining_buffers.len);
@memcpy(remaining_buffers[0..len], remaining_headers[0..len]);
const send_buffers = buffers[0 .. len + 1];
if (len >= remaining_headers.len) {
@branchHint(.likely);
// Made it past the headers, so we can call `splat`.
const written = try bw.unbuffered_writer.splat(send_buffers, pattern, n);
if (written < end) {
@branchHint(.unlikely);
const remainder = buffer[written..end];
std.mem.copyForwards(u8, buffer[0..remainder.len], remainder);
bw.end = remainder.len;
return end - start_end;
}
bw.end = 0;
return written - start_end;
}
const written = try bw.unbuffered_writer.writev(send_buffers);
if (written < end) {
@branchHint(.unlikely);
const remainder = buffer[written..end];
std.mem.copyForwards(u8, buffer[0..remainder.len], remainder);
bw.end = remainder.len;
return end - start_end;
}
bw.end = 0;
return written - start_end;
}
const remaining_splat = splat - 1;
switch (pattern.len) {
0 => {
@ -184,26 +159,28 @@ fn passthru_splat(context: *anyopaque, headers: []const []const u8, pattern: []c
return end - start_end;
},
1 => {
const new_end = end + n;
const new_end = end + remaining_splat;
if (new_end <= buffer.len) {
@branchHint(.likely);
@memset(buffer[end..new_end], pattern[0]);
bw.end = new_end;
return end - start_end;
}
const written = try bw.unbuffered_writer.splat(buffer[0..end], pattern, n);
if (written < end) {
buffers[0] = buffer[0..end];
buffers[1] = pattern;
const n = try bw.unbuffered_writer.writeSplat(buffers[0..2], remaining_splat);
if (n < end) {
@branchHint(.unlikely);
const remainder = buffer[written..end];
const remainder = buffer[n..end];
std.mem.copyForwards(u8, buffer[0..remainder.len], remainder);
bw.end = remainder.len;
return end - start_end;
}
bw.end = 0;
return written - start_end;
return n - start_end;
},
else => {
const new_end = end + pattern.len * n;
const new_end = end + pattern.len * remaining_splat;
if (new_end <= buffer.len) {
@branchHint(.likely);
while (end < new_end) : (end += pattern.len) {
@ -212,16 +189,18 @@ fn passthru_splat(context: *anyopaque, headers: []const []const u8, pattern: []c
bw.end = end;
return end - start_end;
}
const written = try bw.unbuffered_writer.splat(buffer[0..end], pattern, n);
if (written < end) {
buffers[0] = buffer[0..end];
buffers[1] = pattern;
const n = try bw.unbuffered_writer.writeSplat(buffers[0..2], remaining_splat);
if (n < end) {
@branchHint(.unlikely);
const remainder = buffer[written..end];
const remainder = buffer[n..end];
std.mem.copyForwards(u8, buffer[0..remainder.len], remainder);
bw.end = remainder.len;
return end - start_end;
}
bw.end = 0;
return written - start_end;
return n - start_end;
},
}
}
@ -237,15 +216,24 @@ fn fixed_writev(context: *anyopaque, data: []const []const u8) anyerror!usize {
return error.NoSpaceLeft;
}
fn fixed_splat(context: *anyopaque, headers: []const []const u8, pattern: []const u8, n: usize) anyerror!usize {
/// When this function is called it means the buffer got full, so it's time
/// to return an error. However, we still need to make sure all of the
/// available buffer has been filled.
fn fixed_writeSplat(context: *anyopaque, data: []const []const u8, splat: usize) anyerror!usize {
const bw: *BufferedWriter = @alignCast(@ptrCast(context));
for (data) |bytes| {
const dest = bw.buffer[bw.end..];
if (headers.len > 0) {
@memcpy(dest, headers[0][0..dest.len]);
} else switch (pattern.len) {
if (dest.len == 0) return error.NoSpaceLeft;
const len = @min(bytes.len, dest.len);
@memcpy(dest[0..len], bytes[0..len]);
bw.end += len;
}
const pattern = data[data.len - 1];
const dest = bw.buffer[bw.end..];
switch (pattern.len) {
0 => unreachable,
1 => @memset(dest, pattern[0]),
else => for (0..n) |i| @memcpy(dest[i * pattern.len ..][0..pattern.len], pattern),
else => for (0..splat - 1) |i| @memcpy(dest[i * pattern.len ..][0..pattern.len], pattern),
}
return error.NoSpaceLeft;
}
@ -329,21 +317,26 @@ pub fn splatByteAll(bw: *BufferedWriter, byte: u8, n: usize) anyerror!void {
///
/// Does maximum of one underlying `Writer.VTable.writev`.
pub fn splatByte(bw: *BufferedWriter, byte: u8, n: usize) anyerror!usize {
return passthru_splat(bw, &.{}, &.{byte}, n);
return passthru_writeSplat(bw, &.{&.{byte}}, n);
}
/// Writes the same slice many times, performing the underlying write call as
/// many times as necessary.
pub fn splatBytesAll(bw: *BufferedWriter, bytes: []const u8, n: usize) anyerror!void {
var remaining: usize = n * bytes.len;
while (remaining > 0) remaining -= try splatBytes(bw, bytes, remaining);
pub fn splatBytesAll(bw: *BufferedWriter, bytes: []const u8, splat: usize) anyerror!void {
var remaining_bytes: usize = bytes.len * splat;
remaining_bytes -= try splatBytes(bw, bytes, splat);
while (remaining_bytes > 0) {
const leftover = remaining_bytes % bytes.len;
const buffers: [2][]const u8 = .{ bytes[bytes.len - leftover ..], bytes };
remaining_bytes -= try splatBytes(bw, &buffers, splat);
}
}
/// Writes the same slice many times, allowing short writes.
///
/// Does maximum of one underlying `Writer.VTable.writev`.
pub fn splatBytes(bw: *BufferedWriter, bytes: []const u8, n: usize) anyerror!usize {
return passthru_splat(bw, &.{}, bytes, n);
return passthru_writeSplat(bw, &.{bytes}, n);
}
/// Asserts the `buffer` was initialized with a capacity of at least `@sizeOf(T)` bytes.

View file

@ -13,8 +13,7 @@ pub fn writer(cw: *CountingWriter) Writer {
return .{
.context = cw,
.vtable = &.{
.writev = passthru_writev,
.splat = passthru_splat,
.writeSplat = passthru_writeSplat,
.writeFile = passthru_writeFile,
},
};
@ -27,18 +26,11 @@ pub fn unbufferedWriter(cw: *CountingWriter) std.io.BufferedWriter {
};
}
fn passthru_writev(context: *anyopaque, data: []const []const u8) anyerror!usize {
fn passthru_writeSplat(context: *anyopaque, data: []const []const u8, splat: usize) anyerror!usize {
const cw: *CountingWriter = @alignCast(@ptrCast(context));
const written = try cw.child_writer.writev(data);
cw.bytes_written += written;
return written;
}
fn passthru_splat(context: *anyopaque, header: []const u8, pattern: []const u8, n: usize) anyerror!usize {
const cw: *CountingWriter = @alignCast(@ptrCast(context));
const written = try cw.child_writer.splat(header, pattern, n);
cw.bytes_written += written;
return written;
const n = try cw.child_writer.writeSplat(data, splat);
cw.bytes_written += n;
return n;
}
fn passthru_writeFile(
@ -50,9 +42,9 @@ fn passthru_writeFile(
headers_len: usize,
) anyerror!usize {
const cw: *CountingWriter = @alignCast(@ptrCast(context));
const written = try cw.child_writer.writeFile(file, offset, len, headers_and_trailers, headers_len);
cw.bytes_written += written;
return written;
const n = try cw.child_writer.writeFile(file, offset, len, headers_and_trailers, headers_len);
cw.bytes_written += n;
return n;
}
test CountingWriter {

View file

@ -8,25 +8,16 @@ vtable: *const VTable,
pub const VTable = struct {
/// Each slice in `data` is written in order.
///
/// Number of bytes actually written is returned.
///
/// Number of bytes returned may be zero, which does not mean
/// end-of-stream. A subsequent call may return nonzero, or may signal end
/// of stream via an error.
writev: *const fn (context: *anyopaque, data: []const []const u8) anyerror!usize,
/// `headers_and_pattern` must have length of at least one. The last slice
/// is `pattern` which is the byte sequence to repeat `n` times. The rest
/// of the slices are headers to write before the pattern.
///
/// When `n == 1`, this is equivalent to `writev`.
/// `data.len` must be greater than zero, and the last element of `data` is
/// special. It is repeated as necessary so that it is written `splat`
/// number of times.
///
/// Number of bytes actually written is returned.
///
/// Number of bytes returned may be zero, which does not mean
/// end-of-stream. A subsequent call may return nonzero, or may signal end
/// of stream via an error.
splat: *const fn (context: *anyopaque, headers_and_pattern: []const []const u8, n: usize) anyerror!usize,
writeSplat: *const fn (context: *anyopaque, data: []const []const u8, splat: usize) anyerror!usize,
/// Writes contents from an open file. `headers` are written first, then `len`
/// bytes of `file` starting from `offset`, then `trailers`.
@ -67,7 +58,11 @@ pub const VTable = struct {
};
pub fn writev(w: Writer, data: []const []const u8) anyerror!usize {
return w.vtable.writev(w.context, data);
return w.vtable.writeSplat(w.context, data, 1);
}
pub fn writeSplat(w: Writer, data: []const []const u8, splat: usize) anyerror!usize {
return w.vtable.writeSplat(w.context, data, splat);
}
pub fn writeFile(
@ -83,12 +78,12 @@ pub fn writeFile(
pub fn write(w: Writer, bytes: []const u8) anyerror!usize {
const single: [1][]const u8 = .{bytes};
return w.vtable.writev(w.context, &single);
return w.vtable.writeSplat(w.context, &single, 1);
}
pub fn writeAll(w: Writer, bytes: []const u8) anyerror!void {
var index: usize = 0;
while (index < bytes.len) index += try write(w, bytes[index..]);
while (index < bytes.len) index += try w.vtable.writeSplat(w.context, &.{bytes[index..]}, 1);
}
///// Directly calls `writeAll` many times to render the formatted text. To
@ -102,7 +97,7 @@ pub fn writeAll(w: Writer, bytes: []const u8) anyerror!void {
pub fn writevAll(w: Writer, data: [][]const u8) anyerror!void {
var i: usize = 0;
while (true) {
var n = try w.vtable.writev(w.context, data[i..]);
var n = try w.vtable.writeSplat(w.context, data[i..], 1);
while (n >= data[i].len) {
n -= data[i].len;
i += 1;