From a59d31bd28f82e002f68ce25581c0437463137ed Mon Sep 17 00:00:00 2001 From: LemonBoy Date: Tue, 3 Mar 2020 21:46:30 +0100 Subject: [PATCH] ir: Support tuple multiplication --- src/ir.cpp | 122 ++++++++++++++++++++++++++++++--- test/stage1/behavior/tuple.zig | 31 ++++++++- 2 files changed, 141 insertions(+), 12 deletions(-) diff --git a/src/ir.cpp b/src/ir.cpp index 6fed044c6..5a96bc2d5 100644 --- a/src/ir.cpp +++ b/src/ir.cpp @@ -17351,14 +17351,15 @@ static IrInstGen *ir_analyze_tuple_cat(IrAnalyze *ira, IrInst* source_instr, ContainerKindStruct, source_instr->source_node, buf_ptr(name), bare_name, ContainerLayoutAuto); new_type->data.structure.special = StructSpecialInferredTuple; new_type->data.structure.resolve_status = ResolveStatusBeingInferred; - - IrInstGen *new_struct_ptr = ir_resolve_result(ira, source_instr, no_result_loc(), - new_type, nullptr, false, true); uint32_t new_field_count = op1_field_count + op2_field_count; new_type->data.structure.src_field_count = new_field_count; new_type->data.structure.fields = realloc_type_struct_fields(new_type->data.structure.fields, 0, new_field_count); + + IrInstGen *new_struct_ptr = ir_resolve_result(ira, source_instr, no_result_loc(), + new_type, nullptr, false, true); + for (uint32_t i = 0; i < new_field_count; i += 1) { TypeStructField *src_field; if (i < op1_field_count) { @@ -17422,8 +17423,10 @@ static IrInstGen *ir_analyze_tuple_cat(IrAnalyze *ira, IrInst* source_instr, ir_analyze_store_ptr(ira, &elem_result_loc->base, elem_result_loc, deref, true); } } - IrInstGen *result = ir_get_deref(ira, source_instr, new_struct_ptr, nullptr); - return result; + + const_ptrs.deinit(); + + return ir_get_deref(ira, source_instr, new_struct_ptr, nullptr); } static IrInstGen *ir_analyze_array_cat(IrAnalyze *ira, IrInstSrcBinOp *instruction) { @@ -17480,8 +17483,9 @@ static IrInstGen *ir_analyze_array_cat(IrAnalyze *ira, IrInstSrcBinOp *instructi ZigValue *len_val = op1_val->data.x_struct.fields[slice_len_index]; op1_array_end = op1_array_index + bigint_as_usize(&len_val->data.x_bigint); sentinel1 = ptr_type->data.pointer.sentinel; - } else if (op1_type->id == ZigTypeIdPointer && op1_type->data.pointer.ptr_len == PtrLenSingle && - op1_type->data.pointer.child_type->id == ZigTypeIdArray) + } else if (op1_type->id == ZigTypeIdPointer && + op1_type->data.pointer.ptr_len == PtrLenSingle && + op1_type->data.pointer.child_type->id == ZigTypeIdArray) { ZigType *array_type = op1_type->data.pointer.child_type; child_type = array_type->data.array.child_type; @@ -17654,6 +17658,103 @@ static IrInstGen *ir_analyze_array_cat(IrAnalyze *ira, IrInstSrcBinOp *instructi return result; } +static IrInstGen *ir_analyze_tuple_mult(IrAnalyze *ira, IrInst* source_instr, + IrInstGen *op1, IrInstGen *op2) +{ + Error err; + ZigType *op1_type = op1->value->type; + uint64_t op1_field_count = op1_type->data.structure.src_field_count; + + uint64_t mult_amt; + if (!ir_resolve_usize(ira, op2, &mult_amt)) + return ira->codegen->invalid_inst_gen; + + uint64_t new_field_count; + if (mul_u64_overflow(op1_field_count, mult_amt, &new_field_count)) { + ir_add_error(ira, source_instr, buf_sprintf("operation results in overflow")); + return ira->codegen->invalid_inst_gen; + } + + Buf *bare_name = buf_alloc(); + Buf *name = get_anon_type_name(ira->codegen, nullptr, container_string(ContainerKindStruct), + source_instr->scope, source_instr->source_node, bare_name); + ZigType *new_type = get_partial_container_type(ira->codegen, source_instr->scope, + ContainerKindStruct, source_instr->source_node, buf_ptr(name), bare_name, ContainerLayoutAuto); + new_type->data.structure.special = StructSpecialInferredTuple; + new_type->data.structure.resolve_status = ResolveStatusBeingInferred; + new_type->data.structure.src_field_count = new_field_count; + new_type->data.structure.fields = realloc_type_struct_fields( + new_type->data.structure.fields, 0, new_field_count); + + IrInstGen *new_struct_ptr = ir_resolve_result(ira, source_instr, no_result_loc(), + new_type, nullptr, false, true); + + for (uint64_t i = 0; i < new_field_count; i += 1) { + TypeStructField *src_field = op1_type->data.structure.fields[i % op1_field_count]; + TypeStructField *new_field = new_type->data.structure.fields[i]; + + new_field->name = buf_sprintf("%lu", i); + new_field->type_entry = src_field->type_entry; + new_field->type_val = src_field->type_val; + new_field->src_index = i; + new_field->decl_node = src_field->decl_node; + new_field->init_val = src_field->init_val; + new_field->is_comptime = src_field->is_comptime; + } + + if ((err = type_resolve(ira->codegen, new_type, ResolveStatusZeroBitsKnown))) + return ira->codegen->invalid_inst_gen; + + ZigList const_ptrs = {}; + for (uint64_t i = 0; i < new_field_count; i += 1) { + TypeStructField *src_field = op1_type->data.structure.fields[i % op1_field_count]; + TypeStructField *dst_field = new_type->data.structure.fields[i]; + + IrInstGen *field_value = ir_analyze_struct_value_field_value( + ira, source_instr, op1, src_field); + if (type_is_invalid(field_value->value->type)) + return ira->codegen->invalid_inst_gen; + + IrInstGen *dest_ptr = ir_analyze_struct_field_ptr( + ira, source_instr, dst_field, new_struct_ptr, new_type, true); + if (type_is_invalid(dest_ptr->value->type)) + return ira->codegen->invalid_inst_gen; + + if (instr_is_comptime(field_value)) { + const_ptrs.append(dest_ptr); + } + + IrInstGen *store_ptr_inst = ir_analyze_store_ptr( + ira, source_instr, dest_ptr, field_value, true); + if (type_is_invalid(store_ptr_inst->value->type)) + return ira->codegen->invalid_inst_gen; + } + + if (const_ptrs.length != new_field_count) { + new_struct_ptr->value->special = ConstValSpecialRuntime; + for (size_t i = 0; i < const_ptrs.length; i += 1) { + IrInstGen *elem_result_loc = const_ptrs.at(i); + assert(elem_result_loc->value->special == ConstValSpecialStatic); + if (elem_result_loc->value->type->data.pointer.inferred_struct_field != nullptr) { + // This field will be generated comptime; no need to do this. + continue; + } + IrInstGen *deref = ir_get_deref(ira, &elem_result_loc->base, elem_result_loc, nullptr); + if (!type_requires_comptime(ira->codegen, elem_result_loc->value->type->data.pointer.child_type)) { + elem_result_loc->value->special = ConstValSpecialRuntime; + } + IrInstGen *store_ptr_inst = ir_analyze_store_ptr( + ira, &elem_result_loc->base, elem_result_loc, deref, true); + if (type_is_invalid(store_ptr_inst->value->type)) + return ira->codegen->invalid_inst_gen; + } + } + + const_ptrs.deinit(); + + return ir_get_deref(ira, source_instr, new_struct_ptr, nullptr); +} + static IrInstGen *ir_analyze_array_mult(IrAnalyze *ira, IrInstSrcBinOp *instruction) { IrInstGen *op1 = instruction->op1->child; if (type_is_invalid(op1->value->type)) @@ -17671,8 +17772,9 @@ static IrInstGen *ir_analyze_array_mult(IrAnalyze *ira, IrInstSrcBinOp *instruct array_val = ir_resolve_const(ira, op1, UndefOk); if (array_val == nullptr) return ira->codegen->invalid_inst_gen; - } else if (op1->value->type->id == ZigTypeIdPointer && op1->value->type->data.pointer.ptr_len == PtrLenSingle && - op1->value->type->data.pointer.child_type->id == ZigTypeIdArray) + } else if (op1->value->type->id == ZigTypeIdPointer && + op1->value->type->data.pointer.ptr_len == PtrLenSingle && + op1->value->type->data.pointer.child_type->id == ZigTypeIdArray) { array_type = op1->value->type->data.pointer.child_type; IrInstGen *array_inst = ir_get_deref(ira, &op1->base, op1, nullptr); @@ -17682,6 +17784,8 @@ static IrInstGen *ir_analyze_array_mult(IrAnalyze *ira, IrInstSrcBinOp *instruct if (array_val == nullptr) return ira->codegen->invalid_inst_gen; want_ptr_to_array = true; + } else if (is_tuple(op1->value->type)) { + return ir_analyze_tuple_mult(ira, &instruction->base.base, op1, op2); } else { ir_add_error(ira, &op1->base, buf_sprintf("expected array type, found '%s'", buf_ptr(&op1->value->type->name))); return ira->codegen->invalid_inst_gen; diff --git a/test/stage1/behavior/tuple.zig b/test/stage1/behavior/tuple.zig index 3c467baaa..0b2fdfe4e 100644 --- a/test/stage1/behavior/tuple.zig +++ b/test/stage1/behavior/tuple.zig @@ -1,5 +1,7 @@ const std = @import("std"); -const expect = std.testing.expect; +const testing = std.testing; +const expect = testing.expect; +const expectEqual = testing.expectEqual; test "tuple concatenation" { const S = struct { @@ -9,8 +11,31 @@ test "tuple concatenation" { var x = .{a}; var y = .{b}; var c = x ++ y; - expect(c[0] == 1); - expect(c[1] == 2); + expectEqual(@as(i32, 1), c[0]); + expectEqual(@as(i32, 2), c[1]); + } + }; + S.doTheTest(); + comptime S.doTheTest(); +} + +test "tuple multiplication" { + const S = struct { + fn doTheTest() void { + { + const t = .{} ** 4; + expectEqual(0, @typeInfo(@TypeOf(t)).Struct.fields.len); + } + { + const t = .{'a'} ** 4; + expectEqual(4, @typeInfo(@TypeOf(t)).Struct.fields.len); + inline for (t) |x| expectEqual('a', x); + } + { + const t = .{ 1, 2, 3 } ** 4; + expectEqual(12, @typeInfo(@TypeOf(t)).Struct.fields.len); + inline for (t) |x, i| expectEqual(1 + i % 3, x); + } } }; S.doTheTest();