From 2c639d657002ac66749d08c4977cbb201d113ce1 Mon Sep 17 00:00:00 2001 From: GethDW Date: Thu, 23 Mar 2023 10:00:10 +0000 Subject: [PATCH] std.MultiArrayList: add support for tagged unions. --- lib/std/multi_array_list.zig | 179 +++++++++++++++++++++++++++-------- lib/std/zig/Parse.zig | 4 +- src/translate_c/ast.zig | 2 +- 3 files changed, 144 insertions(+), 41 deletions(-) diff --git a/lib/std/multi_array_list.zig b/lib/std/multi_array_list.zig index b9624692fc..e7b9586eba 100644 --- a/lib/std/multi_array_list.zig +++ b/lib/std/multi_array_list.zig @@ -1,4 +1,4 @@ -const std = @import("std.zig"); +const std = @import("std"); const builtin = @import("builtin"); const assert = std.debug.assert; const meta = std.meta; @@ -6,24 +6,57 @@ const mem = std.mem; const Allocator = mem.Allocator; const testing = std.testing; -/// A MultiArrayList stores a list of a struct type. +/// A MultiArrayList stores a list of a struct or tagged union type. /// Instead of storing a single list of items, MultiArrayList -/// stores separate lists for each field of the struct. -/// This allows for memory savings if the struct has padding, -/// and also improves cache usage if only some fields are needed -/// for a computation. The primary API for accessing fields is +/// stores separate lists for each field of the struct or +/// lists of tags and bare unions. +/// This allows for memory savings if the struct or union has padding, +/// and also improves cache usage if only some fields or just tags +/// are needed for a computation. The primary API for accessing fields is /// the `slice()` function, which computes the start pointers /// for the array of each field. From the slice you can call /// `.items(.field_name)` to obtain a slice of field values. -pub fn MultiArrayList(comptime S: type) type { +/// For unions you can call `.items(.tags)` or `.items(.data)`.
+pub fn MultiArrayList(comptime T: type) type { return struct { - bytes: [*]align(@alignOf(S)) u8 = undefined, + bytes: [*]align(@alignOf(T)) u8 = undefined, len: usize = 0, capacity: usize = 0, - pub const Elem = S; + const Elem = switch (@typeInfo(T)) { + .Struct => T, + .Union => |u| struct { + pub const Bare = + @Type(.{ .Union = .{ + .layout = u.layout, + .tag_type = null, + .fields = u.fields, + .decls = &.{}, + } }); + pub const Tag = + u.tag_type orelse @compileError("MultiArrayList does not support untagged unions"); + tags: Tag, + data: Bare, - pub const Field = meta.FieldEnum(S); + pub fn fromT(outer: T) @This() { + const tag = meta.activeTag(outer); + return .{ + .tags = tag, + .data = switch (tag) { + inline else => |t| @unionInit(Bare, @tagName(t), @field(outer, @tagName(t))), + }, + }; + } + pub fn toT(tag: Tag, bare: Bare) T { + return switch (tag) { + inline else => |t| @unionInit(T, @tagName(t), @field(bare, @tagName(t))), + }; + } + }, + else => @compileError("MultiArrayList only supports structs and tagged unions"), + }; + + pub const Field = meta.FieldEnum(Elem); /// A MultiArrayList.Slice contains cached start pointers for each field in the list. /// These pointers are not normally stored to reduce the size of the list in memory. @@ -49,18 +82,27 @@ pub fn MultiArrayList(comptime S: type) type { return casted_ptr[0..self.len]; } - pub fn set(self: Slice, index: usize, elem: S) void { - inline for (fields) |field_info| { - self.items(@field(Field, field_info.name))[index] = @field(elem, field_info.name); + pub fn set(self: *Slice, index: usize, elem: T) void { + const e = switch (@typeInfo(T)) { + .Struct => elem, + .Union => Elem.fromT(elem), + else => unreachable, + }; + inline for (fields, 0..) 
|field_info, i| { + self.items(@intToEnum(Field, i))[index] = @field(e, field_info.name); } } - pub fn get(self: Slice, index: usize) S { - var elem: S = undefined; - inline for (fields) |field_info| { - @field(elem, field_info.name) = self.items(@field(Field, field_info.name))[index]; + pub fn get(self: Slice, index: usize) T { + var result: Elem = undefined; + inline for (fields, 0..) |field_info, i| { + @field(result, field_info.name) = self.items(@intToEnum(Field, i))[index]; } - return elem; + return switch (@typeInfo(T)) { + .Struct => result, + .Union => Elem.toT(result.tags, result.data), + else => unreachable, + }; } pub fn toMultiArrayList(self: Slice) Self { @@ -68,8 +110,8 @@ pub fn MultiArrayList(comptime S: type) type { return .{}; } const unaligned_ptr = self.ptrs[sizes.fields[0]]; - const aligned_ptr = @alignCast(@alignOf(S), unaligned_ptr); - const casted_ptr = @ptrCast([*]align(@alignOf(S)) u8, aligned_ptr); + const aligned_ptr = @alignCast(@alignOf(Elem), unaligned_ptr); + const casted_ptr = @ptrCast([*]align(@alignOf(Elem)) u8, aligned_ptr); return .{ .bytes = casted_ptr, .len = self.len, @@ -85,7 +127,7 @@ pub fn MultiArrayList(comptime S: type) type { /// This function is used in the debugger pretty formatters in tools/ to fetch the /// child field order and entry type to facilitate fancy debug printing for this type. - fn dbHelper(self: *Slice, child: *S, field: *Field, entry: *Entry) void { + fn dbHelper(self: *Slice, child: *Elem, field: *Field, entry: *Entry) void { _ = self; _ = child; _ = field; @@ -95,8 +137,8 @@ pub fn MultiArrayList(comptime S: type) type { const Self = @This(); - const fields = meta.fields(S); - /// `sizes.bytes` is an array of @sizeOf each S field. Sorted by alignment, descending. + const fields = meta.fields(Elem); + /// `sizes.bytes` is an array of @sizeOf each `Elem` field. Sorted by alignment, descending. /// `sizes.fields` is an array mapping from `sizes.bytes` array index to field index.
const sizes = blk: { const Data = struct { @@ -169,24 +211,25 @@ pub fn MultiArrayList(comptime S: type) type { } /// Overwrite one array element with new data. - pub fn set(self: *Self, index: usize, elem: S) void { - return self.slice().set(index, elem); + pub fn set(self: *Self, index: usize, elem: T) void { + var slices = self.slice(); + slices.set(index, elem); } /// Obtain all the data for one array element. - pub fn get(self: Self, index: usize) S { + pub fn get(self: Self, index: usize) T { return self.slice().get(index); } /// Extend the list by 1 element. Allocates more memory as necessary. - pub fn append(self: *Self, gpa: Allocator, elem: S) !void { + pub fn append(self: *Self, gpa: Allocator, elem: T) !void { try self.ensureUnusedCapacity(gpa, 1); self.appendAssumeCapacity(elem); } /// Extend the list by 1 element, but asserting `self.capacity` /// is sufficient to hold an additional item. - pub fn appendAssumeCapacity(self: *Self, elem: S) void { + pub fn appendAssumeCapacity(self: *Self, elem: T) void { assert(self.len < self.capacity); self.len += 1; self.set(self.len - 1, elem); @@ -213,7 +256,7 @@ pub fn MultiArrayList(comptime S: type) type { /// Remove and return the last element from the list. /// Asserts the list has at least one item. /// Invalidates pointers to fields of the removed element. - pub fn pop(self: *Self) S { + pub fn pop(self: *Self) T { const val = self.get(self.len - 1); self.len -= 1; return val; @@ -222,7 +265,7 @@ pub fn MultiArrayList(comptime S: type) type { /// Remove and return the last element from the list, or /// return `null` if list is empty. /// Invalidates pointers to fields of the removed element, if any. - pub fn popOrNull(self: *Self) ?S { + pub fn popOrNull(self: *Self) ?T { if (self.len == 0) return null; return self.pop(); } @@ -231,7 +274,7 @@ pub fn MultiArrayList(comptime S: type) type { /// after and including the specified index back by one and /// sets the given index to the specified element. 
May reallocate /// and invalidate iterators. - pub fn insert(self: *Self, gpa: Allocator, index: usize, elem: S) !void { + pub fn insert(self: *Self, gpa: Allocator, index: usize, elem: T) !void { try self.ensureUnusedCapacity(gpa, 1); self.insertAssumeCapacity(index, elem); } @@ -240,10 +283,15 @@ pub fn MultiArrayList(comptime S: type) type { /// Shifts all elements after and including the specified index /// back by one and sets the given index to the specified element. /// Will not reallocate the array, does not invalidate iterators. - pub fn insertAssumeCapacity(self: *Self, index: usize, elem: S) void { + pub fn insertAssumeCapacity(self: *Self, index: usize, elem: T) void { assert(self.len < self.capacity); assert(index <= self.len); self.len += 1; + const entry = switch (@typeInfo(T)) { + .Struct => elem, + .Union => Elem.fromT(elem), + else => unreachable, + }; const slices = self.slice(); inline for (fields, 0..) |field_info, field_index| { const field_slice = slices.items(@intToEnum(Field, field_index)); @@ -251,7 +299,7 @@ pub fn MultiArrayList(comptime S: type) type { while (i > index) : (i -= 1) { field_slice[i] = field_slice[i - 1]; } - field_slice[index] = @field(elem, field_info.name); + field_slice[index] = @field(entry, field_info.name); } } @@ -304,7 +352,7 @@ pub fn MultiArrayList(comptime S: type) type { const other_bytes = gpa.alignedAlloc( u8, - @alignOf(S), + @alignOf(Elem), capacityInBytes(new_len), ) catch { const self_slice = self.slice(); @@ -375,7 +423,7 @@ pub fn MultiArrayList(comptime S: type) type { assert(new_capacity >= self.len); const new_bytes = try gpa.alignedAlloc( u8, - @alignOf(S), + @alignOf(Elem), capacityInBytes(new_capacity), ); if (self.len == 0) { @@ -453,12 +501,12 @@ pub fn MultiArrayList(comptime S: type) type { return elem_bytes * capacity; } - fn allocatedBytes(self: Self) []align(@alignOf(S)) u8 { + fn allocatedBytes(self: Self) []align(@alignOf(Elem)) u8 { return self.bytes[0..capacityInBytes(self.capacity)]; 
} fn FieldType(comptime field: Field) type { - return meta.fieldInfo(S, field).type; + return meta.fieldInfo(Elem, field).type; } const Entry = entry: { @@ -479,7 +527,7 @@ pub fn MultiArrayList(comptime S: type) type { }; /// This function is used in the debugger pretty formatters in tools/ to fetch the /// child field order and entry type to facilitate fancy debug printing for this type. - fn dbHelper(self: *Self, child: *S, field: *Field, entry: *Entry) void { + fn dbHelper(self: *Self, child: *Elem, field: *Field, entry: *Entry) void { _ = self; _ = child; _ = field; @@ -719,3 +767,58 @@ test "insert elements" { try testing.expectEqualSlices(u8, &[_]u8{ 1, 2 }, list.items(.a)); try testing.expectEqualSlices(u32, &[_]u32{ 2, 3 }, list.items(.b)); } + +test "union" { + const ally = testing.allocator; + + const Foo = union(enum) { + a: u32, + b: []const u8, + }; + + var list = MultiArrayList(Foo){}; + defer list.deinit(ally); + + try testing.expectEqual(@as(usize, 0), list.items(.tags).len); + + try list.ensureTotalCapacity(ally, 2); + + list.appendAssumeCapacity(.{ .a = 1 }); + list.appendAssumeCapacity(.{ .b = "zigzag" }); + + try testing.expectEqualSlices(meta.Tag(Foo), list.items(.tags), &.{ .a, .b }); + try testing.expectEqual(@as(usize, 2), list.items(.tags).len); + + list.appendAssumeCapacity(.{ .b = "foobar" }); + try testing.expectEqualStrings("zigzag", list.items(.data)[1].b); + try testing.expectEqualStrings("foobar", list.items(.data)[2].b); + + // Add 6 more things to force a capacity increase. 
+ for (0..6) |i| { + try list.append(ally, .{ .a = @intCast(u32, 4 + i) }); + } + + try testing.expectEqualSlices( + meta.Tag(Foo), + &.{ .a, .b, .b, .a, .a, .a, .a, .a, .a }, + list.items(.tags), + ); + try testing.expectEqual(list.get(0), .{ .a = 1 }); + try testing.expectEqual(list.get(1), .{ .b = "zigzag" }); + try testing.expectEqual(list.get(2), .{ .b = "foobar" }); + try testing.expectEqual(list.get(3), .{ .a = 4 }); + try testing.expectEqual(list.get(4), .{ .a = 5 }); + try testing.expectEqual(list.get(5), .{ .a = 6 }); + try testing.expectEqual(list.get(6), .{ .a = 7 }); + try testing.expectEqual(list.get(7), .{ .a = 8 }); + try testing.expectEqual(list.get(8), .{ .a = 9 }); + + list.shrinkAndFree(ally, 3); + + try testing.expectEqual(@as(usize, 3), list.items(.tags).len); + try testing.expectEqualSlices(meta.Tag(Foo), list.items(.tags), &.{ .a, .b, .b }); + + try testing.expectEqual(list.get(0), .{ .a = 1 }); + try testing.expectEqual(list.get(1), .{ .b = "zigzag" }); + try testing.expectEqual(list.get(2), .{ .b = "foobar" }); +} diff --git a/lib/std/zig/Parse.zig b/lib/std/zig/Parse.zig index 258e3b0368..2ef91ca3a6 100644 --- a/lib/std/zig/Parse.zig +++ b/lib/std/zig/Parse.zig @@ -41,13 +41,13 @@ fn listToSpan(p: *Parse, list: []const Node.Index) !Node.SubRange { }; } -fn addNode(p: *Parse, elem: Ast.NodeList.Elem) Allocator.Error!Node.Index { +fn addNode(p: *Parse, elem: Ast.Node) Allocator.Error!Node.Index { const result = @intCast(Node.Index, p.nodes.len); try p.nodes.append(p.gpa, elem); return result; } -fn setNode(p: *Parse, i: usize, elem: Ast.NodeList.Elem) Node.Index { +fn setNode(p: *Parse, i: usize, elem: Ast.Node) Node.Index { p.nodes.set(i, elem); return @intCast(Node.Index, i); } diff --git a/src/translate_c/ast.zig b/src/translate_c/ast.zig index 688235c2d3..328feb989a 100644 --- a/src/translate_c/ast.zig +++ b/src/translate_c/ast.zig @@ -845,7 +845,7 @@ const Context = struct { }; } - fn addNode(c: *Context, elem: 
std.zig.Ast.NodeList.Elem) Allocator.Error!NodeIndex { + fn addNode(c: *Context, elem: std.zig.Ast.Node) Allocator.Error!NodeIndex { const result = @intCast(NodeIndex, c.nodes.len); try c.nodes.append(c.gpa, elem); return result;