From 7aa42f47b79f289829a1b43a68c8c08e374aa6a2 Mon Sep 17 00:00:00 2001 From: HydroH Date: Thu, 28 Mar 2024 18:23:32 +0800 Subject: [PATCH] allow `@errorcast` to cast error sets to error unions --- src/Sema.zig | 24 +++++++++---------- test/behavior/error.zig | 5 ++++ .../@errorCast_with_bad_type.zig | 23 ++++++++++++++++++ 3 files changed, 39 insertions(+), 13 deletions(-) create mode 100644 test/cases/compile_errors/@errorCast_with_bad_type.zig diff --git a/src/Sema.zig b/src/Sema.zig index d21fed6910..001d841959 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -22626,20 +22626,18 @@ fn zirErrorCast(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.InstData const base_operand_ty = sema.typeOf(operand); const dest_tag = base_dest_ty.zigTypeTag(mod); const operand_tag = base_operand_ty.zigTypeTag(mod); - if (dest_tag != operand_tag) { - return sema.fail(block, src, "expected source and destination types to match, found '{s}' and '{s}'", .{ - @tagName(operand_tag), @tagName(dest_tag), - }); - } else if (dest_tag != .ErrorSet and dest_tag != .ErrorUnion) { + + if (dest_tag != .ErrorSet and dest_tag != .ErrorUnion) { return sema.fail(block, src, "expected error set or error union type, found '{s}'", .{@tagName(dest_tag)}); } - const dest_ty, const operand_ty = if (dest_tag == .ErrorUnion) .{ - base_dest_ty.errorUnionSet(mod), - base_operand_ty.errorUnionSet(mod), - } else .{ - base_dest_ty, - base_operand_ty, - }; + if (operand_tag != .ErrorSet and operand_tag != .ErrorUnion) { + return sema.fail(block, src, "expected error set or error union type, found '{s}'", .{@tagName(operand_tag)}); + } + if (dest_tag == .ErrorSet and operand_tag == .ErrorUnion) { + return sema.fail(block, src, "cannot cast an error union type to error set", .{}); + } + const dest_ty = if (dest_tag == .ErrorUnion) base_dest_ty.errorUnionSet(mod) else base_dest_ty; + const operand_ty = if (operand_tag == .ErrorUnion) base_operand_ty.errorUnionSet(mod) else base_operand_ty; // operand must be defined since it can be an invalid error value const maybe_operand_val = try sema.resolveDefinedValue(block, operand_src, operand); @@ -22681,7 +22679,7 @@ fn zirErrorCast(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.InstData if (!dest_ty.isAnyError(mod)) check: { const operand_val = mod.intern_pool.indexToKey(val.toIntern()); var error_name: InternPool.NullTerminatedString = undefined; - if (dest_tag == .ErrorUnion) { + if (operand_tag == .ErrorUnion) { if (operand_val.error_union.val != .err_name) break :check; error_name = operand_val.error_union.val.err_name; } else { diff --git a/test/behavior/error.zig b/test/behavior/error.zig index 8e4dd2c091..952d010a14 100644 --- a/test/behavior/error.zig +++ b/test/behavior/error.zig @@ -1039,3 +1039,8 @@ test "errorCast to adhoc inferred error set" { }; try std.testing.expect((try S.baz()) == 1234); } + +test "errorCast from error sets to error unions" { + const err_union: Set1!void = @errorCast(error.A); + try expectError(error.A, err_union); +} diff --git a/test/cases/compile_errors/@errorCast_with_bad_type.zig b/test/cases/compile_errors/@errorCast_with_bad_type.zig new file mode 100644 index 0000000000..b698203737 --- /dev/null +++ b/test/cases/compile_errors/@errorCast_with_bad_type.zig @@ -0,0 +1,23 @@ +const err = error.Foo; + +export fn entry1() void { + const a: anyerror = @errorCast(1); + _ = a; +} +export fn entry2() void { + const a: i32 = @errorCast(err); + _ = a; +} +export fn entry3() void { + const e: anyerror!void = err; + const a: anyerror = @errorCast(e); + _ = a; +} + +// error +// backend=stage2 +// target=x86_64-linux +// +// :4:25: error: expected error set or error union type, found 'ComptimeInt' +// :8:20: error: expected error set or error union type, found 'Int' +// :13:25: error: cannot cast an error union type to error set