std.MultiArrayList: add support for tagged unions.

This commit is contained in:
GethDW 2023-03-23 10:00:10 +00:00 committed by Andrew Kelley
parent 88dfb13818
commit 2c639d6570
3 changed files with 144 additions and 41 deletions

View file

@ -1,4 +1,4 @@
const std = @import("std.zig"); const std = @import("std");
const builtin = @import("builtin"); const builtin = @import("builtin");
const assert = std.debug.assert; const assert = std.debug.assert;
const meta = std.meta; const meta = std.meta;
@ -6,24 +6,57 @@ const mem = std.mem;
const Allocator = mem.Allocator; const Allocator = mem.Allocator;
const testing = std.testing; const testing = std.testing;
/// A MultiArrayList stores a list of a struct type. /// A MultiArrayList stores a list of a struct or tagged union type.
/// Instead of storing a single list of items, MultiArrayList /// Instead of storing a single list of items, MultiArrayList
/// stores separate lists for each field of the struct. /// stores separate lists for each field of the struct or
/// This allows for memory savings if the struct has padding, /// lists of tags and bare unions.
/// and also improves cache usage if only some fields are needed /// This allows for memory savings if the struct or union has padding,
for a computation. The primary API for accessing fields is /// and also improves cache usage if only some fields or just tags
/// are needed for a computation. The primary API for accessing fields is
/// the `slice()` function, which computes the start pointers /// the `slice()` function, which computes the start pointers
/// for the array of each field. From the slice you can call /// for the array of each field. From the slice you can call
/// `.items(.<field_name>)` to obtain a slice of field values. /// `.items(.<field_name>)` to obtain a slice of field values.
pub fn MultiArrayList(comptime S: type) type { /// For unions you can call `.items(.tags)` or `.items(.data)`.
pub fn MultiArrayList(comptime T: type) type {
return struct { return struct {
bytes: [*]align(@alignOf(S)) u8 = undefined, bytes: [*]align(@alignOf(T)) u8 = undefined,
len: usize = 0, len: usize = 0,
capacity: usize = 0, capacity: usize = 0,
pub const Elem = S; const Elem = switch (@typeInfo(T)) {
.Struct => T,
.Union => |u| struct {
pub const Bare =
@Type(.{ .Union = .{
.layout = u.layout,
.tag_type = null,
.fields = u.fields,
.decls = &.{},
} });
pub const Tag =
u.tag_type orelse @compileError("MultiArrayList does not support untagged unions");
tags: Tag,
data: Bare,
pub const Field = meta.FieldEnum(S); pub fn fromT(outer: T) @This() {
const tag = meta.activeTag(outer);
return .{
.tags = tag,
.data = switch (tag) {
inline else => |t| @unionInit(Bare, @tagName(t), @field(outer, @tagName(t))),
},
};
}
pub fn toT(tag: Tag, bare: Bare) T {
return switch (tag) {
inline else => |t| @unionInit(T, @tagName(t), @field(bare, @tagName(t))),
};
}
},
else => @compileError("MultiArrayList only supports structs and tagged unions"),
};
pub const Field = meta.FieldEnum(Elem);
/// A MultiArrayList.Slice contains cached start pointers for each field in the list. /// A MultiArrayList.Slice contains cached start pointers for each field in the list.
/// These pointers are not normally stored to reduce the size of the list in memory. /// These pointers are not normally stored to reduce the size of the list in memory.
@ -49,18 +82,27 @@ pub fn MultiArrayList(comptime S: type) type {
return casted_ptr[0..self.len]; return casted_ptr[0..self.len];
} }
pub fn set(self: Slice, index: usize, elem: S) void { pub fn set(self: *Slice, index: usize, elem: T) void {
inline for (fields) |field_info| { const e = switch (@typeInfo(T)) {
self.items(@field(Field, field_info.name))[index] = @field(elem, field_info.name); .Struct => elem,
.Union => Elem.fromT(elem),
else => unreachable,
};
inline for (fields, 0..) |field_info, i| {
self.items(@intToEnum(Field, i))[index] = @field(e, field_info.name);
} }
} }
pub fn get(self: Slice, index: usize) S { pub fn get(self: Slice, index: usize) T {
var elem: S = undefined; var result: Elem = undefined;
inline for (fields) |field_info| { inline for (fields, 0..) |field_info, i| {
@field(elem, field_info.name) = self.items(@field(Field, field_info.name))[index]; @field(result, field_info.name) = self.items(@intToEnum(Field, i))[index];
} }
return elem; return switch (@typeInfo(T)) {
.Struct => result,
.Union => Elem.toT(result.tags, result.data),
else => unreachable,
};
} }
pub fn toMultiArrayList(self: Slice) Self { pub fn toMultiArrayList(self: Slice) Self {
@ -68,8 +110,8 @@ pub fn MultiArrayList(comptime S: type) type {
return .{}; return .{};
} }
const unaligned_ptr = self.ptrs[sizes.fields[0]]; const unaligned_ptr = self.ptrs[sizes.fields[0]];
const aligned_ptr = @alignCast(@alignOf(S), unaligned_ptr); const aligned_ptr = @alignCast(@alignOf(Elem), unaligned_ptr);
const casted_ptr = @ptrCast([*]align(@alignOf(S)) u8, aligned_ptr); const casted_ptr = @ptrCast([*]align(@alignOf(Elem)) u8, aligned_ptr);
return .{ return .{
.bytes = casted_ptr, .bytes = casted_ptr,
.len = self.len, .len = self.len,
@ -85,7 +127,7 @@ pub fn MultiArrayList(comptime S: type) type {
/// This function is used in the debugger pretty formatters in tools/ to fetch the /// This function is used in the debugger pretty formatters in tools/ to fetch the
/// child field order and entry type to facilitate fancy debug printing for this type. /// child field order and entry type to facilitate fancy debug printing for this type.
fn dbHelper(self: *Slice, child: *S, field: *Field, entry: *Entry) void { fn dbHelper(self: *Slice, child: *Elem, field: *Field, entry: *Entry) void {
_ = self; _ = self;
_ = child; _ = child;
_ = field; _ = field;
@ -95,8 +137,8 @@ pub fn MultiArrayList(comptime S: type) type {
const Self = @This(); const Self = @This();
const fields = meta.fields(S); const fields = meta.fields(Elem);
/// `sizes.bytes` is an array of @sizeOf each S field. Sorted by alignment, descending. /// `sizes.bytes` is an array of @sizeOf each T field. Sorted by alignment, descending.
/// `sizes.fields` is an array mapping from `sizes.bytes` array index to field index. /// `sizes.fields` is an array mapping from `sizes.bytes` array index to field index.
const sizes = blk: { const sizes = blk: {
const Data = struct { const Data = struct {
@ -169,24 +211,25 @@ pub fn MultiArrayList(comptime S: type) type {
} }
/// Overwrite one array element with new data. /// Overwrite one array element with new data.
pub fn set(self: *Self, index: usize, elem: S) void { pub fn set(self: *Self, index: usize, elem: T) void {
return self.slice().set(index, elem); var slices = self.slice();
slices.set(index, elem);
} }
/// Obtain all the data for one array element. /// Obtain all the data for one array element.
pub fn get(self: Self, index: usize) S { pub fn get(self: Self, index: usize) T {
return self.slice().get(index); return self.slice().get(index);
} }
/// Extend the list by 1 element. Allocates more memory as necessary. /// Extend the list by 1 element. Allocates more memory as necessary.
pub fn append(self: *Self, gpa: Allocator, elem: S) !void { pub fn append(self: *Self, gpa: Allocator, elem: T) !void {
try self.ensureUnusedCapacity(gpa, 1); try self.ensureUnusedCapacity(gpa, 1);
self.appendAssumeCapacity(elem); self.appendAssumeCapacity(elem);
} }
/// Extend the list by 1 element, but asserting `self.capacity` /// Extend the list by 1 element, but asserting `self.capacity`
/// is sufficient to hold an additional item. /// is sufficient to hold an additional item.
pub fn appendAssumeCapacity(self: *Self, elem: S) void { pub fn appendAssumeCapacity(self: *Self, elem: T) void {
assert(self.len < self.capacity); assert(self.len < self.capacity);
self.len += 1; self.len += 1;
self.set(self.len - 1, elem); self.set(self.len - 1, elem);
@ -213,7 +256,7 @@ pub fn MultiArrayList(comptime S: type) type {
/// Remove and return the last element from the list. /// Remove and return the last element from the list.
/// Asserts the list has at least one item. /// Asserts the list has at least one item.
/// Invalidates pointers to fields of the removed element. /// Invalidates pointers to fields of the removed element.
pub fn pop(self: *Self) S { pub fn pop(self: *Self) T {
const val = self.get(self.len - 1); const val = self.get(self.len - 1);
self.len -= 1; self.len -= 1;
return val; return val;
@ -222,7 +265,7 @@ pub fn MultiArrayList(comptime S: type) type {
/// Remove and return the last element from the list, or /// Remove and return the last element from the list, or
/// return `null` if list is empty. /// return `null` if list is empty.
/// Invalidates pointers to fields of the removed element, if any. /// Invalidates pointers to fields of the removed element, if any.
pub fn popOrNull(self: *Self) ?S { pub fn popOrNull(self: *Self) ?T {
if (self.len == 0) return null; if (self.len == 0) return null;
return self.pop(); return self.pop();
} }
@ -231,7 +274,7 @@ pub fn MultiArrayList(comptime S: type) type {
/// after and including the specified index back by one and /// after and including the specified index back by one and
/// sets the given index to the specified element. May reallocate /// sets the given index to the specified element. May reallocate
/// and invalidate iterators. /// and invalidate iterators.
pub fn insert(self: *Self, gpa: Allocator, index: usize, elem: S) !void { pub fn insert(self: *Self, gpa: Allocator, index: usize, elem: T) !void {
try self.ensureUnusedCapacity(gpa, 1); try self.ensureUnusedCapacity(gpa, 1);
self.insertAssumeCapacity(index, elem); self.insertAssumeCapacity(index, elem);
} }
@ -240,10 +283,15 @@ pub fn MultiArrayList(comptime S: type) type {
/// Shifts all elements after and including the specified index /// Shifts all elements after and including the specified index
/// back by one and sets the given index to the specified element. /// back by one and sets the given index to the specified element.
/// Will not reallocate the array, does not invalidate iterators. /// Will not reallocate the array, does not invalidate iterators.
pub fn insertAssumeCapacity(self: *Self, index: usize, elem: S) void { pub fn insertAssumeCapacity(self: *Self, index: usize, elem: T) void {
assert(self.len < self.capacity); assert(self.len < self.capacity);
assert(index <= self.len); assert(index <= self.len);
self.len += 1; self.len += 1;
const entry = switch (@typeInfo(T)) {
.Struct => elem,
.Union => Elem.fromT(elem),
else => unreachable,
};
const slices = self.slice(); const slices = self.slice();
inline for (fields, 0..) |field_info, field_index| { inline for (fields, 0..) |field_info, field_index| {
const field_slice = slices.items(@intToEnum(Field, field_index)); const field_slice = slices.items(@intToEnum(Field, field_index));
@ -251,7 +299,7 @@ pub fn MultiArrayList(comptime S: type) type {
while (i > index) : (i -= 1) { while (i > index) : (i -= 1) {
field_slice[i] = field_slice[i - 1]; field_slice[i] = field_slice[i - 1];
} }
field_slice[index] = @field(elem, field_info.name); field_slice[index] = @field(entry, field_info.name);
} }
} }
@ -304,7 +352,7 @@ pub fn MultiArrayList(comptime S: type) type {
const other_bytes = gpa.alignedAlloc( const other_bytes = gpa.alignedAlloc(
u8, u8,
@alignOf(S), @alignOf(Elem),
capacityInBytes(new_len), capacityInBytes(new_len),
) catch { ) catch {
const self_slice = self.slice(); const self_slice = self.slice();
@ -375,7 +423,7 @@ pub fn MultiArrayList(comptime S: type) type {
assert(new_capacity >= self.len); assert(new_capacity >= self.len);
const new_bytes = try gpa.alignedAlloc( const new_bytes = try gpa.alignedAlloc(
u8, u8,
@alignOf(S), @alignOf(Elem),
capacityInBytes(new_capacity), capacityInBytes(new_capacity),
); );
if (self.len == 0) { if (self.len == 0) {
@ -453,12 +501,12 @@ pub fn MultiArrayList(comptime S: type) type {
return elem_bytes * capacity; return elem_bytes * capacity;
} }
fn allocatedBytes(self: Self) []align(@alignOf(S)) u8 { fn allocatedBytes(self: Self) []align(@alignOf(Elem)) u8 {
return self.bytes[0..capacityInBytes(self.capacity)]; return self.bytes[0..capacityInBytes(self.capacity)];
} }
fn FieldType(comptime field: Field) type { fn FieldType(comptime field: Field) type {
return meta.fieldInfo(S, field).type; return meta.fieldInfo(Elem, field).type;
} }
const Entry = entry: { const Entry = entry: {
@ -479,7 +527,7 @@ pub fn MultiArrayList(comptime S: type) type {
}; };
/// This function is used in the debugger pretty formatters in tools/ to fetch the /// This function is used in the debugger pretty formatters in tools/ to fetch the
/// child field order and entry type to facilitate fancy debug printing for this type. /// child field order and entry type to facilitate fancy debug printing for this type.
fn dbHelper(self: *Self, child: *S, field: *Field, entry: *Entry) void { fn dbHelper(self: *Self, child: *Elem, field: *Field, entry: *Entry) void {
_ = self; _ = self;
_ = child; _ = child;
_ = field; _ = field;
@ -719,3 +767,58 @@ test "insert elements" {
try testing.expectEqualSlices(u8, &[_]u8{ 1, 2 }, list.items(.a)); try testing.expectEqualSlices(u8, &[_]u8{ 1, 2 }, list.items(.a));
try testing.expectEqualSlices(u32, &[_]u32{ 2, 3 }, list.items(.b)); try testing.expectEqualSlices(u32, &[_]u32{ 2, 3 }, list.items(.b));
} }
// Exercises MultiArrayList with a tagged-union element type: appending,
// reading the `.tags` and `.data` columns, `get()` round-trips, and
// `shrinkAndFree()`.
test "union" {
const ally = testing.allocator;
// Two-variant tagged union used as the element type for this test.
const Foo = union(enum) {
a: u32,
b: []const u8,
};
var list = MultiArrayList(Foo){};
defer list.deinit(ally);
// A freshly initialised list exposes an empty `.tags` column.
try testing.expectEqual(@as(usize, 0), list.items(.tags).len);
try list.ensureTotalCapacity(ally, 2);
list.appendAssumeCapacity(.{ .a = 1 });
list.appendAssumeCapacity(.{ .b = "zigzag" });
// The `.tags` column holds the active tag of each appended union value.
try testing.expectEqualSlices(meta.Tag(Foo), list.items(.tags), &.{ .a, .b });
try testing.expectEqual(@as(usize, 2), list.items(.tags).len);
// NOTE(review): this is a third appendAssumeCapacity after requesting a total
// capacity of 2; presumably ensureTotalCapacity reserves more than asked for —
// confirm against its implementation (not visible here).
list.appendAssumeCapacity(.{ .b = "foobar" });
// The `.data` column holds the bare (untagged) payloads, accessed by field name.
try testing.expectEqualStrings("zigzag", list.items(.data)[1].b);
try testing.expectEqualStrings("foobar", list.items(.data)[2].b);
// Add 6 more things to force a capacity increase.
for (0..6) |i| {
try list.append(ally, .{ .a = @intCast(u32, 4 + i) });
}
// After growth, the tags of all nine elements must still be in insertion order.
try testing.expectEqualSlices(
meta.Tag(Foo),
&.{ .a, .b, .b, .a, .a, .a, .a, .a, .a },
list.items(.tags),
);
// `get()` rebuilds the original tagged union from the stored tag and payload.
// NOTE(review): the typed `list.get(i)` is passed first, presumably so the
// anonymous literal can coerce to `Foo` — confirm against testing.expectEqual.
try testing.expectEqual(list.get(0), .{ .a = 1 });
try testing.expectEqual(list.get(1), .{ .b = "zigzag" });
try testing.expectEqual(list.get(2), .{ .b = "foobar" });
try testing.expectEqual(list.get(3), .{ .a = 4 });
try testing.expectEqual(list.get(4), .{ .a = 5 });
try testing.expectEqual(list.get(5), .{ .a = 6 });
try testing.expectEqual(list.get(6), .{ .a = 7 });
try testing.expectEqual(list.get(7), .{ .a = 8 });
try testing.expectEqual(list.get(8), .{ .a = 9 });
// Shrinking to 3 elements must keep the first three entries intact.
list.shrinkAndFree(ally, 3);
try testing.expectEqual(@as(usize, 3), list.items(.tags).len);
try testing.expectEqualSlices(meta.Tag(Foo), list.items(.tags), &.{ .a, .b, .b });
try testing.expectEqual(list.get(0), .{ .a = 1 });
try testing.expectEqual(list.get(1), .{ .b = "zigzag" });
try testing.expectEqual(list.get(2), .{ .b = "foobar" });
}

View file

@ -41,13 +41,13 @@ fn listToSpan(p: *Parse, list: []const Node.Index) !Node.SubRange {
}; };
} }
fn addNode(p: *Parse, elem: Ast.NodeList.Elem) Allocator.Error!Node.Index { fn addNode(p: *Parse, elem: Ast.Node) Allocator.Error!Node.Index {
const result = @intCast(Node.Index, p.nodes.len); const result = @intCast(Node.Index, p.nodes.len);
try p.nodes.append(p.gpa, elem); try p.nodes.append(p.gpa, elem);
return result; return result;
} }
fn setNode(p: *Parse, i: usize, elem: Ast.NodeList.Elem) Node.Index { fn setNode(p: *Parse, i: usize, elem: Ast.Node) Node.Index {
p.nodes.set(i, elem); p.nodes.set(i, elem);
return @intCast(Node.Index, i); return @intCast(Node.Index, i);
} }

View file

@ -845,7 +845,7 @@ const Context = struct {
}; };
} }
fn addNode(c: *Context, elem: std.zig.Ast.NodeList.Elem) Allocator.Error!NodeIndex { fn addNode(c: *Context, elem: std.zig.Ast.Node) Allocator.Error!NodeIndex {
const result = @intCast(NodeIndex, c.nodes.len); const result = @intCast(NodeIndex, c.nodes.len);
try c.nodes.append(c.gpa, elem); try c.nodes.append(c.gpa, elem);
return result; return result;