diff --git a/doc/langref.md b/doc/langref.md index 37ac9d8d4..805de170b 100644 --- a/doc/langref.md +++ b/doc/langref.md @@ -25,7 +25,7 @@ UseDecl = "use" Expression ";" ExternDecl = "extern" (FnProto | VariableDeclaration) ";" -FnProto = "fn" option("Symbol") option(ParamDeclList) ParamDeclList option("->" TypeExpr) +FnProto = "fn" option("Symbol") ParamDeclList option("->" TypeExpr) Directive = "#" "Symbol" "(" Expression ")" @@ -35,7 +35,7 @@ FnDef = option("inline" | "extern") FnProto Block ParamDeclList = "(" list(ParamDecl, ",") ")" -ParamDecl = option("noalias") option("Symbol" ":") TypeExpr | "..." +ParamDecl = option("noalias" | "inline") option("Symbol" ":") TypeExpr | "..." Block = "{" list(option(Statement), ";") "}" diff --git a/example/guess_number/main.zig b/example/guess_number/main.zig index b702209ce..c9f5c6f93 100644 --- a/example/guess_number/main.zig +++ b/example/guess_number/main.zig @@ -23,7 +23,7 @@ pub fn main(args: [][]u8) -> %void { return err; }; - const guess = io.parse_unsigned(u8)(line_buf[0...line_len - 1], 10) %% { + const guess = io.parse_unsigned(u8, line_buf[0...line_len - 1], 10) %% { %%io.stdout.printf("Invalid number.\n"); continue; }; diff --git a/src/all_types.hpp b/src/all_types.hpp index 49336f892..b94188bdb 100644 --- a/src/all_types.hpp +++ b/src/all_types.hpp @@ -195,10 +195,8 @@ struct AstNodeRoot { struct AstNodeFnProto { TopLevelDecl top_level_decl; Buf name; - ZigList generic_params; ZigList params; AstNode *return_type; - bool generic_params_is_var_args; bool is_var_args; bool is_extern; bool is_inline; @@ -210,7 +208,10 @@ struct AstNodeFnProto { FnTableEntry *fn_table_entry; bool skip; Expr resolved_expr; - TypeTableEntry *generic_fn_type; + // computed from params field + int inline_arg_count; + // if this is a generic function implementation, this points to the generic node + AstNode *generic_proto_node; }; struct AstNodeFnDef { @@ -219,6 +220,7 @@ struct AstNodeFnDef { // populated by semantic analyzer TypeTableEntry *implicit_return_type; + // the first child block context BlockContext *block_context; }; @@ -230,6 +232,7 @@ struct AstNodeParamDecl { Buf name; AstNode *type; bool is_noalias; + bool is_inline; // populated by semantic analyzer VariableTableEntry *variable; @@ -841,6 +844,7 @@ struct FnTypeId { bool is_naked; bool is_cold; bool is_extern; + bool is_inline; FnTypeParamInfo prealloc_param_info[fn_type_id_prealloc_param_info_count]; }; @@ -1063,7 +1067,6 @@ struct FnTableEntry { ZigList all_labels; Buf symbol_name; TypeTableEntry *type_entry; // function type - bool is_inline; bool internal_linkage; bool is_extern; bool is_test; @@ -1172,8 +1175,8 @@ struct CodeGen { ZigList import_queue; int import_queue_index; - ZigList export_queue; - int export_queue_index; + ZigList resolve_queue; + int resolve_queue_index; ZigList use_queue; int use_queue_index; diff --git a/src/analyze.cpp b/src/analyze.cpp index 2f7a1f78b..d5a93a398 100644 --- a/src/analyze.cpp +++ b/src/analyze.cpp @@ -32,6 +32,8 @@ static TypeTableEntry *analyze_block_expr(CodeGen *g, ImportTableEntry *import, static TypeTableEntry *resolve_expr_const_val_as_void(CodeGen *g, AstNode *node); static TypeTableEntry *resolve_expr_const_val_as_fn(CodeGen *g, AstNode *node, FnTableEntry *fn, bool depends_on_compile_var); +static TypeTableEntry *resolve_expr_const_val_as_generic_fn(CodeGen *g, AstNode *node, + TypeTableEntry *type_entry, bool depends_on_compile_var); static TypeTableEntry *resolve_expr_const_val_as_type(CodeGen *g, AstNode *node, TypeTableEntry *type, bool depends_on_compile_var); static TypeTableEntry *resolve_expr_const_val_as_unsigned_num_lit(CodeGen *g, AstNode *node, @@ -874,7 +876,8 @@ static TypeTableEntry *analyze_fn_proto_type(CodeGen *g, ImportTableEntry *impor fn_type_id.is_extern = fn_proto->is_extern || (fn_proto->top_level_decl.visib_mod == VisibModExport); fn_type_id.is_naked = is_naked; fn_type_id.is_cold = is_cold; - fn_type_id.param_count = node->data.fn_proto.params.length; + fn_type_id.is_inline = fn_proto->is_inline; + fn_type_id.param_count = fn_proto->params.length; if (fn_type_id.param_count > fn_type_id_prealloc_param_info_count) { fn_type_id.param_info = allocate_nonzero(fn_type_id.param_count); @@ -883,15 +886,52 @@ static TypeTableEntry *analyze_fn_proto_type(CodeGen *g, ImportTableEntry *impor } fn_type_id.is_var_args = fn_proto->is_var_args; - fn_type_id.return_type = analyze_type_expr(g, import, context, node->data.fn_proto.return_type); + fn_type_id.return_type = analyze_type_expr(g, import, context, fn_proto->return_type); - if (fn_type_id.return_type->id == TypeTableEntryIdInvalid) { - fn_proto->skip = true; + switch (fn_type_id.return_type->id) { + case TypeTableEntryIdInvalid: + fn_proto->skip = true; + break; + case TypeTableEntryIdNumLitFloat: + case TypeTableEntryIdNumLitInt: + case TypeTableEntryIdUndefLit: + case TypeTableEntryIdNamespace: + case TypeTableEntryIdGenericFn: + fn_proto->skip = true; + add_node_error(g, fn_proto->return_type, + buf_sprintf("return type '%s' not allowed", buf_ptr(&fn_type_id.return_type->name))); + break; + case TypeTableEntryIdMetaType: + if (!fn_proto->is_inline) { + fn_proto->skip = true; + add_node_error(g, fn_proto->return_type, + buf_sprintf("function with return type '%s' must be declared inline", + buf_ptr(&fn_type_id.return_type->name))); + return g->builtin_types.entry_invalid; + } + break; + case TypeTableEntryIdUnreachable: + case TypeTableEntryIdVoid: + case TypeTableEntryIdBool: + case TypeTableEntryIdInt: + case TypeTableEntryIdFloat: + case TypeTableEntryIdPointer: + case TypeTableEntryIdArray: + case TypeTableEntryIdStruct: + case TypeTableEntryIdMaybe: + case TypeTableEntryIdErrorUnion: + case TypeTableEntryIdPureError: + case TypeTableEntryIdEnum: + case TypeTableEntryIdUnion: + case TypeTableEntryIdFn: + case TypeTableEntryIdTypeDecl: + break; } for (int i = 0; i < fn_type_id.param_count; i += 1) { - AstNode *child = node->data.fn_proto.params.at(i); + AstNode *child = fn_proto->params.at(i); assert(child->type == NodeTypeParamDecl); + TypeTableEntry *type_entry = analyze_type_expr(g, import, context, child->data.param_decl.type); switch (type_entry->id) { @@ -901,13 +941,20 @@ static TypeTableEntry *analyze_fn_proto_type(CodeGen *g, ImportTableEntry *impor case TypeTableEntryIdNumLitFloat: case TypeTableEntryIdNumLitInt: case TypeTableEntryIdUndefLit: - case TypeTableEntryIdMetaType: case TypeTableEntryIdUnreachable: case TypeTableEntryIdNamespace: case TypeTableEntryIdGenericFn: fn_proto->skip = true; add_node_error(g, child->data.param_decl.type, - buf_sprintf("parameter of type '%s' not allowed'", buf_ptr(&type_entry->name))); + buf_sprintf("parameter of type '%s' not allowed", buf_ptr(&type_entry->name))); + break; + case TypeTableEntryIdMetaType: + if (!child->data.param_decl.is_inline) { + fn_proto->skip = true; + add_node_error(g, child->data.param_decl.type, + buf_sprintf("parameter of type '%s' must be declared inline", + buf_ptr(&type_entry->name))); + } break; case TypeTableEntryIdVoid: case TypeTableEntryIdBool: @@ -998,8 +1045,6 @@ static void resolve_function_proto(CodeGen *g, AstNode *node, FnTableEntry *fn_t return; } - fn_table_entry->is_inline = fn_proto->is_inline; - bool is_cold = false; bool is_naked = false; bool is_test = false; @@ -1095,7 +1140,7 @@ static void resolve_function_proto(CodeGen *g, AstNode *node, FnTableEntry *fn_t return; } - if (fn_table_entry->is_inline && fn_table_entry->is_noinline) { + if (fn_proto->is_inline && fn_table_entry->is_noinline) { add_node_error(g, node, buf_sprintf("function is both inline and noinline")); fn_proto->skip = true; return; @@ -1109,10 +1154,14 @@ static void resolve_function_proto(CodeGen *g, AstNode *node, FnTableEntry *fn_t symbol_name = buf_sprintf("_%s", buf_ptr(&fn_table_entry->symbol_name)); } - fn_table_entry->fn_value = LLVMAddFunction(g->module, buf_ptr(symbol_name), - fn_type->data.fn.raw_type_ref); + if (fn_table_entry->fn_def_node) { + BlockContext *context = new_block_context(fn_table_entry->fn_def_node, containing_context); + fn_table_entry->fn_def_node->data.fn_def.block_context = context; + } - if (fn_table_entry->is_inline) { + fn_table_entry->fn_value = LLVMAddFunction(g->module, buf_ptr(symbol_name), fn_type->data.fn.raw_type_ref); + + if (fn_proto->is_inline) { LLVMAddFunctionAttr(fn_table_entry->fn_value, LLVMAlwaysInlineAttribute); } if (fn_table_entry->is_noinline) { @@ -1150,9 +1199,7 @@ static void resolve_function_proto(CodeGen *g, AstNode *node, FnTableEntry *fn_t fn_type->di_type, fn_table_entry->internal_linkage, is_definition, scope_line, flags, is_optimized, nullptr); - BlockContext *context = new_block_context(fn_table_entry->fn_def_node, containing_context); - fn_table_entry->fn_def_node->data.fn_def.block_context = context; - context->di_scope = LLVMZigSubprogramToScope(subprogram); + fn_table_entry->fn_def_node->data.fn_def.block_context->di_scope = LLVMZigSubprogramToScope(subprogram); ZigLLVMFnSetSubprogram(fn_table_entry->fn_value, subprogram); } } @@ -1176,6 +1223,7 @@ static void resolve_enum_type(CodeGen *g, ImportTableEntry *import, TypeTableEnt return; } + assert(decl_node->type == NodeTypeContainerDecl); assert(enum_type->di_type); enum_type->deep_const = true; @@ -1370,7 +1418,7 @@ static void resolve_struct_type(CodeGen *g, ImportTableEntry *import, TypeTableE return; } - + assert(decl_node->type == NodeTypeContainerDecl); assert(struct_type->di_type); struct_type->deep_const = true; @@ -1496,38 +1544,30 @@ static void get_fully_qualified_decl_name(Buf *buf, AstNode *decl_node, uint8_t } static void preview_generic_fn_proto(CodeGen *g, ImportTableEntry *import, AstNode *node) { - if (node->type == NodeTypeFnProto) { - if (node->data.fn_proto.generic_params_is_var_args) { - add_node_error(g, node, buf_sprintf("generic parameters cannot be var args")); - node->data.fn_proto.skip = true; - node->data.fn_proto.generic_fn_type = g->builtin_types.entry_invalid; - return; - } + assert(node->type == NodeTypeContainerDecl); - node->data.fn_proto.generic_fn_type = get_generic_fn_type(g, node); - } else if (node->type == NodeTypeContainerDecl) { - if (node->data.struct_decl.generic_params_is_var_args) { - add_node_error(g, node, buf_sprintf("generic parameters cannot be var args")); - node->data.struct_decl.skip = true; - node->data.struct_decl.generic_fn_type = g->builtin_types.entry_invalid; - return; - } - - node->data.struct_decl.generic_fn_type = get_generic_fn_type(g, node); - } else { - zig_unreachable(); + if (node->data.struct_decl.generic_params_is_var_args) { + add_node_error(g, node, buf_sprintf("generic parameters cannot be var args")); + node->data.struct_decl.skip = true; + node->data.struct_decl.generic_fn_type = g->builtin_types.entry_invalid; + return; } + node->data.struct_decl.generic_fn_type = get_generic_fn_type(g, node); } static void preview_fn_proto_instance(CodeGen *g, ImportTableEntry *import, AstNode *proto_node, BlockContext *containing_context) { + assert(proto_node->type == NodeTypeFnProto); + if (proto_node->data.fn_proto.skip) { return; } - bool is_generic_instance = (proto_node->data.fn_proto.generic_params.length > 0); + bool is_generic_instance = proto_node->data.fn_proto.generic_proto_node; + bool is_generic_fn = proto_node->data.fn_proto.inline_arg_count > 0; + assert(!is_generic_instance || !is_generic_fn); AstNode *parent_decl = proto_node->data.fn_proto.top_level_decl.parent_decl; Buf *proto_name = &proto_node->data.fn_proto.name; @@ -1551,43 +1591,52 @@ static void preview_fn_proto_instance(CodeGen *g, ImportTableEntry *import, AstN get_fully_qualified_decl_name(&fn_table_entry->symbol_name, proto_node, '_'); - g->fn_protos.append(fn_table_entry); - - if (fn_def_node) { - g->fn_defs.append(fn_table_entry); - } - - bool is_main_fn = !is_generic_instance && - !parent_decl && (import == g->root_import) && - buf_eql_str(proto_name, "main"); - if (is_main_fn) { - g->main_fn = fn_table_entry; - } - proto_node->data.fn_proto.fn_table_entry = fn_table_entry; - resolve_function_proto(g, proto_node, fn_table_entry, import, containing_context); - if (is_main_fn && !g->link_libc) { - TypeTableEntry *err_void = get_error_type(g, g->builtin_types.entry_void); - TypeTableEntry *actual_return_type = fn_table_entry->type_entry->data.fn.fn_type_id.return_type; - if (actual_return_type != err_void) { - AstNode *return_type_node = fn_table_entry->proto_node->data.fn_proto.return_type; - add_node_error(g, return_type_node, - buf_sprintf("expected return type of main to be '%%void', instead is '%s'", - buf_ptr(&actual_return_type->name))); + if (is_generic_fn) { + fn_table_entry->type_entry = get_generic_fn_type(g, proto_node); + + if (is_extern || proto_node->data.fn_proto.top_level_decl.visib_mod == VisibModExport) { + for (int i = 0; i < proto_node->data.fn_proto.params.length; i += 1) { + AstNode *param_decl_node = proto_node->data.fn_proto.params.at(i); + if (param_decl_node->data.param_decl.is_inline) { + proto_node->data.fn_proto.skip = true; + add_node_error(g, param_decl_node, + buf_sprintf("inline parameter not allowed in extern function")); + } + } + } + + + } else { + g->fn_protos.append(fn_table_entry); + + if (fn_def_node) { + g->fn_defs.append(fn_table_entry); + } + + bool is_main_fn = !is_generic_instance && + !parent_decl && (import == g->root_import) && + buf_eql_str(proto_name, "main"); + if (is_main_fn) { + g->main_fn = fn_table_entry; + } + + resolve_function_proto(g, proto_node, fn_table_entry, import, containing_context); + + if (is_main_fn && !g->link_libc) { + TypeTableEntry *err_void = get_error_type(g, g->builtin_types.entry_void); + TypeTableEntry *actual_return_type = fn_table_entry->type_entry->data.fn.fn_type_id.return_type; + if (actual_return_type != err_void) { + AstNode *return_type_node = fn_table_entry->proto_node->data.fn_proto.return_type; + add_node_error(g, return_type_node, + buf_sprintf("expected return type of main to be '%%void', instead is '%s'", + buf_ptr(&actual_return_type->name))); + } } } } -static void preview_fn_proto(CodeGen *g, ImportTableEntry *import, AstNode *proto_node) { - if (proto_node->data.fn_proto.generic_params.length > 0) { - return preview_generic_fn_proto(g, import, proto_node); - } else { - return preview_fn_proto_instance(g, import, proto_node, proto_node->block_context); - } - -} - static void scan_struct_decl(CodeGen *g, ImportTableEntry *import, BlockContext *context, AstNode *node) { assert(node->type == NodeTypeContainerDecl); @@ -1683,7 +1732,7 @@ static void resolve_top_level_decl(CodeGen *g, AstNode *node, bool pointer_only) switch (node->type) { case NodeTypeFnProto: - preview_fn_proto(g, import, node); + preview_fn_proto_instance(g, import, node, node->block_context); break; case NodeTypeContainerDecl: resolve_struct_decl(g, import, node); @@ -2600,7 +2649,11 @@ static TypeTableEntry *analyze_field_access_expr(CodeGen *g, ImportTableEntry *i node->data.field_access_expr.is_member_fn = true; FnTableEntry *fn_entry = fn_decl_node->data.fn_proto.fn_table_entry; - return resolve_expr_const_val_as_fn(g, node, fn_entry, false); + if (fn_entry->type_entry->id == TypeTableEntryIdGenericFn) { + return resolve_expr_const_val_as_generic_fn(g, node, fn_entry->type_entry, false); + } else { + return resolve_expr_const_val_as_fn(g, node, fn_entry, false); + } } else { add_node_error(g, node, buf_sprintf("no function named '%s' in '%s'", buf_ptr(field_name), buf_ptr(&bare_struct_type->name))); @@ -3004,13 +3057,11 @@ static TypeTableEntry *analyze_decl_ref(CodeGen *g, AstNode *source_node, AstNod VariableTableEntry *var = decl_node->data.variable_declaration.variable; return analyze_var_ref(g, source_node, var, block_context, depends_on_compile_var); } else if (decl_node->type == NodeTypeFnProto) { - if (decl_node->data.fn_proto.generic_params.length > 0) { - TypeTableEntry *type_entry = decl_node->data.fn_proto.generic_fn_type; - assert(type_entry); - return resolve_expr_const_val_as_generic_fn(g, source_node, type_entry, depends_on_compile_var); + FnTableEntry *fn_entry = decl_node->data.fn_proto.fn_table_entry; + assert(fn_entry->type_entry); + if (fn_entry->type_entry->id == TypeTableEntryIdGenericFn) { + return resolve_expr_const_val_as_generic_fn(g, source_node, fn_entry->type_entry, depends_on_compile_var); } else { - FnTableEntry *fn_entry = decl_node->data.fn_proto.fn_table_entry; - assert(fn_entry->type_entry); return resolve_expr_const_val_as_fn(g, source_node, fn_entry, depends_on_compile_var); } } else if (decl_node->type == NodeTypeContainerDecl) { @@ -5238,6 +5289,8 @@ static TypeTableEntry *analyze_builtin_fn_call_expr(CodeGen *g, ImportTableEntry zig_unreachable(); } +// Before calling this function, set node->data.fn_call_expr.fn_table_entry if the function is known +// at compile time. Otherwise this is a function pointer call. static TypeTableEntry *analyze_fn_call_ptr(CodeGen *g, ImportTableEntry *import, BlockContext *context, TypeTableEntry *expected_type, AstNode *node, TypeTableEntry *fn_type, AstNode *struct_node) @@ -5248,26 +5301,30 @@ static TypeTableEntry *analyze_fn_call_ptr(CodeGen *g, ImportTableEntry *import, return fn_type; } - // count parameters - int src_param_count = fn_type->data.fn.fn_type_id.param_count; - int actual_param_count = node->data.fn_call_expr.params.length; + // The function call might include inline parameters which we need to ignore according to the + // fn_type. + FnTableEntry *fn_table_entry = node->data.fn_call_expr.fn_entry; + AstNode *generic_proto_node = fn_table_entry ? + fn_table_entry->proto_node->data.fn_proto.generic_proto_node : nullptr; - if (struct_node) { - actual_param_count += 1; - } + // count parameters + int struct_node_1_or_0 = struct_node ? 1 : 0; + int src_param_count = fn_type->data.fn.fn_type_id.param_count + + (generic_proto_node ? generic_proto_node->data.fn_proto.inline_arg_count : 0); + int call_param_count = node->data.fn_call_expr.params.length; bool ok_invocation = true; if (fn_type->data.fn.fn_type_id.is_var_args) { - if (actual_param_count < src_param_count) { + if (call_param_count < src_param_count - struct_node_1_or_0) { ok_invocation = false; add_node_error(g, node, - buf_sprintf("expected at least %d arguments, got %d", src_param_count, actual_param_count)); + buf_sprintf("expected at least %d arguments, got %d", src_param_count, call_param_count)); } - } else if (src_param_count != actual_param_count) { + } else if (src_param_count - struct_node_1_or_0 != call_param_count) { ok_invocation = false; add_node_error(g, node, - buf_sprintf("expected %d arguments, got %d", src_param_count, actual_param_count)); + buf_sprintf("expected %d arguments, got %d", src_param_count, call_param_count)); } bool all_args_const_expr = true; @@ -5281,17 +5338,30 @@ static TypeTableEntry *analyze_fn_call_ptr(CodeGen *g, ImportTableEntry *import, // analyze each parameter. in the case of a method, we already analyzed the // first parameter in order to figure out which struct we were calling a method on. - for (int i = 0; i < node->data.fn_call_expr.params.length; i += 1) { - AstNode **child = &node->data.fn_call_expr.params.at(i); + int next_type_i = struct_node_1_or_0; + for (int call_i = 0; call_i < call_param_count; call_i += 1) { + int proto_i = call_i + struct_node_1_or_0; + AstNode **param_node = &node->data.fn_call_expr.params.at(call_i); // determine the expected type for each parameter TypeTableEntry *expected_param_type = nullptr; - int fn_proto_i = i + (struct_node ? 1 : 0); - if (fn_proto_i < src_param_count) { - expected_param_type = fn_type->data.fn.fn_type_id.param_info[fn_proto_i].type; - } - analyze_expression(g, import, context, expected_param_type, *child); + if (proto_i < src_param_count) { + if (generic_proto_node && + generic_proto_node->data.fn_proto.params.at(proto_i)->data.param_decl.is_inline) + { + continue; + } - ConstExprValue *const_arg_val = &get_resolved_expr(*child)->const_val; + FnTypeParamInfo *param_info = &fn_type->data.fn.fn_type_id.param_info[next_type_i]; + next_type_i += 1; + + expected_param_type = param_info->type; + } + TypeTableEntry *param_type = analyze_expression(g, import, context, expected_param_type, *param_node); + if (param_type->id == TypeTableEntryIdInvalid) { + return param_type; + } + + ConstExprValue *const_arg_val = &get_resolved_expr(*param_node)->const_val; if (!const_arg_val->ok) { all_args_const_expr = false; } @@ -5303,7 +5373,6 @@ static TypeTableEntry *analyze_fn_call_ptr(CodeGen *g, ImportTableEntry *import, return return_type; } - FnTableEntry *fn_table_entry = node->data.fn_call_expr.fn_entry; ConstExprValue *result_val = &get_resolved_expr(node)->const_val; if (ok_invocation && fn_table_entry && fn_table_entry->is_pure && fn_table_entry->want_pure != WantPureFalse) { if (fn_table_entry->anal_state == FnAnalStateReady) { @@ -5335,14 +5404,103 @@ static TypeTableEntry *analyze_fn_call_ptr(CodeGen *g, ImportTableEntry *import, return return_type; } -static TypeTableEntry *analyze_fn_call_raw(CodeGen *g, ImportTableEntry *import, BlockContext *context, - TypeTableEntry *expected_type, AstNode *node, FnTableEntry *fn_table_entry, AstNode *struct_node) +static TypeTableEntry *analyze_fn_call_with_inline_args(CodeGen *g, ImportTableEntry *import, + BlockContext *parent_context, TypeTableEntry *expected_type, AstNode *call_node, + FnTableEntry *fn_table_entry, AstNode *struct_node) { - assert(node->type == NodeTypeFnCallExpr); + assert(call_node->type == NodeTypeFnCallExpr); + assert(fn_table_entry); - node->data.fn_call_expr.fn_entry = fn_table_entry; + AstNode *decl_node = fn_table_entry->proto_node; - return analyze_fn_call_ptr(g, import, context, expected_type, node, fn_table_entry->type_entry, struct_node); + // count parameters + int struct_node_1_or_0 = (struct_node ? 1 : 0); + int src_param_count = decl_node->data.fn_proto.params.length; + int call_param_count = call_node->data.fn_call_expr.params.length; + + if (src_param_count != call_param_count + struct_node_1_or_0) { + add_node_error(g, call_node, + buf_sprintf("expected %d arguments, got %d", src_param_count, call_param_count)); + return g->builtin_types.entry_invalid; + } + + int inline_arg_count = decl_node->data.fn_proto.inline_arg_count; + assert(inline_arg_count > 0); + + BlockContext *child_context = decl_node->owner->block_context; + int next_generic_param_index = 0; + + GenericFnTypeId *generic_fn_type_id = allocate(1); + generic_fn_type_id->decl_node = decl_node; + generic_fn_type_id->generic_param_count = inline_arg_count; + generic_fn_type_id->generic_params = allocate(inline_arg_count); + + for (int call_i = 0; call_i < call_param_count; call_i += 1) { + int proto_i = call_i + struct_node_1_or_0; + AstNode *generic_param_decl_node = decl_node->data.fn_proto.params.at(proto_i); + assert(generic_param_decl_node->type == NodeTypeParamDecl); + bool is_inline = generic_param_decl_node->data.param_decl.is_inline; + if (!is_inline) continue; + + AstNode **generic_param_type_node = &generic_param_decl_node->data.param_decl.type; + TypeTableEntry *expected_param_type = analyze_type_expr(g, decl_node->owner, child_context, + *generic_param_type_node); + if (expected_param_type->id == TypeTableEntryIdInvalid) { + return expected_param_type; + } + + AstNode **param_node = &call_node->data.fn_call_expr.params.at(call_i); + TypeTableEntry *param_type = analyze_expression(g, import, parent_context, + expected_param_type, *param_node); + if (param_type->id == TypeTableEntryIdInvalid) { + return param_type; + } + + // set child_context so that the previous param is in scope + child_context = new_block_context(generic_param_decl_node, child_context); + + ConstExprValue *const_val = &get_resolved_expr(*param_node)->const_val; + if (const_val->ok) { + add_local_var(g, generic_param_decl_node, decl_node->owner, child_context, + &generic_param_decl_node->data.param_decl.name, param_type, true, *param_node); + } else { + add_node_error(g, *param_node, + buf_sprintf("unable to evaluate constant expression for inline parameter")); + + return g->builtin_types.entry_invalid; + } + + GenericParamValue *generic_param_value = + &generic_fn_type_id->generic_params[next_generic_param_index]; + generic_param_value->type = param_type; + generic_param_value->node = *param_node; + next_generic_param_index += 1; + } + + assert(next_generic_param_index == inline_arg_count); + + auto entry = g->generic_table.maybe_get(generic_fn_type_id); + FnTableEntry *impl_fn; + if (entry) { + AstNode *impl_decl_node = entry->value; + assert(impl_decl_node->type == NodeTypeFnProto); + impl_fn = impl_decl_node->data.fn_proto.fn_table_entry; + } else { + AstNode *decl_node = generic_fn_type_id->decl_node; + AstNode *impl_fn_def_node = ast_clone_subtree_special(decl_node->data.fn_proto.fn_def_node, + &g->next_node_index, AstCloneSpecialOmitInlineParams); + AstNode *impl_decl_node = impl_fn_def_node->data.fn_def.fn_proto; + impl_decl_node->data.fn_proto.inline_arg_count = 0; + impl_decl_node->data.fn_proto.generic_proto_node = decl_node; + + preview_fn_proto_instance(g, import, impl_decl_node, child_context); + g->generic_table.put(generic_fn_type_id, impl_decl_node); + impl_fn = impl_decl_node->data.fn_proto.fn_table_entry; + } + + call_node->data.fn_call_expr.fn_entry = impl_fn; + return analyze_fn_call_ptr(g, import, parent_context, expected_type, call_node, + impl_fn->type_entry, struct_node); } static TypeTableEntry *analyze_generic_fn_call(CodeGen *g, ImportTableEntry *import, BlockContext *parent_context, @@ -5352,14 +5510,8 @@ static TypeTableEntry *analyze_generic_fn_call(CodeGen *g, ImportTableEntry *imp assert(generic_fn_type->id == TypeTableEntryIdGenericFn); AstNode *decl_node = generic_fn_type->data.generic_fn.decl_node; - ZigList *generic_params; - if (decl_node->type == NodeTypeFnProto) { - generic_params = &decl_node->data.fn_proto.generic_params; - } else if (decl_node->type == NodeTypeContainerDecl) { - generic_params = &decl_node->data.struct_decl.generic_params; - } else { - zig_unreachable(); - } + assert(decl_node->type == NodeTypeContainerDecl); + ZigList *generic_params = &decl_node->data.struct_decl.generic_params; int expected_param_count = generic_params->length; int actual_param_count = node->data.fn_call_expr.params.length; @@ -5405,10 +5557,6 @@ static TypeTableEntry *analyze_generic_fn_call(CodeGen *g, ImportTableEntry *imp } else { add_node_error(g, *param_node, buf_sprintf("unable to evaluate constant expression")); - add_local_var(g, generic_param_decl_node, decl_node->owner, child_context, - &generic_param_decl_node->data.param_decl.name, g->builtin_types.entry_invalid, - true, nullptr); - return g->builtin_types.entry_invalid; } @@ -5420,36 +5568,19 @@ static TypeTableEntry *analyze_generic_fn_call(CodeGen *g, ImportTableEntry *imp auto entry = g->generic_table.maybe_get(generic_fn_type_id); if (entry) { AstNode *impl_decl_node = entry->value; - if (impl_decl_node->type == NodeTypeFnProto) { - FnTableEntry *fn_table_entry = impl_decl_node->data.fn_proto.fn_table_entry; - return resolve_expr_const_val_as_fn(g, node, fn_table_entry, false); - } else if (impl_decl_node->type == NodeTypeContainerDecl) { - TypeTableEntry *type_entry = impl_decl_node->data.struct_decl.type_entry; - return resolve_expr_const_val_as_type(g, node, type_entry, false); - } else { - zig_unreachable(); - } + assert(impl_decl_node->type == NodeTypeContainerDecl); + TypeTableEntry *type_entry = impl_decl_node->data.struct_decl.type_entry; + return resolve_expr_const_val_as_type(g, node, type_entry, false); } // make a type from the generic parameters supplied - if (decl_node->type == NodeTypeFnProto) { - AstNode *impl_fn_def_node = ast_clone_subtree(decl_node->data.fn_proto.fn_def_node, &g->next_node_index); - AstNode *impl_decl_node = impl_fn_def_node->data.fn_def.fn_proto; - - preview_fn_proto_instance(g, import, impl_decl_node, child_context); - g->generic_table.put(generic_fn_type_id, impl_decl_node); - FnTableEntry *fn_table_entry = impl_decl_node->data.fn_proto.fn_table_entry; - return resolve_expr_const_val_as_fn(g, node, fn_table_entry, false); - } else if (decl_node->type == NodeTypeContainerDecl) { - AstNode *impl_decl_node = ast_clone_subtree(decl_node, &g->next_node_index); - g->generic_table.put(generic_fn_type_id, impl_decl_node); - scan_struct_decl(g, import, child_context, impl_decl_node); - TypeTableEntry *type_entry = impl_decl_node->data.struct_decl.type_entry; - resolve_struct_type(g, import, type_entry); - return resolve_expr_const_val_as_type(g, node, type_entry, false); - } else { - zig_unreachable(); - } + assert(decl_node->type == NodeTypeContainerDecl); + AstNode *impl_decl_node = ast_clone_subtree(decl_node, &g->next_node_index); + g->generic_table.put(generic_fn_type_id, impl_decl_node); + scan_struct_decl(g, import, child_context, impl_decl_node); + TypeTableEntry *type_entry = impl_decl_node->data.struct_decl.type_entry; + resolve_struct_type(g, import, type_entry); + return resolve_expr_const_val_as_type(g, node, type_entry, false); } static TypeTableEntry *analyze_fn_call_expr(CodeGen *g, ImportTableEntry *import, BlockContext *context, @@ -5487,10 +5618,32 @@ static TypeTableEntry *analyze_fn_call_expr(CodeGen *g, ImportTableEntry *import struct_node = nullptr; } - return analyze_fn_call_raw(g, import, context, expected_type, node, - const_val->data.x_fn, struct_node); + FnTableEntry *fn_table_entry = const_val->data.x_fn; + node->data.fn_call_expr.fn_entry = fn_table_entry; + return analyze_fn_call_ptr(g, import, context, expected_type, node, + fn_table_entry->type_entry, struct_node); } else if (invoke_type_entry->id == TypeTableEntryIdGenericFn) { - return analyze_generic_fn_call(g, import, context, expected_type, node, const_val->data.x_type); + TypeTableEntry *generic_fn_type = const_val->data.x_type; + AstNode *decl_node = generic_fn_type->data.generic_fn.decl_node; + if (decl_node->type == NodeTypeFnProto) { + AstNode *struct_node; + if (fn_ref_expr->type == NodeTypeFieldAccessExpr && + fn_ref_expr->data.field_access_expr.is_member_fn) + { + struct_node = fn_ref_expr->data.field_access_expr.struct_expr; + } else { + struct_node = nullptr; + } + + FnTableEntry *fn_table_entry = decl_node->data.fn_proto.fn_table_entry; + if (fn_table_entry->proto_node->data.fn_proto.skip) { + return g->builtin_types.entry_invalid; + } + return analyze_fn_call_with_inline_args(g, import, context, expected_type, node, + fn_table_entry, struct_node); + } else { + return analyze_generic_fn_call(g, import, context, expected_type, node, const_val->data.x_type); + } } else { add_node_error(g, fn_ref_expr, buf_sprintf("type '%s' not a function", buf_ptr(&invoke_type_entry->name))); @@ -6367,7 +6520,9 @@ static void analyze_fn_body(CodeGen *g, FnTableEntry *fn_table_entry) { var->src_arg_index = i; param_decl_node->data.param_decl.variable = var; - var->gen_arg_index = fn_type->data.fn.gen_param_info[i].gen_index; + if (fn_type->data.fn.gen_param_info) { + var->gen_arg_index = fn_type->data.fn.gen_param_info[i].gen_index; + } if (!type->deep_const) { fn_table_entry->is_pure = false; @@ -6406,11 +6561,11 @@ static void add_top_level_decl(CodeGen *g, ImportTableEntry *import, BlockContex tld->import = import; tld->name = name; - bool want_as_export = (g->check_unused || g->is_test_build || tld->visib_mod == VisibModExport); - bool is_generic = (node->type == NodeTypeFnProto && node->data.fn_proto.generic_params.length > 0) || - (node->type == NodeTypeContainerDecl && node->data.struct_decl.generic_params.length > 0); - if (!is_generic && want_as_export) { - g->export_queue.append(node); + bool want_to_resolve = (g->check_unused || g->is_test_build || tld->visib_mod == VisibModExport); + bool is_generic_container = (node->type == NodeTypeContainerDecl && + node->data.struct_decl.generic_params.length > 0); + if (want_to_resolve && !is_generic_container) { + g->resolve_queue.append(node); } node->block_context = block_context; @@ -6425,6 +6580,18 @@ static void add_top_level_decl(CodeGen *g, ImportTableEntry *import, BlockContex } } +static int fn_proto_inline_arg_count(AstNode *proto_node) { + assert(proto_node->type == NodeTypeFnProto); + int result = 0; + for (int i = 0; i < proto_node->data.fn_proto.params.length; i += 1) { + AstNode *param_node = proto_node->data.fn_proto.params.at(i); + assert(param_node->type == NodeTypeParamDecl); + result += param_node->data.param_decl.is_inline ? 1 : 0; + } + return result; +} + + static void scan_decls(CodeGen *g, ImportTableEntry *import, BlockContext *context, AstNode *node) { switch (node->type) { case NodeTypeRoot: @@ -6467,6 +6634,7 @@ static void scan_decls(CodeGen *g, ImportTableEntry *import, BlockContext *conte add_node_error(g, node, buf_sprintf("missing function name")); break; } + node->data.fn_proto.inline_arg_count = fn_proto_inline_arg_count(node); add_top_level_decl(g, import, context, node, fn_name); break; @@ -6692,8 +6860,8 @@ void semantic_analyze(CodeGen *g) { resolve_use_decl(g, use_decl_node); } - for (; g->export_queue_index < g->export_queue.length; g->export_queue_index += 1) { - AstNode *decl_node = g->export_queue.at(g->export_queue_index); + for (; g->resolve_queue_index < g->resolve_queue.length; g->resolve_queue_index += 1) { + AstNode *decl_node = g->resolve_queue.at(g->resolve_queue_index); bool pointer_only = false; resolve_top_level_decl(g, decl_node, pointer_only); } @@ -6983,11 +7151,9 @@ bool fn_type_id_eql(FnTypeId *a, FnTypeId *b) { FnTypeParamInfo *a_param_info = &a->param_info[i]; FnTypeParamInfo *b_param_info = &b->param_info[i]; - if (a_param_info->type != b_param_info->type) { - return false; - } - - if (a_param_info->is_noalias != b_param_info->is_noalias) { + if (a_param_info->type != b_param_info->type || + a_param_info->is_noalias != b_param_info->is_noalias) + { return false; } } diff --git a/src/ast_render.cpp b/src/ast_render.cpp index 86b70d3c3..9d9f56cb3 100644 --- a/src/ast_render.cpp +++ b/src/ast_render.cpp @@ -353,7 +353,8 @@ static void render_node(AstRender *ar, AstNode *node) { assert(param_decl->type == NodeTypeParamDecl); if (buf_len(¶m_decl->data.param_decl.name) > 0) { const char *noalias_str = param_decl->data.param_decl.is_noalias ? "noalias " : ""; - fprintf(ar->f, "%s", noalias_str); + const char *inline_str = param_decl->data.param_decl.is_inline ? "inline " : ""; + fprintf(ar->f, "%s%s", noalias_str, inline_str); print_symbol(ar, ¶m_decl->data.param_decl.name); fprintf(ar->f, ": "); } diff --git a/src/codegen.cpp b/src/codegen.cpp index da2aa2541..55e101fff 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -1062,12 +1062,15 @@ static LLVMValueRef gen_fn_call_expr(CodeGen *g, AstNode *node) { TypeTableEntry *fn_type; LLVMValueRef fn_val; + AstNode *generic_proto_node; if (fn_table_entry) { fn_val = fn_table_entry->fn_value; fn_type = fn_table_entry->type_entry; + generic_proto_node = fn_table_entry->proto_node->data.fn_proto.generic_proto_node; } else { fn_val = gen_expr(g, fn_ref_expr); fn_type = get_expr_type(fn_ref_expr); + generic_proto_node = nullptr; } TypeTableEntry *src_return_type = fn_type->data.fn.fn_type_id.return_type; @@ -1093,8 +1096,14 @@ static LLVMValueRef gen_fn_call_expr(CodeGen *g, AstNode *node) { gen_param_index += 1; } - for (int i = 0; i < fn_call_param_count; i += 1) { - AstNode *expr_node = node->data.fn_call_expr.params.at(i); + for (int call_i = 0; call_i < fn_call_param_count; call_i += 1) { + int proto_i = call_i + (struct_type ? 1 : 0); + if (generic_proto_node && + generic_proto_node->data.fn_proto.params.at(proto_i)->data.param_decl.is_inline) + { + continue; + } + AstNode *expr_node = node->data.fn_call_expr.params.at(call_i); LLVMValueRef param_value = gen_expr(g, expr_node); assert(param_value); TypeTableEntry *param_type = get_expr_type(expr_node); @@ -3734,7 +3743,7 @@ static void delete_unused_builtin_fns(CodeGen *g) { } } -static bool skip_fn_codegen(CodeGen *g, FnTableEntry *fn_entry) { +static bool should_skip_fn_codegen(CodeGen *g, FnTableEntry *fn_entry) { if (g->is_test_build) { if (fn_entry->is_test) { return false; @@ -3889,7 +3898,7 @@ static void do_code_gen(CodeGen *g) { // Generate function prototypes for (int fn_proto_i = 0; fn_proto_i < g->fn_protos.length; fn_proto_i += 1) { FnTableEntry *fn_table_entry = g->fn_protos.at(fn_proto_i); - if (skip_fn_codegen(g, fn_table_entry)) { + if (should_skip_fn_codegen(g, fn_table_entry)) { // huge time saver LLVMDeleteFunction(fn_table_entry->fn_value); fn_table_entry->fn_value = nullptr; @@ -3995,7 +4004,7 @@ static void do_code_gen(CodeGen *g) { // Generate function definitions. for (int fn_i = 0; fn_i < g->fn_defs.length; fn_i += 1) { FnTableEntry *fn_table_entry = g->fn_defs.at(fn_i); - if (skip_fn_codegen(g, fn_table_entry)) { + if (should_skip_fn_codegen(g, fn_table_entry)) { // huge time saver continue; } diff --git a/src/eval.cpp b/src/eval.cpp index 334a15fb4..321a5d667 100644 --- a/src/eval.cpp +++ b/src/eval.cpp @@ -884,9 +884,9 @@ static bool eval_fn_call_expr(EvalFn *ef, AstNode *node, ConstExprValue *out_val int param_count = node->data.fn_call_expr.params.length; ConstExprValue *args = allocate(param_count); - for (int i = 0; i < param_count; i += 1) { - AstNode *param_expr_node = node->data.fn_call_expr.params.at(i); - ConstExprValue *param_val = &args[i]; + for (int call_i = 0; call_i < param_count; call_i += 1) { + AstNode *param_expr_node = node->data.fn_call_expr.params.at(call_i); + ConstExprValue *param_val = &args[call_i]; if (eval_expr(ef, param_expr_node, param_val)) return true; } @@ -1291,6 +1291,13 @@ static bool eval_expr(EvalFn *ef, AstNode *node, ConstExprValue *out) { } static bool eval_fn_args(EvalFnRoot *efr, FnTableEntry *fn, ConstExprValue *args, ConstExprValue *out_val) { + AstNode *acting_proto_node; + if (fn->proto_node->data.fn_proto.generic_proto_node) { + acting_proto_node = fn->proto_node->data.fn_proto.generic_proto_node; + } else { + acting_proto_node = fn->proto_node; + } + EvalFn ef = {0}; ef.root = efr; ef.fn = fn; @@ -1300,12 +1307,12 @@ static bool eval_fn_args(EvalFnRoot *efr, FnTableEntry *fn, ConstExprValue *args root_scope->block_context = fn->fn_def_node->data.fn_def.body->block_context; ef.scope_stack.append(root_scope); - int param_count = fn->type_entry->data.fn.fn_type_id.param_count; - for (int i = 0; i < param_count; i += 1) { - AstNode *decl_param_node = fn->proto_node->data.fn_proto.params.at(i); + int param_count = acting_proto_node->data.fn_proto.params.length; + for (int proto_i = 0; proto_i < param_count; proto_i += 1) { + AstNode *decl_param_node = acting_proto_node->data.fn_proto.params.at(proto_i); assert(decl_param_node->type == NodeTypeParamDecl); - ConstExprValue *src_const_val = &args[i]; + ConstExprValue *src_const_val = &args[proto_i]; assert(src_const_val->ok); root_scope->vars.add_one(); @@ -1315,7 +1322,6 @@ static bool eval_fn_args(EvalFnRoot *efr, FnTableEntry *fn, ConstExprValue *args } return eval_expr(&ef, fn->fn_def_node->data.fn_def.body, out_val); - } bool eval_fn(CodeGen *g, AstNode *node, FnTableEntry *fn, ConstExprValue *out_val, @@ -1329,9 +1335,16 @@ bool eval_fn(CodeGen *g, AstNode *node, FnTableEntry *fn, ConstExprValue *out_va efr.call_node = node; efr.branch_quota = branch_quota; + AstNode *acting_proto_node; + if (fn->proto_node->data.fn_proto.generic_proto_node) { + acting_proto_node = fn->proto_node->data.fn_proto.generic_proto_node; + } else { + acting_proto_node = fn->proto_node; + } + int call_param_count = node->data.fn_call_expr.params.length; - int type_param_count = fn->type_entry->data.fn.fn_type_id.param_count; - ConstExprValue *args = allocate(type_param_count); + int proto_param_count = acting_proto_node->data.fn_proto.params.length; + ConstExprValue *args = allocate(proto_param_count); int next_arg_index = 0; if (struct_node) { ConstExprValue *struct_val = &get_resolved_expr(struct_node)->const_val; diff --git a/src/parser.cpp b/src/parser.cpp index ff9a6268f..418a9d210 100644 --- a/src/parser.cpp +++ b/src/parser.cpp @@ -747,7 +747,7 @@ static void ast_parse_directives(ParseContext *pc, int *token_index, } /* -ParamDecl = option("noalias") option("Symbol" ":") PrefixOpExpression | "..." +ParamDecl = option("noalias" | "inline") option("Symbol" ":") TypeExpr | "..." */ static AstNode *ast_parse_param_decl(ParseContext *pc, int *token_index) { Token *token = &pc->tokens->at(*token_index); @@ -763,6 +763,10 @@ static AstNode *ast_parse_param_decl(ParseContext *pc, int *token_index) { node->data.param_decl.is_noalias = true; *token_index += 1; token = &pc->tokens->at(*token_index); + } else if (token->id == TokenIdKeywordInline) { + node->data.param_decl.is_inline = true; + *token_index += 1; + token = &pc->tokens->at(*token_index); } buf_resize(&node->data.param_decl.name, 0); @@ -2472,7 +2476,7 @@ static AstNode *ast_parse_block(ParseContext *pc, int *token_index, bool mandato } /* -FnProto = "fn" option("Symbol") option(ParamDeclList) ParamDeclList option("->" TypeExpr) +FnProto = "fn" option("Symbol") ParamDeclList option("->" TypeExpr) */ static AstNode *ast_parse_fn_proto(ParseContext *pc, int *token_index, bool mandatory, ZigList *directives, VisibMod visib_mod) @@ -2502,17 +2506,6 @@ static AstNode *ast_parse_fn_proto(ParseContext *pc, int *token_index, bool mand ast_parse_param_decl_list(pc, token_index, &node->data.fn_proto.params, &node->data.fn_proto.is_var_args); - Token *maybe_lparen = &pc->tokens->at(*token_index); - if (maybe_lparen->id == TokenIdLParen) { - for (int i = 0; i < node->data.fn_proto.params.length; i += 1) { - node->data.fn_proto.generic_params.append(node->data.fn_proto.params.at(i)); - } - node->data.fn_proto.generic_params_is_var_args = node->data.fn_proto.is_var_args; - - node->data.fn_proto.params.resize(0); - ast_parse_param_decl_list(pc, token_index, &node->data.fn_proto.params, &node->data.fn_proto.is_var_args); - } - Token *next_token = &pc->tokens->at(*token_index); if (next_token->id == TokenIdArrow) { *token_index += 1; @@ -2931,7 +2924,6 @@ void ast_visit_node_children(AstNode *node, void (*visit)(AstNode **, void *cont case NodeTypeFnProto: visit_field(&node->data.fn_proto.return_type, visit, context); visit_node_list(node->data.fn_proto.top_level_decl.directives, visit, context); - visit_node_list(&node->data.fn_proto.generic_params, visit, context); visit_node_list(&node->data.fn_proto.params, visit, context); break; case NodeTypeFnDef: @@ -3123,6 +3115,22 @@ static void clone_subtree_list(ZigList *dest, ZigList *src } } +static void clone_subtree_list_omit_inline_params(ZigList *dest, ZigList *src, + uint32_t *next_node_index) +{ + memset(dest, 0, sizeof(ZigList)); + dest->ensure_capacity(src->length); + for (int i = 0; i < src->length; i += 1) { + AstNode *src_node = src->at(i); + assert(src_node->type == NodeTypeParamDecl); + if (src_node->data.param_decl.is_inline) { + continue; + } + dest->append(ast_clone_subtree(src_node, next_node_index)); + dest->last()->parent_field = &dest->last(); + } +} + static void clone_subtree_list_ptr(ZigList **dest_ptr, ZigList *src, uint32_t *next_node_index) { @@ -3133,20 +3141,26 @@ static void clone_subtree_list_ptr(ZigList **dest_ptr, ZigListparent_field = dest; } else { *dest = nullptr; } } +static void clone_subtree_field(AstNode **dest, AstNode *src, uint32_t *next_node_index) { + return clone_subtree_field_special(dest, src, next_node_index, AstCloneSpecialNone); +} + static void clone_subtree_tld(TopLevelDecl *dest, TopLevelDecl *src, uint32_t *next_node_index) { clone_subtree_list_ptr(&dest->directives, src->directives, next_node_index); } -AstNode *ast_clone_subtree(AstNode *old_node, uint32_t *next_node_index) { +AstNode *ast_clone_subtree_special(AstNode *old_node, uint32_t *next_node_index, enum AstCloneSpecial special) { AstNode *new_node = allocate_nonzero(1); memcpy(new_node, old_node, sizeof(AstNode)); new_node->create_index = *next_node_index; @@ -3163,14 +3177,19 @@ AstNode *ast_clone_subtree(AstNode *old_node, uint32_t *next_node_index) { next_node_index); clone_subtree_field(&new_node->data.fn_proto.return_type, old_node->data.fn_proto.return_type, next_node_index); - clone_subtree_list(&new_node->data.fn_proto.generic_params, - &old_node->data.fn_proto.generic_params, next_node_index); - clone_subtree_list(&new_node->data.fn_proto.params, &old_node->data.fn_proto.params, - next_node_index); + + if (special == AstCloneSpecialOmitInlineParams) { + clone_subtree_list_omit_inline_params(&new_node->data.fn_proto.params, &old_node->data.fn_proto.params, + next_node_index); + } else { + clone_subtree_list(&new_node->data.fn_proto.params, &old_node->data.fn_proto.params, + next_node_index); + } break; case NodeTypeFnDef: - clone_subtree_field(&new_node->data.fn_def.fn_proto, old_node->data.fn_def.fn_proto, next_node_index); + clone_subtree_field_special(&new_node->data.fn_def.fn_proto, old_node->data.fn_def.fn_proto, + next_node_index, special); new_node->data.fn_def.fn_proto->data.fn_proto.fn_def_node = new_node; clone_subtree_field(&new_node->data.fn_def.body, old_node->data.fn_def.body, next_node_index); break; @@ -3354,3 +3373,7 @@ AstNode *ast_clone_subtree(AstNode *old_node, uint32_t *next_node_index) { return new_node; } + +AstNode *ast_clone_subtree(AstNode *old_node, uint32_t *next_node_index) { + return ast_clone_subtree_special(old_node, next_node_index, AstCloneSpecialNone); +} diff --git a/src/parser.hpp b/src/parser.hpp index 58da0234e..ce8a42a4d 100644 --- a/src/parser.hpp +++ b/src/parser.hpp @@ -25,6 +25,13 @@ void ast_print(AstNode *node, int indent); void normalize_parent_ptrs(AstNode *node); AstNode *ast_clone_subtree(AstNode *node, uint32_t *next_node_index); + +enum AstCloneSpecial { + AstCloneSpecialNone, + AstCloneSpecialOmitInlineParams, +}; +AstNode *ast_clone_subtree_special(AstNode *node, uint32_t *next_node_index, enum AstCloneSpecial special); + void ast_visit_node_children(AstNode *node, void (*visit)(AstNode **, void *context), void *context); #endif diff --git a/std/hash_map.zig b/std/hash_map.zig index f5147fdf2..6c4e44cac 100644 --- a/std/hash_map.zig +++ b/std/hash_map.zig @@ -7,7 +7,7 @@ const want_modification_safety = !@compile_var("is_release"); const debug_u32 = if (want_modification_safety) u32 else void; /* -pub fn HashMap(K: type, V: type, hash: fn(key: K)->u32, eql: fn(a: K, b: K)->bool) { +pub inline fn HashMap(inline K: type, inline V: type, inline hash: fn(key: K)->u32, inline eql: fn(a: K, b: K)->bool) { SmallHashMap(K, V, hash, eql, 8); } */ @@ -70,7 +70,7 @@ pub struct SmallHashMap(K: type, V: type, hash: fn(key: K)->u32, eql: fn(a: K, b pub fn deinit(hm: &Self) { if (hm.entries.ptr != &hm.prealloc_entries[0]) { - hm.allocator.free(hm.allocator, ([]u8)(hm.entries)); + hm.allocator.free(Entry, hm.entries); } } @@ -103,7 +103,7 @@ pub struct SmallHashMap(K: type, V: type, hash: fn(key: K)->u32, eql: fn(a: K, b } } if (old_entries.ptr != &hm.prealloc_entries[0]) { - hm.allocator.free(hm.allocator, ([]u8)(old_entries)); + hm.allocator.free(Entry, old_entries); } } @@ -152,7 +152,7 @@ pub struct SmallHashMap(K: type, V: type, hash: fn(key: K)->u32, eql: fn(a: K, b } fn init_capacity(hm: &Self, capacity: isize) -> %void { - hm.entries = ([]Entry)(%return hm.allocator.alloc(hm.allocator, capacity * @sizeof(Entry))); + hm.entries = %return hm.allocator.alloc(Entry, capacity); hm.size = 0; hm.max_distance_from_start_index = 0; for (hm.entries) |*entry| { @@ -180,7 +180,7 @@ pub struct SmallHashMap(K: type, V: type, hash: fn(key: K)->u32, eql: fn(a: K, b if (entry.distance_from_start_index < distance_from_start_index) { // robin hood to the rescue const tmp = *entry; - hm.max_distance_from_start_index = math.max(isize)( + hm.max_distance_from_start_index = math.max(isize, hm.max_distance_from_start_index, distance_from_start_index); *entry = Entry { .used = true, @@ -201,7 +201,8 @@ pub struct SmallHashMap(K: type, V: type, hash: fn(key: K)->u32, eql: fn(a: K, b hm.size += 1; } - hm.max_distance_from_start_index = math.max(isize)(distance_from_start_index, hm.max_distance_from_start_index); + hm.max_distance_from_start_index = math.max(isize, distance_from_start_index, + hm.max_distance_from_start_index); *entry = Entry { .used = true, .distance_from_start_index = distance_from_start_index, @@ -231,9 +232,9 @@ pub struct SmallHashMap(K: type, V: type, hash: fn(key: K)->u32, eql: fn(a: K, b } var global_allocator = Allocator { - .alloc = global_alloc, - .realloc = global_realloc, - .free = global_free, + .alloc_fn = global_alloc, + .realloc_fn = global_realloc, + .free_fn = global_free, .context = null, }; diff --git a/std/io.zig b/std/io.zig index 432600ee0..d64514ac3 100644 --- a/std/io.zig +++ b/std/io.zig @@ -69,7 +69,7 @@ pub struct OutStream { const dest_space_left = os.buffer.len - os.index; while (src_bytes_left > 0) { - const copy_amt = math.min(isize)(dest_space_left, src_bytes_left); + const copy_amt = math.min(isize, dest_space_left, src_bytes_left); @memcpy(&os.buffer[os.index], &bytes[src_index], copy_amt); os.index += copy_amt; if (os.index == os.buffer.len) { @@ -208,59 +208,47 @@ pub struct InStream { } } -pub error InvalidChar; -pub error Overflow; - -pub fn parse_unsigned(T: type)(buf: []u8, radix: u8) -> %T { +pub fn parse_unsigned(inline T: type, buf: []u8, radix: u8) -> %T { var x: T = 0; for (buf) |c| { - const digit = char_to_digit(c); - - if (digit >= radix) { - return error.InvalidChar; - } - - // x *= radix - if (@mul_with_overflow(T, x, radix, &x)) { - return error.Overflow; - } - - // x += digit - if (@add_with_overflow(T, x, digit, &x)) { - return error.Overflow; - } + const digit = %return char_to_digit(c, radix); + x = %return math.mul_overflow(T, x, radix); + x = %return math.add_overflow(T, x, digit); } return x; } -fn char_to_digit(c: u8) -> u8 { - // TODO use switch with range - if ('0' <= c && c <= '9') { +pub error InvalidChar; +fn char_to_digit(c: u8, radix: u8) -> %u8 { + const value = if ('0' <= c && c <= '9') { c - '0' } else if ('A' <= c && c <= 'Z') { c - 'A' + 10 } else if ('a' <= c && c <= 'z') { c - 'a' + 10 } else { - @max_value(u8) - } + return error.InvalidChar; + }; + return if (value >= radix) error.InvalidChar else value; } -pub fn buf_print_signed(T: type)(out_buf: []u8, x: T) -> isize { +pub fn buf_print_signed(inline T: type, out_buf: []u8, x: T) -> isize { const uint = @int_type(false, T.bit_count, false); if (x < 0) { out_buf[0] = '-'; - return 1 + buf_print_unsigned(uint)(out_buf[1...], uint(-(x + 1)) + 1); + return 1 + buf_print_unsigned(uint, out_buf[1...], uint(-(x + 1)) + 1); } else { - return buf_print_unsigned(uint)(out_buf, uint(x)); + return buf_print_unsigned(uint, out_buf, uint(x)); } } -pub const buf_print_i64 = buf_print_signed(i64); +pub fn buf_print_i64(out_buf: []u8, x: i64) -> isize { + buf_print_signed(i64, out_buf, x) +} -pub fn buf_print_unsigned(T: type)(out_buf: []u8, x: T) -> isize { +pub fn buf_print_unsigned(inline T: type, out_buf: []u8, x: T) -> isize { var buf: [max_u64_base10_digits]u8 = undefined; var a = x; var index: isize = buf.len; @@ -281,7 +269,9 @@ pub fn buf_print_unsigned(T: type)(out_buf: []u8, x: T) -> isize { return len; } -pub const buf_print_u64 = buf_print_unsigned(u64); +pub fn buf_print_u64(out_buf: []u8, x: u64) -> isize { + buf_print_unsigned(u64, out_buf, x) +} pub fn buf_print_f64(out_buf: []u8, x: f64, decimals: isize) -> isize { const numExpBits = 11; @@ -409,7 +399,7 @@ pub fn buf_print_f64(out_buf: []u8, x: f64, decimals: isize) -> isize { #attribute("test") fn parse_u64_digit_too_big() { - parse_unsigned(u64)("123a", 10) %% |err| { + parse_unsigned(u64, "123a", 10) %% |err| { if (err == error.InvalidChar) return; unreachable{}; }; diff --git a/std/list.zig b/std/list.zig index 31300a071..14b72e6a5 100644 --- a/std/list.zig +++ b/std/list.zig @@ -2,59 +2,58 @@ const assert = @import("debug.zig").assert; const mem = @import("mem.zig"); const Allocator = mem.Allocator; -/* -pub fn List(T: type) -> type { +pub inline fn List(inline T: type) -> type { SmallList(T, 8) } -*/ pub struct SmallList(T: type, STATIC_SIZE: isize) { + const Self = SmallList(T, STATIC_SIZE); + items: []T, length: isize, prealloc_items: [STATIC_SIZE]T, allocator: &Allocator, - pub fn init(l: &SmallList(T, STATIC_SIZE), allocator: &Allocator) { + pub fn init(l: &Self, allocator: &Allocator) { l.items = l.prealloc_items[0...]; l.length = 0; l.allocator = allocator; } - pub fn deinit(l: &SmallList(T, STATIC_SIZE)) { + pub fn deinit(l: &Self) { if (l.items.ptr != &l.prealloc_items[0]) { - l.allocator.free(l.allocator, ([]u8)(l.items)); + l.allocator.free(T, l.items); } } - pub fn append(l: &SmallList(T, STATIC_SIZE), item: T) -> %void { + pub fn append(l: &Self, item: T) -> %void { const new_length = l.length + 1; %return l.ensure_capacity(new_length); l.items[l.length] = item; l.length = new_length; } - pub fn ensure_capacity(l: &SmallList(T, STATIC_SIZE), new_capacity: isize) -> %void { + pub fn ensure_capacity(l: &Self, new_capacity: isize) -> %void { const old_capacity = l.items.len; var better_capacity = old_capacity; while (better_capacity < new_capacity) { better_capacity *= 2; } if (better_capacity != old_capacity) { - const alloc_bytes = better_capacity * @sizeof(T); if (l.items.ptr == &l.prealloc_items[0]) { - l.items = ([]T)(%return l.allocator.alloc(l.allocator, alloc_bytes)); - @memcpy(l.items.ptr, &l.prealloc_items[0], old_capacity * @sizeof(T)); + l.items = %return l.allocator.alloc(T, better_capacity); + mem.copy(T, l.items, l.prealloc_items[0...old_capacity]); } else { - l.items = ([]T)(%return l.allocator.realloc(l.allocator, ([]u8)(l.items), alloc_bytes)); + l.items = %return l.allocator.realloc(T, l.items, better_capacity); } } } } var global_allocator = Allocator { - .alloc = global_alloc, - .realloc = global_realloc, - .free = global_free, + .alloc_fn = global_alloc, + .realloc_fn = global_realloc, + .free_fn = global_free, .context = null, }; diff --git a/std/math.zig b/std/math.zig index 39f982b36..ee208e23d 100644 --- a/std/math.zig +++ b/std/math.zig @@ -26,10 +26,24 @@ pub fn f64_is_inf(f: f64) -> bool { f == f64_get_neg_inf() || f == f64_get_pos_inf() } -pub fn min(T: type)(x: T, y: T) -> T { +pub fn min(inline T: type, x: T, y: T) -> T { if (x < y) x else y } -pub fn max(T: type)(x: T, y: T) -> T { +pub fn max(inline T: type, x: T, y: T) -> T { if (x > y) x else y } + +pub error Overflow; +pub fn mul_overflow(inline T: type, a: T, b: T) -> %T { + var answer: T = undefined; + if (@mul_with_overflow(T, a, b, &answer)) error.Overflow else answer +} +pub fn add_overflow(inline T: type, a: T, b: T) -> %T { + var answer: T = undefined; + if (@add_with_overflow(T, a, b, &answer)) error.Overflow else answer +} +pub fn sub_overflow(inline T: type, a: T, b: T) -> %T { + var answer: T = undefined; + if (@sub_with_overflow(T, a, b, &answer)) error.Overflow else answer +} diff --git a/std/mem.zig b/std/mem.zig index d2a1ab049..4f837df25 100644 --- a/std/mem.zig +++ b/std/mem.zig @@ -1,18 +1,46 @@ const assert = @import("debug.zig").assert; +const math = @import("math.zig"); +const os = @import("os.zig"); +const io = @import("io.zig"); pub error NoMem; pub type Context = u8; pub struct Allocator { - alloc: fn (self: &Allocator, n: isize) -> %[]u8, - realloc: fn (self: &Allocator, old_mem: []u8, new_size: isize) -> %[]u8, - free: fn (self: &Allocator, mem: []u8), + alloc_fn: fn (self: &Allocator, n: isize) -> %[]u8, + realloc_fn: fn (self: &Allocator, old_mem: []u8, new_size: isize) -> %[]u8, + free_fn: fn (self: &Allocator, mem: []u8), context: ?&Context, + + /// Aborts the program if an allocation fails. + fn checked_alloc(self: &Allocator, inline T: type, n: isize) -> []T { + alloc(self, T, n) %% |err| { + // TODO var args printf + %%io.stderr.write("allocation failure: "); + %%io.stderr.write(@err_name(err)); + %%io.stderr.printf("\n"); + os.abort() + } + } + + fn alloc(self: &Allocator, inline T: type, n: isize) -> %[]T { + const byte_count = %return math.mul_overflow(isize, @sizeof(T), n); + ([]T)(%return self.alloc_fn(self, byte_count)) + } + + fn realloc(self: &Allocator, inline T: type, old_mem: []T, n: isize) -> %[]T { + const byte_count = %return math.mul_overflow(isize, @sizeof(T), n); + ([]T)(%return self.realloc_fn(self, ([]u8)(old_mem), byte_count)) + } + + fn free(self: &Allocator, inline T: type, mem: []T) { + self.free_fn(self, ([]u8)(mem)); + } } /// Copy all of source into dest at position 0. /// dest.len must be >= source.len. -pub fn copy(T)(dest: []T, source: []T) { +pub fn copy(inline T: type, dest: []T, source: []T) { assert(dest.len >= source.len); @memcpy(dest.ptr, source.ptr, @sizeof(T) * source.len); } diff --git a/std/net.zig b/std/net.zig index f98b6fbad..e126f0b8f 100644 --- a/std/net.zig +++ b/std/net.zig @@ -99,14 +99,14 @@ pub fn connect_addr(addr: &Address, port: u16) -> %Connection { const connect_ret = if (addr.family == linux.AF_INET) { var os_addr: linux.sockaddr_in = undefined; os_addr.family = addr.family; - os_addr.port = host_to_be(u16)(port); + os_addr.port = swap_if_little_endian(u16, port); @memcpy((&u8)(&os_addr.addr), &addr.addr[0], 4); @memset(&os_addr.zero, 0, @sizeof(@typeof(os_addr.zero))); linux.connect(socket_fd, (&linux.sockaddr)(&os_addr), @sizeof(linux.sockaddr_in)) } else if (addr.family == linux.AF_INET6) { var os_addr: linux.sockaddr_in6 = undefined; os_addr.family = addr.family; - os_addr.port = host_to_be(u16)(port); + os_addr.port = swap_if_little_endian(u16, port); os_addr.flowinfo = 0; os_addr.scope_id = addr.scope_id; @memcpy(&os_addr.addr[0], &addr.addr[0], 16); @@ -319,7 +319,7 @@ fn parse_ip4(buf: []const u8) -> %u32 { #attribute("test") fn test_parse_ip4() { - assert(%%parse_ip4("127.0.0.1") == be_to_host(u32)(0x7f000001)); + assert(%%parse_ip4("127.0.0.1") == swap_if_little_endian(u32, 0x7f000001)); switch (parse_ip4("256.0.0.1")) { Overflow => {}, else => unreachable {}, } switch (parse_ip4("x.0.0.1")) { InvalidChar => {}, else => unreachable {}, } switch (parse_ip4("127.0.0.1.1")) { JunkAtEnd => {}, else => unreachable {}, } @@ -352,12 +352,11 @@ fn test_lookup_simple_ip() { } } -const be_to_host = host_to_be; -fn host_to_be(T: type)(x: T) -> T { - if (@compile_var("is_big_endian")) x else endian_swap(T)(x) +fn swap_if_little_endian(inline T: type, x: T) -> T { + if (@compile_var("is_big_endian")) x else endian_swap(T, x) } -fn endian_swap(T: type)(x: T) -> T { +fn endian_swap(inline T: type, x: T) -> T { const x_slice = ([]u8)((&const x)[0...1]); var result: T = undefined; const result_slice = ([]u8)((&result)[0...1]); diff --git a/std/str.zig b/std/str.zig index 06879e9be..c5b0afbb6 100644 --- a/std/str.zig +++ b/std/str.zig @@ -1,8 +1,10 @@ const assert = @import("debug.zig").assert; -pub const eql = slice_eql(u8); +pub fn eql(a: []const u8, b: []const u8) -> bool { + slice_eql(u8, a, b) +} -pub fn slice_eql(T: type)(a: []const T, b: []const T) -> bool { +pub fn slice_eql(inline T: type, a: []const T, b: []const T) -> bool { if (a.len != b.len) return false; for (a) |item, index| { if (b[index] != item) return false; diff --git a/std/test_runner.zig b/std/test_runner.zig index 9533aeadb..5b5599db0 100644 --- a/std/test_runner.zig +++ b/std/test_runner.zig @@ -9,6 +9,7 @@ extern var zig_test_fn_list: []TestFn; pub fn run_tests() -> %void { for (zig_test_fn_list) |test_fn, i| { + // TODO: print var args %%io.stderr.write("Test "); %%io.stderr.print_i64(i + 1); %%io.stderr.write("/"); diff --git a/test/run_tests.cpp b/test/run_tests.cpp index 16b555f4a..096c76809 100644 --- a/test/run_tests.cpp +++ b/test/run_tests.cpp @@ -1181,11 +1181,11 @@ const invalid = foo > foo; )SOURCE", 1, ".tmp_source.zig:3:21: error: operator not allowed for type 'fn()'"); add_compile_fail_case("generic function instance with non-constant expression", R"SOURCE( -fn foo(x: i32)(y: i32) -> i32 { return x + y; } +fn foo(inline x: i32, y: i32) -> i32 { return x + y; } fn test1(a: i32, b: i32) -> i32 { - return foo(a)(b); + return foo(a, b); } - )SOURCE", 1, ".tmp_source.zig:4:16: error: unable to evaluate constant expression"); + )SOURCE", 1, ".tmp_source.zig:4:16: error: unable to evaluate constant expression for inline parameter"); add_compile_fail_case("goto jumping into block", R"SOURCE( fn f() { @@ -1406,6 +1406,27 @@ fn f() { } )SOURCE", 1, ".tmp_source.zig:3:13: error: unable to evaluate constant expression"); + add_compile_fail_case("export function with inline parameter", R"SOURCE( +export fn foo(inline x: i32, y: i32) -> i32{ + x + y +} + )SOURCE", 1, ".tmp_source.zig:2:15: error: inline parameter not allowed in extern function"); + + add_compile_fail_case("extern function with inline parameter", R"SOURCE( +extern fn foo(inline x: i32, y: i32) -> i32; +fn f() -> i32 { + foo(1, 2) +} + )SOURCE", 1, ".tmp_source.zig:2:15: error: inline parameter not allowed in extern function"); + + /* TODO + add_compile_fail_case("inline export function", R"SOURCE( +export inline fn foo(x: i32, y: i32) -> i32{ + x + y +} + )SOURCE", 1, ".tmp_source.zig:2:1: error: extern functions cannot be inline"); + */ + } ////////////////////////////////////////////////////////////////////////////// diff --git a/test/self_hosted.zig b/test/self_hosted.zig index 8ce198a9a..85eeb8bab 100644 --- a/test/self_hosted.zig +++ b/test/self_hosted.zig @@ -712,17 +712,17 @@ three)"; #attribute("test") fn simple_generic_fn() { - assert(max(i32)(3, -1) == 3); - assert(max(f32)(0.123, 0.456) == 0.456); - assert(add(2)(3) == 5); + assert(max(i32, 3, -1) == 3); + assert(max(f32, 0.123, 0.456) == 0.456); + assert(add(2, 3) == 5); } -fn max(T: type)(a: T, b: T) -> T { +fn max(inline T: type, a: T, b: T) -> T { return if (a > b) a else b; } -fn add(a: i32)(b: i32) -> i32 { - return a + b; +fn add(inline a: i32, b: i32) -> i32 { + return @const_eval(a) + b; } @@ -734,23 +734,18 @@ fn constant_equal_function_pointers() { fn empty_fn() {} -#attribute("test") -fn generic_function_equality() { - assert(max(i32) == max(i32)); -} - #attribute("test") fn generic_malloc_free() { - const a = %%mem_alloc(u8)(10); - mem_free(u8)(a); + const a = %%mem_alloc(u8, 10); + mem_free(u8, a); } const some_mem : [100]u8 = undefined; #static_eval_enable(false) -fn mem_alloc(T: type)(n: isize) -> %[]T { +fn mem_alloc(inline T: type, n: isize) -> %[]T { return (&T)(&some_mem[0])[0...n]; } -fn mem_free(T: type)(mem: []T) { } +fn mem_free(inline T: type, mem: []T) { } #attribute("test") @@ -982,11 +977,11 @@ pub fn vec3(x: f32, y: f32, z: f32) -> Vec3 { #attribute("test") fn generic_fn_with_implicit_cast() { - assert(get_first_byte(u8)([]u8 {13}) == 13); - assert(get_first_byte(u16)([]u16 {0, 13}) == 0); + assert(get_first_byte(u8, []u8 {13}) == 13); + assert(get_first_byte(u16, []u16 {0, 13}) == 0); } fn get_byte(ptr: ?&u8) -> u8 {*??ptr} -fn get_first_byte(T: type)(mem: []T) -> u8 { +fn get_first_byte(inline T: type, mem: []T) -> u8 { get_byte((&u8)(&mem[0])) } @@ -1651,9 +1646,9 @@ struct GenericDataThing(count: isize) { #attribute("test") fn use_generic_param_in_generic_param() { - assert(a_generic_fn(i32, 3)(4) == 7); + assert(a_generic_fn(i32, 3, 4) == 7); } -fn a_generic_fn(T: type, a: T)(b: T) -> T { +fn a_generic_fn(inline T: type, inline a: T, b: T) -> T { return a + b; }