diff --git a/src/all_types.hpp b/src/all_types.hpp index 9c9631a05..3bd51002c 100644 --- a/src/all_types.hpp +++ b/src/all_types.hpp @@ -691,6 +691,7 @@ struct AstNodeSymbolExpr { // set this to instead of analyzing the node, pretend it's a type entry and it's this one. TypeTableEntry *override_type_entry; TypeEnumField *enum_field; + uint32_t err_value; }; struct AstNodeBoolLiteral { diff --git a/src/analyze.cpp b/src/analyze.cpp index 46f05fb64..737fff0a7 100644 --- a/src/analyze.cpp +++ b/src/analyze.cpp @@ -5116,8 +5116,11 @@ static TypeTableEntry *analyze_switch_expr(CodeGen *g, ImportTableEntry *import, int *field_use_counts = nullptr; + HashMap err_use_nodes; if (expr_type->id == TypeTableEntryIdEnum) { field_use_counts = allocate(expr_type->data.enumeration.field_count); + } else if (expr_type->id == TypeTableEntryIdErrorUnion) { + err_use_nodes.init(10); } int *const_chosen_prong_index = &node->data.switch_expr.const_chosen_prong_index; @@ -5186,8 +5189,54 @@ static TypeTableEntry *analyze_switch_expr(CodeGen *g, ImportTableEntry *import, add_node_error(g, item_node, buf_sprintf("expected enum tag name")); any_errors = true; } + } else if (expr_type->id == TypeTableEntryIdErrorUnion) { + if (item_node->type == NodeTypeSymbol) { + Buf *err_name = &item_node->data.symbol_expr.symbol; + bool is_ok_case = buf_eql_str(err_name, "Ok"); + auto err_table_entry = is_ok_case ? nullptr: g->error_table.maybe_get(err_name); + if (is_ok_case || err_table_entry) { + uint32_t err_value = is_ok_case ? 0 : err_table_entry->value->value; + item_node->data.symbol_expr.err_value = err_value; + TypeTableEntry *this_var_type; + if (is_ok_case) { + this_var_type = expr_type->data.error.child_type; + } else { + this_var_type = g->builtin_types.entry_pure_error; + } + if (!var_type) { + var_type = this_var_type; + } + if (this_var_type != var_type) { + all_agree_on_var_type = false; + } + + // detect duplicate switch values + auto existing_entry = err_use_nodes.maybe_get(err_value); + if (existing_entry) { + add_node_error(g, existing_entry->value, + buf_sprintf("duplicate switch value: '%s'", buf_ptr(err_name))); + any_errors = true; + } else { + err_use_nodes.put(err_value, item_node); + } + + if (!any_errors && expr_val->ok) { + if (expr_val->data.x_err.err->value == err_value) { + *const_chosen_prong_index = prong_i; + } + } + } else { + add_node_error(g, item_node, + buf_sprintf("use of undeclared error value '%s'", buf_ptr(err_name))); + any_errors = true; + } + } else { + add_node_error(g, item_node, buf_sprintf("expected error value name")); + any_errors = true; + } } else { if (!any_errors && expr_val->ok) { + // note: there is now a function in eval.cpp for doing const expr comparison zig_panic("TODO determine if const exprs are equal"); } TypeTableEntry *item_type = analyze_expression(g, import, context, expr_type, item_node); @@ -5252,17 +5301,25 @@ static TypeTableEntry *analyze_switch_expr(CodeGen *g, ImportTableEntry *import, return g->builtin_types.entry_invalid; } + TypeTableEntry *result_type = resolve_peer_type_compatibility(g, import, context, node, + peer_nodes, peer_types, prong_count); + if (expr_val->ok) { assert(*const_chosen_prong_index != -1); *const_val = get_resolved_expr(peer_nodes[*const_chosen_prong_index])->const_val; - // the target expr depends on a compile var, - // so the entire if statement does too + // the target expr depends on a compile var because we have an error on unnecessary + // switch statement, so the entire switch statement does too const_val->depends_on_compile_var = true; + + if (!const_val->ok) { + return add_error_if_type_is_num_lit(g, result_type, node); + } + } else { + return add_error_if_type_is_num_lit(g, result_type, node); } - - return resolve_peer_type_compatibility(g, import, context, node, peer_nodes, peer_types, prong_count); + return result_type; } static TypeTableEntry *analyze_return_expr(CodeGen *g, ImportTableEntry *import, BlockContext *context, diff --git a/src/codegen.cpp b/src/codegen.cpp index e15553e2d..aa3b4d152 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -2627,12 +2627,6 @@ static LLVMValueRef gen_symbol(CodeGen *g, AstNode *node) { } zig_unreachable(); - - /* TODO delete - FnTableEntry *fn_entry = node->data.symbol_expr.fn_entry; - assert(fn_entry); - return fn_entry->fn_value; - */ } static LLVMValueRef gen_switch_expr(CodeGen *g, AstNode *node) { @@ -2653,6 +2647,10 @@ static LLVMValueRef gen_switch_expr(CodeGen *g, AstNode *node) { add_debug_source_node(g, node); LLVMValueRef tag_field_ptr = LLVMBuildStructGEP(g->builder, target_value_handle, 0, ""); target_value = LLVMBuildLoad(g->builder, tag_field_ptr, ""); + } else if (target_type->id == TypeTableEntryIdErrorUnion) { + add_debug_source_node(g, node); + LLVMValueRef tag_field_ptr = LLVMBuildStructGEP(g->builder, target_value_handle, 0, ""); + target_value = LLVMBuildLoad(g->builder, tag_field_ptr, ""); } else { zig_unreachable(); } @@ -2696,12 +2694,23 @@ static LLVMValueRef gen_switch_expr(CodeGen *g, AstNode *node) { assert(item_node->type != NodeTypeSwitchRange); LLVMValueRef val; - if (target_type->id == TypeTableEntryIdEnum) { + if (target_type->id == TypeTableEntryIdEnum || + target_type->id == TypeTableEntryIdErrorUnion) + { assert(item_node->type == NodeTypeSymbol); - TypeEnumField *enum_field = item_node->data.symbol_expr.enum_field; - assert(enum_field); - val = LLVMConstInt(target_type->data.enumeration.tag_type->type_ref, - enum_field->value, false); + TypeEnumField *enum_field = nullptr; + uint32_t err_value = 0; + if (target_type->id == TypeTableEntryIdEnum) { + enum_field = item_node->data.symbol_expr.enum_field; + assert(enum_field); + val = LLVMConstInt(target_type->data.enumeration.tag_type->type_ref, + enum_field->value, false); + } else if (target_type->id == TypeTableEntryIdErrorUnion) { + err_value = item_node->data.symbol_expr.err_value; + val = LLVMConstInt(g->err_tag_type->type_ref, err_value, false); + } else { + zig_unreachable(); + } if (prong_var && type_has_bits(prong_var->type)) { LLVMBasicBlockRef item_block; @@ -2721,6 +2730,7 @@ static LLVMValueRef gen_switch_expr(CodeGen *g, AstNode *node) { gen_assign_raw(g, var_node, BinOpTypeAssign, prong_var->value_ref, target_value, prong_var->type, target_type); } else if (target_type->id == TypeTableEntryIdEnum) { + assert(enum_field); assert(type_has_bits(enum_field->type_entry)); LLVMValueRef union_field_ptr = LLVMBuildStructGEP(g->builder, target_value_handle, 1, ""); @@ -2731,6 +2741,25 @@ static LLVMValueRef gen_switch_expr(CodeGen *g, AstNode *node) { gen_assign_raw(g, var_node, BinOpTypeAssign, prong_var->value_ref, handle_val, prong_var->type, enum_field->type_entry); + } else if (target_type->id == TypeTableEntryIdErrorUnion) { + if (err_value == 0) { + // variable is the payload + LLVMValueRef err_payload_ptr = LLVMBuildStructGEP(g->builder, + target_value_handle, 1, ""); + LLVMValueRef handle_val = get_handle_value(g, var_node, + err_payload_ptr, prong_var->type); + gen_assign_raw(g, var_node, BinOpTypeAssign, + prong_var->value_ref, handle_val, prong_var->type, prong_var->type); + } else { + // variable is the pure error value + LLVMValueRef err_tag_ptr = LLVMBuildStructGEP(g->builder, + target_value_handle, 0, ""); + LLVMValueRef handle_val = LLVMBuildLoad(g->builder, err_tag_ptr, ""); + gen_assign_raw(g, var_node, BinOpTypeAssign, + prong_var->value_ref, handle_val, prong_var->type, g->err_tag_type); + } + } else { + zig_unreachable(); } if (make_item_blocks) { LLVMBuildBr(g->builder, prong_block); diff --git a/test/run_tests.cpp b/test/run_tests.cpp index 09b86dc7a..b0d3006c8 100644 --- a/test/run_tests.cpp +++ b/test/run_tests.cpp @@ -1233,6 +1233,18 @@ fn bad_eql_2(a: EnumWithData, b: EnumWithData) -> bool { )SOURCE", 2, ".tmp_source.zig:3:7: error: operator not allowed for type '[]u8'", ".tmp_source.zig:10:7: error: operator not allowed for type 'EnumWithData'"); + + add_compile_fail_case("non-const switch number literal", R"SOURCE( +fn foo() { + const x = switch (bar()) { + 1, 2 => 1, + 3, 4 => 2, + else => 3, + }; +} +#static_eval_enable(false) +fn bar() -> i32 { 2 } + )SOURCE", 1, ".tmp_source.zig:3:15: error: unable to infer expression type"); } ////////////////////////////////////////////////////////////////////////////// diff --git a/test/self_hosted.zig b/test/self_hosted.zig index 610e9c755..413969479 100644 --- a/test/self_hosted.zig +++ b/test/self_hosted.zig @@ -1339,3 +1339,32 @@ fn character_literals() { assert('\'' == single_quote); } const single_quote = '\''; + + +#attribute("test") +fn switch_with_multiple_expressions() { + const x: i32 = switch (returns_five()) { + 1, 2, 3 => 1, + 4, 5, 6 => 2, + else => 3, + }; + assert(x == 2); +} +#static_eval_enable(false) +fn returns_five() -> i32 { 5 } + + +#attribute("test") +fn switch_on_error_union() { + const x = switch (returns_ten()) { + Ok => |val| val + 1, + ItBroke, NoMem => 1, + CrappedOut => 2, + }; + assert(x == 11); +} +error ItBroke; +error NoMem; +error CrappedOut; +#static_eval_enable(false) +fn returns_ten() -> %i32 { 10 }