diff --git a/src/codegen.cpp b/src/codegen.cpp index cbaf974cd..850c40699 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -677,9 +677,14 @@ static LLVMValueRef ir_llvm_value(CodeGen *g, IrInstruction *instruction) { assert(instruction->value.special != ConstValSpecialRuntime); assert(instruction->value.type); render_const_val(g, &instruction->value); + // we might have to do some pointer casting here due to the way union + // values are rendered with a type other than the one we expect if (handle_is_ptr(instruction->value.type)) { render_const_val_global(g, &instruction->value); - instruction->llvm_value = instruction->value.llvm_global; + TypeTableEntry *ptr_type = get_pointer_to_type(g, instruction->value.type, true); + instruction->llvm_value = LLVMBuildBitCast(g->builder, instruction->value.llvm_global, ptr_type->type_ref, ""); + } else if (instruction->value.type->id == TypeTableEntryIdPointer) { + instruction->llvm_value = LLVMBuildBitCast(g->builder, instruction->value.llvm_value, instruction->value.type->type_ref, ""); } else { instruction->llvm_value = instruction->value.llvm_value; } @@ -2540,7 +2545,7 @@ static LLVMValueRef gen_const_val(CodeGen *g, ConstExprValue *const_val) { tag_value, union_value, }; - return LLVMConstNamedStruct(canon_type->type_ref, fields, 2); + return LLVMConstStruct(fields, 2, false); } } case TypeTableEntryIdFn: @@ -2553,11 +2558,7 @@ static LLVMValueRef gen_const_val(CodeGen *g, ConstExprValue *const_val) { render_const_val(g, const_val->data.x_ptr.base_ptr); render_const_val_global(g, const_val->data.x_ptr.base_ptr); ConstExprValue *other_val = const_val->data.x_ptr.base_ptr; - if (other_val->type == const_val->type->data.pointer.child_type) { - const_val->llvm_value = other_val->llvm_global; - } else { - const_val->llvm_value = LLVMConstBitCast(other_val->llvm_global, const_val->type->type_ref); - } + const_val->llvm_value = LLVMConstBitCast(other_val->llvm_global, const_val->type->type_ref); render_const_val_global(g, const_val); return const_val->llvm_value; } else { @@ -2636,7 +2637,8 @@ static void render_const_val(CodeGen *g, ConstExprValue *const_val) { static void render_const_val_global(CodeGen *g, ConstExprValue *const_val) { if (!const_val->llvm_global) { - LLVMValueRef global_value = LLVMAddGlobal(g->module, const_val->type->type_ref, ""); + LLVMTypeRef type_ref = const_val->llvm_value ? LLVMTypeOf(const_val->llvm_value) : const_val->type->type_ref; + LLVMValueRef global_value = LLVMAddGlobal(g->module, type_ref, ""); LLVMSetLinkage(global_value, LLVMInternalLinkage); LLVMSetGlobalConstant(global_value, true); LLVMSetUnnamedAddr(global_value, true); diff --git a/src/ir.cpp b/src/ir.cpp index 50eae17eb..cd72d029d 100644 --- a/src/ir.cpp +++ b/src/ir.cpp @@ -9565,7 +9565,20 @@ static TypeTableEntry *ir_analyze_instruction_switch_var(IrAnalyze *ira, IrInstr } if (instr_is_comptime(target_value_ptr)) { - zig_panic("TODO comptime switch var"); + ConstExprValue *target_val_ptr = ir_resolve_const(ira, target_value_ptr, UndefBad); + if (!target_value_ptr) + return ira->codegen->builtin_types.entry_invalid; + + ConstExprValue *pointee_val = const_ptr_pointee(target_val_ptr); + if (pointee_val->type->id == TypeTableEntryIdEnum) { + bool depends_on_compile_var = target_value_ptr->value.depends_on_compile_var; + ConstExprValue *out_val = ir_build_const_from(ira, &instruction->base, depends_on_compile_var); + out_val->data.x_ptr.base_ptr = pointee_val->data.x_enum.payload; + out_val->data.x_ptr.index = SIZE_MAX; + return get_pointer_to_type(ira->codegen, pointee_val->type, target_value_ptr->value.type->data.pointer.is_const); + } else { + zig_panic("TODO comptime switch var"); + } } ir_build_enum_field_ptr_from(&ira->new_irb, &instruction->base, target_value_ptr, field); diff --git a/test/cases/switch.zig b/test/cases/switch.zig index 801ce52b1..edb42c871 100644 --- a/test/cases/switch.zig +++ b/test/cases/switch.zig @@ -129,3 +129,25 @@ fn switchWithMultipleExpressions() { fn returnsFive() -> i32 { 5 } + + +const Number = enum { + One: u64, + Two: u8, + Three: f32, +}; + +const number = Number.Three { 1.23 }; + +fn returnsFalse() -> bool { + switch (number) { + Number.One => |x| return x > 1234, + Number.Two => |x| return x == 'a', + Number.Three => |x| return x > 12.34, + } +} +fn switchOnConstEnumWithVar() { + @setFnTest(this); + + assert(!returnsFalse()); +}