diff --git a/src/stage1/analyze.cpp b/src/stage1/analyze.cpp index c5a4a7aa2..e8958052b 100644 --- a/src/stage1/analyze.cpp +++ b/src/stage1/analyze.cpp @@ -6113,7 +6113,7 @@ ZigValue *get_the_one_possible_value(CodeGen *g, ZigType *type_entry) { TypeUnionField *only_field = &union_type->data.unionation.fields[0]; ZigType *field_type = resolve_union_field_type(g, only_field); assert(field_type); - bigint_init_unsigned(&result->data.x_union.tag, 0); + bigint_init_bigint(&result->data.x_union.tag, &only_field->enum_field->value); result->data.x_union.payload = g->pass1_arena->create(); copy_const_val(g, result->data.x_union.payload, get_the_one_possible_value(g, field_type)); @@ -6122,6 +6122,11 @@ ZigValue *get_the_one_possible_value(CodeGen *g, ZigType *type_entry) { result->data.x_ptr.mut = ConstPtrMutComptimeConst; result->data.x_ptr.special = ConstPtrSpecialRef; result->data.x_ptr.data.ref.pointee = get_the_one_possible_value(g, result->type->data.pointer.child_type); + } else if (result->type->id == ZigTypeIdEnum) { + ZigType *enum_type = result->type; + assert(enum_type->data.enumeration.src_field_count == 1); + TypeEnumField *only_field = &result->type->data.enumeration.fields[0]; + bigint_init_bigint(&result->data.x_enum_tag, &only_field->value); } g->one_possible_values.put(type_entry, result); return result; diff --git a/src/stage1/ir.cpp b/src/stage1/ir.cpp index f4cb8a9ae..8f42b8fbb 100644 --- a/src/stage1/ir.cpp +++ b/src/stage1/ir.cpp @@ -14186,6 +14186,18 @@ static ZigType *ir_resolve_union_tag_type(IrAnalyze *ira, AstNode *source_node, } } +static bool can_fold_enum_type(ZigType *ty) { + assert(ty->id == ZigTypeIdEnum); + // We can fold the enum type (and avoid any check, be it at runtime or at + // compile time) iff it has only a single element and its tag type is + // zero-sized. + ZigType *tag_int_type = ty->data.enumeration.tag_int_type; + return ty->data.enumeration.layout == ContainerLayoutAuto && + ty->data.enumeration.src_field_count == 1 && + !ty->data.enumeration.non_exhaustive && + (tag_int_type->id == ZigTypeIdInt && tag_int_type->data.integral.bit_count == 0); +} + static IrInstGen *ir_analyze_enum_to_int(IrAnalyze *ira, IrInst *source_instr, IrInstGen *target) { Error err; @@ -14214,10 +14226,7 @@ static IrInstGen *ir_analyze_enum_to_int(IrAnalyze *ira, IrInst *source_instr, I assert(tag_type->id == ZigTypeIdInt || tag_type->id == ZigTypeIdComptimeInt); // If there is only one possible tag, then we know at comptime what it is. - if (enum_type->data.enumeration.layout == ContainerLayoutAuto && - enum_type->data.enumeration.src_field_count == 1 && - !enum_type->data.enumeration.non_exhaustive) - { + if (can_fold_enum_type(enum_type)) { IrInstGen *result = ir_const(ira, source_instr, tag_type); init_const_bigint(result->value, tag_type, &enum_type->data.enumeration.fields[0].value); @@ -14255,10 +14264,7 @@ static IrInstGen *ir_analyze_union_to_tag(IrAnalyze *ira, IrInst* source_instr, } // If there is only 1 possible tag, then we know at comptime what it is. - if (wanted_type->data.enumeration.layout == ContainerLayoutAuto && - wanted_type->data.enumeration.src_field_count == 1 && - !wanted_type->data.enumeration.non_exhaustive) - { + if (can_fold_enum_type(wanted_type)) { IrInstGen *result = ir_const(ira, source_instr, wanted_type); result->value->special = ConstValSpecialStatic; result->value->type = wanted_type; @@ -24039,7 +24045,8 @@ static IrInstGen *ir_analyze_instruction_switch_target(IrAnalyze *ira, bigint_init_bigint(&result->value->data.x_enum_tag, &pointee_val->data.x_union.tag); return result; } - if (tag_type->data.enumeration.src_field_count == 1 && !tag_type->data.enumeration.non_exhaustive) { + + if (can_fold_enum_type(tag_type)) { IrInstGen *result = ir_const(ira, &switch_target_instruction->base.base, tag_type); TypeEnumField *only_field = &tag_type->data.enumeration.fields[0]; bigint_init_bigint(&result->value->data.x_enum_tag, &only_field->value); @@ -24054,7 +24061,8 @@ static IrInstGen *ir_analyze_instruction_switch_target(IrAnalyze *ira, case ZigTypeIdEnum: { if ((err = type_resolve(ira->codegen, target_type, ResolveStatusZeroBitsKnown))) return ira->codegen->invalid_inst_gen; - if (target_type->data.enumeration.src_field_count == 1 && !target_type->data.enumeration.non_exhaustive) { + + if (can_fold_enum_type(target_type)) { TypeEnumField *only_field = &target_type->data.enumeration.fields[0]; IrInstGen *result = ir_const(ira, &switch_target_instruction->base.base, target_type); bigint_init_bigint(&result->value->data.x_enum_tag, &only_field->value); @@ -24789,7 +24797,9 @@ static IrInstGen *ir_analyze_instruction_enum_tag_name(IrAnalyze *ira, IrInstSrc if (type_is_invalid(target->value->type)) return ira->codegen->invalid_inst_gen; - if (target->value->type->id == ZigTypeIdEnumLiteral) { + ZigType *target_type = target->value->type; + + if (target_type->id == ZigTypeIdEnumLiteral) { IrInstGen *result = ir_const(ira, &instruction->base.base, nullptr); Buf *field_name = target->value->data.x_enum_literal; ZigValue *array_val = create_const_str_lit(ira->codegen, field_name)->data.x_ptr.data.ref.pointee; @@ -24797,21 +24807,21 @@ static IrInstGen *ir_analyze_instruction_enum_tag_name(IrAnalyze *ira, IrInstSrc return result; } - if (target->value->type->id == ZigTypeIdUnion) { + if (target_type->id == ZigTypeIdUnion) { target = ir_analyze_union_tag(ira, &instruction->base.base, target, instruction->base.is_gen); if (type_is_invalid(target->value->type)) return ira->codegen->invalid_inst_gen; + target_type = target->value->type; } - if (target->value->type->id != ZigTypeIdEnum) { + if (target_type->id != ZigTypeIdEnum) { ir_add_error(ira, &target->base, - buf_sprintf("expected enum tag, found '%s'", buf_ptr(&target->value->type->name))); + buf_sprintf("expected enum tag, found '%s'", buf_ptr(&target_type->name))); return ira->codegen->invalid_inst_gen; } - if (target->value->type->data.enumeration.src_field_count == 1 && - !target->value->type->data.enumeration.non_exhaustive) { - TypeEnumField *only_field = &target->value->type->data.enumeration.fields[0]; + if (can_fold_enum_type(target_type)) { + TypeEnumField *only_field = &target_type->data.enumeration.fields[0]; ZigValue *array_val = create_const_str_lit(ira->codegen, only_field->name)->data.x_ptr.data.ref.pointee; IrInstGen *result = ir_const(ira, &instruction->base.base, nullptr); init_const_slice(ira->codegen, result->value, array_val, 0, buf_len(only_field->name), true); @@ -24819,9 +24829,9 @@ static IrInstGen *ir_analyze_instruction_enum_tag_name(IrAnalyze *ira, IrInstSrc } if (instr_is_comptime(target)) { - if ((err = type_resolve(ira->codegen, target->value->type, ResolveStatusZeroBitsKnown))) + if ((err = type_resolve(ira->codegen, target_type, ResolveStatusZeroBitsKnown))) return ira->codegen->invalid_inst_gen; - TypeEnumField *field = find_enum_field_by_tag(target->value->type, &target->value->data.x_bigint); + TypeEnumField *field = find_enum_field_by_tag(target_type, &target->value->data.x_bigint); if (field == nullptr) { Buf *int_buf = buf_alloc(); bigint_append_buf(int_buf, &target->value->data.x_bigint, 10); diff --git a/test/runtime_safety.zig b/test/runtime_safety.zig index f4bcfa6f9..2ab728b58 100644 --- a/test/runtime_safety.zig +++ b/test/runtime_safety.zig @@ -1,6 +1,85 @@ const tests = @import("tests.zig"); pub fn addCases(cases: *tests.CompareOutputContext) void { + { + const check_panic_msg = + \\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn { + \\ if (std.mem.eql(u8, message, "reached unreachable code")) { + \\ std.process.exit(126); // good + \\ } + \\ std.process.exit(0); // test failed + \\} + ; + + cases.addRuntimeSafety("switch on corrupted enum value", + \\const std = @import("std"); + ++ check_panic_msg ++ + \\const E = enum(u32) { + \\ X = 1, + \\}; + \\pub fn main() void { + \\ var e: E = undefined; + \\ @memset(@ptrCast([*]u8, &e), 0x55, @sizeOf(E)); + \\ switch (e) { + \\ .X => @breakpoint(), + \\ } + \\} + ); + + cases.addRuntimeSafety("switch on corrupted union value", + \\const std = @import("std"); + ++ check_panic_msg ++ + \\const U = union(enum(u32)) { + \\ X: u8, + \\}; + \\pub fn main() void { + \\ var u: U = undefined; + \\ @memset(@ptrCast([*]u8, &u), 0x55, @sizeOf(U)); + \\ switch (u) { + \\ .X => @breakpoint(), + \\ } + \\} + ); + } + + { + const check_panic_msg = + \\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn { + \\ if (std.mem.eql(u8, message, "invalid enum value")) { + \\ std.process.exit(126); // good + \\ } + \\ std.process.exit(0); // test failed + \\} + ; + + cases.addRuntimeSafety("@tagName on corrupted enum value", + \\const std = @import("std"); + ++ check_panic_msg ++ + \\const E = enum(u32) { + \\ X = 1, + \\}; + \\pub fn main() void { + \\ var e: E = undefined; + \\ @memset(@ptrCast([*]u8, &e), 0x55, @sizeOf(E)); + \\ var n = @tagName(e); + \\} + ); + + cases.addRuntimeSafety("@tagName on corrupted union value", + \\const std = @import("std"); + ++ check_panic_msg ++ + \\const U = union(enum(u32)) { + \\ X: u8, + \\}; + \\pub fn main() void { + \\ var u: U = undefined; + \\ @memset(@ptrCast([*]u8, &u), 0x55, @sizeOf(U)); + \\ var t: @TagType(U) = u; + \\ var n = @tagName(t); + \\} + ); + } + { const check_panic_msg = \\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn {