AstGen: rework multi-object for loop

* Allow unbounded looping.
* Lower by incrementing raw pointers for each iterable rather than
  incrementing a single index variable (see the sketch after this list).
  This elides safety checks without any analysis required, thanks to the
  length assertion, and lowers to decent machine code even in debug builds.
  - An "end" value is selected, prioritizing a counter if possible, and
    falling back to a runtime calculation of ptr+len on a slice input.
* Specialize on the pattern `0..`, avoiding emitting an unnecessary
  subtraction instruction.
* Add the `for_check_lens` ZIR instruction.
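
To make the lowering concrete, here is a hand-written Zig sketch of what the
strategy amounts to (illustrative only: names like `ptr`, `end`, and `sum2`
are hypothetical, and the real transformation happens in ZIR, not in source
code):

const std = @import("std");

pub fn main() void {
    const items: []const u32 = &[_]u32{ 10, 20, 30 };
    var sum: usize = 0;

    // A multi-object for loop: a slice iterated together with the `0..`
    // counting pattern that this commit specializes on.
    for (items, 0..) |item, i| {
        sum += item + i;
    }

    // Hand-written equivalent of the new lowering: each iterable advances
    // its own raw pointer or counter, and the loop condition compares one
    // of them against a precomputed "end" value (here, the ptr+len
    // fallback computed from the slice).
    var ptr = items.ptr;
    const end = items.ptr + items.len;
    var i: usize = 0; // `0..` starts at zero, so no subtraction is needed
    var sum2: usize = 0;
    while (ptr != end) : ({
        ptr += 1;
        i += 1;
    }) {
        sum2 += ptr[0] + i;
    }
    std.debug.assert(sum == sum2);
}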
Andrew Kelley 2023-02-17 11:51:22 -07:00
parent 6733e43d87
commit faa44e2e58
4 changed files with 156 additions and 55 deletions

src/AstGen.zig

@@ -2666,6 +2666,7 @@ fn addEnsureResult(gz: *GenZir, maybe_unused_result: Zir.Inst.Ref, statement: As
.validate_deref,
.save_err_ret_index,
.restore_err_ret_index,
.for_check_lens,
=> break :b true,
.@"defer" => unreachable,
@@ -6294,37 +6295,35 @@ fn forExpr(
try astgen.checkLabelRedefinition(scope, label_token);
}
// Set up variables and constants.
const is_inline = parent_gz.force_comptime or for_full.inline_token != null;
const tree = astgen.tree;
const token_tags = tree.tokens.items(.tag);
const node_tags = tree.nodes.items(.tag);
const node_data = tree.nodes.items(.data);
const gpa = astgen.gpa;
// Check for unterminated ranges.
{
var unterminated: ?Ast.Node.Index = null;
for (for_full.ast.inputs) |input| {
if (node_tags[input] != .for_range) break;
if (node_data[input].rhs != 0) break;
unterminated = unterminated orelse input;
} else {
return astgen.failNode(unterminated.?, "unterminated for range", .{});
}
}
const lens = try astgen.gpa.alloc(Zir.Inst.Ref, for_full.ast.inputs.len);
defer astgen.gpa.free(lens);
const indexables = try astgen.gpa.alloc(Zir.Inst.Ref, for_full.ast.inputs.len);
defer astgen.gpa.free(indexables);
var counters = std.ArrayList(Zir.Inst.Ref).init(astgen.gpa);
defer counters.deinit();
const allocs = try gpa.alloc(Zir.Inst.Ref, for_full.ast.inputs.len);
defer gpa.free(allocs);
// elements of this array can be `none`, indicating no length check.
const lens = try gpa.alloc(Zir.Inst.Ref, for_full.ast.inputs.len);
defer gpa.free(lens);
const counter_alloc_tag: Zir.Inst.Tag = if (is_inline) .alloc_comptime_mut else .alloc;
// Tracks the index into allocs/lens of the input whose length is to be
// checked and used for the end value.
// If this is null, there are no len checks.
var end_input_index: ?u32 = null;
// This is the value used to find out whether the for loop has reached the
// end. It prefers to use a counter, since the end value is provided
// directly, and otherwise falls back to adding ptr+len of a slice to
// compute the end.
// Corresponds to end_input_index and will be .none when that value is null.
var cond_end_val: Zir.Inst.Ref = .none;
{
var payload = for_full.payload_token;
for (for_full.ast.inputs) |input, i| {
for (for_full.ast.inputs) |input, i_usize| {
const i = @intCast(u32, i_usize);
const payload_is_ref = token_tags[payload] == .asterisk;
const ident_tok = payload + @boolToInt(payload_is_ref);
@@ -6339,59 +6338,101 @@ fn forExpr(
return astgen.failTok(ident_tok, "cannot capture reference to range", .{});
}
const counter_ptr = try parent_gz.addUnNode(counter_alloc_tag, .usize_type, node);
const start_val = try expr(parent_gz, scope, .{ .rl = .none }, node_data[input].lhs);
const start_node = node_data[input].lhs;
const start_val = try expr(parent_gz, scope, .{ .rl = .none }, start_node);
_ = try parent_gz.addBin(.store, counter_ptr, start_val);
indexables[i] = counter_ptr;
try counters.append(counter_ptr);
const end_node = node_data[input].rhs;
const end_val = if (end_node != 0) try expr(parent_gz, scope, .{ .rl = .none }, node_data[input].rhs) else .none;
const range_len = try parent_gz.addPlNode(.for_range_len, input, Zir.Inst.Bin{
.lhs = start_val,
.rhs = end_val,
});
const end_val = if (end_node != 0)
try expr(parent_gz, scope, .{ .rl = .none }, node_data[input].rhs)
else
.none;
const range_len = if (end_val == .none or nodeIsTriviallyZero(tree, start_node))
end_val
else
try parent_gz.addPlNode(.sub, input, Zir.Inst.Bin{
.lhs = end_val,
.rhs = start_val,
});
if (range_len != .none and cond_end_val == .none) {
end_input_index = i;
cond_end_val = end_val;
}
allocs[i] = counter_ptr;
lens[i] = range_len;
} else {
const cond_ri: ResultInfo = .{ .rl = if (payload_is_ref) .ref else .none };
const indexable = try expr(parent_gz, scope, cond_ri, input);
indexables[i] = indexable;
const base_ptr = try parent_gz.addPlNode(.elem_ptr_imm, input, Zir.Inst.ElemPtrImm{
.ptr = indexable,
.index = 0,
});
const indexable_len = try parent_gz.addUnNode(.indexable_ptr_len, indexable, input);
lens[i] = indexable_len;
if (end_input_index == null) {
end_input_index = i;
assert(cond_end_val == .none);
}
allocs[i] = base_ptr;
lens[i] = try parent_gz.addUnNode(.indexable_ptr_len, indexable, input);
}
}
}
const len = "check_for_lens";
// If no counter already provided an end value, compute one from the base
// pointer plus the length.
if (end_input_index) |i| {
if (cond_end_val == .none) {
cond_end_val = try parent_gz.addPlNode(.add, for_full.ast.inputs[i], Zir.Inst.Bin{
.lhs = allocs[i],
.rhs = lens[i],
});
}
}
const index_ptr = blk: {
// Future optimization:
// for loops with only ranges don't need a separate index variable.
const index_ptr = try parent_gz.addUnNode(counter_alloc_tag, .usize_type, node);
// initialize to zero
_ = try parent_gz.addBin(.store, index_ptr, .zero_usize);
try counters.append(index_ptr);
break :blk index_ptr;
};
// We use a dedicated ZIR instruction to assert the lengths to assist with
// nicer error reporting as well as fewer ZIR bytes emitted.
if (end_input_index != null) {
const lens_len = @intCast(u32, lens.len);
try astgen.extra.ensureUnusedCapacity(gpa, @typeInfo(Zir.Inst.MultiOp).Struct.fields.len + lens_len);
_ = try parent_gz.addPlNode(.for_check_lens, node, Zir.Inst.MultiOp{
.operands_len = lens_len,
});
appendRefsAssumeCapacity(astgen, lens);
}
const loop_tag: Zir.Inst.Tag = if (is_inline) .block_inline else .loop;
const loop_block = try parent_gz.makeBlockInst(loop_tag, node);
try parent_gz.instructions.append(astgen.gpa, loop_block);
try parent_gz.instructions.append(gpa, loop_block);
var loop_scope = parent_gz.makeSubBlock(scope);
loop_scope.is_inline = is_inline;
loop_scope.setBreakResultInfo(ri);
defer loop_scope.unstack();
defer loop_scope.labeled_breaks.deinit(astgen.gpa);
defer loop_scope.labeled_breaks.deinit(gpa);
var cond_scope = parent_gz.makeSubBlock(&loop_scope.base);
defer cond_scope.unstack();
// check condition i < array_expr.len
const index = try cond_scope.addUnNode(.load, index_ptr, for_full.ast.cond_expr);
const cond = try cond_scope.addPlNode(.cmp_lt, for_full.ast.cond_expr, Zir.Inst.Bin{
.lhs = index,
.rhs = len,
// Load the current pointer/counter for each iterable.
const loaded_ptrs = try gpa.alloc(Zir.Inst.Ref, allocs.len);
defer gpa.free(loaded_ptrs);
for (allocs) |alloc, i| {
loaded_ptrs[i] = try cond_scope.addUnNode(.load, alloc, for_full.ast.inputs[i]);
}
// Check the condition.
const input_index = end_input_index orelse {
return astgen.failNode(node, "TODO: handle infinite for loop", .{});
};
assert(cond_end_val != .none);
const cond = try cond_scope.addPlNode(.cmp_neq, for_full.ast.inputs[input_index], Zir.Inst.Bin{
.lhs = loaded_ptrs[input_index],
.rhs = cond_end_val,
});
const condbr_tag: Zir.Inst.Tag = if (is_inline) .condbr_inline else .condbr;
@@ -6400,16 +6441,15 @@ fn forExpr(
const cond_block = try loop_scope.makeBlockInst(block_tag, node);
try cond_scope.setBlockBody(cond_block);
// cond_block unstacked now, can add new instructions to loop_scope
try loop_scope.instructions.append(astgen.gpa, cond_block);
try loop_scope.instructions.append(gpa, cond_block);
// Increment the index variable and ranges.
for (counters.items) |counter_ptr| {
const counter = try loop_scope.addUnNode(.load, counter_ptr, for_full.ast.cond_expr);
const counter_plus_one = try loop_scope.addPlNode(.add, node, Zir.Inst.Bin{
.lhs = counter,
// Increment the loop variables.
for (allocs) |alloc, i| {
const incremented = try loop_scope.addPlNode(.add, node, Zir.Inst.Bin{
.lhs = loaded_ptrs[i],
.rhs = .one_usize,
});
_ = try loop_scope.addBin(.store, counter_ptr, counter_plus_one);
_ = try loop_scope.addBin(.store, alloc, incremented);
}
const repeat_tag: Zir.Inst.Tag = if (is_inline) .repeat_inline else .repeat;
_ = try loop_scope.addNode(repeat_tag, node);
@@ -8960,6 +9000,25 @@ comptime {
}
}
fn nodeIsTriviallyZero(tree: *const Ast, node: Ast.Node.Index) bool {
const node_tags = tree.nodes.items(.tag);
const main_tokens = tree.nodes.items(.main_token);
switch (node_tags[node]) {
.number_literal => {
const ident = main_tokens[node];
return switch (std.zig.parseNumberLiteral(tree.tokenSlice(ident))) {
.int => |number| switch (number) {
0 => true,
else => false,
},
else => false,
};
},
else => return false,
}
}
fn nodeMayNeedMemoryLocation(tree: *const Ast, start_node: Ast.Node.Index, have_res_ty: bool) bool {
const node_tags = tree.nodes.items(.tag);
const node_datas = tree.nodes.items(.data);

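At the source level, the elided per-element safety checks are justified by
exactly this kind of up-front agreement between lengths. A hand-written
analogue (an illustrative sketch; `sumPairs` is a hypothetical name, and the
real check is the for_check_lens ZIR instruction, not a source-level assert):

const std = @import("std");

fn sumPairs(a: []const u32, b: []const u32) u64 {
    // `for (a, b)` requires a.len == b.len. The commit lowers that check
    // to a single up-front length assertion instead of a bounds check on
    // every element access; this assert is a source-level stand-in.
    std.debug.assert(a.len == b.len);
    var total: u64 = 0;
    for (a, b) |x, y| {
        total += x + y;
    }
    return total;
}

pub fn main() void {
    const a = [_]u32{ 1, 2, 3 };
    const b = [_]u32{ 4, 5, 6 };
    std.debug.print("{d}\n", .{sumPairs(&a, &b)});
}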
src/Sema.zig

@@ -1386,6 +1386,11 @@ fn analyzeBodyInner(
i += 1;
continue;
},
.for_check_lens => {
try sema.zirForCheckLens(block, inst);
i += 1;
continue;
},
// Special case instructions to handle comptime control flow.
.@"break" => {
@@ -17096,6 +17101,16 @@ fn zirRestoreErrRetIndex(sema: *Sema, start_block: *Block, inst: Zir.Inst.Index)
return sema.popErrorReturnTrace(start_block, src, operand, saved_index);
}
fn zirForCheckLens(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!void {
const inst_data = sema.code.instructions.items(.data)[inst].pl_node;
const extra = sema.code.extraData(Zir.Inst.MultiOp, inst_data.payload_index);
const args = sema.code.refSlice(extra.end, extra.data.operands_len);
const src = inst_data.src();
_ = args;
return sema.fail(block, src, "TODO implement zirForCheckLens", .{});
}
fn addToInferredErrorSet(sema: *Sema, uncasted_operand: Air.Inst.Ref) !void {
assert(sema.fn_ret_ty.zigTypeTag() == .ErrorUnion);

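Both zirForCheckLens above and writePlNodeMultiOp in print_zir.zig below
decode the same Zir.Inst.MultiOp layout: a fixed header in the `extra` array
followed by `operands_len` trailing refs. A simplified standalone model of
that decoding (the `Ref` and `refSlice` stand-ins here are assumptions for
illustration, not the compiler's actual types):

const std = @import("std");

// Simplified stand-in for Zir.Inst.Ref.
const Ref = u32;

// Reads the one-field MultiOp header at `payload_index`, then returns the
// `operands_len` refs stored immediately after it in `extra`.
fn refSlice(extra: []const u32, payload_index: usize) []const Ref {
    const operands_len = extra[payload_index];
    return extra[payload_index + 1 ..][0..operands_len];
}

test "decode MultiOp trailing operands" {
    // extra = { <unrelated>, operands_len = 3, ref, ref, ref }
    const extra = [_]u32{ 99, 3, 7, 8, 9 };
    try std.testing.expectEqualSlices(u32, &.{ 7, 8, 9 }, refSlice(&extra, 1));
}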
src/Zir.zig

@@ -497,6 +497,15 @@ pub const Inst = struct {
/// Sends comptime control flow back to the beginning of the current block.
/// Uses the `node` field.
repeat_inline,
/// Asserts that all the lengths provided match. Used to build a for loop.
/// Return value is always void.
/// Uses the `pl_node` field with payload `MultiOp`.
/// There is exactly one item corresponding to each AST node inside the for
/// loop condition. Each item may be `none`, indicating an unbounded range.
/// Illegal behaviors:
/// * If all lengths are unbounded ranges (always a compile error).
/// * If any two lengths do not match each other.
for_check_lens,
/// Merge two error sets into one, `E1 || E2`.
/// Uses the `pl_node` field with payload `Bin`.
merge_error_sets,
@@ -1242,6 +1251,7 @@
.defer_err_code,
.save_err_ret_index,
.restore_err_ret_index,
.for_check_lens,
=> false,
.@"break",
@@ -1309,6 +1319,7 @@
.memcpy,
.memset,
.check_comptime_control_flow,
.for_check_lens,
.@"defer",
.defer_err_code,
.restore_err_ret_index,
@@ -1588,6 +1599,7 @@
.@"break" = .@"break",
.break_inline = .@"break",
.check_comptime_control_flow = .un_node,
.for_check_lens = .pl_node,
.call = .pl_node,
.cmp_lt = .pl_node,
.cmp_lte = .pl_node,

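Restating the doc comment's rules as concrete loops (an illustrative sketch;
`demo` is a hypothetical function):

fn demo(a: []const u8, b: []const u8) void {
    // Legal: `0..` contributes a `none` length operand, but `a` supplies
    // a real length for the check to use.
    for (a, 0..) |x, i| {
        _ = x;
        _ = i;
    }

    // Legal ZIR, but the emitted length assertion fails unless
    // a.len == b.len.
    for (a, b) |x, y| {
        _ = x;
        _ = y;
    }

    // Every operand is an unbounded range: always a compile error.
    // for (0.., 1..) |i, j| {}
}

test {
    demo("abc", "xyz");
}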
src/print_zir.zig

@@ -355,6 +355,8 @@ const Writer = struct {
.array_type,
=> try self.writePlNodeBin(stream, inst),
.for_check_lens => try self.writePlNodeMultiOp(stream, inst),
.elem_ptr_imm => try self.writeElemPtrImm(stream, inst),
.@"export" => try self.writePlNodeExport(stream, inst),
@@ -868,6 +870,19 @@
try self.writeSrc(stream, inst_data.src());
}
fn writePlNodeMultiOp(self: *Writer, stream: anytype, inst: Zir.Inst.Index) !void {
const inst_data = self.code.instructions.items(.data)[inst].pl_node;
const extra = self.code.extraData(Zir.Inst.MultiOp, inst_data.payload_index);
const args = self.code.refSlice(extra.end, extra.data.operands_len);
try stream.writeAll("{");
for (args) |arg, i| {
if (i != 0) try stream.writeAll(", ");
try self.writeInstRef(stream, arg);
}
try stream.writeAll("}) ");
try self.writeSrc(stream, inst_data.src());
}
fn writeElemPtrImm(self: *Writer, stream: anytype, inst: Zir.Inst.Index) !void {
const inst_data = self.code.instructions.items(.data)[inst].pl_node;
const extra = self.code.extraData(Zir.Inst.ElemPtrImm, inst_data.payload_index).data;