Sema: fix generics with struct literal coerced to tagged union

The `Value.eql` function has to test for value equality *as-if* the lhs
value parameter is coerced into the type of the rhs. For tagged unions,
there was a problematic case when the lhs was an anonymous struct,
because in such case the value is empty_struct_value and the type
contains all the value information. But the only type available in the
function was the rhs type.

So the fix involved making `Value.eqlAdvanced` also accept the lhs type,
and then enhancing the logic to handle the case of the `.anon_struct` tag.

closes #12418

Tests run locally:
 * test-behavior
 * test-cases
This commit is contained in:
Andrew Kelley 2022-08-17 12:55:08 -07:00
parent a12abc6d6c
commit c764640e92
3 changed files with 110 additions and 42 deletions

View file

@ -1495,7 +1495,8 @@ pub fn resolveInst(sema: *Sema, zir_ref: Zir.Inst.Ref) !Air.Inst.Ref {
// Finally, the last section of indexes refers to the map of ZIR=>AIR. // Finally, the last section of indexes refers to the map of ZIR=>AIR.
const inst = sema.inst_map.get(@intCast(u32, i)).?; const inst = sema.inst_map.get(@intCast(u32, i)).?;
if (sema.typeOf(inst).tag() == .generic_poison) return error.GenericPoison; const ty = sema.typeOf(inst);
if (ty.tag() == .generic_poison) return error.GenericPoison;
return inst; return inst;
} }
@ -5570,11 +5571,15 @@ const GenericCallAdapter = struct {
generic_fn: *Module.Fn, generic_fn: *Module.Fn,
precomputed_hash: u64, precomputed_hash: u64,
func_ty_info: Type.Payload.Function.Data, func_ty_info: Type.Payload.Function.Data,
/// Unlike comptime_args, the Type here is not always present. args: []const Arg,
/// .generic_poison is used to communicate non-anytype parameters.
comptime_tvs: []const TypedValue,
module: *Module, module: *Module,
const Arg = struct {
ty: Type,
val: Value,
is_anytype: bool,
};
pub fn eql(ctx: @This(), adapted_key: void, other_key: *Module.Fn) bool { pub fn eql(ctx: @This(), adapted_key: void, other_key: *Module.Fn) bool {
_ = adapted_key; _ = adapted_key;
// The generic function Decl is guaranteed to be the first dependency // The generic function Decl is guaranteed to be the first dependency
@ -5585,10 +5590,10 @@ const GenericCallAdapter = struct {
const other_comptime_args = other_key.comptime_args.?; const other_comptime_args = other_key.comptime_args.?;
for (other_comptime_args[0..ctx.func_ty_info.param_types.len]) |other_arg, i| { for (other_comptime_args[0..ctx.func_ty_info.param_types.len]) |other_arg, i| {
const this_arg = ctx.comptime_tvs[i]; const this_arg = ctx.args[i];
const this_is_comptime = this_arg.val.tag() != .generic_poison; const this_is_comptime = this_arg.val.tag() != .generic_poison;
const other_is_comptime = other_arg.val.tag() != .generic_poison; const other_is_comptime = other_arg.val.tag() != .generic_poison;
const this_is_anytype = this_arg.ty.tag() != .generic_poison; const this_is_anytype = this_arg.is_anytype;
const other_is_anytype = other_key.isAnytypeParam(ctx.module, @intCast(u32, i)); const other_is_anytype = other_key.isAnytypeParam(ctx.module, @intCast(u32, i));
if (other_is_anytype != this_is_anytype) return false; if (other_is_anytype != this_is_anytype) return false;
@ -5607,7 +5612,17 @@ const GenericCallAdapter = struct {
} }
} else if (this_is_comptime) { } else if (this_is_comptime) {
// Both are comptime parameters but not anytype parameters. // Both are comptime parameters but not anytype parameters.
if (!this_arg.val.eql(other_arg.val, other_arg.ty, ctx.module)) { // We assert no error is possible here because any lazy values must be resolved
// before inserting into the generic function hash map.
const is_eql = Value.eqlAdvanced(
this_arg.val,
this_arg.ty,
other_arg.val,
other_arg.ty,
ctx.module,
null,
) catch unreachable;
if (!is_eql) {
return false; return false;
} }
} }
@ -6258,8 +6273,7 @@ fn instantiateGenericCall(
var hasher = std.hash.Wyhash.init(0); var hasher = std.hash.Wyhash.init(0);
std.hash.autoHash(&hasher, @ptrToInt(module_fn)); std.hash.autoHash(&hasher, @ptrToInt(module_fn));
const comptime_tvs = try sema.arena.alloc(TypedValue, func_ty_info.param_types.len); const generic_args = try sema.arena.alloc(GenericCallAdapter.Arg, func_ty_info.param_types.len);
{ {
var i: usize = 0; var i: usize = 0;
for (fn_info.param_body) |inst| { for (fn_info.param_body) |inst| {
@ -6283,8 +6297,9 @@ fn instantiateGenericCall(
else => continue, else => continue,
} }
const arg_ty = sema.typeOf(uncasted_args[i]);
if (is_comptime) { if (is_comptime) {
const arg_ty = sema.typeOf(uncasted_args[i]);
const arg_val = sema.analyzeGenericCallArgVal(block, .unneeded, uncasted_args[i]) catch |err| switch (err) { const arg_val = sema.analyzeGenericCallArgVal(block, .unneeded, uncasted_args[i]) catch |err| switch (err) {
error.NeededSourceLocation => { error.NeededSourceLocation => {
const decl = sema.mod.declPtr(block.src_decl); const decl = sema.mod.declPtr(block.src_decl);
@ -6297,27 +6312,30 @@ fn instantiateGenericCall(
arg_val.hash(arg_ty, &hasher, mod); arg_val.hash(arg_ty, &hasher, mod);
if (is_anytype) { if (is_anytype) {
arg_ty.hashWithHasher(&hasher, mod); arg_ty.hashWithHasher(&hasher, mod);
comptime_tvs[i] = .{ generic_args[i] = .{
.ty = arg_ty, .ty = arg_ty,
.val = arg_val, .val = arg_val,
.is_anytype = true,
}; };
} else { } else {
comptime_tvs[i] = .{ generic_args[i] = .{
.ty = Type.initTag(.generic_poison), .ty = arg_ty,
.val = arg_val, .val = arg_val,
.is_anytype = false,
}; };
} }
} else if (is_anytype) { } else if (is_anytype) {
const arg_ty = sema.typeOf(uncasted_args[i]);
arg_ty.hashWithHasher(&hasher, mod); arg_ty.hashWithHasher(&hasher, mod);
comptime_tvs[i] = .{ generic_args[i] = .{
.ty = arg_ty, .ty = arg_ty,
.val = Value.initTag(.generic_poison), .val = Value.initTag(.generic_poison),
.is_anytype = true,
}; };
} else { } else {
comptime_tvs[i] = .{ generic_args[i] = .{
.ty = Type.initTag(.generic_poison), .ty = arg_ty,
.val = Value.initTag(.generic_poison), .val = Value.initTag(.generic_poison),
.is_anytype = false,
}; };
} }
@ -6331,7 +6349,7 @@ fn instantiateGenericCall(
.generic_fn = module_fn, .generic_fn = module_fn,
.precomputed_hash = precomputed_hash, .precomputed_hash = precomputed_hash,
.func_ty_info = func_ty_info, .func_ty_info = func_ty_info,
.comptime_tvs = comptime_tvs, .args = generic_args,
.module = mod, .module = mod,
}; };
const gop = try mod.monomorphed_funcs.getOrPutAdapted(gpa, {}, adapter); const gop = try mod.monomorphed_funcs.getOrPutAdapted(gpa, {}, adapter);
@ -30124,7 +30142,7 @@ fn valuesEqual(
rhs: Value, rhs: Value,
ty: Type, ty: Type,
) CompileError!bool { ) CompileError!bool {
return Value.eqlAdvanced(lhs, rhs, ty, sema.mod, sema.kit(block, src)); return Value.eqlAdvanced(lhs, ty, rhs, ty, sema.mod, sema.kit(block, src));
} }
/// Asserts the values are comparable vectors of type `ty`. /// Asserts the values are comparable vectors of type `ty`.

View file

@ -2004,6 +2004,10 @@ pub const Value = extern union {
return (try orderAgainstZeroAdvanced(lhs, sema_kit)).compare(op); return (try orderAgainstZeroAdvanced(lhs, sema_kit)).compare(op);
} }
pub fn eql(a: Value, b: Value, ty: Type, mod: *Module) bool {
return eqlAdvanced(a, ty, b, ty, mod, null) catch unreachable;
}
/// This function is used by hash maps and so treats floating-point NaNs as equal /// This function is used by hash maps and so treats floating-point NaNs as equal
/// to each other, and not equal to other floating-point values. /// to each other, and not equal to other floating-point values.
/// Similarly, it treats `undef` as a distinct value from all other values. /// Similarly, it treats `undef` as a distinct value from all other values.
@ -2012,13 +2016,10 @@ pub const Value = extern union {
/// for `a`. This function must act *as if* `a` has been coerced to `ty`. This complication /// for `a`. This function must act *as if* `a` has been coerced to `ty`. This complication
/// is required in order to make generic function instantiation efficient - specifically /// is required in order to make generic function instantiation efficient - specifically
/// the insertion into the monomorphized function table. /// the insertion into the monomorphized function table.
pub fn eql(a: Value, b: Value, ty: Type, mod: *Module) bool {
return eqlAdvanced(a, b, ty, mod, null) catch unreachable;
}
/// If `null` is provided for `sema_kit` then it is guaranteed no error will be returned. /// If `null` is provided for `sema_kit` then it is guaranteed no error will be returned.
pub fn eqlAdvanced( pub fn eqlAdvanced(
a: Value, a: Value,
a_ty: Type,
b: Value, b: Value,
ty: Type, ty: Type,
mod: *Module, mod: *Module,
@ -2044,33 +2045,34 @@ pub const Value = extern union {
const a_payload = a.castTag(.opt_payload).?.data; const a_payload = a.castTag(.opt_payload).?.data;
const b_payload = b.castTag(.opt_payload).?.data; const b_payload = b.castTag(.opt_payload).?.data;
var buffer: Type.Payload.ElemType = undefined; var buffer: Type.Payload.ElemType = undefined;
return eqlAdvanced(a_payload, b_payload, ty.optionalChild(&buffer), mod, sema_kit); const payload_ty = ty.optionalChild(&buffer);
return eqlAdvanced(a_payload, payload_ty, b_payload, payload_ty, mod, sema_kit);
}, },
.slice => { .slice => {
const a_payload = a.castTag(.slice).?.data; const a_payload = a.castTag(.slice).?.data;
const b_payload = b.castTag(.slice).?.data; const b_payload = b.castTag(.slice).?.data;
if (!(try eqlAdvanced(a_payload.len, b_payload.len, Type.usize, mod, sema_kit))) { if (!(try eqlAdvanced(a_payload.len, Type.usize, b_payload.len, Type.usize, mod, sema_kit))) {
return false; return false;
} }
var ptr_buf: Type.SlicePtrFieldTypeBuffer = undefined; var ptr_buf: Type.SlicePtrFieldTypeBuffer = undefined;
const ptr_ty = ty.slicePtrFieldType(&ptr_buf); const ptr_ty = ty.slicePtrFieldType(&ptr_buf);
return eqlAdvanced(a_payload.ptr, b_payload.ptr, ptr_ty, mod, sema_kit); return eqlAdvanced(a_payload.ptr, ptr_ty, b_payload.ptr, ptr_ty, mod, sema_kit);
}, },
.elem_ptr => { .elem_ptr => {
const a_payload = a.castTag(.elem_ptr).?.data; const a_payload = a.castTag(.elem_ptr).?.data;
const b_payload = b.castTag(.elem_ptr).?.data; const b_payload = b.castTag(.elem_ptr).?.data;
if (a_payload.index != b_payload.index) return false; if (a_payload.index != b_payload.index) return false;
return eqlAdvanced(a_payload.array_ptr, b_payload.array_ptr, ty, mod, sema_kit); return eqlAdvanced(a_payload.array_ptr, ty, b_payload.array_ptr, ty, mod, sema_kit);
}, },
.field_ptr => { .field_ptr => {
const a_payload = a.castTag(.field_ptr).?.data; const a_payload = a.castTag(.field_ptr).?.data;
const b_payload = b.castTag(.field_ptr).?.data; const b_payload = b.castTag(.field_ptr).?.data;
if (a_payload.field_index != b_payload.field_index) return false; if (a_payload.field_index != b_payload.field_index) return false;
return eqlAdvanced(a_payload.container_ptr, b_payload.container_ptr, ty, mod, sema_kit); return eqlAdvanced(a_payload.container_ptr, ty, b_payload.container_ptr, ty, mod, sema_kit);
}, },
.@"error" => { .@"error" => {
const a_name = a.castTag(.@"error").?.data.name; const a_name = a.castTag(.@"error").?.data.name;
@ -2080,7 +2082,8 @@ pub const Value = extern union {
.eu_payload => { .eu_payload => {
const a_payload = a.castTag(.eu_payload).?.data; const a_payload = a.castTag(.eu_payload).?.data;
const b_payload = b.castTag(.eu_payload).?.data; const b_payload = b.castTag(.eu_payload).?.data;
return eqlAdvanced(a_payload, b_payload, ty.errorUnionPayload(), mod, sema_kit); const payload_ty = ty.errorUnionPayload();
return eqlAdvanced(a_payload, payload_ty, b_payload, payload_ty, mod, sema_kit);
}, },
.eu_payload_ptr => @panic("TODO: Implement more pointer eql cases"), .eu_payload_ptr => @panic("TODO: Implement more pointer eql cases"),
.opt_payload_ptr => @panic("TODO: Implement more pointer eql cases"), .opt_payload_ptr => @panic("TODO: Implement more pointer eql cases"),
@ -2098,7 +2101,7 @@ pub const Value = extern union {
const types = ty.tupleFields().types; const types = ty.tupleFields().types;
assert(types.len == a_field_vals.len); assert(types.len == a_field_vals.len);
for (types) |field_ty, i| { for (types) |field_ty, i| {
if (!(try eqlAdvanced(a_field_vals[i], b_field_vals[i], field_ty, mod, sema_kit))) { if (!(try eqlAdvanced(a_field_vals[i], field_ty, b_field_vals[i], field_ty, mod, sema_kit))) {
return false; return false;
} }
} }
@ -2109,7 +2112,7 @@ pub const Value = extern union {
const fields = ty.structFields().values(); const fields = ty.structFields().values();
assert(fields.len == a_field_vals.len); assert(fields.len == a_field_vals.len);
for (fields) |field, i| { for (fields) |field, i| {
if (!(try eqlAdvanced(a_field_vals[i], b_field_vals[i], field.ty, mod, sema_kit))) { if (!(try eqlAdvanced(a_field_vals[i], field.ty, b_field_vals[i], field.ty, mod, sema_kit))) {
return false; return false;
} }
} }
@ -2120,7 +2123,7 @@ pub const Value = extern union {
for (a_field_vals) |a_elem, i| { for (a_field_vals) |a_elem, i| {
const b_elem = b_field_vals[i]; const b_elem = b_field_vals[i];
if (!(try eqlAdvanced(a_elem, b_elem, elem_ty, mod, sema_kit))) { if (!(try eqlAdvanced(a_elem, elem_ty, b_elem, elem_ty, mod, sema_kit))) {
return false; return false;
} }
} }
@ -2132,7 +2135,7 @@ pub const Value = extern union {
switch (ty.containerLayout()) { switch (ty.containerLayout()) {
.Packed, .Extern => { .Packed, .Extern => {
const tag_ty = ty.unionTagTypeHypothetical(); const tag_ty = ty.unionTagTypeHypothetical();
if (!(try a_union.tag.eqlAdvanced(b_union.tag, tag_ty, mod, sema_kit))) { if (!(try eqlAdvanced(a_union.tag, tag_ty, b_union.tag, tag_ty, mod, sema_kit))) {
// In this case, we must disregard mismatching tags and compare // In this case, we must disregard mismatching tags and compare
// based on the in-memory bytes of the payloads. // based on the in-memory bytes of the payloads.
@panic("TODO comptime comparison of extern union values with mismatching tags"); @panic("TODO comptime comparison of extern union values with mismatching tags");
@ -2140,13 +2143,13 @@ pub const Value = extern union {
}, },
.Auto => { .Auto => {
const tag_ty = ty.unionTagTypeHypothetical(); const tag_ty = ty.unionTagTypeHypothetical();
if (!(try a_union.tag.eqlAdvanced(b_union.tag, tag_ty, mod, sema_kit))) { if (!(try eqlAdvanced(a_union.tag, tag_ty, b_union.tag, tag_ty, mod, sema_kit))) {
return false; return false;
} }
}, },
} }
const active_field_ty = ty.unionFieldType(a_union.tag, mod); const active_field_ty = ty.unionFieldType(a_union.tag, mod);
return a_union.val.eqlAdvanced(b_union.val, active_field_ty, mod, sema_kit); return eqlAdvanced(a_union.val, active_field_ty, b_union.val, active_field_ty, mod, sema_kit);
}, },
else => {}, else => {},
} else if (a_tag == .null_value or b_tag == .null_value) { } else if (a_tag == .null_value or b_tag == .null_value) {
@ -2180,7 +2183,7 @@ pub const Value = extern union {
const b_val = b.enumToInt(ty, &buf_b); const b_val = b.enumToInt(ty, &buf_b);
var buf_ty: Type.Payload.Bits = undefined; var buf_ty: Type.Payload.Bits = undefined;
const int_ty = ty.intTagType(&buf_ty); const int_ty = ty.intTagType(&buf_ty);
return eqlAdvanced(a_val, b_val, int_ty, mod, sema_kit); return eqlAdvanced(a_val, int_ty, b_val, int_ty, mod, sema_kit);
}, },
.Array, .Vector => { .Array, .Vector => {
const len = ty.arrayLen(); const len = ty.arrayLen();
@ -2191,17 +2194,44 @@ pub const Value = extern union {
while (i < len) : (i += 1) { while (i < len) : (i += 1) {
const a_elem = elemValueBuffer(a, mod, i, &a_buf); const a_elem = elemValueBuffer(a, mod, i, &a_buf);
const b_elem = elemValueBuffer(b, mod, i, &b_buf); const b_elem = elemValueBuffer(b, mod, i, &b_buf);
if (!(try eqlAdvanced(a_elem, b_elem, elem_ty, mod, sema_kit))) { if (!(try eqlAdvanced(a_elem, elem_ty, b_elem, elem_ty, mod, sema_kit))) {
return false; return false;
} }
} }
return true; return true;
}, },
.Struct => { .Struct => {
// A tuple can be represented with .empty_struct_value, // A struct can be represented with one of:
// the_one_possible_value, .aggregate in which case we could // .empty_struct_value,
// end up here and the values are equal if the type has zero fields. // .the_one_possible_value,
return ty.isTupleOrAnonStruct() and ty.structFieldCount() != 0; // .aggregate,
// Note that we already checked above for matching tags, e.g. both .aggregate.
return ty.onePossibleValue() != null;
},
.Union => {
// Here we have to check for value equality, as-if `a` has been coerced to `ty`.
if (ty.onePossibleValue() != null) {
return true;
}
if (a_ty.castTag(.anon_struct)) |payload| {
const tuple = payload.data;
if (tuple.values.len != 1) {
return false;
}
const field_name = tuple.names[0];
const union_obj = ty.cast(Type.Payload.Union).?.data;
const field_index = union_obj.fields.getIndex(field_name) orelse return false;
const tag_and_val = b.castTag(.@"union").?.data;
var field_tag_buf: Value.Payload.U32 = .{
.base = .{ .tag = .enum_field_index },
.data = @intCast(u32, field_index),
};
const field_tag = Value.initPayload(&field_tag_buf.base);
const tag_matches = tag_and_val.tag.eql(field_tag, union_obj.tag_ty, mod);
if (!tag_matches) return false;
return eqlAdvanced(tag_and_val.val, union_obj.tag_ty, tuple.values[0], tuple.types[0], mod, sema_kit);
}
return false;
}, },
.Float => { .Float => {
switch (ty.floatBits(target)) { switch (ty.floatBits(target)) {
@ -2230,7 +2260,8 @@ pub const Value = extern union {
.base = .{ .tag = .opt_payload }, .base = .{ .tag = .opt_payload },
.data = a, .data = a,
}; };
return eqlAdvanced(Value.initPayload(&buffer.base), b, ty, mod, sema_kit); const opt_val = Value.initPayload(&buffer.base);
return eqlAdvanced(opt_val, ty, b, ty, mod, sema_kit);
} }
}, },
else => {}, else => {},

View file

@ -323,3 +323,22 @@ test "generic function instantiation non-duplicates" {
S.copy(u8, &buffer, "hello"); S.copy(u8, &buffer, "hello");
S.copy(u8, &buffer, "hello2"); S.copy(u8, &buffer, "hello2");
} }
test "generic instantiation of tagged union with only one field" {
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
if (builtin.os.tag == .wasi) return error.SkipZigTest;
const S = struct {
const U = union(enum) {
s: []const u8,
};
fn foo(comptime u: U) usize {
return u.s.len;
}
};
try expect(S.foo(.{ .s = "a" }) == 1);
try expect(S.foo(.{ .s = "ab" }) == 2);
}