From 6db9be8900bf43632c8a98d91c6a92f30b33500a Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Fri, 9 Mar 2018 14:20:44 -0500 Subject: [PATCH] don't memoize comptime functions if they can mutate state via parameters closes #639 --- src/all_types.hpp | 9 ++++++++- src/analyze.cpp | 28 ++++++++++++++++++++++++++++ src/analyze.hpp | 1 + src/ir.cpp | 17 +++++++++++------ std/sort.zig | 3 +-- test/cases/eval.zig | 28 ++++++++++++++++++++++++++++ 6 files changed, 77 insertions(+), 9 deletions(-) diff --git a/src/all_types.hpp b/src/all_types.hpp index 3d732d4ac..c85a778b6 100644 --- a/src/all_types.hpp +++ b/src/all_types.hpp @@ -1168,10 +1168,17 @@ struct TypeTableEntry { LLVMTypeRef type_ref; ZigLLVMDIType *di_type; - bool zero_bits; + bool zero_bits; // this is denormalized data bool is_copyable; bool gen_h_loop_flag; + // This is denormalized data. The simplest type that has this + // flag set to true is a mutable pointer. A const pointer has + // the same value for this flag as the child type. + // If a struct has any fields that have this flag true, then + // the flag is true for the struct. + bool can_mutate_state_through_it; + union { TypeTableEntryPointer pointer; TypeTableEntryInt integral; diff --git a/src/analyze.cpp b/src/analyze.cpp index 37b2798f3..e40412e86 100644 --- a/src/analyze.cpp +++ b/src/analyze.cpp @@ -398,6 +398,7 @@ TypeTableEntry *get_pointer_to_type_extra(CodeGen *g, TypeTableEntry *child_type TypeTableEntry *entry = new_type_table_entry(TypeTableEntryIdPointer); entry->is_copyable = true; + entry->can_mutate_state_through_it = is_const ? child_type->can_mutate_state_through_it : true; const char *const_str = is_const ? "const " : ""; const char *volatile_str = is_volatile ? "volatile " : ""; @@ -482,6 +483,7 @@ TypeTableEntry *get_maybe_type(CodeGen *g, TypeTableEntry *child_type) { assert(child_type->type_ref || child_type->zero_bits); assert(child_type->di_type); entry->is_copyable = type_is_copyable(g, child_type); + entry->can_mutate_state_through_it = child_type->can_mutate_state_through_it; buf_resize(&entry->name, 0); buf_appendf(&entry->name, "?%s", buf_ptr(&child_type->name)); @@ -572,6 +574,7 @@ TypeTableEntry *get_error_union_type(CodeGen *g, TypeTableEntry *err_set_type, T entry->is_copyable = true; assert(payload_type->di_type); ensure_complete_type(g, payload_type); + entry->can_mutate_state_through_it = payload_type->can_mutate_state_through_it; buf_resize(&entry->name, 0); buf_appendf(&entry->name, "%s!%s", buf_ptr(&err_set_type->name), buf_ptr(&payload_type->name)); @@ -730,6 +733,7 @@ TypeTableEntry *get_slice_type(CodeGen *g, TypeTableEntry *ptr_type) { TypeTableEntry *entry = new_type_table_entry(TypeTableEntryIdStruct); entry->is_copyable = true; + entry->can_mutate_state_through_it = ptr_type->can_mutate_state_through_it; // replace the & with [] to go from a ptr type name to a slice type name buf_resize(&entry->name, 0); @@ -1735,6 +1739,8 @@ TypeTableEntry *get_struct_type(CodeGen *g, const char *type_name, const char *f struct_type->data.structure.gen_field_count += 1; } else { field->gen_index = SIZE_MAX; + struct_type->can_mutate_state_through_it = struct_type->can_mutate_state_through_it || + field->type_entry->can_mutate_state_through_it; } auto prev_entry = struct_type->data.structure.fields_by_name.put_unique(field->name, field); @@ -2475,6 +2481,9 @@ static void resolve_struct_zero_bits(CodeGen *g, TypeTableEntry *struct_type) { if (!type_has_bits(field_type)) continue; + struct_type->can_mutate_state_through_it = struct_type->can_mutate_state_through_it || + field_type->can_mutate_state_through_it; + if (gen_field_index == 0) { if (struct_type->data.structure.layout == ContainerLayoutPacked) { struct_type->data.structure.abi_alignment = 1; @@ -2662,6 +2671,8 @@ static void resolve_union_zero_bits(CodeGen *g, TypeTableEntry *union_type) { } } union_field->type_entry = field_type; + union_type->can_mutate_state_through_it = union_type->can_mutate_state_through_it || + field_type->can_mutate_state_through_it; if (field_node->data.struct_field.value != nullptr && !decl_node->data.container_decl.auto_enum) { ErrorMsg *msg = add_node_error(g, field_node->data.struct_field.value, @@ -4565,6 +4576,23 @@ bool generic_fn_type_id_eql(GenericFnTypeId *a, GenericFnTypeId *b) { return true; } +bool fn_eval_cacheable(Scope *scope) { + while (scope) { + if (scope->id == ScopeIdVarDecl) { + ScopeVarDecl *var_scope = (ScopeVarDecl *)scope; + if (var_scope->var->value->type->can_mutate_state_through_it) + return false; + } else if (scope->id == ScopeIdFnDef) { + return true; + } else { + zig_unreachable(); + } + + scope = scope->parent; + } + zig_unreachable(); +} + uint32_t fn_eval_hash(Scope* scope) { uint32_t result = 0; while (scope) { diff --git a/src/analyze.hpp b/src/analyze.hpp index 9ed24f4e4..bdc60b02b 100644 --- a/src/analyze.hpp +++ b/src/analyze.hpp @@ -195,5 +195,6 @@ TypeTableEntry *get_auto_err_set_type(CodeGen *g, FnTableEntry *fn_entry); uint32_t get_coro_frame_align_bytes(CodeGen *g); bool fn_type_can_fail(FnTypeId *fn_type_id); +bool fn_eval_cacheable(Scope *scope); #endif diff --git a/src/ir.cpp b/src/ir.cpp index 67caa84a9..36a6defd9 100644 --- a/src/ir.cpp +++ b/src/ir.cpp @@ -11830,12 +11830,15 @@ static TypeTableEntry *ir_analyze_fn_call(IrAnalyze *ira, IrInstructionCall *cal return_type = specified_return_type; } - IrInstruction *result; + bool cacheable = fn_eval_cacheable(exec_scope); + IrInstruction *result = nullptr; + if (cacheable) { + auto entry = ira->codegen->memoized_fn_eval_table.maybe_get(exec_scope); + if (entry) + result = entry->value; + } - auto entry = ira->codegen->memoized_fn_eval_table.maybe_get(exec_scope); - if (entry) { - result = entry->value; - } else { + if (result == nullptr) { // Analyze the fn body block like any other constant expression. AstNode *body_node = fn_entry->body_node; result = ir_eval_const_value(ira->codegen, exec_scope, body_node, return_type, @@ -11859,7 +11862,9 @@ static TypeTableEntry *ir_analyze_fn_call(IrAnalyze *ira, IrInstructionCall *cal } } - ira->codegen->memoized_fn_eval_table.put(exec_scope, result); + if (cacheable) { + ira->codegen->memoized_fn_eval_table.put(exec_scope, result); + } if (type_is_invalid(result->value.type)) return ira->codegen->builtin_types.entry_invalid; diff --git a/std/sort.zig b/std/sort.zig index a771cbdb4..c13e99fed 100644 --- a/std/sort.zig +++ b/std/sort.zig @@ -964,8 +964,7 @@ fn u8desc(lhs: &const u8, rhs: &const u8) bool { test "stable sort" { testStableSort(); - // TODO: uncomment this after https://github.com/zig-lang/zig/issues/639 - //comptime testStableSort(); + comptime testStableSort(); } fn testStableSort() void { var expected = []IdAndValue { diff --git a/test/cases/eval.zig b/test/cases/eval.zig index e5e826eff..58877f7e2 100644 --- a/test/cases/eval.zig +++ b/test/cases/eval.zig @@ -420,3 +420,31 @@ test "binary math operator in partially inlined function" { assert(s[2] == 0x90a0b0c); assert(s[3] == 0xd0e0f10); } + + +test "comptime function with the same args is memoized" { + comptime { + assert(MakeType(i32) == MakeType(i32)); + assert(MakeType(i32) != MakeType(f64)); + } +} + +fn MakeType(comptime T: type) type { + return struct { + field: T, + }; +} + +test "comptime function with mutable pointer is not memoized" { + comptime { + var x: i32 = 1; + const ptr = &x; + increment(ptr); + increment(ptr); + assert(x == 3); + } +} + +fn increment(value: &i32) void { + *value += 1; +}