diff --git a/src/all_types.hpp b/src/all_types.hpp index 64330acd3..88e31d65d 100644 --- a/src/all_types.hpp +++ b/src/all_types.hpp @@ -324,6 +324,7 @@ struct AstNodeFieldAccessExpr { TypeStructField *type_struct_field; TypeEnumField *type_enum_field; Expr resolved_expr; + StructValExprCodeGen resolved_struct_val_expr; // for enum values }; struct AstNodeExternBlock { @@ -718,8 +719,10 @@ struct TypeTableEntryMetaType { struct TypeTableEntryEnum { AstNode *decl_node; uint32_t field_count; + uint32_t gen_field_count; TypeEnumField *fields; bool is_invalid; // true if any fields are invalid + TypeTableEntry *tag_type; // reminder: hash tables must be initialized before use HashMap fn_table; @@ -916,6 +919,7 @@ struct CodeGen { ImportTableEntry *root_import; ImportTableEntry *bootstrap_import; LLVMValueRef memcpy_fn_val; + LLVMValueRef memset_fn_val; bool error_during_imports; }; diff --git a/src/analyze.cpp b/src/analyze.cpp index 1ba102f75..04090d39b 100644 --- a/src/analyze.cpp +++ b/src/analyze.cpp @@ -465,6 +465,8 @@ static TypeTableEntry *eval_const_expr(CodeGen *g, BlockContext *context, zig_panic("TODO eval_const_expr max_value"); } else if (buf_eql_str(name, "min_value")) { zig_panic("TODO eval_const_expr min_value"); + } else if (buf_eql_str(name, "value_count")) { + zig_panic("TODO eval_const_expr value_count"); } else { return g->builtin_types.entry_invalid; } @@ -767,10 +769,13 @@ static void resolve_enum_type(CodeGen *g, ImportTableEntry *import, TypeTableEnt enum_type->data.enumeration.embedded_in_current = false; if (!enum_type->data.enumeration.is_invalid) { - uint64_t tag_size_in_bits = get_number_literal_type_unsigned(g, field_count)->size_in_bits; + enum_type->data.enumeration.gen_field_count = gen_field_index; + + uint64_t tag_size_in_bits = num_lit_bit_count(get_number_literal_kind_unsigned(field_count)); enum_type->align_in_bits = tag_size_in_bits; enum_type->size_in_bits = tag_size_in_bits + biggest_union_member_size_in_bits; TypeTableEntry *tag_type_entry = get_int_type_unsigned(g, field_count); + enum_type->data.enumeration.tag_type = tag_type_entry; if (biggest_union_member) { // create llvm type for union @@ -1520,22 +1525,20 @@ static TypeStructField *get_struct_field(TypeTableEntry *struct_type, Buf *name) static TypeTableEntry *analyze_enum_value_expr(CodeGen *g, ImportTableEntry *import, BlockContext *context, AstNode *field_access_node, AstNode *value_node, TypeTableEntry *enum_type, Buf *field_name) { + assert(field_access_node->type == NodeTypeFieldAccessExpr); + TypeEnumField *type_enum_field = get_enum_field(enum_type, field_name); field_access_node->data.field_access_expr.type_enum_field = type_enum_field; + if (type_enum_field) { if (value_node) { - if (type_enum_field->type_entry->id == TypeTableEntryIdVoid) { - add_node_error(g, field_access_node, - buf_sprintf("enum value '%s.%s' has void parameter", - buf_ptr(&enum_type->name), - buf_ptr(field_name))); + analyze_expression(g, import, context, type_enum_field->type_entry, value_node); - } else { - analyze_expression(g, import, context, type_enum_field->type_entry, value_node); - } - } else if (type_enum_field->type_entry->id == TypeTableEntryIdVoid) { - // OK - } else { + StructValExprCodeGen *codegen = &field_access_node->data.field_access_expr.resolved_struct_val_expr; + codegen->type_entry = enum_type; + codegen->source_node = field_access_node; + context->struct_val_expr_alloca_list.append(codegen); + } else if (type_enum_field->type_entry->id != TypeTableEntryIdVoid) { add_node_error(g, field_access_node, buf_sprintf("enum value '%s.%s' requires parameter of type '%s'", buf_ptr(&enum_type->name), @@ -2295,7 +2298,8 @@ static TypeTableEntry *analyze_min_max_value(CodeGen *g, AstNode *node, TypeTabl { if (type_entry->id == TypeTableEntryIdInt || type_entry->id == TypeTableEntryIdFloat || - type_entry->id == TypeTableEntryIdBool) + type_entry->id == TypeTableEntryIdBool || + type_entry->id == TypeTableEntryIdInvalid) { return type_entry; } else { @@ -2314,15 +2318,38 @@ static TypeTableEntry *analyze_compiler_fn_type(CodeGen *g, ImportTableEntry *im TypeTableEntry *type_entry = resolve_type(g, node->data.compiler_fn_type.type, import, context, false); if (buf_eql_str(name, "sizeof")) { - uint64_t size_in_bytes = type_entry->size_in_bits / 8; + if (type_entry->id == TypeTableEntryIdInvalid) { + return type_entry; + } else if (type_entry->id == TypeTableEntryIdUnreachable) { + add_node_error(g, node, + buf_sprintf("no size available for type '%s'", buf_ptr(&type_entry->name))); + return g->builtin_types.entry_invalid; + } else { + uint64_t size_in_bytes = type_entry->size_in_bits / 8; - TypeTableEntry *num_lit_type = get_number_literal_type_unsigned(g, size_in_bytes); - TypeTableEntry *resolved_type = resolve_rhs_number_literal(g, nullptr, expected_type, node, num_lit_type); - return resolved_type ? resolved_type : num_lit_type; + TypeTableEntry *num_lit_type = get_number_literal_type_unsigned(g, size_in_bytes); + TypeTableEntry *resolved_type = resolve_rhs_number_literal(g, nullptr, expected_type, node, num_lit_type); + return resolved_type ? resolved_type : num_lit_type; + } } else if (buf_eql_str(name, "min_value")) { return analyze_min_max_value(g, node, type_entry, "no min value available for type '%s'"); } else if (buf_eql_str(name, "max_value")) { return analyze_min_max_value(g, node, type_entry, "no max value available for type '%s'"); + } else if (buf_eql_str(name, "value_count")) { + if (type_entry->id == TypeTableEntryIdInvalid) { + return type_entry; + } else if (type_entry->id == TypeTableEntryIdEnum) { + uint64_t value_count = type_entry->data.enumeration.field_count; + + TypeTableEntry *num_lit_type = get_number_literal_type_unsigned(g, value_count); + TypeTableEntry *resolved_type = resolve_rhs_number_literal(g, nullptr, expected_type, node, num_lit_type); + return resolved_type ? resolved_type : num_lit_type; + + } else { + add_node_error(g, node, + buf_sprintf("no value count available for type '%s'", buf_ptr(&type_entry->name))); + return g->builtin_types.entry_invalid; + } } else { add_node_error(g, node, buf_sprintf("invalid compiler function: '%s'", buf_ptr(name))); @@ -2451,7 +2478,24 @@ static TypeTableEntry *analyze_fn_call_expr(CodeGen *g, ImportTableEntry *import } else if (struct_type->id == TypeTableEntryIdMetaType && struct_type->data.meta_type.child_type->id == TypeTableEntryIdEnum) { - zig_panic("TODO enum initialization"); + TypeTableEntry *enum_type = struct_type->data.meta_type.child_type; + Buf *field_name = &fn_ref_expr->data.field_access_expr.field_name; + int param_count = node->data.fn_call_expr.params.length; + if (param_count > 1) { + add_node_error(g, first_executing_node(node->data.fn_call_expr.params.at(1)), + buf_sprintf("enum values accept only one parameter")); + return enum_type; + } else { + AstNode *value_node; + if (param_count == 1) { + value_node = node->data.fn_call_expr.params.at(0); + } else { + value_node = nullptr; + } + + return analyze_enum_value_expr(g, import, context, fn_ref_expr, value_node, + enum_type, field_name); + } } else { add_node_error(g, fn_ref_expr->data.field_access_expr.struct_expr, buf_sprintf("member reference base type not struct or enum")); diff --git a/src/codegen.cpp b/src/codegen.cpp index 1a12f61c0..3f5daa21d 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -137,8 +137,13 @@ static LLVMValueRef find_or_create_string(CodeGen *g, Buf *str, bool c) { static TypeTableEntry *get_expr_type(AstNode *node) { Expr *expr = get_resolved_expr(node); - TypeTableEntry *cast_type = expr->implicit_cast.after_type; - return cast_type ? cast_type : expr->type_entry; + if (expr->implicit_maybe_cast.after_type) { + return expr->implicit_maybe_cast.after_type; + } + if (expr->implicit_cast.after_type) { + return expr->implicit_cast.after_type; + } + return expr->type_entry; } static LLVMValueRef gen_builtin_fn_call_expr(CodeGen *g, AstNode *node) { @@ -237,6 +242,51 @@ static LLVMValueRef gen_builtin_fn_call_expr(CodeGen *g, AstNode *node) { zig_unreachable(); } +static LLVMValueRef gen_enum_value_expr(CodeGen *g, AstNode *node, TypeTableEntry *enum_type, + AstNode *arg_node) +{ + assert(node->type == NodeTypeFieldAccessExpr); + + uint64_t value = node->data.field_access_expr.type_enum_field->value; + LLVMTypeRef tag_type_ref = enum_type->data.enumeration.tag_type->type_ref; + LLVMValueRef tag_value = LLVMConstInt(tag_type_ref, value, false); + + if (enum_type->data.enumeration.gen_field_count == 0) { + return tag_value; + } else { + TypeTableEntry *arg_node_type = nullptr; + LLVMValueRef new_union_val = gen_expr(g, arg_node); + if (arg_node) { + arg_node_type = get_expr_type(arg_node); + new_union_val = gen_expr(g, arg_node); + } else { + arg_node_type = g->builtin_types.entry_void; + } + + LLVMValueRef tmp_struct_ptr = node->data.field_access_expr.resolved_struct_val_expr.ptr; + + // populate the new tag value + add_debug_source_node(g, node); + LLVMValueRef tag_field_ptr = LLVMBuildStructGEP(g->builder, tmp_struct_ptr, 0, ""); + LLVMBuildStore(g->builder, tag_value, tag_field_ptr); + + if (arg_node_type->id != TypeTableEntryIdVoid) { + // populate the union value + TypeTableEntry *union_val_type = get_expr_type(arg_node); + LLVMValueRef union_field_ptr = LLVMBuildStructGEP(g->builder, tmp_struct_ptr, 1, ""); + LLVMValueRef bitcasted_union_field_ptr = LLVMBuildBitCast(g->builder, union_field_ptr, + LLVMPointerType(union_val_type->type_ref, 0), ""); + + gen_assign_raw(g, arg_node, BinOpTypeAssign, bitcasted_union_field_ptr, new_union_val, + union_val_type, union_val_type); + + } + + return tmp_struct_ptr; + } +} + + static LLVMValueRef gen_fn_call_expr(CodeGen *g, AstNode *node) { assert(node->type == NodeTypeFnCallExpr); @@ -253,6 +303,19 @@ static LLVMValueRef gen_fn_call_expr(CodeGen *g, AstNode *node) { } else if (struct_type->id == TypeTableEntryIdPointer) { assert(struct_type->data.pointer.child_type->id == TypeTableEntryIdStruct); fn_table_entry = struct_type->data.pointer.child_type->data.structure.fn_table.get(name); + } else if (struct_type->id == TypeTableEntryIdMetaType && + struct_type->data.meta_type.child_type->id == TypeTableEntryIdEnum) + { + TypeTableEntry *enum_type = struct_type->data.meta_type.child_type; + int param_count = node->data.fn_call_expr.params.length; + AstNode *arg1_node; + if (param_count == 1) { + arg1_node = node->data.fn_call_expr.params.at(0); + } else { + assert(param_count == 0); + arg1_node = nullptr; + } + return gen_enum_value_expr(g, fn_ref_expr, enum_type, arg1_node); } else { zig_unreachable(); } @@ -500,15 +563,6 @@ static LLVMValueRef gen_array_access_expr(CodeGen *g, AstNode *node, bool is_lva } } -static LLVMValueRef gen_enum_value_expr(CodeGen *g, AstNode *node, TypeTableEntry *enum_type) { - assert(node->type == NodeTypeFieldAccessExpr); - - uint64_t value = node->data.field_access_expr.type_enum_field->value; - LLVMTypeRef tag_type_ref = enum_type->type_ref; - - return LLVMConstInt(tag_type_ref, value, false); -} - static LLVMValueRef gen_field_access_expr(CodeGen *g, AstNode *node, bool is_lvalue) { assert(node->type == NodeTypeFieldAccessExpr); @@ -546,7 +600,7 @@ static LLVMValueRef gen_field_access_expr(CodeGen *g, AstNode *node, bool is_lva { assert(!is_lvalue); TypeTableEntry *enum_type = struct_type->data.meta_type.child_type; - return gen_enum_value_expr(g, node, enum_type); + return gen_enum_value_expr(g, node, enum_type, nullptr); } else { zig_panic("gen_field_access_expr bad struct type"); } @@ -968,7 +1022,9 @@ static LLVMValueRef gen_bool_or_expr(CodeGen *g, AstNode *expr_node) { static LLVMValueRef gen_struct_memcpy(CodeGen *g, AstNode *source_node, LLVMValueRef src, LLVMValueRef dest, TypeTableEntry *type_entry) { - assert(type_entry->id == TypeTableEntryIdStruct || type_entry->id == TypeTableEntryIdMaybe); + assert(type_entry->id == TypeTableEntryIdStruct || + type_entry->id == TypeTableEntryIdMaybe || + (type_entry->id == TypeTableEntryIdEnum && type_entry->data.enumeration.gen_field_count != 0)); LLVMTypeRef ptr_u8 = LLVMPointerType(LLVMInt8Type(), 0); @@ -991,8 +1047,13 @@ static LLVMValueRef gen_assign_raw(CodeGen *g, AstNode *source_node, BinOpType b LLVMValueRef target_ref, LLVMValueRef value, TypeTableEntry *op1_type, TypeTableEntry *op2_type) { - if (op1_type->id == TypeTableEntryIdStruct) { - assert(op2_type->id == TypeTableEntryIdStruct); + if (op1_type->id == TypeTableEntryIdStruct || + (op1_type->id == TypeTableEntryIdEnum && op1_type->data.enumeration.gen_field_count != 0) || + op1_type->id == TypeTableEntryIdMaybe) + { + assert(op2_type->id == TypeTableEntryIdStruct || + (op2_type->id == TypeTableEntryIdEnum && op2_type->data.enumeration.gen_field_count != 0) || + op2_type->id == TypeTableEntryIdMaybe); assert(op1_type == op2_type); assert(bin_op == BinOpTypeAssign); @@ -1546,32 +1607,48 @@ static LLVMValueRef gen_var_decl_raw(CodeGen *g, AstNode *source_node, AstNodeVa if (var_decl->expr) { *init_value = gen_expr(g, var_decl->expr); - } else { - *init_value = LLVMConstNull(variable->type->type_ref); } if (variable->type->id == TypeTableEntryIdVoid) { return nullptr; } else { - LLVMValueRef store_instr; - LLVMValueRef value; - if (unwrap_maybe) { - assert(var_decl->expr); - value = gen_unwrap_maybe(g, source_node, *init_value); - } else { - value = *init_value; - } - if ((variable->type->id == TypeTableEntryIdStruct || variable->type->id == TypeTableEntryIdMaybe) && - var_decl->expr) - { - store_instr = gen_struct_memcpy(g, source_node, value, variable->value_ref, variable->type); - } else { + if (var_decl->expr) { + TypeTableEntry *expr_type = get_expr_type(var_decl->expr); + LLVMValueRef value; + if (unwrap_maybe) { + assert(var_decl->expr); + assert(expr_type->id == TypeTableEntryIdMaybe); + value = gen_unwrap_maybe(g, source_node, *init_value); + expr_type = expr_type->data.maybe.child_type; + } else { + value = *init_value; + } + gen_assign_raw(g, var_decl->expr, BinOpTypeAssign, variable->value_ref, + value, variable->type, expr_type); + } else if (g->build_type != CodeGenBuildTypeRelease) { + // memset uninitialized memory to 0xa add_debug_source_node(g, source_node); - store_instr = LLVMBuildStore(g->builder, value, variable->value_ref); + LLVMTypeRef ptr_u8 = LLVMPointerType(LLVMInt8Type(), 0); + LLVMValueRef fill_char = LLVMConstInt(LLVMInt8Type(), 0xaa, false); + LLVMValueRef dest_ptr = LLVMBuildBitCast(g->builder, variable->value_ref, ptr_u8, ""); + LLVMValueRef byte_count = LLVMConstInt(LLVMIntType(g->pointer_size_bytes * 8), + variable->type->size_in_bits / 8, false); + LLVMValueRef align_in_bytes = LLVMConstInt(LLVMInt32Type(), + variable->type->align_in_bits / 8, false); + LLVMValueRef params[] = { + dest_ptr, + fill_char, + byte_count, + align_in_bytes, + LLVMConstNull(LLVMInt1Type()), // is volatile + }; + + LLVMBuildCall(g->builder, g->memset_fn_val, params, 5, ""); } LLVMZigDILocation *debug_loc = LLVMZigGetDebugLoc(source_node->line + 1, source_node->column + 1, g->cur_block_context->di_scope); - LLVMZigInsertDeclare(g->dbuilder, variable->value_ref, variable->di_loc_var, debug_loc, store_instr); + LLVMZigInsertDeclareAtEnd(g->dbuilder, variable->value_ref, variable->di_loc_var, debug_loc, + LLVMGetInsertBlock(g->builder)); return nullptr; } } @@ -1644,6 +1721,17 @@ static LLVMValueRef gen_compiler_fn_type(CodeGen *g, AstNode *node) { } else { zig_unreachable(); } + } else if (buf_eql_str(name, "value_count")) { + if (type_entry->id == TypeTableEntryIdEnum) { + NumLitCodeGen *codegen_num_lit = get_resolved_num_lit(node); + AstNodeNumberLiteral num_lit_node; + num_lit_node.kind = type_entry->data.num_lit.kind; + num_lit_node.overflow = false; + num_lit_node.data.x_uint = type_entry->data.enumeration.field_count; + return gen_number_literal_raw(g, node, codegen_num_lit, &num_lit_node); + } else { + zig_unreachable(); + } } else { zig_unreachable(); } @@ -2112,6 +2200,7 @@ static void define_builtin_types(CodeGen *g) { buf_resize(&entry->name, 0); buf_appendf(&entry->name, "(%s literal)", num_lit_str(num_lit_kind)); entry->data.num_lit.kind = num_lit_kind; + entry->size_in_bits = num_lit_bit_count(num_lit_kind); g->num_lit_types[i] = entry; } @@ -2377,10 +2466,10 @@ static void define_builtin_fns(CodeGen *g) { }; LLVMTypeRef fn_type = LLVMFunctionType(LLVMVoidType(), param_types, 5, false); Buf *name = buf_sprintf("llvm.memcpy.p0i8.p0i8.i%d", g->pointer_size_bytes * 8); - g->memcpy_fn_val = LLVMAddFunction(g->module, buf_ptr(name), fn_type); - builtin_fn->fn_val = g->memcpy_fn_val; - assert(LLVMGetIntrinsicID(g->memcpy_fn_val)); + builtin_fn->fn_val = LLVMAddFunction(g->module, buf_ptr(name), fn_type); + assert(LLVMGetIntrinsicID(builtin_fn->fn_val)); + g->memcpy_fn_val = builtin_fn->fn_val; g->builtin_fn_table.put(&builtin_fn->name, builtin_fn); } { @@ -2406,6 +2495,7 @@ static void define_builtin_fns(CodeGen *g) { builtin_fn->fn_val = LLVMAddFunction(g->module, buf_ptr(name), fn_type); assert(LLVMGetIntrinsicID(builtin_fn->fn_val)); + g->memset_fn_val = builtin_fn->fn_val; g->builtin_fn_table.put(&builtin_fn->name, builtin_fn); } } @@ -2658,7 +2748,8 @@ void codegen_add_root_code(CodeGen *g, Buf *src_dir, Buf *src_basename, Buf *sou g->bootstrap_import = add_special_code(g, "bootstrap.zig"); } - add_special_code(g, "builtin.zig"); + // TODO re-enable this + //add_special_code(g, "builtin.zig"); } if (g->verbose) { diff --git a/test/run_tests.cpp b/test/run_tests.cpp index 159124609..6955b6c0e 100644 --- a/test/run_tests.cpp +++ b/test/run_tests.cpp @@ -355,7 +355,7 @@ pub fn main(argc: isize, argv: &&u8, env: &&u8) -> i32 { use "std.zig"; pub fn main(argc: isize, argv: &&u8, env: &&u8) -> i32 { - var zero : i32; + var zero : i32 = 0; if (zero == 0) { print_str("zero\n"); } var i = 0 as i32; @@ -619,6 +619,7 @@ use "std.zig"; pub fn main(argc : isize, argv : &&u8, env : &&u8) -> i32 { var foo : Foo; + @memset(&foo, 0, #sizeof(Foo)); foo.a += 1; foo.b = foo.a == 1; test_foo(foo); @@ -689,7 +690,7 @@ fn test_initializer() { use "std.zig"; const g1 : i32 = 1233 + 1; -var g2 : i32; +var g2 : i32 = 0; pub fn main(argc : isize, argv : &&u8, env : &&u8) -> i32 { if (g2 != 0) { print_str("BAD\n"); } @@ -1044,6 +1045,53 @@ pub fn main(argc: isize, argv: &&u8, env: &&u8) -> i32 { return 0; } )SOURCE", "OK\n"); + + add_simple_case("enum type", R"SOURCE( +use "std.zig"; + +struct Point { + x: u64, + y: u64, +} + +enum Foo { + One: i32, + Two: Point, + Three: void, +} + +enum Bar { + A, + B, + C, + D, +} + +pub fn main(argc: isize, argv: &&u8, env: &&u8) -> i32 { + const foo1 = Foo.One(13); + const foo2 = Foo.Two(Point { .x = 1234, .y = 5678, }); + const bar = Bar.A; + + if (#value_count(Foo) != 3) { + print_str("BAD\n"); + } + + if (#value_count(Bar) != 4) { + print_str("BAD\n"); + } + + if (#sizeof(Foo) != 17) { + print_str("BAD\n"); + } + if (#sizeof(Bar) != 1) { + print_str("BAD\n"); + } + + print_str("OK\n"); + + return 0; +} + )SOURCE", "OK\n"); }