mirror of
https://codeberg.org/ziglang/zig.git
synced 2025-12-06 05:44:20 +00:00
Merge pull request #22529 from xtexx/x86-64/shl-sat-int
x86_64: Implement integer saturating left shifting codegen
This commit is contained in:
commit
a6525c1762
2 changed files with 184 additions and 6 deletions
|
|
@ -85078,10 +85078,132 @@ fn airShlShrBinOp(self: *CodeGen, inst: Air.Inst.Index) !void {
|
|||
}
|
||||
|
||||
fn airShlSat(self: *CodeGen, inst: Air.Inst.Index) !void {
|
||||
const zcu = self.pt.zcu;
|
||||
const bin_op = self.air.instructions.items(.data)[@intFromEnum(inst)].bin_op;
|
||||
_ = bin_op;
|
||||
return self.fail("TODO implement shl_sat for {}", .{self.target.cpu.arch});
|
||||
//return self.finishAir(inst, result, .{ bin_op.lhs, bin_op.rhs, .none });
|
||||
const lhs_ty = self.typeOf(bin_op.lhs);
|
||||
const rhs_ty = self.typeOf(bin_op.rhs);
|
||||
|
||||
const result: MCValue = result: {
|
||||
switch (lhs_ty.zigTypeTag(zcu)) {
|
||||
.int => {
|
||||
const lhs_bits = lhs_ty.bitSize(zcu);
|
||||
const rhs_bits = rhs_ty.bitSize(zcu);
|
||||
if (!(lhs_bits <= 32 and rhs_bits <= 5) and !(lhs_bits > 32 and lhs_bits <= 64 and rhs_bits <= 6) and !(rhs_bits <= std.math.log2(lhs_bits))) {
|
||||
return self.fail("TODO implement shl_sat for {} with lhs bits {}, rhs bits {}", .{ self.target.cpu.arch, lhs_bits, rhs_bits });
|
||||
}
|
||||
|
||||
// clobberred by genShiftBinOp
|
||||
try self.spillRegisters(&.{.rcx});
|
||||
|
||||
const lhs_mcv = try self.resolveInst(bin_op.lhs);
|
||||
var lhs_temp1 = try self.tempInit(lhs_ty, lhs_mcv);
|
||||
const rhs_mcv = try self.resolveInst(bin_op.rhs);
|
||||
|
||||
const lhs_lock = switch (lhs_mcv) {
|
||||
.register => |reg| self.register_manager.lockRegAssumeUnused(reg),
|
||||
else => null,
|
||||
};
|
||||
defer if (lhs_lock) |lock| self.register_manager.unlockReg(lock);
|
||||
|
||||
// shift left
|
||||
const dst_mcv = try self.genShiftBinOp(.shl, null, lhs_mcv, rhs_mcv, lhs_ty, rhs_ty);
|
||||
switch (dst_mcv) {
|
||||
.register => |dst_reg| try self.truncateRegister(lhs_ty, dst_reg),
|
||||
.register_pair => |dst_regs| try self.truncateRegister(lhs_ty, dst_regs[1]),
|
||||
.load_frame => |frame_addr| {
|
||||
const tmp_reg =
|
||||
try self.register_manager.allocReg(null, abi.RegisterClass.gp);
|
||||
const tmp_lock = self.register_manager.lockRegAssumeUnused(tmp_reg);
|
||||
defer self.register_manager.unlockReg(tmp_lock);
|
||||
|
||||
const lhs_bits_u31: u31 = @intCast(lhs_bits);
|
||||
const tmp_ty: Type = if (lhs_bits_u31 > 64) .usize else lhs_ty;
|
||||
const off = frame_addr.off + (lhs_bits_u31 - 1) / 64 * 8;
|
||||
try self.genSetReg(
|
||||
tmp_reg,
|
||||
tmp_ty,
|
||||
.{ .load_frame = .{ .index = frame_addr.index, .off = off } },
|
||||
.{},
|
||||
);
|
||||
try self.truncateRegister(lhs_ty, tmp_reg);
|
||||
try self.genSetMem(
|
||||
.{ .frame = frame_addr.index },
|
||||
off,
|
||||
tmp_ty,
|
||||
.{ .register = tmp_reg },
|
||||
.{},
|
||||
);
|
||||
},
|
||||
else => {},
|
||||
}
|
||||
const dst_lock = switch (dst_mcv) {
|
||||
.register => |reg| self.register_manager.lockRegAssumeUnused(reg),
|
||||
else => null,
|
||||
};
|
||||
defer if (dst_lock) |lock| self.register_manager.unlockReg(lock);
|
||||
|
||||
// shift right
|
||||
const tmp_mcv = try self.genShiftBinOp(.shr, null, dst_mcv, rhs_mcv, lhs_ty, rhs_ty);
|
||||
var tmp_temp = try self.tempInit(lhs_ty, tmp_mcv);
|
||||
|
||||
// check if overflow happens
|
||||
const cc_temp = lhs_temp1.cmpInts(.neq, &tmp_temp, self) catch |err| switch (err) {
|
||||
error.SelectFailed => unreachable,
|
||||
else => |e| return e,
|
||||
};
|
||||
try lhs_temp1.die(self);
|
||||
try tmp_temp.die(self);
|
||||
const overflow_reloc = try self.genCondBrMir(lhs_ty, cc_temp.tracking(self).short);
|
||||
try cc_temp.die(self);
|
||||
|
||||
// if overflow,
|
||||
// for unsigned integers, the saturating result is just its max
|
||||
// for signed integers,
|
||||
// if lhs is positive, the result is its max
|
||||
// if lhs is negative, it is min
|
||||
switch (lhs_ty.intInfo(zcu).signedness) {
|
||||
.unsigned => {
|
||||
const bound_mcv = try self.genTypedValue(try lhs_ty.maxIntScalar(self.pt, lhs_ty));
|
||||
try self.genCopy(lhs_ty, dst_mcv, bound_mcv, .{});
|
||||
},
|
||||
.signed => {
|
||||
// check the sign of lhs
|
||||
// TODO: optimize this.
|
||||
// we only need the highest bit so shifting the highest part of lhs_mcv
|
||||
// is enough to check the signedness. other parts can be skipped here.
|
||||
var lhs_temp2 = try self.tempInit(lhs_ty, lhs_mcv);
|
||||
var zero_temp = try self.tempInit(lhs_ty, try self.genTypedValue(try self.pt.intValue(lhs_ty, 0)));
|
||||
const sign_cc_temp = lhs_temp2.cmpInts(.lt, &zero_temp, self) catch |err| switch (err) {
|
||||
error.SelectFailed => unreachable,
|
||||
else => |e| return e,
|
||||
};
|
||||
try lhs_temp2.die(self);
|
||||
try zero_temp.die(self);
|
||||
const sign_reloc_condbr = try self.genCondBrMir(lhs_ty, sign_cc_temp.tracking(self).short);
|
||||
try sign_cc_temp.die(self);
|
||||
|
||||
// if it is negative
|
||||
const min_mcv = try self.genTypedValue(try lhs_ty.minIntScalar(self.pt, lhs_ty));
|
||||
try self.genCopy(lhs_ty, dst_mcv, min_mcv, .{});
|
||||
const sign_reloc_br = try self.asmJmpReloc(undefined);
|
||||
self.performReloc(sign_reloc_condbr);
|
||||
|
||||
// if it is positive
|
||||
const max_mcv = try self.genTypedValue(try lhs_ty.maxIntScalar(self.pt, lhs_ty));
|
||||
try self.genCopy(lhs_ty, dst_mcv, max_mcv, .{});
|
||||
self.performReloc(sign_reloc_br);
|
||||
},
|
||||
}
|
||||
|
||||
self.performReloc(overflow_reloc);
|
||||
break :result dst_mcv;
|
||||
},
|
||||
else => {
|
||||
return self.fail("TODO implement shl_sat for {} op type {}", .{ self.target.cpu.arch, lhs_ty.zigTypeTag(zcu) });
|
||||
},
|
||||
}
|
||||
};
|
||||
return self.finishAir(inst, result, .{ bin_op.lhs, bin_op.rhs, .none });
|
||||
}
|
||||
|
||||
fn airOptionalPayload(self: *CodeGen, inst: Air.Inst.Index) !void {
|
||||
|
|
@ -88466,7 +88588,7 @@ fn genShiftBinOpMir(
|
|||
) !void {
|
||||
const pt = self.pt;
|
||||
const zcu = pt.zcu;
|
||||
const abi_size: u32 = @intCast(lhs_ty.abiSize(zcu));
|
||||
const abi_size: u31 = @intCast(lhs_ty.abiSize(zcu));
|
||||
const shift_abi_size: u32 = @intCast(rhs_ty.abiSize(zcu));
|
||||
try self.spillEflagsIfOccupied();
|
||||
|
||||
|
|
@ -88650,7 +88772,17 @@ fn genShiftBinOpMir(
|
|||
.immediate => {},
|
||||
else => self.performReloc(skip),
|
||||
}
|
||||
}
|
||||
} else try self.asmRegisterMemory(.{ ._, .mov }, temp_regs[2].to64(), .{
|
||||
.base = .{ .frame = lhs_mcv.load_frame.index },
|
||||
.mod = .{ .rm = .{
|
||||
.size = .qword,
|
||||
.disp = switch (tag[0]) {
|
||||
._l => lhs_mcv.load_frame.off,
|
||||
._r => lhs_mcv.load_frame.off + abi_size - 8,
|
||||
else => unreachable,
|
||||
},
|
||||
} },
|
||||
});
|
||||
switch (rhs_mcv) {
|
||||
.immediate => |shift_imm| try self.asmRegisterImmediate(
|
||||
tag,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
const std = @import("std");
|
||||
const expect = std.testing.expect;
|
||||
const expectEqual = std.testing.expectEqual;
|
||||
const builtin = @import("builtin");
|
||||
|
||||
fn ShardedTable(comptime Key: type, comptime mask_bit_count: comptime_int, comptime V: type) type {
|
||||
|
|
@ -111,7 +112,6 @@ test "comptime shift safety check" {
|
|||
}
|
||||
|
||||
test "Saturating Shift Left where lhs is of a computed type" {
|
||||
if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
|
||||
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
|
||||
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
|
||||
if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
|
||||
|
|
@ -159,3 +159,49 @@ comptime {
|
|||
_ = ℑ
|
||||
_ = @shlExact(@as(u16, image[0]), 8);
|
||||
}
|
||||
|
||||
test "Saturating Shift Left" {
|
||||
if (builtin.zig_backend == .stage2_c) return error.SkipZigTest;
|
||||
if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest;
|
||||
if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
|
||||
if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
|
||||
if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest;
|
||||
if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest;
|
||||
|
||||
const S = struct {
|
||||
fn shlSat(x: anytype, y: std.math.Log2Int(@TypeOf(x))) @TypeOf(x) {
|
||||
// workaround https://github.com/ziglang/zig/issues/23033
|
||||
@setRuntimeSafety(false);
|
||||
return x <<| y;
|
||||
}
|
||||
|
||||
fn testType(comptime T: type) !void {
|
||||
comptime var rhs: std.math.Log2Int(T) = 0;
|
||||
inline while (true) : (rhs += 1) {
|
||||
comptime var lhs: T = std.math.minInt(T);
|
||||
inline while (true) : (lhs += 1) {
|
||||
try expectEqual(lhs <<| rhs, shlSat(lhs, rhs));
|
||||
if (lhs == std.math.maxInt(T)) break;
|
||||
}
|
||||
if (rhs == @bitSizeOf(T) - 1) break;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
try S.testType(u2);
|
||||
try S.testType(i2);
|
||||
try S.testType(u3);
|
||||
try S.testType(i3);
|
||||
try S.testType(u4);
|
||||
try S.testType(i4);
|
||||
|
||||
try expectEqual(0xfffffffffffffff0fffffffffffffff0, S.shlSat(@as(u128, 0x0fffffffffffffff0fffffffffffffff), 4));
|
||||
try expectEqual(0xffffffffffffffffffffffffffffffff, S.shlSat(@as(u128, 0x0fffffffffffffff0fffffffffffffff), 5));
|
||||
try expectEqual(-0x80000000000000000000000000000000, S.shlSat(@as(i128, -0x0fffffffffffffff0fffffffffffffff), 5));
|
||||
|
||||
// TODO
|
||||
// try expectEqual(51146728248377216718956089012931236753385031969422887335676427626502090568823039920051095192592252455482604439493126109519019633529459266458258243583, S.shlSat(@as(i495, 0x2fe6bc5448c55ce18252e2c9d44777505dfe63ff249a8027a6626c7d8dd9893fd5731e51474727be556f757facb586a4e04bbc0148c6c7ad692302f46fbd), 0x31));
|
||||
try expectEqual(-57896044618658097711785492504343953926634992332820282019728792003956564819968, S.shlSat(@as(i256, -0x53d4148cee74ea43477a65b3daa7b8fdadcbf4508e793f4af113b8d8da5a7eb6), 0x91));
|
||||
try expectEqual(170141183460469231731687303715884105727, S.shlSat(@as(i128, 0x2fe6bc5448c55ce18252e2c9d4477750), 0x31));
|
||||
try expectEqual(0, S.shlSat(@as(i128, 0), 127));
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue