Allow more operators on bool vectors (#24131)

* Sema: allow binary operations and boolean not on vectors of bool

* langref: Clarify use of operators on vectors (`and` and `or` not allowed)

closes #24093
This commit is contained in:
Daniel Kongsgaard 2025-06-13 00:16:23 +02:00 committed by GitHub
parent 4a02e080d1
commit 5e3c0b7af7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 79 additions and 50 deletions

View file

@ -1926,8 +1926,10 @@ or
Vector types are created with the builtin function {#link|@Vector#}. Vector types are created with the builtin function {#link|@Vector#}.
</p> </p>
<p> <p>
Vectors support the same builtin operators as their underlying base types. Vectors generally support the same builtin operators as their underlying base types.
These operations are performed element-wise, and return a vector of the same length The only exception to this is the keywords `and` and `or` on vectors of bools, since
these operators affect control flow, which is not allowed for vectors.
All other operations are performed element-wise, and return a vector of the same length
as the input vectors. This includes: as the input vectors. This includes:
</p> </p>
<ul> <ul>
@ -1937,6 +1939,7 @@ or
<li>Bitwise operators ({#syntax#}>>{#endsyntax#}, {#syntax#}<<{#endsyntax#}, {#syntax#}&{#endsyntax#}, <li>Bitwise operators ({#syntax#}>>{#endsyntax#}, {#syntax#}<<{#endsyntax#}, {#syntax#}&{#endsyntax#},
{#syntax#}|{#endsyntax#}, {#syntax#}~{#endsyntax#}, etc.)</li> {#syntax#}|{#endsyntax#}, {#syntax#}~{#endsyntax#}, etc.)</li>
<li>Comparison operators ({#syntax#}<{#endsyntax#}, {#syntax#}>{#endsyntax#}, {#syntax#}=={#endsyntax#}, etc.)</li> <li>Comparison operators ({#syntax#}<{#endsyntax#}, {#syntax#}>{#endsyntax#}, {#syntax#}=={#endsyntax#}, etc.)</li>
<li>Boolean not ({#syntax#}!{#endsyntax#})</li>
</ul> </ul>
<p> <p>
It is prohibited to use a math operator on a mixture of scalars (individual numbers) It is prohibited to use a math operator on a mixture of scalars (individual numbers)

View file

@ -806,7 +806,7 @@ fn expr(gz: *GenZir, scope: *Scope, ri: ResultInfo, node: Ast.Node.Index) InnerE
.bool_and => return boolBinOp(gz, scope, ri, node, .bool_br_and), .bool_and => return boolBinOp(gz, scope, ri, node, .bool_br_and),
.bool_or => return boolBinOp(gz, scope, ri, node, .bool_br_or), .bool_or => return boolBinOp(gz, scope, ri, node, .bool_br_or),
.bool_not => return simpleUnOp(gz, scope, ri, node, coerced_bool_ri, tree.nodeData(node).node, .bool_not), .bool_not => return simpleUnOp(gz, scope, ri, node, .{ .rl = .none }, tree.nodeData(node).node, .bool_not),
.bit_not => return simpleUnOp(gz, scope, ri, node, .{ .rl = .none }, tree.nodeData(node).node, .bit_not), .bit_not => return simpleUnOp(gz, scope, ri, node, .{ .rl = .none }, tree.nodeData(node).node, .bit_not),
.negation => return negation(gz, scope, ri, node), .negation => return negation(gz, scope, ri, node),

View file

@ -1171,11 +1171,11 @@ fn analyzeBodyInner(
.as_node => try sema.zirAsNode(block, inst), .as_node => try sema.zirAsNode(block, inst),
.as_shift_operand => try sema.zirAsShiftOperand(block, inst), .as_shift_operand => try sema.zirAsShiftOperand(block, inst),
.bit_and => try sema.zirBitwise(block, inst, .bit_and), .bit_and => try sema.zirBitwise(block, inst, .bit_and),
.bit_not => try sema.zirBitNot(block, inst), .bit_not => try sema.zirBitNot(block, inst, false),
.bit_or => try sema.zirBitwise(block, inst, .bit_or), .bit_or => try sema.zirBitwise(block, inst, .bit_or),
.bitcast => try sema.zirBitcast(block, inst), .bitcast => try sema.zirBitcast(block, inst),
.suspend_block => try sema.zirSuspendBlock(block, inst), .suspend_block => try sema.zirSuspendBlock(block, inst),
.bool_not => try sema.zirBoolNot(block, inst), .bool_not => try sema.zirBitNot(block, inst, true),
.bool_br_and => try sema.zirBoolBr(block, inst, false), .bool_br_and => try sema.zirBoolBr(block, inst, false),
.bool_br_or => try sema.zirBoolBr(block, inst, true), .bool_br_or => try sema.zirBoolBr(block, inst, true),
.c_import => try sema.zirCImport(block, inst), .c_import => try sema.zirCImport(block, inst),
@ -14412,9 +14412,9 @@ fn zirBitwise(
const casted_lhs = try sema.coerce(block, resolved_type, lhs, lhs_src); const casted_lhs = try sema.coerce(block, resolved_type, lhs, lhs_src);
const casted_rhs = try sema.coerce(block, resolved_type, rhs, rhs_src); const casted_rhs = try sema.coerce(block, resolved_type, rhs, rhs_src);
const is_int = scalar_tag == .int or scalar_tag == .comptime_int; const is_int_or_bool = scalar_tag == .int or scalar_tag == .comptime_int or scalar_tag == .bool;
if (!is_int) { if (!is_int_or_bool) {
return sema.fail(block, src, "invalid operands to binary bitwise expression: '{s}' and '{s}'", .{ @tagName(lhs_ty.zigTypeTag(zcu)), @tagName(rhs_ty.zigTypeTag(zcu)) }); return sema.fail(block, src, "invalid operands to binary bitwise expression: '{s}' and '{s}'", .{ @tagName(lhs_ty.zigTypeTag(zcu)), @tagName(rhs_ty.zigTypeTag(zcu)) });
} }
@ -14442,7 +14442,12 @@ fn zirBitwise(
return block.addBinOp(air_tag, casted_lhs, casted_rhs); return block.addBinOp(air_tag, casted_lhs, casted_rhs);
} }
fn zirBitNot(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref { fn zirBitNot(
sema: *Sema,
block: *Block,
inst: Zir.Inst.Index,
is_bool_not: bool,
) CompileError!Air.Inst.Ref {
const tracy = trace(@src()); const tracy = trace(@src());
defer tracy.end(); defer tracy.end();
@ -14455,10 +14460,14 @@ fn zirBitNot(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.
const operand = try sema.resolveInst(inst_data.operand); const operand = try sema.resolveInst(inst_data.operand);
const operand_type = sema.typeOf(operand); const operand_type = sema.typeOf(operand);
const scalar_type = operand_type.scalarType(zcu); const scalar_type = operand_type.scalarType(zcu);
const scalar_tag = scalar_type.zigTypeTag(zcu);
if (scalar_type.zigTypeTag(zcu) != .int) { const is_finite_int_or_bool = scalar_tag == .int or scalar_tag == .bool;
return sema.fail(block, src, "unable to perform binary not operation on type '{}'", .{ const is_allowed_type = if (is_bool_not) scalar_tag == .bool else is_finite_int_or_bool;
operand_type.fmt(pt),
if (!is_allowed_type) {
return sema.fail(block, src, "unable to perform {s} not operation on type '{}'", .{
if (is_bool_not) "boolean" else "binary", operand_type.fmt(pt),
}); });
} }
@ -18336,25 +18345,6 @@ fn zirTypeofPeer(
return Air.internedToRef(result_type.toIntern()); return Air.internedToRef(result_type.toIntern());
} }
fn zirBoolNot(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref {
const tracy = trace(@src());
defer tracy.end();
const pt = sema.pt;
const zcu = pt.zcu;
const inst_data = sema.code.instructions.items(.data)[@intFromEnum(inst)].un_node;
const src = block.nodeOffset(inst_data.src_node);
const operand_src = block.src(.{ .node_offset_un_op = inst_data.src_node });
const uncasted_operand = try sema.resolveInst(inst_data.operand);
const operand = try sema.coerce(block, .bool, uncasted_operand, operand_src);
if (try sema.resolveValue(operand)) |val| {
return if (val.isUndef(zcu)) .undef_bool else if (val.toBool()) .bool_false else .bool_true;
}
try sema.requireRuntimeBlock(block, src, null);
return block.addTyOp(.not, .bool, operand);
}
fn zirBoolBr( fn zirBoolBr(
sema: *Sema, sema: *Sema,
parent_block: *Block, parent_block: *Block,

View file

@ -1627,7 +1627,7 @@ pub fn numberMin(lhs: Value, rhs: Value, zcu: *Zcu) Value {
}; };
} }
/// operands must be (vectors of) integers; handles undefined scalars. /// operands must be (vectors of) integers or bools; handles undefined scalars.
pub fn bitwiseNot(val: Value, ty: Type, arena: Allocator, pt: Zcu.PerThread) !Value { pub fn bitwiseNot(val: Value, ty: Type, arena: Allocator, pt: Zcu.PerThread) !Value {
const zcu = pt.zcu; const zcu = pt.zcu;
if (ty.zigTypeTag(zcu) == .vector) { if (ty.zigTypeTag(zcu) == .vector) {
@ -1645,7 +1645,7 @@ pub fn bitwiseNot(val: Value, ty: Type, arena: Allocator, pt: Zcu.PerThread) !Va
return bitwiseNotScalar(val, ty, arena, pt); return bitwiseNotScalar(val, ty, arena, pt);
} }
/// operands must be integers; handles undefined. /// operands must be integers or bools; handles undefined.
pub fn bitwiseNotScalar(val: Value, ty: Type, arena: Allocator, pt: Zcu.PerThread) !Value { pub fn bitwiseNotScalar(val: Value, ty: Type, arena: Allocator, pt: Zcu.PerThread) !Value {
const zcu = pt.zcu; const zcu = pt.zcu;
if (val.isUndef(zcu)) return Value.fromInterned(try pt.intern(.{ .undef = ty.toIntern() })); if (val.isUndef(zcu)) return Value.fromInterned(try pt.intern(.{ .undef = ty.toIntern() }));
@ -1671,7 +1671,7 @@ pub fn bitwiseNotScalar(val: Value, ty: Type, arena: Allocator, pt: Zcu.PerThrea
return pt.intValue_big(ty, result_bigint.toConst()); return pt.intValue_big(ty, result_bigint.toConst());
} }
/// operands must be (vectors of) integers; handles undefined scalars. /// operands must be (vectors of) integers or bools; handles undefined scalars.
pub fn bitwiseAnd(lhs: Value, rhs: Value, ty: Type, allocator: Allocator, pt: Zcu.PerThread) !Value { pub fn bitwiseAnd(lhs: Value, rhs: Value, ty: Type, allocator: Allocator, pt: Zcu.PerThread) !Value {
const zcu = pt.zcu; const zcu = pt.zcu;
if (ty.zigTypeTag(zcu) == .vector) { if (ty.zigTypeTag(zcu) == .vector) {
@ -1690,7 +1690,7 @@ pub fn bitwiseAnd(lhs: Value, rhs: Value, ty: Type, allocator: Allocator, pt: Zc
return bitwiseAndScalar(lhs, rhs, ty, allocator, pt); return bitwiseAndScalar(lhs, rhs, ty, allocator, pt);
} }
/// operands must be integers; handles undefined. /// operands must be integers or bools; handles undefined.
pub fn bitwiseAndScalar(orig_lhs: Value, orig_rhs: Value, ty: Type, arena: Allocator, pt: Zcu.PerThread) !Value { pub fn bitwiseAndScalar(orig_lhs: Value, orig_rhs: Value, ty: Type, arena: Allocator, pt: Zcu.PerThread) !Value {
const zcu = pt.zcu; const zcu = pt.zcu;
// If one operand is defined, we turn the other into `0xAA` so the bitwise AND can // If one operand is defined, we turn the other into `0xAA` so the bitwise AND can
@ -1744,7 +1744,7 @@ fn intValueAa(ty: Type, arena: Allocator, pt: Zcu.PerThread) !Value {
return pt.intValue_big(ty, result_bigint.toConst()); return pt.intValue_big(ty, result_bigint.toConst());
} }
/// operands must be (vectors of) integers; handles undefined scalars. /// operands must be (vectors of) integers or bools; handles undefined scalars.
pub fn bitwiseNand(lhs: Value, rhs: Value, ty: Type, arena: Allocator, pt: Zcu.PerThread) !Value { pub fn bitwiseNand(lhs: Value, rhs: Value, ty: Type, arena: Allocator, pt: Zcu.PerThread) !Value {
const zcu = pt.zcu; const zcu = pt.zcu;
if (ty.zigTypeTag(zcu) == .vector) { if (ty.zigTypeTag(zcu) == .vector) {
@ -1763,7 +1763,7 @@ pub fn bitwiseNand(lhs: Value, rhs: Value, ty: Type, arena: Allocator, pt: Zcu.P
return bitwiseNandScalar(lhs, rhs, ty, arena, pt); return bitwiseNandScalar(lhs, rhs, ty, arena, pt);
} }
/// operands must be integers; handles undefined. /// operands must be integers or bools; handles undefined.
pub fn bitwiseNandScalar(lhs: Value, rhs: Value, ty: Type, arena: Allocator, pt: Zcu.PerThread) !Value { pub fn bitwiseNandScalar(lhs: Value, rhs: Value, ty: Type, arena: Allocator, pt: Zcu.PerThread) !Value {
const zcu = pt.zcu; const zcu = pt.zcu;
if (lhs.isUndef(zcu) or rhs.isUndef(zcu)) return Value.fromInterned(try pt.intern(.{ .undef = ty.toIntern() })); if (lhs.isUndef(zcu) or rhs.isUndef(zcu)) return Value.fromInterned(try pt.intern(.{ .undef = ty.toIntern() }));
@ -1774,7 +1774,7 @@ pub fn bitwiseNandScalar(lhs: Value, rhs: Value, ty: Type, arena: Allocator, pt:
return bitwiseXor(anded, all_ones, ty, arena, pt); return bitwiseXor(anded, all_ones, ty, arena, pt);
} }
/// operands must be (vectors of) integers; handles undefined scalars. /// operands must be (vectors of) integers or bools; handles undefined scalars.
pub fn bitwiseOr(lhs: Value, rhs: Value, ty: Type, allocator: Allocator, pt: Zcu.PerThread) !Value { pub fn bitwiseOr(lhs: Value, rhs: Value, ty: Type, allocator: Allocator, pt: Zcu.PerThread) !Value {
const zcu = pt.zcu; const zcu = pt.zcu;
if (ty.zigTypeTag(zcu) == .vector) { if (ty.zigTypeTag(zcu) == .vector) {
@ -1793,7 +1793,7 @@ pub fn bitwiseOr(lhs: Value, rhs: Value, ty: Type, allocator: Allocator, pt: Zcu
return bitwiseOrScalar(lhs, rhs, ty, allocator, pt); return bitwiseOrScalar(lhs, rhs, ty, allocator, pt);
} }
/// operands must be integers; handles undefined. /// operands must be integers or bools; handles undefined.
pub fn bitwiseOrScalar(orig_lhs: Value, orig_rhs: Value, ty: Type, arena: Allocator, pt: Zcu.PerThread) !Value { pub fn bitwiseOrScalar(orig_lhs: Value, orig_rhs: Value, ty: Type, arena: Allocator, pt: Zcu.PerThread) !Value {
// If one operand is defined, we turn the other into `0xAA` so the bitwise AND can // If one operand is defined, we turn the other into `0xAA` so the bitwise AND can
// still zero out some bits. // still zero out some bits.
@ -1827,7 +1827,7 @@ pub fn bitwiseOrScalar(orig_lhs: Value, orig_rhs: Value, ty: Type, arena: Alloca
return pt.intValue_big(ty, result_bigint.toConst()); return pt.intValue_big(ty, result_bigint.toConst());
} }
/// operands must be (vectors of) integers; handles undefined scalars. /// operands must be (vectors of) integers or bools; handles undefined scalars.
pub fn bitwiseXor(lhs: Value, rhs: Value, ty: Type, allocator: Allocator, pt: Zcu.PerThread) !Value { pub fn bitwiseXor(lhs: Value, rhs: Value, ty: Type, allocator: Allocator, pt: Zcu.PerThread) !Value {
const zcu = pt.zcu; const zcu = pt.zcu;
if (ty.zigTypeTag(zcu) == .vector) { if (ty.zigTypeTag(zcu) == .vector) {
@ -1846,7 +1846,7 @@ pub fn bitwiseXor(lhs: Value, rhs: Value, ty: Type, allocator: Allocator, pt: Zc
return bitwiseXorScalar(lhs, rhs, ty, allocator, pt); return bitwiseXorScalar(lhs, rhs, ty, allocator, pt);
} }
/// operands must be integers; handles undefined. /// operands must be integers or bools; handles undefined.
pub fn bitwiseXorScalar(lhs: Value, rhs: Value, ty: Type, arena: Allocator, pt: Zcu.PerThread) !Value { pub fn bitwiseXorScalar(lhs: Value, rhs: Value, ty: Type, arena: Allocator, pt: Zcu.PerThread) !Value {
const zcu = pt.zcu; const zcu = pt.zcu;
if (lhs.isUndef(zcu) or rhs.isUndef(zcu)) return Value.fromInterned(try pt.intern(.{ .undef = ty.toIntern() })); if (lhs.isUndef(zcu) or rhs.isUndef(zcu)) return Value.fromInterned(try pt.intern(.{ .undef = ty.toIntern() }));

View file

@ -152,12 +152,22 @@ test "vector bit operators" {
const S = struct { const S = struct {
fn doTheTest() !void { fn doTheTest() !void {
var v: @Vector(4, u8) = [4]u8{ 0b10101010, 0b10101010, 0b10101010, 0b10101010 }; {
var x: @Vector(4, u8) = [4]u8{ 0b11110000, 0b00001111, 0b10101010, 0b01010101 }; var v: @Vector(4, bool) = [4]bool{ false, false, true, true };
_ = .{ &v, &x }; var x: @Vector(4, bool) = [4]bool{ true, false, true, false };
try expect(mem.eql(u8, &@as([4]u8, v ^ x), &[4]u8{ 0b01011010, 0b10100101, 0b00000000, 0b11111111 })); _ = .{ &v, &x };
try expect(mem.eql(u8, &@as([4]u8, v | x), &[4]u8{ 0b11111010, 0b10101111, 0b10101010, 0b11111111 })); try expect(mem.eql(bool, &@as([4]bool, v ^ x), &[4]bool{ true, false, false, true }));
try expect(mem.eql(u8, &@as([4]u8, v & x), &[4]u8{ 0b10100000, 0b00001010, 0b10101010, 0b00000000 })); try expect(mem.eql(bool, &@as([4]bool, v | x), &[4]bool{ true, false, true, true }));
try expect(mem.eql(bool, &@as([4]bool, v & x), &[4]bool{ false, false, true, false }));
}
{
var v: @Vector(4, u8) = [4]u8{ 0b10101010, 0b10101010, 0b10101010, 0b10101010 };
var x: @Vector(4, u8) = [4]u8{ 0b11110000, 0b00001111, 0b10101010, 0b01010101 };
_ = .{ &v, &x };
try expect(mem.eql(u8, &@as([4]u8, v ^ x), &[4]u8{ 0b01011010, 0b10100101, 0b00000000, 0b11111111 }));
try expect(mem.eql(u8, &@as([4]u8, v | x), &[4]u8{ 0b11111010, 0b10101111, 0b10101010, 0b11111111 }));
try expect(mem.eql(u8, &@as([4]u8, v & x), &[4]u8{ 0b10100000, 0b00001010, 0b10101010, 0b00000000 }));
}
} }
}; };
try S.doTheTest(); try S.doTheTest();
@ -659,15 +669,41 @@ test "vector bitwise not operator" {
} }
} }
fn doTheTest() !void { fn doTheTest() !void {
try doTheTestNot(u8, [_]u8{ 0, 2, 4, 255 }); try doTheTestNot(bool, [_]bool{ true, false, true, false });
try doTheTestNot(u16, [_]u16{ 0, 2, 4, 255 });
try doTheTestNot(u32, [_]u32{ 0, 2, 4, 255 });
try doTheTestNot(u64, [_]u64{ 0, 2, 4, 255 });
try doTheTestNot(u8, [_]u8{ 0, 2, 4, 255 }); try doTheTestNot(u8, [_]u8{ 0, 2, 4, 255 });
try doTheTestNot(u16, [_]u16{ 0, 2, 4, 255 }); try doTheTestNot(u16, [_]u16{ 0, 2, 4, 255 });
try doTheTestNot(u32, [_]u32{ 0, 2, 4, 255 }); try doTheTestNot(u32, [_]u32{ 0, 2, 4, 255 });
try doTheTestNot(u64, [_]u64{ 0, 2, 4, 255 }); try doTheTestNot(u64, [_]u64{ 0, 2, 4, 255 });
try doTheTestNot(i8, [_]i8{ 0, 2, 4, 127 });
try doTheTestNot(i16, [_]i16{ 0, 2, 4, 127 });
try doTheTestNot(i32, [_]i32{ 0, 2, 4, 127 });
try doTheTestNot(i64, [_]i64{ 0, 2, 4, 127 });
}
};
try S.doTheTest();
try comptime S.doTheTest();
}
test "vector boolean not operator" {
if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
const S = struct {
fn doTheTestNot(comptime T: type, x: @Vector(4, T)) !void {
const y = !x;
for (@as([4]T, y), 0..) |v, i| {
try expect(!x[i] == v);
}
}
fn doTheTest() !void {
try doTheTestNot(bool, [_]bool{ true, false, true, false });
} }
}; };