diff --git a/tools/gen_spirv_spec.zig b/tools/gen_spirv_spec.zig index 465404e5d1..721203ff4e 100644 --- a/tools/gen_spirv_spec.zig +++ b/tools/gen_spirv_spec.zig @@ -1,5 +1,8 @@ const std = @import("std"); const g = @import("spirv/grammar.zig"); +const Allocator = std.mem.Allocator; + +const ExtendedStructSet = std.StringHashMap(void); pub fn main() !void { var arena = std.heap.ArenaAllocator.init(std.heap.page_allocator); @@ -20,101 +23,308 @@ pub fn main() !void { var tokens = std.json.TokenStream.init(spec); var registry = try std.json.parse(g.Registry, &tokens, .{ .allocator = allocator }); + const core_reg = switch (registry) { + .core => |core_reg| core_reg, + .extension => return error.TODOSpirVExtensionSpec, + }; + var bw = std.io.bufferedWriter(std.io.getStdOut().writer()); - try render(bw.writer(), registry); + try render(bw.writer(), allocator, core_reg); try bw.flush(); } -fn render(writer: anytype, registry: g.Registry) !void { +/// Returns a set with types that require an extra struct for the `Instruction` interface +/// to the spir-v spec, or whether the original type can be used. +fn extendedStructs( + arena: Allocator, + kinds: []const g.OperandKind, +) !ExtendedStructSet { + var map = ExtendedStructSet.init(arena); + try map.ensureTotalCapacity(@intCast(u32, kinds.len)); + + for (kinds) |kind| { + const enumerants = kind.enumerants orelse continue; + + for (enumerants) |enumerant| { + if (enumerant.parameters.len > 0) { + break; + } + } else continue; + + map.putAssumeCapacity(kind.kind, {}); + } + + return map; +} + +// Return a score for a particular priority. Duplicate instruction/operand enum values are +// removed by picking the tag with the lowest score to keep, and by making an alias for the +// other. Note that the tag does not need to be just a tag at this point, in which case it +// gets the lowest score automatically anyway. +fn tagPriorityScore(tag: []const u8) usize { + if (tag.len == 0) { + return 1; + } else if (std.mem.eql(u8, tag, "EXT")) { + return 2; + } else if (std.mem.eql(u8, tag, "KHR")) { + return 3; + } else { + return 4; + } +} + +fn render(writer: anytype, allocator: Allocator, registry: g.CoreRegistry) !void { try writer.writeAll( \\//! This file is auto-generated by tools/gen_spirv_spec.zig. \\ \\const Version = @import("std").builtin.Version; \\ + \\pub const Word = u32; + \\pub const IdResultType = struct{ + \\ id: Word, + \\ pub fn toRef(self: IdResultType) IdRef { + \\ return .{.id = self.id}; + \\ } + \\}; + \\pub const IdResult = struct{ + \\ id: Word, + \\ pub fn toRef(self: IdResult) IdRef { + \\ return .{.id = self.id}; + \\ } + \\ pub fn toResultType(self: IdResult) IdResultType { + \\ return .{.id = self.id}; + \\ } + \\}; + \\pub const IdRef = struct{ id: Word }; + \\ + \\pub const IdMemorySemantics = IdRef; + \\pub const IdScope = IdRef; + \\ + \\pub const LiteralInteger = Word; + \\pub const LiteralString = []const u8; + \\pub const LiteralContextDependentNumber = union(enum) { + \\ int32: i32, + \\ uint32: u32, + \\ int64: i64, + \\ uint64: u64, + \\ float32: f32, + \\ float64: f64, + \\}; + \\pub const LiteralExtInstInteger = struct{ inst: Word }; + \\pub const LiteralSpecConstantOpInteger = struct { opcode: Opcode }; + \\pub const PairLiteralIntegerIdRef = struct { value: LiteralInteger, label: IdRef }; + \\pub const PairIdRefLiteralInteger = struct { target: IdRef, member: LiteralInteger }; + \\pub const PairIdRefIdRef = [2]IdRef; + \\ + \\ ); - switch (registry) { - .core => |core_reg| { - try writer.print( - \\pub const version = Version{{ .major = {}, .minor = {}, .patch = {} }}; - \\pub const magic_number: u32 = {s}; - \\ - , - .{ core_reg.major_version, core_reg.minor_version, core_reg.revision, core_reg.magic_number }, - ); - try renderOpcodes(writer, core_reg.instructions); - try renderOperandKinds(writer, core_reg.operand_kinds); - }, - .extension => |ext_reg| { - try writer.print( - \\pub const version = Version{{ .major = {}, .minor = 0, .patch = {} }}; - \\ - , - .{ ext_reg.version, ext_reg.revision }, - ); - try renderOpcodes(writer, ext_reg.instructions); - try renderOperandKinds(writer, ext_reg.operand_kinds); - }, - } + try writer.print( + \\pub const version = Version{{ .major = {}, .minor = {}, .patch = {} }}; + \\pub const magic_number: Word = {s}; + \\ + , + .{ registry.major_version, registry.minor_version, registry.revision, registry.magic_number }, + ); + const extended_structs = try extendedStructs(allocator, registry.operand_kinds); + try renderOpcodes(writer, allocator, registry.instructions, extended_structs); + try renderOperandKinds(writer, allocator, registry.operand_kinds, extended_structs); } -fn renderOpcodes(writer: anytype, instructions: []const g.Instruction) !void { - try writer.writeAll("pub const Opcode = extern enum(u16) {\n"); - for (instructions) |instr| { - try writer.print(" {} = {},\n", .{ std.zig.fmtId(instr.opname), instr.opcode }); +fn renderOpcodes( + writer: anytype, + allocator: Allocator, + instructions: []const g.Instruction, + extended_structs: ExtendedStructSet, +) !void { + var inst_map = std.AutoArrayHashMap(u32, usize).init(allocator); + try inst_map.ensureTotalCapacity(instructions.len); + + var aliases = std.ArrayList(struct { inst: usize, alias: usize }).init(allocator); + try aliases.ensureTotalCapacity(instructions.len); + + for (instructions) |inst, i| { + const result = inst_map.getOrPutAssumeCapacity(inst.opcode); + if (!result.found_existing) { + result.value_ptr.* = i; + continue; + } + + const existing = instructions[result.value_ptr.*]; + + const tag_index = std.mem.indexOfDiff(u8, inst.opname, existing.opname).?; + const inst_priority = tagPriorityScore(inst.opname[tag_index..]); + const existing_priority = tagPriorityScore(existing.opname[tag_index..]); + + if (inst_priority < existing_priority) { + aliases.appendAssumeCapacity(.{ .inst = result.value_ptr.*, .alias = i }); + result.value_ptr.* = i; + } else { + aliases.appendAssumeCapacity(.{ .inst = i, .alias = result.value_ptr.* }); + } } - try writer.writeAll(" _,\n};\n"); + + const instructions_indices = inst_map.values(); + + try writer.writeAll("pub const Opcode = enum(u16) {\n"); + for (instructions_indices) |i| { + const inst = instructions[i]; + try writer.print("{} = {},\n", .{ std.zig.fmtId(inst.opname), inst.opcode }); + } + + try writer.writeByte('\n'); + + for (aliases.items) |alias| { + try writer.print("pub const {} = Opcode.{};\n", .{ + std.zig.fmtId(instructions[alias.inst].opname), + std.zig.fmtId(instructions[alias.alias].opname), + }); + } + + try writer.writeAll( + \\ + \\pub fn Operands(comptime self: Opcode) type { + \\return switch (self) { + \\ + ); + + for (instructions_indices) |i| { + const inst = instructions[i]; + try renderOperand(writer, .instruction, inst.opname, inst.operands, extended_structs); + } + try writer.writeAll("};\n}\n};\n"); + _ = extended_structs; } -fn renderOperandKinds(writer: anytype, kinds: []const g.OperandKind) !void { +fn renderOperandKinds( + writer: anytype, + allocator: Allocator, + kinds: []const g.OperandKind, + extended_structs: ExtendedStructSet, +) !void { for (kinds) |kind| { switch (kind.category) { - .ValueEnum => try renderValueEnum(writer, kind), - .BitEnum => try renderBitEnum(writer, kind), + .ValueEnum => try renderValueEnum(writer, allocator, kind, extended_structs), + .BitEnum => try renderBitEnum(writer, allocator, kind, extended_structs), else => {}, } } } -fn renderValueEnum(writer: anytype, enumeration: g.OperandKind) !void { - try writer.print("pub const {s} = extern enum(u32) {{\n", .{enumeration.kind}); - +fn renderValueEnum( + writer: anytype, + allocator: Allocator, + enumeration: g.OperandKind, + extended_structs: ExtendedStructSet, +) !void { const enumerants = enumeration.enumerants orelse return error.InvalidRegistry; - for (enumerants) |enumerant| { - if (enumerant.value != .int) return error.InvalidRegistry; - try writer.print(" {} = {},\n", .{ std.zig.fmtId(enumerant.enumerant), enumerant.value.int }); + var enum_map = std.AutoArrayHashMap(u32, usize).init(allocator); + try enum_map.ensureTotalCapacity(enumerants.len); + + var aliases = std.ArrayList(struct { enumerant: usize, alias: usize }).init(allocator); + try aliases.ensureTotalCapacity(enumerants.len); + + for (enumerants) |enumerant, i| { + const result = enum_map.getOrPutAssumeCapacity(enumerant.value.int); + if (!result.found_existing) { + result.value_ptr.* = i; + continue; + } + + const existing = enumerants[result.value_ptr.*]; + + const tag_index = std.mem.indexOfDiff(u8, enumerant.enumerant, existing.enumerant).?; + const enum_priority = tagPriorityScore(enumerant.enumerant[tag_index..]); + const existing_priority = tagPriorityScore(existing.enumerant[tag_index..]); + + if (enum_priority < existing_priority) { + aliases.appendAssumeCapacity(.{ .enumerant = result.value_ptr.*, .alias = i }); + result.value_ptr.* = i; + } else { + aliases.appendAssumeCapacity(.{ .enumerant = i, .alias = result.value_ptr.* }); + } } - try writer.writeAll(" _,\n};\n"); + const enum_indices = enum_map.values(); + + try writer.print("pub const {s} = enum(u32) {{\n", .{std.zig.fmtId(enumeration.kind)}); + + for (enum_indices) |i| { + const enumerant = enumerants[i]; + if (enumerant.value != .int) return error.InvalidRegistry; + + try writer.print("{} = {},\n", .{ std.zig.fmtId(enumerant.enumerant), enumerant.value.int }); + } + + try writer.writeByte('\n'); + + for (aliases.items) |alias| { + try writer.print("pub const {} = {}.{};\n", .{ + std.zig.fmtId(enumerants[alias.enumerant].enumerant), + std.zig.fmtId(enumeration.kind), + std.zig.fmtId(enumerants[alias.alias].enumerant), + }); + } + + if (!extended_structs.contains(enumeration.kind)) { + try writer.writeAll("};\n"); + return; + } + + try writer.print("\npub const Extended = union({}) {{\n", .{std.zig.fmtId(enumeration.kind)}); + + for (enum_indices) |i| { + const enumerant = enumerants[i]; + try renderOperand(writer, .@"union", enumerant.enumerant, enumerant.parameters, extended_structs); + } + + try writer.writeAll("};\n};\n"); } -fn renderBitEnum(writer: anytype, enumeration: g.OperandKind) !void { - try writer.print("pub const {s} = packed struct {{\n", .{enumeration.kind}); +fn renderBitEnum( + writer: anytype, + allocator: Allocator, + enumeration: g.OperandKind, + extended_structs: ExtendedStructSet, +) !void { + try writer.print("pub const {s} = packed struct {{\n", .{std.zig.fmtId(enumeration.kind)}); - var flags_by_bitpos = [_]?[]const u8{null} ** 32; + var flags_by_bitpos = [_]?usize{null} ** 32; const enumerants = enumeration.enumerants orelse return error.InvalidRegistry; - for (enumerants) |enumerant| { + + var aliases = std.ArrayList(struct { flag: usize, alias: u5 }).init(allocator); + try aliases.ensureTotalCapacity(enumerants.len); + + for (enumerants) |enumerant, i| { if (enumerant.value != .bitflag) return error.InvalidRegistry; const value = try parseHexInt(enumerant.value.bitflag); - if (@popCount(u32, value) != 1) { - continue; // Skip combinations and 'none' items + if (@popCount(u32, value) == 0) { + continue; // Skip 'none' items } + std.debug.assert(@popCount(u32, value) == 1); + var bitpos = std.math.log2_int(u32, value); if (flags_by_bitpos[bitpos]) |*existing| { - // Keep the shortest - if (enumerant.enumerant.len < existing.len) - existing.* = enumerant.enumerant; + const tag_index = std.mem.indexOfDiff(u8, enumerant.enumerant, enumerants[existing.*].enumerant).?; + const enum_priority = tagPriorityScore(enumerant.enumerant[tag_index..]); + const existing_priority = tagPriorityScore(enumerants[existing.*].enumerant[tag_index..]); + + if (enum_priority < existing_priority) { + aliases.appendAssumeCapacity(.{ .flag = existing.*, .alias = bitpos }); + existing.* = i; + } else { + aliases.appendAssumeCapacity(.{ .flag = i, .alias = bitpos }); + } } else { - flags_by_bitpos[bitpos] = enumerant.enumerant; + flags_by_bitpos[bitpos] = i; } } - for (flags_by_bitpos) |maybe_flag_name, bitpos| { - try writer.writeAll(" "); - if (maybe_flag_name) |flag_name| { - try writer.writeAll(flag_name); + for (flags_by_bitpos) |maybe_flag_index, bitpos| { + if (maybe_flag_index) |flag_index| { + try writer.print("{}", .{std.zig.fmtId(enumerants[flag_index].enumerant)}); } else { try writer.print("_reserved_bit_{}", .{bitpos}); } @@ -126,7 +336,169 @@ fn renderBitEnum(writer: anytype, enumeration: g.OperandKind) !void { try writer.writeAll("= false,\n"); } - try writer.writeAll("};\n"); + try writer.writeByte('\n'); + + for (aliases.items) |alias| { + try writer.print("pub const {}: {} = .{{.{} = true}};\n", .{ + std.zig.fmtId(enumerants[alias.flag].enumerant), + std.zig.fmtId(enumeration.kind), + std.zig.fmtId(enumerants[flags_by_bitpos[alias.alias].?].enumerant), + }); + } + + if (!extended_structs.contains(enumeration.kind)) { + try writer.writeAll("};\n"); + return; + } + + try writer.print("\npub const Extended = struct {{\n", .{}); + + for (flags_by_bitpos) |maybe_flag_index, bitpos| { + const flag_index = maybe_flag_index orelse { + try writer.print("_reserved_bit_{}: bool = false,\n", .{bitpos}); + continue; + }; + const enumerant = enumerants[flag_index]; + + try renderOperand(writer, .mask, enumerant.enumerant, enumerant.parameters, extended_structs); + } + + try writer.writeAll("};\n};\n"); +} + +fn renderOperand( + writer: anytype, + kind: enum { + @"union", + instruction, + mask, + }, + field_name: []const u8, + parameters: []const g.Operand, + extended_structs: ExtendedStructSet, +) !void { + if (kind == .instruction) { + try writer.writeByte('.'); + } + try writer.print("{}", .{std.zig.fmtId(field_name)}); + if (parameters.len == 0) { + switch (kind) { + .@"union" => try writer.writeAll(",\n"), + .instruction => try writer.writeAll(" => void,\n"), + .mask => try writer.writeAll(": bool = false,\n"), + } + return; + } + + if (kind == .instruction) { + try writer.writeAll(" => "); + } else { + try writer.writeAll(": "); + } + + if (kind == .mask) { + try writer.writeByte('?'); + } + + try writer.writeAll("struct{"); + + for (parameters) |param, j| { + if (j != 0) { + try writer.writeAll(", "); + } + + try renderFieldName(writer, parameters, j); + try writer.writeAll(": "); + + if (param.quantifier) |q| { + switch (q) { + .@"?" => try writer.writeByte('?'), + .@"*" => try writer.writeAll("[]const "), + } + } + + try writer.print("{}", .{std.zig.fmtId(param.kind)}); + + if (extended_structs.contains(param.kind)) { + try writer.writeAll(".Extended"); + } + + if (param.quantifier) |q| { + switch (q) { + .@"?" => try writer.writeAll(" = null"), + .@"*" => try writer.writeAll(" = &.{}"), + } + } + } + + try writer.writeAll("}"); + + if (kind == .mask) { + try writer.writeAll(" = null"); + } + + try writer.writeAll(",\n"); +} + +fn renderFieldName(writer: anytype, operands: []const g.Operand, field_index: usize) !void { + const operand = operands[field_index]; + + // Should be enough for all names - adjust as needed. + var name_buffer = std.BoundedArray(u8, 64){ + .buffer = undefined, + }; + + derive_from_kind: { + // Operand names are often in the json encoded as "'Name'" (with two sets of quotes). + // Additionally, some operands have ~ in them at the end (D~ref~). + const name = std.mem.trim(u8, operand.name, "'~"); + if (name.len == 0) { + break :derive_from_kind; + } + + // Some names have weird characters in them (like newlines) - skip any such ones. + // Use the same loop to transform to snake-case. + for (name) |c| { + switch (c) { + 'a'...'z', '0'...'9' => try name_buffer.append(c), + 'A'...'Z' => try name_buffer.append(std.ascii.toLower(c)), + ' ', '~' => try name_buffer.append('_'), + else => break :derive_from_kind, + } + } + + // Assume there are no duplicate 'name' fields. + try writer.print("{}", .{std.zig.fmtId(name_buffer.slice())}); + return; + } + + // Translate to snake case. + name_buffer.len = 0; + for (operand.kind) |c, i| { + switch (c) { + 'a'...'z', '0'...'9' => try name_buffer.append(c), + 'A'...'Z' => if (i > 0 and std.ascii.isLower(operand.kind[i - 1])) { + try name_buffer.appendSlice(&[_]u8{ '_', std.ascii.toLower(c) }); + } else { + try name_buffer.append(std.ascii.toLower(c)); + }, + else => unreachable, // Assume that the name is valid C-syntax (and contains no underscores). + } + } + + try writer.print("{}", .{std.zig.fmtId(name_buffer.slice())}); + + // For fields derived from type name, there could be any amount. + // Simply check against all other fields, and if another similar one exists, add a number. + const need_extra_index = for (operands) |other_operand, i| { + if (i != field_index and std.mem.eql(u8, operand.kind, other_operand.kind)) { + break true; + } + } else false; + + if (need_extra_index) { + try writer.print("_{}", .{field_index}); + } } fn parseHexInt(text: []const u8) !u31 { @@ -142,7 +514,7 @@ fn usageAndExit(file: std.fs.File, arg0: []const u8, code: u8) noreturn { \\ \\Generates Zig bindings for a SPIR-V specification .json (either core or \\extinst versions). The result, printed to stdout, should be used to update - \\files in src/codegen/spirv. + \\files in src/codegen/spirv. Don't forget to format the output. \\ \\The relevant specifications can be obtained from the SPIR-V registry: \\https://github.com/KhronosGroup/SPIRV-Headers/blob/master/include/spirv/unified1/