mirror of
https://codeberg.org/ziglang/zig.git
synced 2025-12-06 13:54:21 +00:00
ir: Implement more safety checks for shl/shr
The checks are now valid on types whose size is not a power of two. Closes #2096
This commit is contained in:
parent
9c4dc7b1bb
commit
300fceac6e
5 changed files with 120 additions and 27 deletions
|
|
@ -1834,6 +1834,7 @@ enum PanicMsgId {
|
||||||
PanicMsgIdBadNoAsyncCall,
|
PanicMsgIdBadNoAsyncCall,
|
||||||
PanicMsgIdResumeNotSuspendedFn,
|
PanicMsgIdResumeNotSuspendedFn,
|
||||||
PanicMsgIdBadSentinel,
|
PanicMsgIdBadSentinel,
|
||||||
|
PanicMsgIdShxTooBigRhs,
|
||||||
|
|
||||||
PanicMsgIdCount,
|
PanicMsgIdCount,
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -974,6 +974,8 @@ static Buf *panic_msg_buf(PanicMsgId msg_id) {
|
||||||
return buf_create_from_str("resumed a non-suspended function");
|
return buf_create_from_str("resumed a non-suspended function");
|
||||||
case PanicMsgIdBadSentinel:
|
case PanicMsgIdBadSentinel:
|
||||||
return buf_create_from_str("sentinel mismatch");
|
return buf_create_from_str("sentinel mismatch");
|
||||||
|
case PanicMsgIdShxTooBigRhs:
|
||||||
|
return buf_create_from_str("shift amount is greater than the type size");
|
||||||
}
|
}
|
||||||
zig_unreachable();
|
zig_unreachable();
|
||||||
}
|
}
|
||||||
|
|
@ -2841,6 +2843,26 @@ static LLVMValueRef gen_rem(CodeGen *g, bool want_runtime_safety, bool want_fast
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void gen_shift_rhs_check(CodeGen *g, ZigType *lhs_type, ZigType *rhs_type, LLVMValueRef value) {
|
||||||
|
// We only check if the rhs value of the shift expression is greater or
|
||||||
|
// equal to the number of bits of the lhs if it's not a power of two,
|
||||||
|
// otherwise the check is useful as the allowed values are limited by the
|
||||||
|
// operand type itself
|
||||||
|
if (!is_power_of_2(lhs_type->data.integral.bit_count)) {
|
||||||
|
LLVMValueRef bit_count_value = LLVMConstInt(get_llvm_type(g, rhs_type),
|
||||||
|
lhs_type->data.integral.bit_count, false);
|
||||||
|
LLVMValueRef less_than_bit = LLVMBuildICmp(g->builder, LLVMIntULT, value, bit_count_value, "");
|
||||||
|
LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "CheckFail");
|
||||||
|
LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "CheckOk");
|
||||||
|
LLVMBuildCondBr(g->builder, less_than_bit, ok_block, fail_block);
|
||||||
|
|
||||||
|
LLVMPositionBuilderAtEnd(g->builder, fail_block);
|
||||||
|
gen_safety_crash(g, PanicMsgIdShxTooBigRhs);
|
||||||
|
|
||||||
|
LLVMPositionBuilderAtEnd(g->builder, ok_block);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
static LLVMValueRef ir_render_bin_op(CodeGen *g, IrExecutableGen *executable,
|
static LLVMValueRef ir_render_bin_op(CodeGen *g, IrExecutableGen *executable,
|
||||||
IrInstGenBinOp *bin_op_instruction)
|
IrInstGenBinOp *bin_op_instruction)
|
||||||
{
|
{
|
||||||
|
|
@ -2949,6 +2971,11 @@ static LLVMValueRef ir_render_bin_op(CodeGen *g, IrExecutableGen *executable,
|
||||||
{
|
{
|
||||||
assert(scalar_type->id == ZigTypeIdInt);
|
assert(scalar_type->id == ZigTypeIdInt);
|
||||||
LLVMValueRef op2_casted = gen_widen_or_shorten(g, false, op2->value->type, scalar_type, op2_value);
|
LLVMValueRef op2_casted = gen_widen_or_shorten(g, false, op2->value->type, scalar_type, op2_value);
|
||||||
|
|
||||||
|
if (want_runtime_safety) {
|
||||||
|
gen_shift_rhs_check(g, scalar_type, op2->value->type, op2_value);
|
||||||
|
}
|
||||||
|
|
||||||
bool is_sloppy = (op_id == IrBinOpBitShiftLeftLossy);
|
bool is_sloppy = (op_id == IrBinOpBitShiftLeftLossy);
|
||||||
if (is_sloppy) {
|
if (is_sloppy) {
|
||||||
return LLVMBuildShl(g->builder, op1_value, op2_casted, "");
|
return LLVMBuildShl(g->builder, op1_value, op2_casted, "");
|
||||||
|
|
@ -2965,6 +2992,11 @@ static LLVMValueRef ir_render_bin_op(CodeGen *g, IrExecutableGen *executable,
|
||||||
{
|
{
|
||||||
assert(scalar_type->id == ZigTypeIdInt);
|
assert(scalar_type->id == ZigTypeIdInt);
|
||||||
LLVMValueRef op2_casted = gen_widen_or_shorten(g, false, op2->value->type, scalar_type, op2_value);
|
LLVMValueRef op2_casted = gen_widen_or_shorten(g, false, op2->value->type, scalar_type, op2_value);
|
||||||
|
|
||||||
|
if (want_runtime_safety) {
|
||||||
|
gen_shift_rhs_check(g, scalar_type, op2->value->type, op2_value);
|
||||||
|
}
|
||||||
|
|
||||||
bool is_sloppy = (op_id == IrBinOpBitShiftRightLossy);
|
bool is_sloppy = (op_id == IrBinOpBitShiftRightLossy);
|
||||||
if (is_sloppy) {
|
if (is_sloppy) {
|
||||||
if (scalar_type->data.integral.is_signed) {
|
if (scalar_type->data.integral.is_signed) {
|
||||||
|
|
|
||||||
46
src/ir.cpp
46
src/ir.cpp
|
|
@ -16648,36 +16648,34 @@ static IrInstGen *ir_analyze_bit_shift(IrAnalyze *ira, IrInstSrcBinOp *bin_op_in
|
||||||
return ira->codegen->invalid_inst_gen;
|
return ira->codegen->invalid_inst_gen;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
assert(op1->value->type->data.integral.bit_count > 0);
|
||||||
ZigType *shift_amt_type = get_smallest_unsigned_int_type(ira->codegen,
|
ZigType *shift_amt_type = get_smallest_unsigned_int_type(ira->codegen,
|
||||||
op1->value->type->data.integral.bit_count - 1);
|
op1->value->type->data.integral.bit_count - 1);
|
||||||
if (bin_op_instruction->op_id == IrBinOpBitShiftLeftLossy &&
|
|
||||||
op2->value->type->id == ZigTypeIdComptimeInt) {
|
|
||||||
|
|
||||||
ZigValue *op2_val = ir_resolve_const(ira, op2, UndefBad);
|
|
||||||
if (op2_val == nullptr)
|
|
||||||
return ira->codegen->invalid_inst_gen;
|
|
||||||
if (!bigint_fits_in_bits(&op2_val->data.x_bigint,
|
|
||||||
shift_amt_type->data.integral.bit_count,
|
|
||||||
op2_val->data.x_bigint.is_negative)) {
|
|
||||||
Buf *val_buf = buf_alloc();
|
|
||||||
bigint_append_buf(val_buf, &op2_val->data.x_bigint, 10);
|
|
||||||
ErrorMsg* msg = ir_add_error(ira,
|
|
||||||
&bin_op_instruction->base.base,
|
|
||||||
buf_sprintf("RHS of shift is too large for LHS type"));
|
|
||||||
add_error_note(
|
|
||||||
ira->codegen,
|
|
||||||
msg,
|
|
||||||
op2->base.source_node,
|
|
||||||
buf_sprintf("value %s cannot fit into type %s",
|
|
||||||
buf_ptr(val_buf),
|
|
||||||
buf_ptr(&shift_amt_type->name)));
|
|
||||||
return ira->codegen->invalid_inst_gen;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
casted_op2 = ir_implicit_cast(ira, op2, shift_amt_type);
|
casted_op2 = ir_implicit_cast(ira, op2, shift_amt_type);
|
||||||
if (type_is_invalid(casted_op2->value->type))
|
if (type_is_invalid(casted_op2->value->type))
|
||||||
return ira->codegen->invalid_inst_gen;
|
return ira->codegen->invalid_inst_gen;
|
||||||
|
|
||||||
|
if (instr_is_comptime(casted_op2)) {
|
||||||
|
ZigValue *op2_val = ir_resolve_const(ira, casted_op2, UndefBad);
|
||||||
|
if (op2_val == nullptr)
|
||||||
|
return ira->codegen->invalid_inst_gen;
|
||||||
|
|
||||||
|
BigInt bit_count_value = {0};
|
||||||
|
bigint_init_unsigned(&bit_count_value, op1->value->type->data.integral.bit_count);
|
||||||
|
|
||||||
|
if (bigint_cmp(&op2_val->data.x_bigint, &bit_count_value) != CmpLT) {
|
||||||
|
ErrorMsg* msg = ir_add_error(ira,
|
||||||
|
&bin_op_instruction->base.base,
|
||||||
|
buf_sprintf("RHS of shift is too large for LHS type"));
|
||||||
|
add_error_note(ira->codegen, msg, op1->base.source_node,
|
||||||
|
buf_sprintf("type %s has only %u bits",
|
||||||
|
buf_ptr(&op1->value->type->name),
|
||||||
|
op1->value->type->data.integral.bit_count));
|
||||||
|
|
||||||
|
return ira->codegen->invalid_inst_gen;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (instr_is_comptime(op1) && instr_is_comptime(casted_op2)) {
|
if (instr_is_comptime(op1) && instr_is_comptime(casted_op2)) {
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,38 @@ const tests = @import("tests.zig");
|
||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
|
|
||||||
pub fn addCases(cases: *tests.CompileErrorContext) void {
|
pub fn addCases(cases: *tests.CompileErrorContext) void {
|
||||||
|
cases.addTest("shift on type with non-power-of-two size",
|
||||||
|
\\export fn entry() void {
|
||||||
|
\\ const S = struct {
|
||||||
|
\\ fn a() void {
|
||||||
|
\\ var x: u24 = 42;
|
||||||
|
\\ _ = x >> 24;
|
||||||
|
\\ }
|
||||||
|
\\ fn b() void {
|
||||||
|
\\ var x: u24 = 42;
|
||||||
|
\\ _ = x << 24;
|
||||||
|
\\ }
|
||||||
|
\\ fn c() void {
|
||||||
|
\\ var x: u24 = 42;
|
||||||
|
\\ _ = @shlExact(x, 24);
|
||||||
|
\\ }
|
||||||
|
\\ fn d() void {
|
||||||
|
\\ var x: u24 = 42;
|
||||||
|
\\ _ = @shrExact(x, 24);
|
||||||
|
\\ }
|
||||||
|
\\ };
|
||||||
|
\\ S.a();
|
||||||
|
\\ S.b();
|
||||||
|
\\ S.c();
|
||||||
|
\\ S.d();
|
||||||
|
\\}
|
||||||
|
, &[_][]const u8{
|
||||||
|
"tmp.zig:5:19: error: RHS of shift is too large for LHS type",
|
||||||
|
"tmp.zig:9:19: error: RHS of shift is too large for LHS type",
|
||||||
|
"tmp.zig:13:17: error: RHS of shift is too large for LHS type",
|
||||||
|
"tmp.zig:17:17: error: RHS of shift is too large for LHS type",
|
||||||
|
});
|
||||||
|
|
||||||
cases.addTest("combination of noasync and async",
|
cases.addTest("combination of noasync and async",
|
||||||
\\export fn entry() void {
|
\\export fn entry() void {
|
||||||
\\ noasync {
|
\\ noasync {
|
||||||
|
|
@ -4029,8 +4061,7 @@ pub fn addCases(cases: *tests.CompileErrorContext) void {
|
||||||
\\}
|
\\}
|
||||||
\\export fn entry() u16 { return f(); }
|
\\export fn entry() u16 { return f(); }
|
||||||
, &[_][]const u8{
|
, &[_][]const u8{
|
||||||
"tmp.zig:3:14: error: RHS of shift is too large for LHS type",
|
"tmp.zig:3:17: error: integer value 8 cannot be coerced to type 'u3'",
|
||||||
"tmp.zig:3:17: note: value 8 cannot fit into type u3",
|
|
||||||
});
|
});
|
||||||
|
|
||||||
cases.add("missing function call param",
|
cases.add("missing function call param",
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,37 @@
|
||||||
const tests = @import("tests.zig");
|
const tests = @import("tests.zig");
|
||||||
|
|
||||||
pub fn addCases(cases: *tests.CompareOutputContext) void {
|
pub fn addCases(cases: *tests.CompareOutputContext) void {
|
||||||
|
cases.addRuntimeSafety("shift left by huge amount",
|
||||||
|
\\const std = @import("std");
|
||||||
|
\\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn {
|
||||||
|
\\ std.debug.warn("{}\n", .{message});
|
||||||
|
\\ if (std.mem.eql(u8, message, "shift amount is greater than the type size")) {
|
||||||
|
\\ std.process.exit(126); // good
|
||||||
|
\\ }
|
||||||
|
\\ std.process.exit(0); // test failed
|
||||||
|
\\}
|
||||||
|
\\pub fn main() void {
|
||||||
|
\\ var x: u24 = 42;
|
||||||
|
\\ var y: u5 = 24;
|
||||||
|
\\ var z = x >> y;
|
||||||
|
\\}
|
||||||
|
);
|
||||||
|
|
||||||
|
cases.addRuntimeSafety("shift right by huge amount",
|
||||||
|
\\const std = @import("std");
|
||||||
|
\\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn {
|
||||||
|
\\ if (std.mem.eql(u8, message, "shift amount is greater than the type size")) {
|
||||||
|
\\ std.process.exit(126); // good
|
||||||
|
\\ }
|
||||||
|
\\ std.process.exit(0); // test failed
|
||||||
|
\\}
|
||||||
|
\\pub fn main() void {
|
||||||
|
\\ var x: u24 = 42;
|
||||||
|
\\ var y: u5 = 24;
|
||||||
|
\\ var z = x << y;
|
||||||
|
\\}
|
||||||
|
);
|
||||||
|
|
||||||
cases.addRuntimeSafety("slice sentinel mismatch - optional pointers",
|
cases.addRuntimeSafety("slice sentinel mismatch - optional pointers",
|
||||||
\\const std = @import("std");
|
\\const std = @import("std");
|
||||||
\\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn {
|
\\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn {
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue