debug safety for unions
parent
f12d36641f
commit
e26ccd5166
|
@ -1317,6 +1317,7 @@ enum PanicMsgId {
|
|||
PanicMsgIdUnwrapMaybeFail,
|
||||
PanicMsgIdInvalidErrorCode,
|
||||
PanicMsgIdIncorrectAlignment,
|
||||
PanicMsgIdBadUnionField,
|
||||
|
||||
PanicMsgIdCount,
|
||||
};
|
||||
|
|
|
@ -810,6 +810,8 @@ static Buf *panic_msg_buf(PanicMsgId msg_id) {
|
|||
return buf_create_from_str("invalid error code");
|
||||
case PanicMsgIdIncorrectAlignment:
|
||||
return buf_create_from_str("incorrect alignment");
|
||||
case PanicMsgIdBadUnionField:
|
||||
return buf_create_from_str("access of inactive union field");
|
||||
}
|
||||
zig_unreachable();
|
||||
}
|
||||
|
@ -2415,6 +2417,23 @@ static LLVMValueRef ir_render_union_field_ptr(CodeGen *g, IrExecutable *executab
|
|||
return bitcasted_union_field_ptr;
|
||||
}
|
||||
|
||||
if (ir_want_debug_safety(g, &instruction->base)) {
|
||||
LLVMValueRef tag_field_ptr = LLVMBuildStructGEP(g->builder, union_ptr, union_type->data.unionation.gen_tag_index, "");
|
||||
LLVMValueRef tag_value = gen_load_untyped(g, tag_field_ptr, 0, false, "");
|
||||
LLVMValueRef expected_tag_value = LLVMConstInt(union_type->data.unionation.tag_type->type_ref,
|
||||
field->value, false);
|
||||
|
||||
LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "UnionCheckOk");
|
||||
LLVMBasicBlockRef bad_block = LLVMAppendBasicBlock(g->cur_fn_val, "UnionCheckFail");
|
||||
LLVMValueRef ok_val = LLVMBuildICmp(g->builder, LLVMIntEQ, tag_value, expected_tag_value, "");
|
||||
LLVMBuildCondBr(g->builder, ok_val, ok_block, bad_block);
|
||||
|
||||
LLVMPositionBuilderAtEnd(g->builder, bad_block);
|
||||
gen_debug_safety_crash(g, PanicMsgIdBadUnionField);
|
||||
|
||||
LLVMPositionBuilderAtEnd(g->builder, ok_block);
|
||||
}
|
||||
|
||||
LLVMValueRef union_field_ptr = LLVMBuildStructGEP(g->builder, union_ptr, union_type->data.unionation.gen_union_index, "");
|
||||
LLVMValueRef bitcasted_union_field_ptr = LLVMBuildBitCast(g->builder, union_field_ptr, field_type_ref, "");
|
||||
return bitcasted_union_field_ptr;
|
||||
|
@ -3977,21 +3996,17 @@ static LLVMValueRef gen_const_val(CodeGen *g, ConstExprValue *const_val) {
|
|||
|
||||
LLVMValueRef union_value_ref;
|
||||
{
|
||||
unsigned field_count;
|
||||
LLVMValueRef fields[2];
|
||||
fields[0] = correctly_typed_value;
|
||||
if (pad_bytes == 0) {
|
||||
field_count = 1;
|
||||
union_value_ref = correctly_typed_value;
|
||||
} else {
|
||||
LLVMValueRef fields[2];
|
||||
fields[0] = correctly_typed_value;
|
||||
fields[1] = LLVMGetUndef(LLVMArrayType(LLVMInt8Type(), (unsigned)pad_bytes));
|
||||
field_count = 2;
|
||||
}
|
||||
|
||||
if (make_unnamed_struct || type_entry->data.unionation.gen_tag_index != SIZE_MAX) {
|
||||
union_value_ref = LLVMConstStruct(fields, field_count, false);
|
||||
} else {
|
||||
union_value_ref = LLVMConstNamedStruct(union_type_ref, fields, field_count);
|
||||
if (make_unnamed_struct || type_entry->data.unionation.gen_tag_index != SIZE_MAX) {
|
||||
union_value_ref = LLVMConstStruct(fields, 2, false);
|
||||
} else {
|
||||
union_value_ref = LLVMConstNamedStruct(union_type_ref, fields, 2);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -41,7 +41,7 @@ const Foo = union {
|
|||
test "basic unions" {
|
||||
var foo = Foo { .int = 1 };
|
||||
assert(foo.int == 1);
|
||||
foo.float = 12.34;
|
||||
foo = Foo {.float = 12.34};
|
||||
assert(foo.float == 12.34);
|
||||
}
|
||||
|
||||
|
|
|
@ -260,4 +260,24 @@ pub fn addCases(cases: &tests.CompareOutputContext) {
|
|||
\\ return int_slice[0];
|
||||
\\}
|
||||
);
|
||||
|
||||
cases.addDebugSafety("bad union field access",
|
||||
\\pub fn panic(message: []const u8) -> noreturn {
|
||||
\\ @import("std").os.exit(126);
|
||||
\\}
|
||||
\\
|
||||
\\const Foo = union {
|
||||
\\ float: f32,
|
||||
\\ int: u32,
|
||||
\\};
|
||||
\\
|
||||
\\pub fn main() -> %void {
|
||||
\\ var f = Foo { .int = 42 };
|
||||
\\ bar(&f);
|
||||
\\}
|
||||
\\
|
||||
\\fn bar(f: &Foo) {
|
||||
\\ f.float = 12.34;
|
||||
\\}
|
||||
);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue