diff --git a/src/all_types.hpp b/src/all_types.hpp index f2f52403d..dc701e293 100644 --- a/src/all_types.hpp +++ b/src/all_types.hpp @@ -726,6 +726,19 @@ struct FnTypeParamInfo { TypeTableEntry *type; }; +struct GenericParamValue { + TypeTableEntry *type; + ConstExprValue *value; +}; + +struct GenericFnTypeId { + FnTableEntry *fn_entry; + GenericParamValue *params; + size_t param_count; +}; + +uint32_t generic_fn_type_id_hash(GenericFnTypeId *id); +bool generic_fn_type_id_eql(GenericFnTypeId *a, GenericFnTypeId *b); struct FnTypeId { TypeTableEntry *return_type; @@ -957,7 +970,6 @@ struct FnTableEntry { ScopeFnDef *fndef_scope; // parent should be the top level decls or container decls Scope *child_scope; // parent is scope for last parameter ScopeBlock *def_scope; // parent is child_scope - ImportTableEntry *import_entry; Buf symbol_name; TypeTableEntry *type_entry; // function type TypeTableEntry *implicit_return_type; @@ -969,6 +981,7 @@ struct FnTableEntry { IrExecutable ir_executable; IrExecutable analyzed_executable; size_t prealloc_bbc; + AstNode **param_source_nodes; AstNode *fn_no_inline_set_node; AstNode *fn_export_set_node; @@ -1050,6 +1063,7 @@ struct CodeGen { HashMap primitive_type_table; HashMap fn_type_table; HashMap error_table; + HashMap generic_table; ZigList import_queue; size_t import_queue_index; @@ -1201,7 +1215,6 @@ struct VariableTableEntry { Scope *parent_scope; Scope *child_scope; LLVMValueRef param_value_ref; - bool force_depends_on_compile_var; bool shadowable; size_t mem_slot_index; size_t ref_count; diff --git a/src/analyze.cpp b/src/analyze.cpp index 041fd9378..87b713b29 100644 --- a/src/analyze.cpp +++ b/src/analyze.cpp @@ -907,22 +907,25 @@ static TypeTableEntry *get_generic_fn_type(CodeGen *g, FnTypeId *fn_type_id) { return fn_type; } -static TypeTableEntry *analyze_fn_type(CodeGen *g, TldFn *tld_fn) { - AstNode *proto_node = tld_fn->base.source_node; +void init_fn_type_id(FnTypeId *fn_type_id, AstNode *proto_node) { + assert(proto_node->type == NodeTypeFnProto); + AstNodeFnProto *fn_proto = &proto_node->data.fn_proto; + + fn_type_id->is_extern = fn_proto->is_extern || (fn_proto->visib_mod == VisibModExport); + fn_type_id->is_naked = fn_proto->is_nakedcc; + fn_type_id->is_cold = fn_proto->is_coldcc; + fn_type_id->param_count = fn_proto->params.length; + fn_type_id->param_info = allocate_nonzero(fn_type_id->param_count); + fn_type_id->next_param_index = 0; + fn_type_id->is_var_args = fn_proto->is_var_args; +} + +static TypeTableEntry *analyze_fn_type(CodeGen *g, AstNode *proto_node, Scope *child_scope) { assert(proto_node->type == NodeTypeFnProto); AstNodeFnProto *fn_proto = &proto_node->data.fn_proto; FnTypeId fn_type_id = {0}; - fn_type_id.is_extern = fn_proto->is_extern || (fn_proto->visib_mod == VisibModExport); - fn_type_id.is_naked = fn_proto->is_nakedcc; - fn_type_id.is_cold = fn_proto->is_coldcc; - fn_type_id.param_count = fn_proto->params.length; - fn_type_id.param_info = allocate_nonzero(fn_type_id.param_count); - fn_type_id.next_param_index = 0; - fn_type_id.is_var_args = fn_proto->is_var_args; - - FnTableEntry *fn_entry = tld_fn->fn_entry; - Scope *child_scope = fn_entry->fndef_scope ? &fn_entry->fndef_scope->base : tld_fn->base.parent_scope; + init_fn_type_id(&fn_type_id, proto_node); for (; fn_type_id.next_param_index < fn_type_id.param_count; fn_type_id.next_param_index += 1) { AstNode *param_node = fn_proto->params.at(fn_type_id.next_param_index); @@ -939,10 +942,6 @@ static TypeTableEntry *analyze_fn_type(CodeGen *g, TldFn *tld_fn) { return get_generic_fn_type(g, &fn_type_id); } - if (fn_entry && buf_len(param_node->data.param_decl.name) == 0) { - add_node_error(g, param_node, buf_sprintf("missing parameter name")); - } - TypeTableEntry *type_entry = analyze_type_expr(g, child_scope, param_node->data.param_decl.type); switch (type_entry->id) { @@ -1366,6 +1365,23 @@ static void get_fully_qualified_decl_name(Buf *buf, Tld *tld, uint8_t sep) { } } +FnTableEntry *create_fn(CodeGen *g, AstNode *proto_node) { + assert(proto_node->type == NodeTypeFnProto); + AstNodeFnProto *fn_proto = &proto_node->data.fn_proto; + + FnTableEntry *fn_table_entry = allocate(1); + fn_table_entry->analyzed_executable.backward_branch_count = &fn_table_entry->prealloc_bbc; + fn_table_entry->analyzed_executable.backward_branch_quota = default_backward_branch_quota; + fn_table_entry->analyzed_executable.fn_entry = fn_table_entry; + fn_table_entry->ir_executable.fn_entry = fn_table_entry; + fn_table_entry->proto_node = proto_node; + fn_table_entry->fn_def_node = proto_node->data.fn_proto.fn_def_node; + fn_table_entry->fn_inline = fn_proto->is_inline ? FnInlineAlways : FnInlineAuto; + fn_table_entry->internal_linkage = (fn_proto->visib_mod != VisibModExport); + + return fn_table_entry; +} + static void resolve_decl_fn(CodeGen *g, TldFn *tld_fn) { ImportTableEntry *import = tld_fn->base.import; AstNode *proto_node = tld_fn->base.source_node; @@ -1381,17 +1397,7 @@ static void resolve_decl_fn(CodeGen *g, TldFn *tld_fn) { return; } - FnTableEntry *fn_table_entry = allocate(1); - fn_table_entry->analyzed_executable.backward_branch_count = &fn_table_entry->prealloc_bbc; - fn_table_entry->analyzed_executable.backward_branch_quota = default_backward_branch_quota; - fn_table_entry->analyzed_executable.fn_entry = fn_table_entry; - fn_table_entry->ir_executable.fn_entry = fn_table_entry; - fn_table_entry->import_entry = import; - fn_table_entry->proto_node = proto_node; - fn_table_entry->fn_def_node = fn_def_node; - fn_table_entry->fn_inline = fn_proto->is_inline ? FnInlineAlways : FnInlineAuto; - fn_table_entry->internal_linkage = (fn_proto->visib_mod != VisibModExport); - + FnTableEntry *fn_table_entry = create_fn(g, tld_fn->base.source_node); get_fully_qualified_decl_name(&fn_table_entry->symbol_name, &tld_fn->base, '_'); tld_fn->fn_entry = fn_table_entry; @@ -1399,9 +1405,18 @@ static void resolve_decl_fn(CodeGen *g, TldFn *tld_fn) { if (fn_table_entry->fn_def_node) { fn_table_entry->fndef_scope = create_fndef_scope( fn_table_entry->fn_def_node, tld_fn->base.parent_scope, fn_table_entry); + + for (size_t i = 0; i < fn_proto->params.length; i += 1) { + AstNode *param_node = fn_proto->params.at(i); + assert(param_node->type == NodeTypeParamDecl); + if (buf_len(param_node->data.param_decl.name) == 0) { + add_node_error(g, param_node, buf_sprintf("missing parameter name")); + } + } } - fn_table_entry->type_entry = analyze_fn_type(g, tld_fn); + Scope *child_scope = fn_table_entry->fndef_scope ? &fn_table_entry->fndef_scope->base : tld_fn->base.parent_scope; + fn_table_entry->type_entry = analyze_fn_type(g, proto_node, child_scope); if (fn_table_entry->type_entry->id == TypeTableEntryIdInvalid) { tld_fn->base.resolution = TldResolutionInvalid; @@ -2142,6 +2157,13 @@ bool type_is_codegen_pointer(TypeTableEntry *type) { return false; } +AstNode *get_param_decl_node(FnTableEntry *fn_entry, size_t index) { + if (fn_entry->param_source_nodes) + return fn_entry->param_source_nodes[index]; + else + return fn_entry->proto_node->data.fn_proto.params.at(index); +} + static void analyze_fn_body(CodeGen *g, FnTableEntry *fn_table_entry) { assert(fn_table_entry->anal_state != FnAnalStateProbing); if (fn_table_entry->anal_state != FnAnalStateReady) @@ -2151,17 +2173,19 @@ static void analyze_fn_body(CodeGen *g, FnTableEntry *fn_table_entry) { AstNodeFnProto *fn_proto = &fn_table_entry->proto_node->data.fn_proto; - Scope *child_scope = &fn_table_entry->fndef_scope->base; - assert(child_scope); + assert(fn_table_entry->fndef_scope); + if (!fn_table_entry->child_scope) + fn_table_entry->child_scope = &fn_table_entry->fndef_scope->base; // define local variables for parameters TypeTableEntry *fn_type = fn_table_entry->type_entry; assert(!fn_type->data.fn.is_generic); FnTypeId *fn_type_id = &fn_type->data.fn.fn_type_id; for (size_t i = 0; i < fn_type_id->param_count; i += 1) { - AstNode *param_decl_node = fn_proto->params.at(i); - AstNodeParamDecl *param_decl = ¶m_decl_node->data.param_decl; FnTypeParamInfo *param_info = &fn_type_id->param_info[i]; + AstNode *param_decl_node = get_param_decl_node(fn_table_entry, i); + AstNodeParamDecl *param_decl = ¶m_decl_node->data.param_decl; + TypeTableEntry *param_type = param_info->type; bool is_noalias = param_info->is_noalias; @@ -2175,9 +2199,9 @@ static void analyze_fn_body(CodeGen *g, FnTableEntry *fn_table_entry) { buf_sprintf("byvalue types not yet supported on extern function parameters")); } - VariableTableEntry *var = add_variable(g, param_decl_node, child_scope, param_decl->name, param_type, true, nullptr); + VariableTableEntry *var = add_variable(g, param_decl_node, fn_table_entry->child_scope, param_decl->name, param_type, true, nullptr); var->src_arg_index = i; - child_scope = var->child_scope; + fn_table_entry->child_scope = var->child_scope; fn_table_entry->variable_list.append(var); if (fn_type->data.fn.gen_param_info) { @@ -2185,8 +2209,6 @@ static void analyze_fn_body(CodeGen *g, FnTableEntry *fn_table_entry) { } } - fn_table_entry->child_scope = child_scope; - TypeTableEntry *expected_type = fn_type_id->return_type; if (fn_type_id->is_extern && handle_is_ptr(expected_type)) { @@ -2622,6 +2644,40 @@ static uint32_t hash_const_val(TypeTableEntry *type, ConstExprValue *const_val) zig_unreachable(); } +uint32_t generic_fn_type_id_hash(GenericFnTypeId *id) { + uint32_t result = 0; + result += hash_ptr(id->fn_entry); + for (size_t i = 0; i < id->param_count; i += 1) { + GenericParamValue *generic_param = &id->params[i]; + if (generic_param->value) { + result += hash_const_val(generic_param->type, generic_param->value); + result += hash_ptr(generic_param->type); + } + } + return result; +} + +bool generic_fn_type_id_eql(GenericFnTypeId *a, GenericFnTypeId *b) { + assert(a->fn_entry); + if (a->fn_entry != b->fn_entry) return false; + assert(a->param_count == b->param_count); + for (size_t i = 0; i < a->param_count; i += 1) { + GenericParamValue *a_val = &a->params[i]; + GenericParamValue *b_val = &b->params[i]; + if (a_val->type != b_val->type) return false; + if (a_val->value && b_val->value) { + assert(a_val->value->special == ConstValSpecialStatic); + assert(b_val->value->special == ConstValSpecialStatic); + if (!const_values_equal(a_val->value, b_val->value, a_val->type)) { + return false; + } + } else { + assert(!a_val->value && !b_val->value); + } + } + return true; +} + bool type_has_bits(TypeTableEntry *type_entry) { assert(type_entry); assert(type_entry->id != TypeTableEntryIdInvalid); diff --git a/src/analyze.hpp b/src/analyze.hpp index 677e8c130..ec487ecc1 100644 --- a/src/analyze.hpp +++ b/src/analyze.hpp @@ -70,6 +70,9 @@ void init_tld(Tld *tld, TldId id, Buf *name, VisibMod visib_mod, AstNode *source VariableTableEntry *add_variable(CodeGen *g, AstNode *source_node, Scope *parent_scope, Buf *name, TypeTableEntry *type_entry, bool is_const, ConstExprValue *init_value); TypeTableEntry *analyze_type_expr(CodeGen *g, Scope *scope, AstNode *node); +FnTableEntry *create_fn(CodeGen *g, AstNode *proto_node); +void init_fn_type_id(FnTypeId *fn_type_id, AstNode *proto_node); +AstNode *get_param_decl_node(FnTableEntry *fn_entry, size_t index); Scope *create_block_scope(AstNode *node, Scope *parent); Scope *create_defer_scope(AstNode *node, Scope *parent); diff --git a/src/ast_render.cpp b/src/ast_render.cpp index ddcb842c8..236213bf9 100644 --- a/src/ast_render.cpp +++ b/src/ast_render.cpp @@ -397,7 +397,7 @@ static void render_node_extra(AstRender *ar, AstNode *node, bool grouped) { assert(param_decl->type == NodeTypeParamDecl); if (buf_len(param_decl->data.param_decl.name) > 0) { const char *noalias_str = param_decl->data.param_decl.is_noalias ? "noalias " : ""; - const char *inline_str = param_decl->data.param_decl.is_inline ? "inline " : ""; + const char *inline_str = param_decl->data.param_decl.is_inline ? "inline " : ""; fprintf(ar->f, "%s%s", noalias_str, inline_str); print_symbol(ar, param_decl->data.param_decl.name); fprintf(ar->f, ": "); diff --git a/src/codegen.cpp b/src/codegen.cpp index af62f57eb..6b872e877 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -60,6 +60,7 @@ CodeGen *codegen_create(Buf *root_source_dir, const ZigTarget *target) { g->primitive_type_table.init(32); g->fn_type_table.init(32); g->error_table.init(16); + g->generic_table.init(16); g->is_release_build = false; g->is_test_build = false; g->want_h_file = true; @@ -2305,11 +2306,8 @@ static void do_code_gen(CodeGen *g) { if (should_skip_fn_codegen(g, fn_table_entry)) continue; - AstNode *proto_node = fn_table_entry->proto_node; - assert(proto_node->type == NodeTypeFnProto); - AstNodeFnProto *fn_proto = &proto_node->data.fn_proto; - TypeTableEntry *fn_type = fn_table_entry->type_entry; + FnTypeId *fn_type_id = &fn_type->data.fn.fn_type_id; LLVMValueRef fn_val = fn_llvm_value(g, fn_table_entry); @@ -2327,22 +2325,20 @@ static void do_code_gen(CodeGen *g) { // set parameter attributes - for (size_t param_decl_i = 0; param_decl_i < fn_proto->params.length; param_decl_i += 1) { - AstNode *param_node = fn_proto->params.at(param_decl_i); - assert(param_node->type == NodeTypeParamDecl); - - FnGenParamInfo *info = &fn_type->data.fn.gen_param_info[param_decl_i]; - size_t gen_index = info->gen_index; - bool is_byval = info->is_byval; + for (size_t param_i = 0; param_i < fn_type_id->param_count; param_i += 1) { + FnGenParamInfo *gen_info = &fn_type->data.fn.gen_param_info[param_i]; + size_t gen_index = gen_info->gen_index; + bool is_byval = gen_info->is_byval; if (gen_index == SIZE_MAX) { continue; } - TypeTableEntry *param_type = info->type; + FnTypeParamInfo *param_info = &fn_type_id->param_info[param_i]; + + TypeTableEntry *param_type = gen_info->type; LLVMValueRef argument_val = LLVMGetParam(fn_val, gen_index); - bool param_is_noalias = param_node->data.param_decl.is_noalias; - if (param_is_noalias) { + if (param_info->is_noalias) { LLVMAddAttribute(argument_val, LLVMNoAliasAttribute); } if ((param_type->id == TypeTableEntryIdPointer && param_type->data.pointer.is_const) || is_byval) { @@ -2402,7 +2398,6 @@ static void do_code_gen(CodeGen *g) { if (should_skip_fn_codegen(g, fn_table_entry)) continue; - ImportTableEntry *import = fn_table_entry->import_entry; LLVMValueRef fn = fn_llvm_value(g, fn_table_entry); g->cur_fn = fn_table_entry; g->cur_fn_val = fn; @@ -2412,10 +2407,6 @@ static void do_code_gen(CodeGen *g) { g->cur_ret_ptr = nullptr; } - AstNode *proto_node = fn_table_entry->proto_node; - assert(proto_node->type == NodeTypeFnProto); - AstNodeFnProto *fn_proto = &proto_node->data.fn_proto; - build_all_basic_blocks(g, fn_table_entry); clear_debug_source_node(g); @@ -2444,6 +2435,8 @@ static void do_code_gen(CodeGen *g) { *slot = LLVMBuildAlloca(g->builder, instruction->type_entry->type_ref, ""); } + ImportTableEntry *import = get_scope_import(&fn_table_entry->fndef_scope->base); + // create debug variable declarations for variables and allocate all local variables for (size_t var_i = 0; var_i < fn_table_entry->variable_list.length; var_i += 1) { VariableTableEntry *var = fn_table_entry->variable_list.at(var_i); @@ -2484,10 +2477,12 @@ static void do_code_gen(CodeGen *g) { } } + FnTypeId *fn_type_id = &fn_table_entry->type_entry->data.fn.fn_type_id; + // create debug variable declarations for parameters // rely on the first variables in the variable_list being parameters. size_t next_var_i = 0; - for (size_t param_i = 0; param_i < fn_proto->params.length; param_i += 1) { + for (size_t param_i = 0; param_i < fn_type_id->param_count; param_i += 1) { FnGenParamInfo *info = &fn_table_entry->type_entry->data.fn.gen_param_info[param_i]; if (info->gen_index == SIZE_MAX) continue; @@ -3392,14 +3387,12 @@ void codegen_generate_h_file(CodeGen *g) { buf_resize(&h_buf, 0); for (size_t fn_def_i = 0; fn_def_i < g->fn_defs.length; fn_def_i += 1) { FnTableEntry *fn_table_entry = g->fn_defs.at(fn_def_i); - AstNode *proto_node = fn_table_entry->proto_node; - assert(proto_node->type == NodeTypeFnProto); - AstNodeFnProto *fn_proto = &proto_node->data.fn_proto; - if (fn_proto->visib_mod != VisibModExport) + if (fn_table_entry->internal_linkage) continue; FnTypeId *fn_type_id = &fn_table_entry->type_entry->data.fn.fn_type_id; + Buf return_type_c = BUF_INIT; get_c_type(g, fn_type_id->return_type, &return_type_c); @@ -3412,7 +3405,7 @@ void codegen_generate_h_file(CodeGen *g) { if (fn_type_id->param_count > 0) { for (size_t param_i = 0; param_i < fn_type_id->param_count; param_i += 1) { FnTypeParamInfo *param_info = &fn_type_id->param_info[param_i]; - AstNode *param_decl_node = fn_proto->params.at(param_i); + AstNode *param_decl_node = get_param_decl_node(fn_table_entry, param_i); Buf *param_name = param_decl_node->data.param_decl.name; const char *comma_str = (param_i == 0) ? "" : ", "; diff --git a/src/ir.cpp b/src/ir.cpp index a0d0b1f86..f354b5a5e 100644 --- a/src/ir.cpp +++ b/src/ir.cpp @@ -2825,9 +2825,8 @@ IrInstruction *ir_gen_fn(CodeGen *codegen, FnTableEntry *fn_entry) { AstNode *body_node = fn_def_node->data.fn_def.body; assert(fn_entry->child_scope); - Scope *child_scope = fn_entry->child_scope; - return ir_gen(codegen, body_node, child_scope, ir_executable); + return ir_gen(codegen, body_node, fn_entry->child_scope, ir_executable); } static ErrorMsg *ir_add_error(IrAnalyze *ira, IrInstruction *source_instruction, Buf *msg) { @@ -4220,9 +4219,9 @@ static TypeTableEntry *ir_analyze_instruction_decl_var(IrAnalyze *ira, IrInstruc } static bool ir_analyze_fn_call_inline_arg(IrAnalyze *ira, AstNode *fn_proto_node, - IrInstruction *arg, Scope **exec_scope, size_t *next_arg_index) + IrInstruction *arg, Scope **exec_scope, size_t *next_proto_i) { - AstNode *param_decl_node = fn_proto_node->data.fn_proto.params.at(*next_arg_index); + AstNode *param_decl_node = fn_proto_node->data.fn_proto.params.at(*next_proto_i); assert(param_decl_node->type == NodeTypeParamDecl); AstNode *param_type_node = param_decl_node->data.param_decl.type; TypeTableEntry *param_type = analyze_type_expr(ira->codegen, *exec_scope, param_type_node); @@ -4241,14 +4240,66 @@ static bool ir_analyze_fn_call_inline_arg(IrAnalyze *ira, AstNode *fn_proto_node VariableTableEntry *var = add_variable(ira->codegen, param_decl_node, *exec_scope, param_name, param_type, true, first_arg_val); *exec_scope = var->child_scope; - *next_arg_index += 1; + *next_proto_i += 1; return true; } +static bool ir_analyze_fn_call_generic_arg(IrAnalyze *ira, AstNode *fn_proto_node, + IrInstruction *arg, Scope **child_scope, size_t *next_proto_i, + GenericFnTypeId *generic_id, FnTypeId *fn_type_id, IrInstruction **casted_args, + FnTableEntry *impl_fn) +{ + AstNode *param_decl_node = fn_proto_node->data.fn_proto.params.at(*next_proto_i); + assert(param_decl_node->type == NodeTypeParamDecl); + AstNode *param_type_node = param_decl_node->data.param_decl.type; + TypeTableEntry *param_type = analyze_type_expr(ira->codegen, *child_scope, param_type_node); + if (param_type->id == TypeTableEntryIdInvalid) + return false; + + bool is_var_type = (param_type->id == TypeTableEntryIdVar); + IrInstruction *casted_arg; + if (is_var_type) { + casted_arg = arg; + } else { + casted_arg = ir_get_casted_value(ira, arg, param_type); + if (casted_arg->type_entry->id == TypeTableEntryIdInvalid) + return false; + } + + bool inline_arg = param_decl_node->data.param_decl.is_inline; + if (inline_arg || is_var_type) { + ConstExprValue *arg_val = ir_resolve_const(ira, casted_arg); + if (!arg_val) + return false; + + Buf *param_name = param_decl_node->data.param_decl.name; + VariableTableEntry *var = add_variable(ira->codegen, param_decl_node, + *child_scope, param_name, param_type, true, arg_val); + *child_scope = var->child_scope; + // This generic function instance could be called with anything, so when this variable is read it + // needs to know that it depends on compile time variable data. + var->value->depends_on_compile_var = true; + + GenericParamValue *generic_param = &generic_id->params[generic_id->param_count]; + generic_param->type = casted_arg->type_entry; + generic_param->value = arg_val; + generic_id->param_count += 1; + } else { + casted_args[fn_type_id->param_count] = casted_arg; + FnTypeParamInfo *param_info = &fn_type_id->param_info[fn_type_id->param_count]; + param_info->type = param_type; + param_info->is_noalias = param_decl_node->data.param_decl.is_noalias; + impl_fn->param_source_nodes[fn_type_id->param_count] = param_decl_node; + fn_type_id->param_count += 1; + } + *next_proto_i += 1; + return true; +} + static TypeTableEntry *ir_analyze_fn_call(IrAnalyze *ira, IrInstructionCall *call_instruction, FnTableEntry *fn_entry, TypeTableEntry *fn_type, IrInstruction *fn_ref, - IrInstruction *first_arg_ptr, bool is_inline) + IrInstruction *first_arg_ptr, bool inline_fn_call) { FnTypeId *fn_type_id = &fn_type->data.fn.fn_type_id; size_t first_arg_1_or_0 = first_arg_ptr ? 1 : 0; @@ -4278,7 +4329,8 @@ static TypeTableEntry *ir_analyze_fn_call(IrAnalyze *ira, IrInstructionCall *cal return ira->codegen->builtin_types.entry_invalid; } - if (is_inline) { + if (inline_fn_call) { + // No special handling is needed for compile time evaluation of generic functions. if (!fn_entry) { ir_add_error(ira, fn_ref, buf_sprintf("unable to evaluate constant expression")); return ira->codegen->builtin_types.entry_invalid; @@ -4290,13 +4342,13 @@ static TypeTableEntry *ir_analyze_fn_call(IrAnalyze *ira, IrInstructionCall *cal // Fork a scope of the function with known values for the parameters. Scope *exec_scope = &fn_entry->fndef_scope->base; - size_t next_arg_index = 0; + size_t next_proto_i = 0; if (first_arg_ptr) { IrInstruction *first_arg = ir_get_deref(ira, first_arg_ptr, first_arg_ptr); if (first_arg->type_entry->id == TypeTableEntryIdInvalid) return ira->codegen->builtin_types.entry_invalid; - if (!ir_analyze_fn_call_inline_arg(ira, fn_proto_node, first_arg, &exec_scope, &next_arg_index)) + if (!ir_analyze_fn_call_inline_arg(ira, fn_proto_node, first_arg, &exec_scope, &next_proto_i)) return ira->codegen->builtin_types.entry_invalid; } @@ -4305,7 +4357,7 @@ static TypeTableEntry *ir_analyze_fn_call(IrAnalyze *ira, IrInstructionCall *cal if (old_arg->type_entry->id == TypeTableEntryIdInvalid) return ira->codegen->builtin_types.entry_invalid; - if (!ir_analyze_fn_call_inline_arg(ira, fn_proto_node, old_arg, &exec_scope, &next_arg_index)) + if (!ir_analyze_fn_call_inline_arg(ira, fn_proto_node, old_arg, &exec_scope, &next_proto_i)) return ira->codegen->builtin_types.entry_invalid; } @@ -4327,9 +4379,89 @@ static TypeTableEntry *ir_analyze_fn_call(IrAnalyze *ira, IrInstructionCall *cal return ir_finish_anal(ira, return_type); } + if (fn_type->data.fn.is_generic) { + assert(fn_entry); + + IrInstruction **casted_args = allocate(call_param_count); + + // Fork a scope of the function with known values for the parameters. + Scope *parent_scope = fn_entry->fndef_scope->base.parent; + FnTableEntry *impl_fn = create_fn(ira->codegen, fn_proto_node); + impl_fn->param_source_nodes = allocate(call_param_count); + buf_init_from_buf(&impl_fn->symbol_name, &fn_entry->symbol_name); + impl_fn->fndef_scope = create_fndef_scope(impl_fn->fn_def_node, parent_scope, impl_fn); + impl_fn->child_scope = &impl_fn->fndef_scope->base; + FnTypeId fn_type_id = {0}; + init_fn_type_id(&fn_type_id, fn_proto_node); + fn_type_id.param_count = 0; + + // TODO maybe GenericFnTypeId can be replaced with using the child_scope directly + // as the key in generic_table + GenericFnTypeId *generic_id = allocate(1); + generic_id->fn_entry = fn_entry; + generic_id->param_count = 0; + generic_id->params = allocate(src_param_count); + size_t next_proto_i = 0; + + if (first_arg_ptr) { + IrInstruction *first_arg = ir_get_deref(ira, first_arg_ptr, first_arg_ptr); + if (first_arg->type_entry->id == TypeTableEntryIdInvalid) + return ira->codegen->builtin_types.entry_invalid; + + if (!ir_analyze_fn_call_generic_arg(ira, fn_proto_node, first_arg, &impl_fn->child_scope, + &next_proto_i, generic_id, &fn_type_id, casted_args, impl_fn)) + { + return ira->codegen->builtin_types.entry_invalid; + } + } + for (size_t call_i = 0; call_i < call_instruction->arg_count; call_i += 1) { + IrInstruction *arg = call_instruction->args[call_i]->other; + if (arg->type_entry->id == TypeTableEntryIdInvalid) + return ira->codegen->builtin_types.entry_invalid; + + if (!ir_analyze_fn_call_generic_arg(ira, fn_proto_node, arg, &impl_fn->child_scope, + &next_proto_i, generic_id, &fn_type_id, casted_args, impl_fn)) + { + return ira->codegen->builtin_types.entry_invalid; + } + } + + auto existing_entry = ira->codegen->generic_table.put_unique(generic_id, impl_fn); + if (existing_entry) { + // throw away all our work and use the existing function + impl_fn = existing_entry->value; + } else { + // finish instantiating the function + AstNode *return_type_node = fn_proto_node->data.fn_proto.return_type; + TypeTableEntry *return_type = analyze_type_expr(ira->codegen, impl_fn->child_scope, return_type_node); + if (return_type->id == TypeTableEntryIdInvalid) + return ira->codegen->builtin_types.entry_invalid; + fn_type_id.return_type = return_type; + + impl_fn->type_entry = get_fn_type(ira->codegen, &fn_type_id); + if (impl_fn->type_entry->id == TypeTableEntryIdInvalid) + return ira->codegen->builtin_types.entry_invalid; + + ira->codegen->fn_protos.append(impl_fn); + ira->codegen->fn_defs.append(impl_fn); + } + + size_t impl_param_count = impl_fn->type_entry->data.fn.fn_type_id.param_count; + IrInstruction *new_call_instruction = ir_build_call_from(&ira->new_irb, &call_instruction->base, + impl_fn, nullptr, impl_param_count, casted_args); + + TypeTableEntry *return_type = impl_fn->type_entry->data.fn.fn_type_id.return_type; + if (type_has_bits(return_type) && handle_is_ptr(return_type)) { + FnTableEntry *callsite_fn = exec_fn_entry(ira->new_irb.exec); + assert(callsite_fn); + callsite_fn->alloca_list.append(new_call_instruction); + } + + return ir_finish_anal(ira, return_type); + } + IrInstruction **casted_args = allocate(call_param_count); size_t next_arg_index = 0; - if (first_arg_ptr) { IrInstruction *first_arg = ir_get_deref(ira, first_arg_ptr, first_arg_ptr); if (first_arg->type_entry->id == TypeTableEntryIdInvalid) diff --git a/test/self_hosted2.zig b/test/self_hosted2.zig index 605720be1..687279f9a 100644 --- a/test/self_hosted2.zig +++ b/test/self_hosted2.zig @@ -111,6 +111,33 @@ fn testCompileTimeFib() { assert(fib_7 == 13); } +fn max(inline T: type, a: T, b: T) -> T { + if (a > b) a else b +} +const the_max = max(u32, 1234, 5678); + +fn testCompileTimeGenericEval() { + assert(the_max == 5678); +} + +fn gimmeTheBigOne(a: u32, b: u32) -> u32 { + max(u32, a, b) +} + +fn shouldCallSameInstance(a: u32, b: u32) -> u32 { + max(u32, a, b) +} + +fn sameButWithFloats(a: f64, b: f64) -> f64 { + max(f64, a, b) +} + +fn testFnWithInlineArgs() { + assert(gimmeTheBigOne(1234, 5678) == 5678); + assert(shouldCallSameInstance(34, 12) == 34); + assert(sameButWithFloats(0.43, 0.49) == 0.49); +} + fn assert(ok: bool) { if (!ok) @@ -129,6 +156,8 @@ fn runAllTests() { testStructStatic(); testStaticFnEval(); testCompileTimeFib(); + testCompileTimeGenericEval(); + testFnWithInlineArgs(); } export nakedcc fn _start() -> unreachable {