diff --git a/src/codegen/spirv.zig b/src/codegen/spirv.zig index 7732f9eccb..286b45f973 100644 --- a/src/codegen/spirv.zig +++ b/src/codegen/spirv.zig @@ -1782,19 +1782,6 @@ const DeclGen = struct { wip.dg.gpa.free(wip.results); } - /// Return the scalar type of an input vector. This type is expected to be a vector - /// if `wip.is_vector`, and a scalar otherwise. - fn scalarType(wip: WipElementWise, ty: Type) Type { - const mod = wip.dg.module; - if (wip.is_vector) { - assert(ty.isVector(mod)); - return ty.childType(mod); - } else { - assert(!ty.isVector(mod)); - return ty; - } - } - /// Utility function to extract the element at a particular index in an /// input vector. This type is expected to be a vector if `wip.is_vector`, and /// a scalar otherwise. @@ -1844,7 +1831,7 @@ const DeclGen = struct { const results = try self.gpa.alloc(IdRef, num_results); for (results) |*result| result.* = undefined; - const scalar_ty = if (is_vector) result_ty.childType(mod) else result_ty; + const scalar_ty = result_ty.scalarType(mod); const scalar_ty_ref = try self.resolveType(scalar_ty, .direct); return .{ @@ -2198,6 +2185,7 @@ const DeclGen = struct { .add_with_overflow => try self.airAddSubOverflow(inst, .OpIAdd, .OpULessThan, .OpSLessThan), .sub_with_overflow => try self.airAddSubOverflow(inst, .OpISub, .OpUGreaterThan, .OpSGreaterThan), + .shl_with_overflow => try self.airShlOverflow(inst), .shuffle => try self.airShuffle(inst), @@ -2343,23 +2331,30 @@ const DeclGen = struct { const bin_op = self.air.instructions.items(.data)[@intFromEnum(inst)].bin_op; const lhs_id = try self.resolve(bin_op.lhs); const rhs_id = try self.resolve(bin_op.rhs); - const result_ty = self.typeOfIndex(inst); - // Sometimes Zig doesn't make both of the arguments the same types here. SPIR-V expects that, - // so just manually upcast it if required. - // TODO(robin) + const result_ty = self.typeOfIndex(inst); + const shift_ty = self.typeOf(bin_op.rhs); + const scalar_shift_ty_ref = try self.resolveType(shift_ty.scalarType(mod), .direct); + + const info = try self.arithmeticTypeInfo(result_ty); + switch (info.class) { + .composite_integer => return self.todo("shift ops for composite integers", .{}), + .integer, .strange_integer => {}, + .float, .bool => unreachable, + } var wip = try self.elementWise(result_ty); defer wip.deinit(); - - const shift_ty = wip.scalarType(self.typeOf(bin_op.rhs)); - const shift_ty_ref = try self.resolveType(shift_ty, .direct); - for (0..wip.results.len) |i| { const lhs_elem_id = try wip.elementAt(result_ty, lhs_id, i); - const rhs_elem_id = try wip.elementAt(result_ty, rhs_id, i); + const rhs_elem_id = try wip.elementAt(shift_ty, rhs_id, i); - const shift_id = if (shift_ty_ref != wip.result_ty_ref) blk: { + // TODO: Can we omit normalizing lhs? + const lhs_norm_id = try self.normalizeInt(wip.scalar_ty_ref, lhs_elem_id, info); + + // Sometimes Zig doesn't make both of the arguments the same types here. SPIR-V expects that, + // so just manually upcast it if required. + const shift_id = if (scalar_shift_ty_ref != wip.scalar_ty_ref) blk: { const shift_id = self.spv.allocId(); try self.func.body.emit(self.spv.gpa, .OpUConvert, .{ .id_result_type = wip.scalar_ty_id, @@ -2368,12 +2363,13 @@ const DeclGen = struct { }); break :blk shift_id; } else rhs_elem_id; + const shift_norm_id = try self.normalizeInt(wip.scalar_ty_ref, shift_id, info); const args = .{ .id_result_type = wip.scalar_ty_id, .id_result = wip.allocId(i), - .base = lhs_elem_id, - .shift = shift_id, + .base = lhs_norm_id, + .shift = shift_norm_id, }; if (result_ty.isSignedInt(mod)) { @@ -2680,6 +2676,88 @@ const DeclGen = struct { ); } + fn airShlOverflow(self: *DeclGen, inst: Air.Inst.Index) !?IdRef { + if (self.liveness.isUnused(inst)) return null; + const mod = self.module; + const ty_pl = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_pl; + const extra = self.air.extraData(Air.Bin, ty_pl.payload).data; + const lhs = try self.resolve(extra.lhs); + const rhs = try self.resolve(extra.rhs); + + const result_ty = self.typeOfIndex(inst); + const operand_ty = self.typeOf(extra.lhs); + const shift_ty = self.typeOf(extra.rhs); + const scalar_shift_ty_ref = try self.resolveType(shift_ty.scalarType(mod), .direct); + + const ov_ty = result_ty.structFieldType(1, self.module); + + const bool_ty_ref = try self.resolveType(Type.bool, .direct); + + const info = try self.arithmeticTypeInfo(operand_ty); + switch (info.class) { + .composite_integer => return self.todo("overflow shift for composite integers", .{}), + .integer, .strange_integer => {}, + .float, .bool => unreachable, + } + + var wip_result = try self.elementWise(operand_ty); + defer wip_result.deinit(); + var wip_ov = try self.elementWise(ov_ty); + defer wip_ov.deinit(); + for (0..wip_result.results.len, wip_ov.results) |i, *ov_id| { + const lhs_elem_id = try wip_result.elementAt(operand_ty, lhs, i); + const rhs_elem_id = try wip_result.elementAt(shift_ty, rhs, i); + + // Normalize both so that we can shift back and check if the result is the same. + const lhs_norm_id = try self.normalizeInt(wip_result.scalar_ty_ref, lhs_elem_id, info); + + // Sometimes Zig doesn't make both of the arguments the same types here. SPIR-V expects that, + // so just manually upcast it if required. + const shift_id = if (scalar_shift_ty_ref != wip_result.scalar_ty_ref) blk: { + const shift_id = self.spv.allocId(); + try self.func.body.emit(self.spv.gpa, .OpUConvert, .{ + .id_result_type = wip_result.scalar_ty_id, + .id_result = shift_id, + .unsigned_value = rhs_elem_id, + }); + break :blk shift_id; + } else rhs_elem_id; + const shift_norm_id = try self.normalizeInt(wip_result.scalar_ty_ref, shift_id, info); + + try self.func.body.emit(self.spv.gpa, .OpShiftLeftLogical, .{ + .id_result_type = wip_result.scalar_ty_id, + .id_result = wip_result.allocId(i), + .base = lhs_norm_id, + .shift = shift_norm_id, + }); + + // To check if overflow happened, just check if the right-shifted result is the same value. + const right_shift_id = self.spv.allocId(); + try self.func.body.emit(self.spv.gpa, .OpShiftRightLogical, .{ + .id_result_type = wip_result.scalar_ty_id, + .id_result = right_shift_id, + .base = try self.normalizeInt(wip_result.scalar_ty_ref, wip_result.results[i], info), + .shift = shift_norm_id, + }); + + const overflowed_id = self.spv.allocId(); + try self.func.body.emit(self.spv.gpa, .OpINotEqual, .{ + .id_result_type = self.typeId(bool_ty_ref), + .id_result = overflowed_id, + .operand_1 = lhs_norm_id, + .operand_2 = right_shift_id, + }); + + ov_id.* = try self.intFromBool(wip_ov.scalar_ty_ref, overflowed_id); + } + + return try self.constructStruct( + result_ty, + &.{ operand_ty, ov_ty }, + &.{ try wip_result.finalize(), try wip_ov.finalize() }, + ); + } + fn airShuffle(self: *DeclGen, inst: Air.Inst.Index) !?IdRef { const mod = self.module; if (self.liveness.isUnused(inst)) return null; diff --git a/test/behavior/math.zig b/test/behavior/math.zig index 93c467eb5d..3aa65dddbb 100644 --- a/test/behavior/math.zig +++ b/test/behavior/math.zig @@ -1328,8 +1328,6 @@ fn testShlTrunc(x: u16) !void { } test "exact shift left" { - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; - try testShlExact(0b00110101); try comptime testShlExact(0b00110101); diff --git a/test/behavior/vector.zig b/test/behavior/vector.zig index 0c28b519b3..d02f0b6515 100644 --- a/test/behavior/vector.zig +++ b/test/behavior/vector.zig @@ -179,7 +179,6 @@ test "array vector coercion - odd sizes" { if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; @@ -219,7 +218,6 @@ test "array to vector with element type coercion" { if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf) return error.SkipZigTest; @@ -659,7 +657,6 @@ test "vector shift operators" { if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; const S = struct { fn doTheTestShift(x: anytype, y: anytype) !void { @@ -1168,7 +1165,6 @@ test "@shlWithOverflow" { if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; const S = struct { fn doTheTest() !void { @@ -1453,7 +1449,6 @@ test "compare vectors with different element types" { if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; // TODO var a: @Vector(2, u8) = .{ 1, 2 }; var b: @Vector(2, u9) = .{ 3, 0 };