From 90f780253e668a3bccc9c559cdf46b9ad0fd92c7 Mon Sep 17 00:00:00 2001 From: pentuppup Date: Wed, 15 Oct 2025 00:33:12 -0400 Subject: [PATCH] detect `comptime var` references in asm input/output and improve errors --- lib/std/zig/AstGen.zig | 2 +- src/Sema.zig | 34 +++++++++++---- src/Zcu.zig | 31 ++++++++++++++ src/print_zir.zig | 41 +++++-------------- .../compile_errors/asm_output_to_const.zig | 16 +++++--- .../comptime_var_referenced_by_asm.zig | 21 ++++++++++ 6 files changed, 100 insertions(+), 45 deletions(-) create mode 100644 test/cases/compile_errors/comptime_var_referenced_by_asm.zig diff --git a/lib/std/zig/AstGen.zig b/lib/std/zig/AstGen.zig index 6306dde4f8..77f03364af 100644 --- a/lib/std/zig/AstGen.zig +++ b/lib/std/zig/AstGen.zig @@ -13010,9 +13010,9 @@ const GenZir = struct { } const small: Zir.Inst.Asm.Small = .{ + .is_volatile = args.is_volatile, .outputs_len = @intCast(args.outputs.len), .inputs_len = @intCast(args.inputs.len), - .is_volatile = args.is_volatile, }; const new_index: Zir.Inst.Index = @enumFromInt(astgen.instructions.len); diff --git a/src/Sema.zig b/src/Sema.zig index 324f1b6867..8e4cf4f6fe 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -16386,6 +16386,7 @@ fn zirAsm( const pt = sema.pt; const zcu = pt.zcu; + const ip = &zcu.intern_pool; const extra = sema.code.extraData(Zir.Inst.Asm, extended.operand); const src = block.nodeOffset(extra.data.src_node); const ret_ty_src = block.src(.{ .node_offset_asm_ret_ty = extra.data.src_node }); @@ -16394,7 +16395,6 @@ fn zirAsm( const inputs_len = small.inputs_len; const is_volatile = small.is_volatile; const is_global_assembly = sema.func_index == .none; - const zir_tags = sema.code.instructions.items(.tag); const asm_source: []const u8 = if (tmpl_is_expr) s: { const tmpl: Zir.Inst.Ref = @enumFromInt(@intFromEnum(extra.data.asm_source)); @@ -16424,29 +16424,37 @@ fn zirAsm( for (out_args, 0..) |*arg, out_i| { const output = sema.code.extraData(Zir.Inst.Asm.Output, extra_i); + const output_src = block.src(.{ .asm_output = .{ + .offset = src.offset.node_offset.x, + .output_index = @intCast(out_i), + } }); extra_i = output.end; const is_type = @as(u1, @truncate(output_type_bits)) != 0; output_type_bits >>= 1; + const name = sema.code.nullTerminatedString(output.data.name); + if (is_type) { // Indicate the output is the asm instruction return value. arg.* = .none; const out_ty = try sema.resolveType(block, ret_ty_src, output.data.operand); expr_ty = Air.internedToRef(out_ty.toIntern()); } else { - arg.* = try sema.resolveInst(output.data.operand); + const inst = try sema.resolveInst(output.data.operand); + if (!sema.checkRuntimeValue(inst)) { + const output_name = try ip.getOrPutString(sema.gpa, pt.tid, name, .no_embedded_nulls); + return sema.failWithContainsReferenceToComptimeVar(block, output_src, output_name, "assembly output", .fromInterned(inst.toInterned().?)); + } + arg.* = inst; } const constraint = sema.code.nullTerminatedString(output.data.constraint); - const name = sema.code.nullTerminatedString(output.data.name); needed_capacity += (constraint.len + name.len + (2 + 3)) / 4; - if (output.data.operand.toIndex()) |index| { - if (zir_tags[@intFromEnum(index)] == .ref) { - // TODO: better error location; it would be even nicer if there were notes that pointed at the output and the variable definition - return sema.fail(block, src, "asm cannot output to const local '{s}'", .{name}); - } + // AstGen gives us a reference to a variable + if (arg.* != .none and sema.typeOf(arg.*).isConstPtr(zcu)) { + return sema.fail(block, output_src, "asm cannot output to const '{s}'", .{name}); } outputs[out_i] = .{ .c = constraint, .n = name }; @@ -16457,9 +16465,18 @@ fn zirAsm( for (args, 0..) |*arg, arg_i| { const input = sema.code.extraData(Zir.Inst.Asm.Input, extra_i); + const input_src = block.src(.{ .asm_input = .{ + .offset = src.offset.node_offset.x, + .input_index = @intCast(arg_i), + } }); extra_i = input.end; const uncasted_arg = try sema.resolveInst(input.data.operand); + const name = sema.code.nullTerminatedString(input.data.name); + if (!sema.checkRuntimeValue(uncasted_arg)) { + const input_name = try ip.getOrPutString(sema.gpa, pt.tid, name, .no_embedded_nulls); + return sema.failWithContainsReferenceToComptimeVar(block, input_src, input_name, "assembly input", .fromInterned(uncasted_arg.toInterned().?)); + } const uncasted_arg_ty = sema.typeOf(uncasted_arg); switch (uncasted_arg_ty.zigTypeTag(zcu)) { .comptime_int => arg.* = try sema.coerce(block, .usize, uncasted_arg, src), @@ -16470,7 +16487,6 @@ fn zirAsm( } const constraint = sema.code.nullTerminatedString(input.data.constraint); - const name = sema.code.nullTerminatedString(input.data.name); needed_capacity += (constraint.len + name.len + (2 + 3)) / 4; inputs[arg_i] = .{ .c = constraint, .n = name }; } diff --git a/src/Zcu.zig b/src/Zcu.zig index 6027ab07fa..59f6c8ee91 100644 --- a/src/Zcu.zig +++ b/src/Zcu.zig @@ -1542,6 +1542,25 @@ pub const SrcLoc = struct { }; return tree.nodeToSpan(src_node); }, + .asm_input => |input| { + const tree = try src_loc.file_scope.getTree(zcu); + const node = input.offset.toAbsolute(src_loc.base_node); + const full = tree.fullAsm(node).?; + const asm_input = full.inputs[input.input_index]; + return tree.nodeToSpan(tree.nodeData(asm_input).node_and_token[0]); + }, + .asm_output => |output| { + const tree = try src_loc.file_scope.getTree(zcu); + const node = output.offset.toAbsolute(src_loc.base_node); + const full = tree.fullAsm(node).?; + const asm_output = full.outputs[output.output_index]; + const data = tree.nodeData(asm_output).opt_node_and_token; + return if (data[0].unwrap()) |output_node| + tree.nodeToSpan(output_node) + else + // token points to the ')' + tree.tokenToSpan(data[1] - 1); + }, .for_input => |for_input| { const tree = try src_loc.file_scope.getTree(zcu); const node = for_input.for_node_offset.toAbsolute(src_loc.base_node); @@ -2507,6 +2526,18 @@ pub const LazySrcLoc = struct { /// The source location points to the operand of a `return` statement, or /// the `return` itself if there is no explicit operand. node_offset_return_operand: Ast.Node.Offset, + /// The source location points to an assembly input + asm_input: struct { + /// Points to the assembly node + offset: Ast.Node.Offset, + input_index: u32, + }, + /// The source location points to an assembly output + asm_output: struct { + /// Points to the assembly node + offset: Ast.Node.Offset, + output_index: u32, + }, /// The source location points to a for loop input. for_input: struct { /// Points to the for loop AST node. diff --git a/src/print_zir.zig b/src/print_zir.zig index 316632f2d3..627374711a 100644 --- a/src/print_zir.zig +++ b/src/print_zir.zig @@ -1267,18 +1267,14 @@ const Writer = struct { tmpl_is_expr: bool, ) !void { const extra = self.code.extraData(Zir.Inst.Asm, extended.operand); - const outputs_len = @as(u5, @truncate(extended.small)); - const inputs_len = @as(u5, @truncate(extended.small >> 5)); - const clobbers_len = @as(u5, @truncate(extended.small >> 10)); - const is_volatile = @as(u1, @truncate(extended.small >> 15)) != 0; + const small: Zir.Inst.Asm.Small = @bitCast(extended.small); - try self.writeFlag(stream, "volatile, ", is_volatile); + try self.writeFlag(stream, "volatile, ", small.is_volatile); if (tmpl_is_expr) { try self.writeInstRef(stream, @enumFromInt(@intFromEnum(extra.data.asm_source))); - try stream.writeAll(", "); } else { const asm_source = self.code.nullTerminatedString(extra.data.asm_source); - try stream.print("\"{f}\", ", .{std.zig.fmtString(asm_source)}); + try stream.print("\"{f}\"", .{std.zig.fmtString(asm_source)}); } try stream.writeAll(", "); @@ -1286,7 +1282,7 @@ const Writer = struct { var output_type_bits = extra.data.output_type_bits; { var i: usize = 0; - while (i < outputs_len) : (i += 1) { + while (i < small.outputs_len) : (i += 1) { const output = self.code.extraData(Zir.Inst.Asm.Output, extra_i); extra_i = output.end; @@ -1298,17 +1294,14 @@ const Writer = struct { try stream.print("output({f}, \"{f}\", ", .{ std.zig.fmtIdP(name), std.zig.fmtString(constraint), }); - try self.writeFlag(stream, "->", is_type); + try self.writeFlag(stream, "-> ", is_type); try self.writeInstRef(stream, output.data.operand); - try stream.writeAll(")"); - if (i + 1 < outputs_len) { - try stream.writeAll("), "); - } + try stream.writeAll("), "); } } { var i: usize = 0; - while (i < inputs_len) : (i += 1) { + while (i < small.inputs_len) : (i += 1) { const input = self.code.extraData(Zir.Inst.Asm.Input, extra_i); extra_i = input.end; @@ -1318,24 +1311,12 @@ const Writer = struct { std.zig.fmtIdP(name), std.zig.fmtString(constraint), }); try self.writeInstRef(stream, input.data.operand); - try stream.writeAll(")"); - if (i + 1 < inputs_len) { - try stream.writeAll(", "); - } - } - } - { - var i: usize = 0; - while (i < clobbers_len) : (i += 1) { - const str_index = self.code.extra[extra_i]; - extra_i += 1; - const clobber = self.code.nullTerminatedString(@enumFromInt(str_index)); - try stream.print("{f}", .{std.zig.fmtIdP(clobber)}); - if (i + 1 < clobbers_len) { - try stream.writeAll(", "); - } + try stream.writeAll("), "); } } + + try self.writeInstRef(stream, extra.data.clobbers); + try stream.writeAll(")) "); try self.writeSrcNode(stream, extra.data.src_node); } diff --git a/test/cases/compile_errors/asm_output_to_const.zig b/test/cases/compile_errors/asm_output_to_const.zig index a25552c3a1..f50080d63a 100644 --- a/test/cases/compile_errors/asm_output_to_const.zig +++ b/test/cases/compile_errors/asm_output_to_const.zig @@ -1,12 +1,18 @@ export fn foo() void { - const f: i64 = 1000; + const local: usize = 0; + asm volatile ("" + : [_] "=r" (local), + ); +} - asm volatile ( - \\ movq $10, %[f] - : [f] "=r" (f), +const global: usize = 0; +export fn bar() void { + asm volatile ("" + : [_] "=r" (global), ); } // error // -// :4:5: error: asm cannot output to const local 'f' +// :4:21: error: asm cannot output to const '_' +// :11:21: error: asm cannot output to const '_' diff --git a/test/cases/compile_errors/comptime_var_referenced_by_asm.zig b/test/cases/compile_errors/comptime_var_referenced_by_asm.zig new file mode 100644 index 0000000000..e6a0828f14 --- /dev/null +++ b/test/cases/compile_errors/comptime_var_referenced_by_asm.zig @@ -0,0 +1,21 @@ +export fn foo() void { + comptime var a: u32 = 0; + asm volatile ("" + : + : [in] "r" (&a), + ); +} + +export fn bar() void { + comptime var a: u32 = 0; + asm volatile ("" + : [out] "=r" (a), + ); +} + +// error +// +// :5:21: error: assembly input contains reference to comptime var +// :2:14: note: 'in' points to comptime var declared here +// :12:23: error: assembly output contains reference to comptime var +// :10:14: note: 'out' points to comptime var declared here