diff --git a/src/all_types.hpp b/src/all_types.hpp index d7642e45ff..64330acd30 100644 --- a/src/all_types.hpp +++ b/src/all_types.hpp @@ -87,6 +87,12 @@ struct TopLevelDecl { bool in_current_deps; }; +struct TypeEnumField { + Buf *name; + TypeTableEntry *type_entry; + uint32_t value; +}; + enum NodeType { NodeTypeRoot, NodeTypeRootExportDecl, @@ -316,6 +322,7 @@ struct AstNodeFieldAccessExpr { // populated by semantic analyzer TypeStructField *type_struct_field; + TypeEnumField *type_enum_field; Expr resolved_expr; }; @@ -680,11 +687,10 @@ struct TypeStructField { int src_index; int gen_index; }; - struct TypeTableEntryStruct { AstNode *decl_node; bool is_packed; - int field_count; + uint32_t field_count; TypeStructField *fields; uint64_t size_bytes; bool is_invalid; // true if any fields are invalid @@ -709,14 +715,9 @@ struct TypeTableEntryMetaType { TypeTableEntry *child_type; }; -struct TypeEnumField { - Buf *name; - TypeTableEntry *type_entry; -}; - struct TypeTableEntryEnum { AstNode *decl_node; - int field_count; + uint32_t field_count; TypeEnumField *fields; bool is_invalid; // true if any fields are invalid diff --git a/src/analyze.cpp b/src/analyze.cpp index 55a352d221..1ba102f754 100644 --- a/src/analyze.cpp +++ b/src/analyze.cpp @@ -176,7 +176,7 @@ static TypeTableEntry *get_int_type_unsigned(CodeGen *g, uint64_t x) { static TypeTableEntry *get_meta_type(CodeGen *g, TypeTableEntry *child_type) { if (child_type->meta_parent) { - return child_type->maybe_parent; + return child_type->meta_parent; } else { TypeTableEntry *entry = new_type_table_entry(TypeTableEntryIdMetaType); buf_resize(&entry->name, 0); @@ -705,7 +705,7 @@ static void resolve_enum_type(CodeGen *g, ImportTableEntry *import, TypeTableEnt assert(enum_type->di_type); - int field_count = decl_node->data.struct_decl.fields.length; + uint32_t field_count = decl_node->data.struct_decl.fields.length; enum_type->data.enumeration.field_count = field_count; enum_type->data.enumeration.fields = allocate(field_count); @@ -723,12 +723,13 @@ static void resolve_enum_type(CodeGen *g, ImportTableEntry *import, TypeTableEnt enum_type->data.enumeration.embedded_in_current = true; int gen_field_index = 0; - for (int i = 0; i < field_count; i += 1) { + for (uint32_t i = 0; i < field_count; i += 1) { AstNode *field_node = decl_node->data.struct_decl.fields.at(i); TypeEnumField *type_enum_field = &enum_type->data.enumeration.fields[i]; type_enum_field->name = &field_node->data.struct_field.name; type_enum_field->type_entry = resolve_type(g, field_node->data.struct_field.type, import, import->block_context, false); + type_enum_field->value = i; di_enumerators[i] = LLVMZigCreateDebugEnumerator(g->dbuilder, buf_ptr(type_enum_field->name), i); @@ -1496,6 +1497,16 @@ TypeTableEntry *find_container(BlockContext *context, Buf *name) { return nullptr; } +static TypeEnumField *get_enum_field(TypeTableEntry *enum_type, Buf *name) { + for (int i = 0; i < enum_type->data.enumeration.field_count; i += 1) { + TypeEnumField *type_enum_field = &enum_type->data.enumeration.fields[i]; + if (buf_eql_buf(type_enum_field->name, name)) { + return type_enum_field; + } + } + return nullptr; +} + static TypeStructField *get_struct_field(TypeTableEntry *struct_type, Buf *name) { for (int i = 0; i < struct_type->data.structure.field_count; i += 1) { TypeStructField *type_struct_field = &struct_type->data.structure.fields[i]; @@ -1506,13 +1517,46 @@ static TypeStructField *get_struct_field(TypeTableEntry *struct_type, Buf *name) return nullptr; } +static TypeTableEntry *analyze_enum_value_expr(CodeGen *g, ImportTableEntry *import, BlockContext *context, + AstNode *field_access_node, AstNode *value_node, TypeTableEntry *enum_type, Buf *field_name) +{ + TypeEnumField *type_enum_field = get_enum_field(enum_type, field_name); + field_access_node->data.field_access_expr.type_enum_field = type_enum_field; + if (type_enum_field) { + if (value_node) { + if (type_enum_field->type_entry->id == TypeTableEntryIdVoid) { + add_node_error(g, field_access_node, + buf_sprintf("enum value '%s.%s' has void parameter", + buf_ptr(&enum_type->name), + buf_ptr(field_name))); + + } else { + analyze_expression(g, import, context, type_enum_field->type_entry, value_node); + } + } else if (type_enum_field->type_entry->id == TypeTableEntryIdVoid) { + // OK + } else { + add_node_error(g, field_access_node, + buf_sprintf("enum value '%s.%s' requires parameter of type '%s'", + buf_ptr(&enum_type->name), + buf_ptr(field_name), + buf_ptr(&type_enum_field->type_entry->name))); + } + } else { + add_node_error(g, field_access_node, + buf_sprintf("no member named '%s' in '%s'", buf_ptr(field_name), + buf_ptr(&enum_type->name))); + } + return enum_type; +} + static TypeTableEntry *analyze_field_access_expr(CodeGen *g, ImportTableEntry *import, BlockContext *context, AstNode *node) { assert(node->type == NodeTypeFieldAccessExpr); - TypeTableEntry *struct_type = analyze_expression(g, import, context, nullptr, - node->data.field_access_expr.struct_expr); + AstNode *struct_expr_node = node->data.field_access_expr.struct_expr; + TypeTableEntry *struct_type = analyze_expression(g, import, context, nullptr, struct_expr_node); TypeTableEntry *return_type; @@ -1548,9 +1592,9 @@ static TypeTableEntry *analyze_field_access_expr(CodeGen *g, ImportTableEntry *i } else if (struct_type->id == TypeTableEntryIdMetaType && struct_type->data.meta_type.child_type->id == TypeTableEntryIdEnum) { - //TypeTableEntry *enum_type = struct_type->data.meta_type.child_type; - - zig_panic("TODO enum field access"); + TypeTableEntry *enum_type = struct_type->data.meta_type.child_type; + Buf *field_name = &node->data.field_access_expr.field_name; + return_type = analyze_enum_value_expr(g, import, context, node, nullptr, enum_type, field_name); } else { if (struct_type->id != TypeTableEntryIdInvalid) { add_node_error(g, node, diff --git a/src/codegen.cpp b/src/codegen.cpp index 79cefd311e..1a12f61c07 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -500,6 +500,15 @@ static LLVMValueRef gen_array_access_expr(CodeGen *g, AstNode *node, bool is_lva } } +static LLVMValueRef gen_enum_value_expr(CodeGen *g, AstNode *node, TypeTableEntry *enum_type) { + assert(node->type == NodeTypeFieldAccessExpr); + + uint64_t value = node->data.field_access_expr.type_enum_field->value; + LLVMTypeRef tag_type_ref = enum_type->type_ref; + + return LLVMConstInt(tag_type_ref, value, false); +} + static LLVMValueRef gen_field_access_expr(CodeGen *g, AstNode *node, bool is_lvalue) { assert(node->type == NodeTypeFieldAccessExpr); @@ -532,6 +541,12 @@ static LLVMValueRef gen_field_access_expr(CodeGen *g, AstNode *node, bool is_lva add_debug_source_node(g, node); return LLVMBuildLoad(g->builder, ptr, ""); } + } else if (struct_type->id == TypeTableEntryIdMetaType && + struct_type->data.meta_type.child_type->id == TypeTableEntryIdEnum) + { + assert(!is_lvalue); + TypeTableEntry *enum_type = struct_type->data.meta_type.child_type; + return gen_enum_value_expr(g, node, enum_type); } else { zig_panic("gen_field_access_expr bad struct type"); } @@ -875,11 +890,15 @@ static LLVMValueRef gen_cmp_expr(CodeGen *g, AstNode *node) { if (op1_type->id == TypeTableEntryIdFloat) { LLVMRealPredicate pred = cmp_op_to_real_predicate(node->data.bin_op_expr.bin_op); return LLVMBuildFCmp(g->builder, pred, val1, val2, ""); - } else { - assert(op1_type->id == TypeTableEntryIdInt); + } else if (op1_type->id == TypeTableEntryIdInt) { LLVMIntPredicate pred = cmp_op_to_int_predicate(node->data.bin_op_expr.bin_op, op1_type->data.integral.is_signed); return LLVMBuildICmp(g->builder, pred, val1, val2, ""); + } else if (op1_type->id == TypeTableEntryIdEnum) { + LLVMIntPredicate pred = cmp_op_to_int_predicate(node->data.bin_op_expr.bin_op, false); + return LLVMBuildICmp(g->builder, pred, val1, val2, ""); + } else { + zig_unreachable(); } } diff --git a/test/run_tests.cpp b/test/run_tests.cpp index 6858cf1bbc..159124609e 100644 --- a/test/run_tests.cpp +++ b/test/run_tests.cpp @@ -1031,6 +1031,19 @@ fn print_ok(val: #typeof(x)) -> #typeof(foo) { } const foo : i32 = 0; )SOURCE", "OK\n"); + + add_simple_case("enum with void types", R"SOURCE( +use "std.zig"; +enum Foo { A, B, C, D, } +pub fn main(argc: isize, argv: &&u8, env: &&u8) -> i32 { + const foo : Foo = Foo.B; + if (foo != Foo.B) { + print_str("BAD\n"); + } + print_str("OK\n"); + return 0; +} + )SOURCE", "OK\n"); }