ir: Support div/mod/rem on vector types

Closes #4050
This commit is contained in:
LemonBoy 2020-03-14 13:23:41 +01:00 committed by Andrew Kelley
parent e2dc63644a
commit 54ffcf95a8
No known key found for this signature in database
GPG Key ID: 7C5F548F728501A9
3 changed files with 327 additions and 174 deletions

View File

@ -2591,12 +2591,7 @@ static LLVMValueRef gen_overflow_shr_op(CodeGen *g, ZigType *type_entry,
}
static LLVMValueRef gen_float_op(CodeGen *g, LLVMValueRef val, ZigType *type_entry, BuiltinFnId op) {
if ((op == BuiltinFnIdCeil ||
op == BuiltinFnIdFloor) &&
type_entry->id == ZigTypeIdInt)
return val;
assert(type_entry->id == ZigTypeIdFloat);
assert(type_entry->id == ZigTypeIdFloat || type_entry->id == ZigTypeIdVector);
LLVMValueRef floor_fn = get_float_fn(g, type_entry, ZigLLVMFnIdFloatOp, op);
return LLVMBuildCall(g->builder, floor_fn, &val, 1, "");
}
@ -2612,6 +2607,21 @@ static LLVMValueRef bigint_to_llvm_const(LLVMTypeRef type_ref, BigInt *bigint) {
if (bigint->digit_count == 0) {
return LLVMConstNull(type_ref);
}
if (LLVMGetTypeKind(type_ref) == LLVMVectorTypeKind) {
const unsigned vector_len = LLVMGetVectorSize(type_ref);
LLVMTypeRef elem_type = LLVMGetElementType(type_ref);
LLVMValueRef *values = heap::c_allocator.allocate_nonzero<LLVMValueRef>(vector_len);
// Create a vector with all the elements having the same value
for (unsigned i = 0; i < vector_len; i++) {
values[i] = bigint_to_llvm_const(elem_type, bigint);
}
LLVMValueRef result = LLVMConstVector(values, vector_len);
heap::c_allocator.deallocate(values, vector_len);
return result;
}
LLVMValueRef unsigned_val;
if (bigint->digit_count == 1) {
unsigned_val = LLVMConstInt(type_ref, bigint_ptr(bigint)[0], false);
@ -2625,22 +2635,40 @@ static LLVMValueRef bigint_to_llvm_const(LLVMTypeRef type_ref, BigInt *bigint) {
}
}
// Collapses a <N x i1> vector into a single i1 whose value is 1 iff all the
// vector elements are 1
static LLVMValueRef scalarize_cmp_result(CodeGen *g, LLVMValueRef val) {
assert(LLVMGetTypeKind(LLVMTypeOf(val)) == LLVMVectorTypeKind);
LLVMTypeRef scalar_type = LLVMIntType(LLVMGetVectorSize(LLVMTypeOf(val)));
LLVMValueRef all_ones = LLVMConstAllOnes(scalar_type);
LLVMValueRef casted = LLVMBuildBitCast(g->builder, val, scalar_type, "");
return LLVMBuildICmp(g->builder, LLVMIntEQ, casted, all_ones, "");
}
static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast_math,
LLVMValueRef val1, LLVMValueRef val2,
ZigType *type_entry, DivKind div_kind)
LLVMValueRef val1, LLVMValueRef val2, ZigType *operand_type, DivKind div_kind)
{
ZigType *scalar_type = (operand_type->id == ZigTypeIdVector) ?
operand_type->data.vector.elem_type : operand_type;
ZigLLVMSetFastMath(g->builder, want_fast_math);
LLVMValueRef zero = LLVMConstNull(get_llvm_type(g, type_entry));
if (want_runtime_safety && (want_fast_math || type_entry->id != ZigTypeIdFloat)) {
LLVMValueRef zero = LLVMConstNull(get_llvm_type(g, operand_type));
if (want_runtime_safety && (want_fast_math || scalar_type->id != ZigTypeIdFloat)) {
// Safety check: divisor != 0
LLVMValueRef is_zero_bit;
if (type_entry->id == ZigTypeIdInt) {
if (scalar_type->id == ZigTypeIdInt) {
is_zero_bit = LLVMBuildICmp(g->builder, LLVMIntEQ, val2, zero, "");
} else if (type_entry->id == ZigTypeIdFloat) {
} else if (scalar_type->id == ZigTypeIdFloat) {
is_zero_bit = LLVMBuildFCmp(g->builder, LLVMRealOEQ, val2, zero, "");
} else {
zig_unreachable();
}
if (operand_type->id == ZigTypeIdVector) {
is_zero_bit = scalarize_cmp_result(g, is_zero_bit);
}
LLVMBasicBlockRef div_zero_fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivZeroFail");
LLVMBasicBlockRef div_zero_ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivZeroOk");
LLVMBuildCondBr(g->builder, is_zero_bit, div_zero_fail_block, div_zero_ok_block);
@ -2650,16 +2678,21 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast
LLVMPositionBuilderAtEnd(g->builder, div_zero_ok_block);
if (type_entry->id == ZigTypeIdInt && type_entry->data.integral.is_signed) {
LLVMValueRef neg_1_value = LLVMConstInt(get_llvm_type(g, type_entry), -1, true);
// Safety check: check for overflow (dividend = minInt and divisor = -1)
if (scalar_type->id == ZigTypeIdInt && scalar_type->data.integral.is_signed) {
LLVMValueRef neg_1_value = LLVMConstAllOnes(get_llvm_type(g, operand_type));
BigInt int_min_bi = {0};
eval_min_max_value_int(g, type_entry, &int_min_bi, false);
LLVMValueRef int_min_value = bigint_to_llvm_const(get_llvm_type(g, type_entry), &int_min_bi);
eval_min_max_value_int(g, scalar_type, &int_min_bi, false);
LLVMValueRef int_min_value = bigint_to_llvm_const(get_llvm_type(g, operand_type), &int_min_bi);
LLVMBasicBlockRef overflow_fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivOverflowFail");
LLVMBasicBlockRef overflow_ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivOverflowOk");
LLVMValueRef num_is_int_min = LLVMBuildICmp(g->builder, LLVMIntEQ, val1, int_min_value, "");
LLVMValueRef den_is_neg_1 = LLVMBuildICmp(g->builder, LLVMIntEQ, val2, neg_1_value, "");
LLVMValueRef overflow_fail_bit = LLVMBuildAnd(g->builder, num_is_int_min, den_is_neg_1, "");
if (operand_type->id == ZigTypeIdVector) {
overflow_fail_bit = scalarize_cmp_result(g, overflow_fail_bit);
}
LLVMBuildCondBr(g->builder, overflow_fail_bit, overflow_fail_block, overflow_ok_block);
LLVMPositionBuilderAtEnd(g->builder, overflow_fail_block);
@ -2669,18 +2702,22 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast
}
}
if (type_entry->id == ZigTypeIdFloat) {
if (scalar_type->id == ZigTypeIdFloat) {
LLVMValueRef result = LLVMBuildFDiv(g->builder, val1, val2, "");
switch (div_kind) {
case DivKindFloat:
return result;
case DivKindExact:
if (want_runtime_safety) {
LLVMValueRef floored = gen_float_op(g, result, type_entry, BuiltinFnIdFloor);
// Safety check: a / b == floor(a / b)
LLVMValueRef floored = gen_float_op(g, result, operand_type, BuiltinFnIdFloor);
LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivExactOk");
LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivExactFail");
LLVMValueRef ok_bit = LLVMBuildFCmp(g->builder, LLVMRealOEQ, floored, result, "");
if (operand_type->id == ZigTypeIdVector) {
ok_bit = scalarize_cmp_result(g, ok_bit);
}
LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block);
LLVMPositionBuilderAtEnd(g->builder, fail_block);
@ -2695,54 +2732,61 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast
LLVMBasicBlockRef gez_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivTruncGEZero");
LLVMBasicBlockRef end_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivTruncEnd");
LLVMValueRef ltz = LLVMBuildFCmp(g->builder, LLVMRealOLT, val1, zero, "");
if (operand_type->id == ZigTypeIdVector) {
ltz = scalarize_cmp_result(g, ltz);
}
LLVMBuildCondBr(g->builder, ltz, ltz_block, gez_block);
LLVMPositionBuilderAtEnd(g->builder, ltz_block);
LLVMValueRef ceiled = gen_float_op(g, result, type_entry, BuiltinFnIdCeil);
LLVMValueRef ceiled = gen_float_op(g, result, operand_type, BuiltinFnIdCeil);
LLVMBasicBlockRef ceiled_end_block = LLVMGetInsertBlock(g->builder);
LLVMBuildBr(g->builder, end_block);
LLVMPositionBuilderAtEnd(g->builder, gez_block);
LLVMValueRef floored = gen_float_op(g, result, type_entry, BuiltinFnIdFloor);
LLVMValueRef floored = gen_float_op(g, result, operand_type, BuiltinFnIdFloor);
LLVMBasicBlockRef floored_end_block = LLVMGetInsertBlock(g->builder);
LLVMBuildBr(g->builder, end_block);
LLVMPositionBuilderAtEnd(g->builder, end_block);
LLVMValueRef phi = LLVMBuildPhi(g->builder, get_llvm_type(g, type_entry), "");
LLVMValueRef phi = LLVMBuildPhi(g->builder, get_llvm_type(g, operand_type), "");
LLVMValueRef incoming_values[] = { ceiled, floored };
LLVMBasicBlockRef incoming_blocks[] = { ceiled_end_block, floored_end_block };
LLVMAddIncoming(phi, incoming_values, incoming_blocks, 2);
return phi;
}
case DivKindFloor:
return gen_float_op(g, result, type_entry, BuiltinFnIdFloor);
return gen_float_op(g, result, operand_type, BuiltinFnIdFloor);
}
zig_unreachable();
}
assert(type_entry->id == ZigTypeIdInt);
assert(scalar_type->id == ZigTypeIdInt);
switch (div_kind) {
case DivKindFloat:
zig_unreachable();
case DivKindTrunc:
if (type_entry->data.integral.is_signed) {
if (scalar_type->data.integral.is_signed) {
return LLVMBuildSDiv(g->builder, val1, val2, "");
} else {
return LLVMBuildUDiv(g->builder, val1, val2, "");
}
case DivKindExact:
if (want_runtime_safety) {
// Safety check: a % b == 0
LLVMValueRef remainder_val;
if (type_entry->data.integral.is_signed) {
if (scalar_type->data.integral.is_signed) {
remainder_val = LLVMBuildSRem(g->builder, val1, val2, "");
} else {
remainder_val = LLVMBuildURem(g->builder, val1, val2, "");
}
LLVMValueRef ok_bit = LLVMBuildICmp(g->builder, LLVMIntEQ, remainder_val, zero, "");
LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivExactOk");
LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivExactFail");
LLVMValueRef ok_bit = LLVMBuildICmp(g->builder, LLVMIntEQ, remainder_val, zero, "");
if (operand_type->id == ZigTypeIdVector) {
ok_bit = scalarize_cmp_result(g, ok_bit);
}
LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block);
LLVMPositionBuilderAtEnd(g->builder, fail_block);
@ -2750,14 +2794,14 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast
LLVMPositionBuilderAtEnd(g->builder, ok_block);
}
if (type_entry->data.integral.is_signed) {
if (scalar_type->data.integral.is_signed) {
return LLVMBuildExactSDiv(g->builder, val1, val2, "");
} else {
return LLVMBuildExactUDiv(g->builder, val1, val2, "");
}
case DivKindFloor:
{
if (!type_entry->data.integral.is_signed) {
if (!scalar_type->data.integral.is_signed) {
return LLVMBuildUDiv(g->builder, val1, val2, "");
}
// const d = @divTrunc(a, b);
@ -2784,22 +2828,30 @@ enum RemKind {
};
static LLVMValueRef gen_rem(CodeGen *g, bool want_runtime_safety, bool want_fast_math,
LLVMValueRef val1, LLVMValueRef val2,
ZigType *type_entry, RemKind rem_kind)
LLVMValueRef val1, LLVMValueRef val2, ZigType *operand_type, RemKind rem_kind)
{
ZigType *scalar_type = (operand_type->id == ZigTypeIdVector) ?
operand_type->data.vector.elem_type : operand_type;
ZigLLVMSetFastMath(g->builder, want_fast_math);
LLVMValueRef zero = LLVMConstNull(get_llvm_type(g, type_entry));
LLVMValueRef zero = LLVMConstNull(get_llvm_type(g, operand_type));
if (want_runtime_safety) {
// Safety check: divisor != 0
LLVMValueRef is_zero_bit;
if (type_entry->id == ZigTypeIdInt) {
LLVMIntPredicate pred = type_entry->data.integral.is_signed ? LLVMIntSLE : LLVMIntEQ;
if (scalar_type->id == ZigTypeIdInt) {
LLVMIntPredicate pred = scalar_type->data.integral.is_signed ? LLVMIntSLE : LLVMIntEQ;
is_zero_bit = LLVMBuildICmp(g->builder, pred, val2, zero, "");
} else if (type_entry->id == ZigTypeIdFloat) {
} else if (scalar_type->id == ZigTypeIdFloat) {
is_zero_bit = LLVMBuildFCmp(g->builder, LLVMRealOEQ, val2, zero, "");
} else {
zig_unreachable();
}
if (operand_type->id == ZigTypeIdVector) {
is_zero_bit = scalarize_cmp_result(g, is_zero_bit);
}
LLVMBasicBlockRef rem_zero_ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "RemZeroOk");
LLVMBasicBlockRef rem_zero_fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "RemZeroFail");
LLVMBuildCondBr(g->builder, is_zero_bit, rem_zero_fail_block, rem_zero_ok_block);
@ -2810,7 +2862,7 @@ static LLVMValueRef gen_rem(CodeGen *g, bool want_runtime_safety, bool want_fast
LLVMPositionBuilderAtEnd(g->builder, rem_zero_ok_block);
}
if (type_entry->id == ZigTypeIdFloat) {
if (scalar_type->id == ZigTypeIdFloat) {
if (rem_kind == RemKindRem) {
return LLVMBuildFRem(g->builder, val1, val2, "");
} else {
@ -2821,8 +2873,8 @@ static LLVMValueRef gen_rem(CodeGen *g, bool want_runtime_safety, bool want_fast
return LLVMBuildSelect(g->builder, ltz, c, a, "");
}
} else {
assert(type_entry->id == ZigTypeIdInt);
if (type_entry->data.integral.is_signed) {
assert(scalar_type->id == ZigTypeIdInt);
if (scalar_type->data.integral.is_signed) {
if (rem_kind == RemKindRem) {
return LLVMBuildSRem(g->builder, val1, val2, "");
} else {
@ -3010,22 +3062,22 @@ static LLVMValueRef ir_render_bin_op(CodeGen *g, IrExecutableGen *executable,
}
case IrBinOpDivUnspecified:
return gen_div(g, want_runtime_safety, ir_want_fast_math(g, &bin_op_instruction->base),
op1_value, op2_value, scalar_type, DivKindFloat);
op1_value, op2_value, operand_type, DivKindFloat);
case IrBinOpDivExact:
return gen_div(g, want_runtime_safety, ir_want_fast_math(g, &bin_op_instruction->base),
op1_value, op2_value, scalar_type, DivKindExact);
op1_value, op2_value, operand_type, DivKindExact);
case IrBinOpDivTrunc:
return gen_div(g, want_runtime_safety, ir_want_fast_math(g, &bin_op_instruction->base),
op1_value, op2_value, scalar_type, DivKindTrunc);
op1_value, op2_value, operand_type, DivKindTrunc);
case IrBinOpDivFloor:
return gen_div(g, want_runtime_safety, ir_want_fast_math(g, &bin_op_instruction->base),
op1_value, op2_value, scalar_type, DivKindFloor);
op1_value, op2_value, operand_type, DivKindFloor);
case IrBinOpRemRem:
return gen_rem(g, want_runtime_safety, ir_want_fast_math(g, &bin_op_instruction->base),
op1_value, op2_value, scalar_type, RemKindRem);
op1_value, op2_value, operand_type, RemKindRem);
case IrBinOpRemMod:
return gen_rem(g, want_runtime_safety, ir_want_fast_math(g, &bin_op_instruction->base),
op1_value, op2_value, scalar_type, RemKindMod);
op1_value, op2_value, operand_type, RemKindMod);
}
zig_unreachable();
}

View File

@ -16943,6 +16943,7 @@ static bool ok_float_op(IrBinOp op) {
case IrBinOpDivExact:
case IrBinOpRemRem:
case IrBinOpRemMod:
case IrBinOpRemUnspecified:
return true;
case IrBinOpBoolOr:
@ -16963,7 +16964,6 @@ static bool ok_float_op(IrBinOp op) {
case IrBinOpAddWrap:
case IrBinOpSubWrap:
case IrBinOpMultWrap:
case IrBinOpRemUnspecified:
case IrBinOpArrayCat:
case IrBinOpArrayMult:
return false;
@ -16991,6 +16991,31 @@ static bool is_pointer_arithmetic_allowed(ZigType *lhs_type, IrBinOp op) {
zig_unreachable();
}
static bool value_cmp_zero_any(ZigValue *value, Cmp predicate) {
assert(value->special == ConstValSpecialStatic);
switch (value->type->id) {
case ZigTypeIdComptimeInt:
case ZigTypeIdInt:
return bigint_cmp_zero(&value->data.x_bigint) == predicate;
case ZigTypeIdComptimeFloat:
case ZigTypeIdFloat:
if (float_is_nan(value))
return false;
return float_cmp_zero(value) == predicate;
case ZigTypeIdVector: {
for (size_t i = 0; i < value->type->data.vector.len; i++) {
ZigValue *scalar_val = &value->data.x_array.data.s_none.elements[i];
if (!value_cmp_zero_any(scalar_val, predicate))
return true;
}
return false;
}
default:
zig_unreachable();
}
}
static IrInstGen *ir_analyze_bin_op_math(IrAnalyze *ira, IrInstSrcBinOp *instruction) {
Error err;
@ -17096,127 +17121,13 @@ static IrInstGen *ir_analyze_bin_op_math(IrAnalyze *ira, IrInstSrcBinOp *instruc
if (type_is_invalid(resolved_type))
return ira->codegen->invalid_inst_gen;
bool is_int = resolved_type->id == ZigTypeIdInt || resolved_type->id == ZigTypeIdComptimeInt;
bool is_float = resolved_type->id == ZigTypeIdFloat || resolved_type->id == ZigTypeIdComptimeFloat;
bool is_signed_div = (
(resolved_type->id == ZigTypeIdInt && resolved_type->data.integral.is_signed) ||
resolved_type->id == ZigTypeIdFloat ||
(resolved_type->id == ZigTypeIdComptimeFloat &&
((bigfloat_cmp_zero(&op1->value->data.x_bigfloat) != CmpGT) !=
(bigfloat_cmp_zero(&op2->value->data.x_bigfloat) != CmpGT))) ||
(resolved_type->id == ZigTypeIdComptimeInt &&
((bigint_cmp_zero(&op1->value->data.x_bigint) != CmpGT) !=
(bigint_cmp_zero(&op2->value->data.x_bigint) != CmpGT)))
);
if (op_id == IrBinOpDivUnspecified && is_int) {
if (is_signed_div) {
bool ok = false;
if (instr_is_comptime(op1) && instr_is_comptime(op2)) {
ZigValue *op1_val = ir_resolve_const(ira, op1, UndefBad);
if (op1_val == nullptr)
return ira->codegen->invalid_inst_gen;
ZigType *scalar_type = (resolved_type->id == ZigTypeIdVector) ?
resolved_type->data.vector.elem_type : resolved_type;
ZigValue *op2_val = ir_resolve_const(ira, op2, UndefBad);
if (op2_val == nullptr)
return ira->codegen->invalid_inst_gen;
bool is_int = scalar_type->id == ZigTypeIdInt || scalar_type->id == ZigTypeIdComptimeInt;
bool is_float = scalar_type->id == ZigTypeIdFloat || scalar_type->id == ZigTypeIdComptimeFloat;
if (bigint_cmp_zero(&op2_val->data.x_bigint) == CmpEQ) {
// the division by zero error will be caught later, but we don't have a
// division function ambiguity problem.
op_id = IrBinOpDivTrunc;
ok = true;
} else {
BigInt trunc_result;
BigInt floor_result;
bigint_div_trunc(&trunc_result, &op1_val->data.x_bigint, &op2_val->data.x_bigint);
bigint_div_floor(&floor_result, &op1_val->data.x_bigint, &op2_val->data.x_bigint);
if (bigint_cmp(&trunc_result, &floor_result) == CmpEQ) {
ok = true;
op_id = IrBinOpDivTrunc;
}
}
}
if (!ok) {
ir_add_error(ira, &instruction->base.base,
buf_sprintf("division with '%s' and '%s': signed integers must use @divTrunc, @divFloor, or @divExact",
buf_ptr(&op1->value->type->name),
buf_ptr(&op2->value->type->name)));
return ira->codegen->invalid_inst_gen;
}
} else {
op_id = IrBinOpDivTrunc;
}
} else if (op_id == IrBinOpRemUnspecified) {
if (is_signed_div && (is_int || is_float)) {
bool ok = false;
if (instr_is_comptime(op1) && instr_is_comptime(op2)) {
ZigValue *op1_val = ir_resolve_const(ira, op1, UndefBad);
if (op1_val == nullptr)
return ira->codegen->invalid_inst_gen;
if (is_int) {
ZigValue *op2_val = ir_resolve_const(ira, op2, UndefBad);
if (op2_val == nullptr)
return ira->codegen->invalid_inst_gen;
if (bigint_cmp_zero(&op2->value->data.x_bigint) == CmpEQ) {
// the division by zero error will be caught later, but we don't
// have a remainder function ambiguity problem
ok = true;
} else {
BigInt rem_result;
BigInt mod_result;
bigint_rem(&rem_result, &op1_val->data.x_bigint, &op2_val->data.x_bigint);
bigint_mod(&mod_result, &op1_val->data.x_bigint, &op2_val->data.x_bigint);
ok = bigint_cmp(&rem_result, &mod_result) == CmpEQ;
}
} else {
IrInstGen *casted_op2 = ir_implicit_cast(ira, op2, resolved_type);
if (type_is_invalid(casted_op2->value->type))
return ira->codegen->invalid_inst_gen;
ZigValue *op2_val = ir_resolve_const(ira, casted_op2, UndefBad);
if (op2_val == nullptr)
return ira->codegen->invalid_inst_gen;
if (float_cmp_zero(casted_op2->value) == CmpEQ) {
// the division by zero error will be caught later, but we don't
// have a remainder function ambiguity problem
ok = true;
} else {
ZigValue rem_result = {};
ZigValue mod_result = {};
float_rem(&rem_result, op1_val, op2_val);
float_mod(&mod_result, op1_val, op2_val);
ok = float_cmp(&rem_result, &mod_result) == CmpEQ;
}
}
}
if (!ok) {
ir_add_error(ira, &instruction->base.base,
buf_sprintf("remainder division with '%s' and '%s': signed integers and floats must use @rem or @mod",
buf_ptr(&op1->value->type->name),
buf_ptr(&op2->value->type->name)));
return ira->codegen->invalid_inst_gen;
}
}
op_id = IrBinOpRemRem;
}
bool ok = false;
if (is_int) {
ok = true;
} else if (is_float && ok_float_op(op_id)) {
ok = true;
} else if (resolved_type->id == ZigTypeIdVector) {
ZigType *elem_type = resolved_type->data.vector.elem_type;
if (elem_type->id == ZigTypeIdInt || elem_type->id == ZigTypeIdComptimeInt) {
ok = true;
} else if ((elem_type->id == ZigTypeIdFloat || elem_type->id == ZigTypeIdComptimeFloat) && ok_float_op(op_id)) {
ok = true;
}
}
if (!ok) {
if (!is_int && !(is_float && ok_float_op(op_id))) {
AstNode *source_node = instruction->base.base.source_node;
ir_add_error_node(ira, source_node,
buf_sprintf("invalid operands to binary expression: '%s' and '%s'",
@ -17225,16 +17136,6 @@ static IrInstGen *ir_analyze_bin_op_math(IrAnalyze *ira, IrInstSrcBinOp *instruc
return ira->codegen->invalid_inst_gen;
}
if (resolved_type->id == ZigTypeIdComptimeInt) {
if (op_id == IrBinOpAddWrap) {
op_id = IrBinOpAdd;
} else if (op_id == IrBinOpSubWrap) {
op_id = IrBinOpSub;
} else if (op_id == IrBinOpMultWrap) {
op_id = IrBinOpMult;
}
}
IrInstGen *casted_op1 = ir_implicit_cast(ira, op1, resolved_type);
if (type_is_invalid(casted_op1->value->type))
return ira->codegen->invalid_inst_gen;
@ -17243,17 +17144,142 @@ static IrInstGen *ir_analyze_bin_op_math(IrAnalyze *ira, IrInstSrcBinOp *instruc
if (type_is_invalid(casted_op2->value->type))
return ira->codegen->invalid_inst_gen;
// Comptime integers have no fixed size
if (scalar_type->id == ZigTypeIdComptimeInt) {
if (op_id == IrBinOpAddWrap) {
op_id = IrBinOpAdd;
} else if (op_id == IrBinOpSubWrap) {
op_id = IrBinOpSub;
} else if (op_id == IrBinOpMultWrap) {
op_id = IrBinOpMult;
}
}
if (instr_is_comptime(casted_op1) && instr_is_comptime(casted_op2)) {
ZigValue *op1_val = ir_resolve_const(ira, casted_op1, UndefBad);
if (op1_val == nullptr)
return ira->codegen->invalid_inst_gen;
ZigValue *op2_val = ir_resolve_const(ira, casted_op2, UndefBad);
if (op2_val == nullptr)
return ira->codegen->invalid_inst_gen;
// Promote division with negative numbers to signed
bool is_signed_div = value_cmp_zero_any(op1_val, CmpLT) ||
value_cmp_zero_any(op2_val, CmpLT);
if (op_id == IrBinOpDivUnspecified && is_int) {
// Default to truncating division and check if it's valid for the
// given operands if signed
op_id = IrBinOpDivTrunc;
if (is_signed_div) {
bool ok = false;
if (value_cmp_zero_any(op2_val, CmpEQ)) {
// the division by zero error will be caught later, but we don't have a
// division function ambiguity problem.
ok = true;
} else {
IrInstGen *trunc_val = ir_analyze_math_op(ira, &instruction->base.base, resolved_type,
op1_val, IrBinOpDivTrunc, op2_val);
if (type_is_invalid(trunc_val->value->type))
return ira->codegen->invalid_inst_gen;
IrInstGen *floor_val = ir_analyze_math_op(ira, &instruction->base.base, resolved_type,
op1_val, IrBinOpDivFloor, op2_val);
if (type_is_invalid(floor_val->value->type))
return ira->codegen->invalid_inst_gen;
IrInstGen *cmp_val = ir_analyze_bin_op_cmp_numeric(ira, &instruction->base.base,
trunc_val, floor_val, IrBinOpCmpEq);
if (type_is_invalid(cmp_val->value->type))
return ira->codegen->invalid_inst_gen;
// We can "upgrade" the operator only if trunc(a/b) == floor(a/b)
if (!ir_resolve_bool(ira, cmp_val, &ok))
return ira->codegen->invalid_inst_gen;
}
if (!ok) {
ir_add_error(ira, &instruction->base.base,
buf_sprintf("division with '%s' and '%s': signed integers must use @divTrunc, @divFloor, or @divExact",
buf_ptr(&op1->value->type->name),
buf_ptr(&op2->value->type->name)));
return ira->codegen->invalid_inst_gen;
}
}
} else if (op_id == IrBinOpRemUnspecified) {
op_id = IrBinOpRemRem;
if (is_signed_div) {
bool ok = false;
if (value_cmp_zero_any(op2_val, CmpEQ)) {
// the division by zero error will be caught later, but we don't have a
// division function ambiguity problem.
ok = true;
} else {
IrInstGen *rem_val = ir_analyze_math_op(ira, &instruction->base.base, resolved_type,
op1_val, IrBinOpRemRem, op2_val);
if (type_is_invalid(rem_val->value->type))
return ira->codegen->invalid_inst_gen;
IrInstGen *mod_val = ir_analyze_math_op(ira, &instruction->base.base, resolved_type,
op1_val, IrBinOpRemMod, op2_val);
if (type_is_invalid(mod_val->value->type))
return ira->codegen->invalid_inst_gen;
IrInstGen *cmp_val = ir_analyze_bin_op_cmp_numeric(ira, &instruction->base.base,
rem_val, mod_val, IrBinOpCmpEq);
if (type_is_invalid(cmp_val->value->type))
return ira->codegen->invalid_inst_gen;
// We can "upgrade" the operator only if mod(a,b) == rem(a,b)
if (!ir_resolve_bool(ira, cmp_val, &ok))
return ira->codegen->invalid_inst_gen;
}
if (!ok) {
ir_add_error(ira, &instruction->base.base,
buf_sprintf("remainder division with '%s' and '%s': signed integers and floats must use @rem or @mod",
buf_ptr(&op1->value->type->name),
buf_ptr(&op2->value->type->name)));
return ira->codegen->invalid_inst_gen;
}
}
}
return ir_analyze_math_op(ira, &instruction->base.base, resolved_type, op1_val, op_id, op2_val);
}
const bool is_signed_div =
(scalar_type->id == ZigTypeIdInt && scalar_type->data.integral.is_signed) ||
scalar_type->id == ZigTypeIdFloat;
// Warn the user to use the proper operators here
if (op_id == IrBinOpDivUnspecified && is_int) {
op_id = IrBinOpDivTrunc;
if (is_signed_div) {
ir_add_error(ira, &instruction->base.base,
buf_sprintf("division with '%s' and '%s': signed integers must use @divTrunc, @divFloor, or @divExact",
buf_ptr(&op1->value->type->name),
buf_ptr(&op2->value->type->name)));
return ira->codegen->invalid_inst_gen;
}
} else if (op_id == IrBinOpRemUnspecified) {
op_id = IrBinOpRemRem;
if (is_signed_div) {
ir_add_error(ira, &instruction->base.base,
buf_sprintf("remainder division with '%s' and '%s': signed integers and floats must use @rem or @mod",
buf_ptr(&op1->value->type->name),
buf_ptr(&op2->value->type->name)));
return ira->codegen->invalid_inst_gen;
}
}
return ir_build_bin_op_gen(ira, &instruction->base.base, resolved_type,
op_id, casted_op1, casted_op2, instruction->safety_check_on);
}

View File

@ -276,3 +276,78 @@ test "vector comparison operators" {
S.doTheTest();
comptime S.doTheTest();
}
test "vector division operators" {
const S = struct {
fn doTheTestDiv(comptime T: type, x: @Vector(4, T), y: @Vector(4, T)) void {
if (!comptime std.meta.trait.isSignedInt(T)) {
const d0 = x / y;
for (@as([4]T, d0)) |v, i| {
expectEqual(x[i] / y[i], v);
}
}
const d1 = @divExact(x, y);
for (@as([4]T, d1)) |v, i| {
expectEqual(@divExact(x[i], y[i]), v);
}
const d2 = @divFloor(x, y);
for (@as([4]T, d2)) |v, i| {
expectEqual(@divFloor(x[i], y[i]), v);
}
const d3 = @divTrunc(x, y);
for (@as([4]T, d3)) |v, i| {
expectEqual(@divTrunc(x[i], y[i]), v);
}
}
fn doTheTestMod(comptime T: type, x: @Vector(4, T), y: @Vector(4, T)) void {
if ((!comptime std.meta.trait.isSignedInt(T)) and @typeInfo(T) != .Float) {
const r0 = x % y;
for (@as([4]T, r0)) |v, i| {
expectEqual(x[i] % y[i], v);
}
}
const r1 = @mod(x, y);
for (@as([4]T, r1)) |v, i| {
expectEqual(@mod(x[i], y[i]), v);
}
const r2 = @rem(x, y);
for (@as([4]T, r2)) |v, i| {
expectEqual(@rem(x[i], y[i]), v);
}
}
fn doTheTest() void {
doTheTestDiv(f16, [4]f16{ 4.0, -4.0, 4.0, -4.0 }, [4]f16{ 1.0, 2.0, -1.0, -2.0 });
doTheTestDiv(f32, [4]f32{ 4.0, -4.0, 4.0, -4.0 }, [4]f32{ 1.0, 2.0, -1.0, -2.0 });
doTheTestDiv(f64, [4]f64{ 4.0, -4.0, 4.0, -4.0 }, [4]f64{ 1.0, 2.0, -1.0, -2.0 });
doTheTestMod(f16, [4]f16{ 4.0, -4.0, 4.0, -4.0 }, [4]f16{ 1.0, 2.0, 0.5, 3.0 });
doTheTestMod(f32, [4]f32{ 4.0, -4.0, 4.0, -4.0 }, [4]f32{ 1.0, 2.0, 0.5, 3.0 });
doTheTestMod(f64, [4]f64{ 4.0, -4.0, 4.0, -4.0 }, [4]f64{ 1.0, 2.0, 0.5, 3.0 });
doTheTestDiv(i8, [4]i8{ 4, -4, 4, -4 }, [4]i8{ 1, 2, -1, -2 });
doTheTestDiv(i16, [4]i16{ 4, -4, 4, -4 }, [4]i16{ 1, 2, -1, -2 });
doTheTestDiv(i32, [4]i32{ 4, -4, 4, -4 }, [4]i32{ 1, 2, -1, -2 });
doTheTestDiv(i64, [4]i64{ 4, -4, 4, -4 }, [4]i64{ 1, 2, -1, -2 });
doTheTestMod(i8, [4]i8{ 4, -4, 4, -4 }, [4]i8{ 1, 2, 4, 8 });
doTheTestMod(i16, [4]i16{ 4, -4, 4, -4 }, [4]i16{ 1, 2, 4, 8 });
doTheTestMod(i32, [4]i32{ 4, -4, 4, -4 }, [4]i32{ 1, 2, 4, 8 });
doTheTestMod(i64, [4]i64{ 4, -4, 4, -4 }, [4]i64{ 1, 2, 4, 8 });
doTheTestDiv(u8, [4]u8{ 1, 2, 4, 8 }, [4]u8{ 1, 1, 2, 4 });
doTheTestDiv(u16, [4]u16{ 1, 2, 4, 8 }, [4]u16{ 1, 1, 2, 4 });
doTheTestDiv(u32, [4]u32{ 1, 2, 4, 8 }, [4]u32{ 1, 1, 2, 4 });
doTheTestDiv(u64, [4]u64{ 1, 2, 4, 8 }, [4]u64{ 1, 1, 2, 4 });
doTheTestMod(u8, [4]u8{ 1, 2, 4, 8 }, [4]u8{ 1, 1, 2, 4 });
doTheTestMod(u16, [4]u16{ 1, 2, 4, 8 }, [4]u16{ 1, 1, 2, 4 });
doTheTestMod(u32, [4]u32{ 1, 2, 4, 8 }, [4]u32{ 1, 1, 2, 4 });
doTheTestMod(u64, [4]u64{ 1, 2, 4, 8 }, [4]u64{ 1, 1, 2, 4 });
}
};
S.doTheTest();
comptime S.doTheTest();
}