diff --git a/.gitignore b/.gitignore index b0b6e7cb61..852d1613c4 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ build/ +build2/ build-release/ /.cproject /.project diff --git a/src/all_types.hpp b/src/all_types.hpp index 0fc74d2212..3deeff5980 100644 --- a/src/all_types.hpp +++ b/src/all_types.hpp @@ -1257,6 +1257,7 @@ enum ZigLLVMFnId { ZigLLVMFnIdClz, ZigLLVMFnIdOverflowArithmetic, ZigLLVMFnIdOverflowArithmeticPanic, + ZigLLVMFnIdBoundsCheck, }; enum AddSubMul { @@ -1280,6 +1281,10 @@ struct ZigLLVMFnKey { uint32_t bit_count; bool is_signed; } overflow_arithmetic; + struct { + LLVMIntPredicate pred; + uint32_t bit_count; + } bounds_check; } data; }; diff --git a/src/analyze.cpp b/src/analyze.cpp index a190ee88b7..facc700d0c 100644 --- a/src/analyze.cpp +++ b/src/analyze.cpp @@ -4145,6 +4145,9 @@ uint32_t zig_llvm_fn_key_hash(ZigLLVMFnKey x) { return ((uint32_t)(x.data.overflow_arithmetic.bit_count) * 3329604261) + ((uint32_t)(x.data.overflow_arithmetic.add_sub_mul) * 966805797) + ((uint32_t)(x.data.overflow_arithmetic.is_signed) ? 3679835291 : 1187552903); + case ZigLLVMFnIdBoundsCheck: + return (uint32_t)(x.data.bounds_check.pred) * (uint32_t)3146725107 + + x.data.bounds_check.bit_count * (uint32_t)2904561957; } zig_unreachable(); } @@ -4162,6 +4165,9 @@ bool zig_llvm_fn_key_eql(ZigLLVMFnKey a, ZigLLVMFnKey b) { return (a.data.overflow_arithmetic.bit_count == b.data.overflow_arithmetic.bit_count) && (a.data.overflow_arithmetic.add_sub_mul == b.data.overflow_arithmetic.add_sub_mul) && (a.data.overflow_arithmetic.is_signed == b.data.overflow_arithmetic.is_signed); + case ZigLLVMFnIdBoundsCheck: + return a.data.bounds_check.pred == b.data.bounds_check.pred && + a.data.bounds_check.bit_count == b.data.bounds_check.bit_count; } zig_unreachable(); } diff --git a/src/codegen.cpp b/src/codegen.cpp index a5411e8df1..cca0037402 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -812,37 +812,102 @@ static void gen_debug_safety_crash_for_err(CodeGen *g, LLVMValueRef err_val) { LLVMBuildUnreachable(g->builder); } -static void add_bounds_check(CodeGen *g, LLVMValueRef target_val, - LLVMIntPredicate lower_pred, LLVMValueRef lower_value, - LLVMIntPredicate upper_pred, LLVMValueRef upper_value) -{ - if (!lower_value && !upper_value) { - return; - } - if (upper_value && !lower_value) { - lower_value = upper_value; - lower_pred = upper_pred; - upper_value = nullptr; +static const char *pred_name(LLVMIntPredicate pred) { + switch (pred) { + case LLVMIntEQ: return "eq"; + case LLVMIntNE: return "ne"; + case LLVMIntULT: return "lt"; + case LLVMIntULE: return "le"; + default: + zig_unreachable(); } +} - LLVMBasicBlockRef bounds_check_fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "BoundsCheckFail"); - LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "BoundsCheckOk"); - LLVMBasicBlockRef lower_ok_block = upper_value ? - LLVMAppendBasicBlock(g->cur_fn_val, "FirstBoundsCheckOk") : ok_block; +static LLVMValueRef get_bounds_check_fn_val(CodeGen *g, LLVMIntPredicate pred, uint32_t bit_count) { + ZigLLVMFnKey key = {}; + key.id = ZigLLVMFnIdBoundsCheck; + key.data.bounds_check.pred = pred; + key.data.bounds_check.bit_count = bit_count; - LLVMValueRef lower_ok_val = LLVMBuildICmp(g->builder, lower_pred, target_val, lower_value, ""); - LLVMBuildCondBr(g->builder, lower_ok_val, lower_ok_block, bounds_check_fail_block); + auto existing_entry = g->llvm_fn_table.maybe_get(key); + if (existing_entry) + return existing_entry->value; + + Buf *desired_name = buf_sprintf("__zig_bounds_check_%s_%" PRIu32, pred_name(pred), bit_count); + Buf *fn_name = get_mangled_name(g, desired_name, false); + LLVMTypeRef type_ref = LLVMIntType(bit_count); + LLVMTypeRef arg_types[] = { type_ref, type_ref }; + LLVMTypeRef fn_type_ref = LLVMFunctionType(LLVMVoidType(), arg_types, 2, false); + LLVMValueRef fn_val = LLVMAddFunction(g->module, buf_ptr(fn_name), fn_type_ref); + LLVMSetLinkage(fn_val, LLVMInternalLinkage); + LLVMSetFunctionCallConv(fn_val, LLVMFastCallConv); + + auto prev_state = save_and_clear_builder_state(g); + + LLVMBasicBlockRef entry_block = LLVMAppendBasicBlock(fn_val, "Entry"); + LLVMPositionBuilderAtEnd(g->builder, entry_block); + + LLVMValueRef target_val = LLVMGetParam(fn_val, 0); + LLVMValueRef bound_val = LLVMGetParam(fn_val, 1); + + LLVMBasicBlockRef bounds_check_fail_block = LLVMAppendBasicBlock(fn_val, "BoundsCheckFail"); + LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(fn_val, "BoundsCheckOk"); + + LLVMValueRef ok_val = LLVMBuildICmp(g->builder, pred, target_val, bound_val, ""); + LLVMBuildCondBr(g->builder, ok_val, ok_block, bounds_check_fail_block); LLVMPositionBuilderAtEnd(g->builder, bounds_check_fail_block); gen_debug_safety_crash(g, PanicMsgIdBoundsCheckFailure); - if (upper_value) { - LLVMPositionBuilderAtEnd(g->builder, lower_ok_block); - LLVMValueRef upper_ok_val = LLVMBuildICmp(g->builder, upper_pred, target_val, upper_value, ""); - LLVMBuildCondBr(g->builder, upper_ok_val, ok_block, bounds_check_fail_block); + LLVMPositionBuilderAtEnd(g->builder, ok_block); + LLVMBuildRetVoid(g->builder); + + restore_builder_state(g, prev_state); + g->llvm_fn_table.put(key, fn_val); + return fn_val; +} + +static void add_one_bounds_check(CodeGen *g, LLVMValueRef target_val, LLVMIntPredicate pred, LLVMValueRef bound_val) { + LLVMValueRef arg1; + LLVMValueRef arg2; + switch (pred) { + case LLVMIntEQ: + case LLVMIntNE: + case LLVMIntULT: + case LLVMIntULE: + arg1 = target_val; + arg2 = bound_val; + break; + case LLVMIntUGT: + arg1 = bound_val; + arg2 = target_val; + pred = LLVMIntULE; + break; + case LLVMIntUGE: + arg1 = bound_val; + arg2 = target_val; + pred = LLVMIntULT; + break; + default: + zig_unreachable(); + } + uint32_t bit_count = LLVMGetIntTypeWidth(LLVMTypeOf(target_val)); + LLVMValueRef fn_val = get_bounds_check_fn_val(g, pred, bit_count); + LLVMValueRef params[] = { arg1, arg2, }; + LLVMBuildCall(g->builder, fn_val, params, 2, ""); +} + +static void add_bounds_check(CodeGen *g, LLVMValueRef target_val, + LLVMIntPredicate lower_pred, LLVMValueRef lower_value, + LLVMIntPredicate upper_pred, LLVMValueRef upper_value) +{ + if (lower_value) { + add_one_bounds_check(g, target_val, lower_pred, lower_value); } - LLVMPositionBuilderAtEnd(g->builder, ok_block); + if (upper_value) { + add_one_bounds_check(g, target_val, upper_pred, upper_value); + } } static LLVMValueRef gen_widen_or_shorten(CodeGen *g, bool want_debug_safety, TypeTableEntry *actual_type,