stage1: add support for @mulAdd fused-multiply-add for floats and vectors of floats

Not all of the softfloat library is being built....

Vector support is very buggy at the moment, but should work when the bugs are fixed.
(as I had the same code working with another vector function, that hasn't been merged yet).
master
Shawn Landden 2019-06-18 17:28:49 -05:00
parent bbfb53d524
commit fce2d2d18b
8 changed files with 292 additions and 7 deletions

View File

@ -389,6 +389,8 @@ set(EMBEDDED_SOFTFLOAT_SOURCES
"${CMAKE_SOURCE_DIR}/deps/SoftFloat-3e/source/s_subMagsF32.c"
"${CMAKE_SOURCE_DIR}/deps/SoftFloat-3e/source/s_subMagsF64.c"
"${CMAKE_SOURCE_DIR}/deps/SoftFloat-3e/source/s_tryPropagateNaNF128M.c"
"${CMAKE_SOURCE_DIR}/deps/SoftFloat-3e/source/f16_mulAdd.c"
"${CMAKE_SOURCE_DIR}/deps/SoftFloat-3e/source/f128M_mulAdd.c"
"${CMAKE_SOURCE_DIR}/deps/SoftFloat-3e/source/softfloat_state.c"
"${CMAKE_SOURCE_DIR}/deps/SoftFloat-3e/source/ui32_to_f128M.c"
"${CMAKE_SOURCE_DIR}/deps/SoftFloat-3e/source/ui64_to_f128M.c"

View File

@ -6259,6 +6259,13 @@ comptime {
This function is only valid within function scope.
</p>
{#header_close#}
{#header_open|@mulAdd#}
<pre>{#syntax#}@mulAdd(comptime T: type, a: T, b: T, c: T) T{#endsyntax#}</pre>
<p>
Fused multiply add (for floats), similar to {#syntax#}(a * b) + c{#endsyntax#}, except
only rounds once, and is thus more accurate.
</p>
{#header_close#}
{#header_open|@byteSwap#}

View File

@ -1406,6 +1406,7 @@ enum BuiltinFnId {
BuiltinFnIdSubWithOverflow,
BuiltinFnIdMulWithOverflow,
BuiltinFnIdShlWithOverflow,
BuiltinFnIdMulAdd,
BuiltinFnIdCInclude,
BuiltinFnIdCDefine,
BuiltinFnIdCUndef,
@ -1554,6 +1555,7 @@ enum ZigLLVMFnId {
ZigLLVMFnIdClz,
ZigLLVMFnIdPopCount,
ZigLLVMFnIdOverflowArithmetic,
ZigLLVMFnIdFMA,
ZigLLVMFnIdFloor,
ZigLLVMFnIdCeil,
ZigLLVMFnIdSqrt,
@ -1584,6 +1586,7 @@ struct ZigLLVMFnKey {
} pop_count;
struct {
uint32_t bit_count;
uint32_t vector_len; // 0 means not a vector
} floating;
struct {
AddSubMul add_sub_mul;
@ -2235,6 +2238,7 @@ enum IrInstructionId {
IrInstructionIdHandle,
IrInstructionIdAlignOf,
IrInstructionIdOverflowOp,
IrInstructionIdMulAdd,
IrInstructionIdTestErr,
IrInstructionIdUnwrapErrCode,
IrInstructionIdUnwrapErrPayload,
@ -3038,6 +3042,15 @@ struct IrInstructionOverflowOp {
ZigType *result_ptr_type;
};
struct IrInstructionMulAdd {
IrInstruction base;
IrInstruction *type_value;
IrInstruction *op1;
IrInstruction *op2;
IrInstruction *op3;
};
struct IrInstructionAlignOf {
IrInstruction base;

View File

@ -5737,11 +5737,11 @@ uint32_t zig_llvm_fn_key_hash(ZigLLVMFnKey x) {
case ZigLLVMFnIdPopCount:
return (uint32_t)(x.data.clz.bit_count) * (uint32_t)101195049;
case ZigLLVMFnIdFloor:
return (uint32_t)(x.data.floating.bit_count) * (uint32_t)1899859168;
case ZigLLVMFnIdCeil:
return (uint32_t)(x.data.floating.bit_count) * (uint32_t)1953839089;
case ZigLLVMFnIdSqrt:
return (uint32_t)(x.data.floating.bit_count) * (uint32_t)2225366385;
case ZigLLVMFnIdFMA:
return (uint32_t)(x.data.floating.bit_count) * ((uint32_t)x.id + 1025) +
(uint32_t)(x.data.floating.vector_len) * (((uint32_t)x.id << 5) + 1025);
case ZigLLVMFnIdBswap:
return (uint32_t)(x.data.bswap.bit_count) * (uint32_t)3661994335;
case ZigLLVMFnIdBitReverse:
@ -5772,6 +5772,7 @@ bool zig_llvm_fn_key_eql(ZigLLVMFnKey a, ZigLLVMFnKey b) {
case ZigLLVMFnIdFloor:
case ZigLLVMFnIdCeil:
case ZigLLVMFnIdSqrt:
case ZigLLVMFnIdFMA:
return a.data.floating.bit_count == b.data.floating.bit_count;
case ZigLLVMFnIdOverflowArithmetic:
return (a.data.overflow_arithmetic.bit_count == b.data.overflow_arithmetic.bit_count) &&

View File

@ -807,31 +807,51 @@ static LLVMValueRef get_int_overflow_fn(CodeGen *g, ZigType *operand_type, AddSu
}
static LLVMValueRef get_float_fn(CodeGen *g, ZigType *type_entry, ZigLLVMFnId fn_id) {
assert(type_entry->id == ZigTypeIdFloat);
assert(type_entry->id == ZigTypeIdFloat ||
type_entry->id == ZigTypeIdVector);
bool is_vector = (type_entry->id == ZigTypeIdVector);
ZigType *float_type = is_vector ? type_entry->data.vector.elem_type : type_entry;
ZigLLVMFnKey key = {};
key.id = fn_id;
key.data.floating.bit_count = (uint32_t)type_entry->data.floating.bit_count;
key.data.floating.bit_count = (uint32_t)float_type->data.floating.bit_count;
key.data.floating.vector_len = is_vector ? (uint32_t)type_entry->data.vector.len : 0;
auto existing_entry = g->llvm_fn_table.maybe_get(key);
if (existing_entry)
return existing_entry->value;
const char *name;
uint32_t num_args;
if (fn_id == ZigLLVMFnIdFloor) {
name = "floor";
num_args = 1;
} else if (fn_id == ZigLLVMFnIdCeil) {
name = "ceil";
num_args = 1;
} else if (fn_id == ZigLLVMFnIdSqrt) {
name = "sqrt";
num_args = 1;
} else if (fn_id == ZigLLVMFnIdFMA) {
name = "fma";
num_args = 3;
} else {
zig_unreachable();
}
char fn_name[64];
sprintf(fn_name, "llvm.%s.f%" ZIG_PRI_usize "", name, type_entry->data.floating.bit_count);
if (is_vector)
sprintf(fn_name, "llvm.%s.v%" PRIu32 "f%" PRIu32, name, key.data.floating.vector_len, key.data.floating.bit_count);
else
sprintf(fn_name, "llvm.%s.f%" PRIu32, name, key.data.floating.bit_count);
LLVMTypeRef float_type_ref = get_llvm_type(g, type_entry);
LLVMTypeRef fn_type = LLVMFunctionType(float_type_ref, &float_type_ref, 1, false);
LLVMTypeRef return_elem_types[3] = {
float_type_ref,
float_type_ref,
float_type_ref,
};
LLVMTypeRef fn_type = LLVMFunctionType(float_type_ref, return_elem_types, num_args, false);
LLVMValueRef fn_val = LLVMAddFunction(g->module, fn_name, fn_type);
assert(LLVMGetIntrinsicID(fn_val));
@ -5437,6 +5457,21 @@ static LLVMValueRef ir_render_sqrt(CodeGen *g, IrExecutable *executable, IrInstr
return LLVMBuildCall(g->builder, fn_val, &op, 1, "");
}
static LLVMValueRef ir_render_mul_add(CodeGen *g, IrExecutable *executable, IrInstructionMulAdd *instruction) {
LLVMValueRef op1 = ir_llvm_value(g, instruction->op1);
LLVMValueRef op2 = ir_llvm_value(g, instruction->op2);
LLVMValueRef op3 = ir_llvm_value(g, instruction->op3);
assert(instruction->base.value.type->id == ZigTypeIdFloat ||
instruction->base.value.type->id == ZigTypeIdVector);
LLVMValueRef fn_val = get_float_fn(g, instruction->base.value.type, ZigLLVMFnIdFMA);
LLVMValueRef args[3] = {
op1,
op2,
op3,
};
return LLVMBuildCall(g->builder, fn_val, args, 3, "");
}
static LLVMValueRef ir_render_bswap(CodeGen *g, IrExecutable *executable, IrInstructionBswap *instruction) {
LLVMValueRef op = ir_llvm_value(g, instruction->op);
ZigType *int_type = instruction->base.value.type;
@ -5781,6 +5816,8 @@ static LLVMValueRef ir_render_instruction(CodeGen *g, IrExecutable *executable,
return ir_render_mark_err_ret_trace_ptr(g, executable, (IrInstructionMarkErrRetTracePtr *)instruction);
case IrInstructionIdSqrt:
return ir_render_sqrt(g, executable, (IrInstructionSqrt *)instruction);
case IrInstructionIdMulAdd:
return ir_render_mul_add(g, executable, (IrInstructionMulAdd *)instruction);
case IrInstructionIdArrayToVector:
return ir_render_array_to_vector(g, executable, (IrInstructionArrayToVector *)instruction);
case IrInstructionIdVectorToArray:
@ -7398,6 +7435,7 @@ static void define_builtin_fns(CodeGen *g) {
create_builtin_fn(g, BuiltinFnIdRem, "rem", 2);
create_builtin_fn(g, BuiltinFnIdMod, "mod", 2);
create_builtin_fn(g, BuiltinFnIdSqrt, "sqrt", 2);
create_builtin_fn(g, BuiltinFnIdMulAdd, "mulAdd", 4);
create_builtin_fn(g, BuiltinFnIdInlineCall, "inlineCall", SIZE_MAX);
create_builtin_fn(g, BuiltinFnIdNoInlineCall, "noInlineCall", SIZE_MAX);
create_builtin_fn(g, BuiltinFnIdNewStackCall, "newStackCall", SIZE_MAX);

View File

@ -747,6 +747,10 @@ static constexpr IrInstructionId ir_instruction_id(IrInstructionTestErr *) {
return IrInstructionIdTestErr;
}
static constexpr IrInstructionId ir_instruction_id(IrInstructionMulAdd *) {
return IrInstructionIdMulAdd;
}
static constexpr IrInstructionId ir_instruction_id(IrInstructionUnwrapErrCode *) {
return IrInstructionIdUnwrapErrCode;
}
@ -2308,6 +2312,22 @@ static IrInstruction *ir_build_overflow_op(IrBuilder *irb, Scope *scope, AstNode
return &instruction->base;
}
static IrInstruction *ir_build_mul_add(IrBuilder *irb, Scope *scope, AstNode *source_node,
IrInstruction *type_value, IrInstruction *op1, IrInstruction *op2, IrInstruction *op3) {
IrInstructionMulAdd *instruction = ir_build_instruction<IrInstructionMulAdd>(irb, scope, source_node);
instruction->type_value = type_value;
instruction->op1 = op1;
instruction->op2 = op2;
instruction->op3 = op3;
ir_ref_instruction(type_value, irb->current_basic_block);
ir_ref_instruction(op1, irb->current_basic_block);
ir_ref_instruction(op2, irb->current_basic_block);
ir_ref_instruction(op3, irb->current_basic_block);
return &instruction->base;
}
static IrInstruction *ir_build_align_of(IrBuilder *irb, Scope *scope, AstNode *source_node, IrInstruction *type_value) {
IrInstructionAlignOf *instruction = ir_build_instruction<IrInstructionAlignOf>(irb, scope, source_node);
instruction->type_value = type_value;
@ -4028,6 +4048,33 @@ static IrInstruction *ir_gen_overflow_op(IrBuilder *irb, Scope *scope, AstNode *
return ir_build_overflow_op(irb, scope, node, op, type_value, op1, op2, result_ptr, nullptr);
}
static IrInstruction *ir_gen_mul_add(IrBuilder *irb, Scope *scope, AstNode *node) {
assert(node->type == NodeTypeFnCallExpr);
AstNode *type_node = node->data.fn_call_expr.params.at(0);
AstNode *op1_node = node->data.fn_call_expr.params.at(1);
AstNode *op2_node = node->data.fn_call_expr.params.at(2);
AstNode *op3_node = node->data.fn_call_expr.params.at(3);
IrInstruction *type_value = ir_gen_node(irb, type_node, scope);
if (type_value == irb->codegen->invalid_instruction)
return irb->codegen->invalid_instruction;
IrInstruction *op1 = ir_gen_node(irb, op1_node, scope);
if (op1 == irb->codegen->invalid_instruction)
return irb->codegen->invalid_instruction;
IrInstruction *op2 = ir_gen_node(irb, op2_node, scope);
if (op2 == irb->codegen->invalid_instruction)
return irb->codegen->invalid_instruction;
IrInstruction *op3 = ir_gen_node(irb, op3_node, scope);
if (op3 == irb->codegen->invalid_instruction)
return irb->codegen->invalid_instruction;
return ir_build_mul_add(irb, scope, node, type_value, op1, op2, op3);
}
static IrInstruction *ir_gen_this(IrBuilder *irb, Scope *orig_scope, AstNode *node) {
for (Scope *it_scope = orig_scope; it_scope != nullptr; it_scope = it_scope->parent) {
if (it_scope->id == ScopeIdDecls) {
@ -4687,6 +4734,8 @@ static IrInstruction *ir_gen_builtin_fn_call(IrBuilder *irb, Scope *scope, AstNo
return ir_lval_wrap(irb, scope, ir_gen_overflow_op(irb, scope, node, IrOverflowOpMul), lval);
case BuiltinFnIdShlWithOverflow:
return ir_lval_wrap(irb, scope, ir_gen_overflow_op(irb, scope, node, IrOverflowOpShl), lval);
case BuiltinFnIdMulAdd:
return ir_lval_wrap(irb, scope, ir_gen_mul_add(irb, scope, node), lval);
case BuiltinFnIdTypeName:
{
AstNode *arg0_node = node->data.fn_call_expr.params.at(0);
@ -21185,6 +21234,125 @@ static IrInstruction *ir_analyze_instruction_overflow_op(IrAnalyze *ira, IrInstr
return result;
}
static void ir_eval_mul_add(IrAnalyze *ira, IrInstructionMulAdd *source_instr, ZigType *float_type,
ConstExprValue *op1, ConstExprValue *op2, ConstExprValue *op3, ConstExprValue *out_val) {
if (float_type->id == ZigTypeIdComptimeFloat) {
f128M_mulAdd(&out_val->data.x_bigfloat.value, &op1->data.x_bigfloat.value, &op2->data.x_bigfloat.value,
&op3->data.x_bigfloat.value);
} else if (float_type->id == ZigTypeIdFloat) {
switch (float_type->data.floating.bit_count) {
case 16:
out_val->data.x_f16 = f16_mulAdd(op1->data.x_f16, op2->data.x_f16, op3->data.x_f16);
break;
case 32:
out_val->data.x_f32 = fmaf(op1->data.x_f32, op2->data.x_f32, op3->data.x_f32);
break;
case 64:
out_val->data.x_f64 = fma(op1->data.x_f64, op2->data.x_f64, op3->data.x_f64);
break;
case 128:
f128M_mulAdd(&op1->data.x_f128, &op2->data.x_f128, &op3->data.x_f128, &out_val->data.x_f128);
break;
default:
zig_unreachable();
}
} else {
zig_unreachable();
}
}
static IrInstruction *ir_analyze_instruction_mul_add(IrAnalyze *ira, IrInstructionMulAdd *instruction) {
IrInstruction *type_value = instruction->type_value->child;
if (type_is_invalid(type_value->value.type))
return ira->codegen->invalid_instruction;
ZigType *expr_type = ir_resolve_type(ira, type_value);
if (type_is_invalid(expr_type))
return ira->codegen->invalid_instruction;
// Only allow float types, and vectors of floats.
ZigType *float_type = (expr_type->id == ZigTypeIdVector) ? expr_type->data.vector.elem_type : expr_type;
if (float_type->id != ZigTypeIdFloat) {
ir_add_error(ira, type_value,
buf_sprintf("expected float or vector of float type, found '%s'", buf_ptr(&float_type->name)));
return ira->codegen->invalid_instruction;
}
IrInstruction *op1 = instruction->op1->child;
if (type_is_invalid(op1->value.type))
return ira->codegen->invalid_instruction;
IrInstruction *casted_op1 = ir_implicit_cast(ira, op1, expr_type);
if (type_is_invalid(casted_op1->value.type))
return ira->codegen->invalid_instruction;
IrInstruction *op2 = instruction->op2->child;
if (type_is_invalid(op2->value.type))
return ira->codegen->invalid_instruction;
IrInstruction *casted_op2 = ir_implicit_cast(ira, op2, expr_type);
if (type_is_invalid(casted_op2->value.type))
return ira->codegen->invalid_instruction;
IrInstruction *op3 = instruction->op3->child;
if (type_is_invalid(op3->value.type))
return ira->codegen->invalid_instruction;
IrInstruction *casted_op3 = ir_implicit_cast(ira, op3, expr_type);
if (type_is_invalid(casted_op3->value.type))
return ira->codegen->invalid_instruction;
if (instr_is_comptime(casted_op1) &&
instr_is_comptime(casted_op2) &&
instr_is_comptime(casted_op3)) {
ConstExprValue *op1_const = ir_resolve_const(ira, casted_op1, UndefBad);
if (!op1_const)
return ira->codegen->invalid_instruction;
ConstExprValue *op2_const = ir_resolve_const(ira, casted_op2, UndefBad);
if (!op2_const)
return ira->codegen->invalid_instruction;
ConstExprValue *op3_const = ir_resolve_const(ira, casted_op3, UndefBad);
if (!op3_const)
return ira->codegen->invalid_instruction;
IrInstruction *result = ir_const(ira, &instruction->base, expr_type);
ConstExprValue *out_val = &result->value;
if (expr_type->id == ZigTypeIdVector) {
expand_undef_array(ira->codegen, op1_const);
expand_undef_array(ira->codegen, op2_const);
expand_undef_array(ira->codegen, op3_const);
out_val->special = ConstValSpecialUndef;
expand_undef_array(ira->codegen, out_val);
size_t len = expr_type->data.vector.len;
for (size_t i = 0; i < len; i += 1) {
ConstExprValue *float_operand_op1 = &op1_const->data.x_array.data.s_none.elements[i];
ConstExprValue *float_operand_op2 = &op2_const->data.x_array.data.s_none.elements[i];
ConstExprValue *float_operand_op3 = &op3_const->data.x_array.data.s_none.elements[i];
ConstExprValue *float_out_val = &out_val->data.x_array.data.s_none.elements[i];
assert(float_operand_op1->type == float_type);
assert(float_operand_op2->type == float_type);
assert(float_operand_op3->type == float_type);
assert(float_out_val->type == float_type);
ir_eval_mul_add(ira, instruction, float_type,
op1_const, op2_const, op3_const, float_out_val);
float_out_val->type = float_type;
}
out_val->type = expr_type;
out_val->special = ConstValSpecialStatic;
} else {
ir_eval_mul_add(ira, instruction, float_type, op1_const, op2_const, op3_const, out_val);
}
return result;
}
IrInstruction *result = ir_build_mul_add(&ira->new_irb,
instruction->base.scope, instruction->base.source_node,
type_value, casted_op1, casted_op2, casted_op3);
result->value.type = expr_type;
return result;
}
static IrInstruction *ir_analyze_instruction_test_err(IrAnalyze *ira, IrInstructionTestErr *instruction) {
IrInstruction *value = instruction->value->child;
if (type_is_invalid(value->value.type))
@ -23596,6 +23764,8 @@ static IrInstruction *ir_analyze_instruction_nocast(IrAnalyze *ira, IrInstructio
return ir_analyze_instruction_mark_err_ret_trace_ptr(ira, (IrInstructionMarkErrRetTracePtr *)instruction);
case IrInstructionIdSqrt:
return ir_analyze_instruction_sqrt(ira, (IrInstructionSqrt *)instruction);
case IrInstructionIdMulAdd:
return ir_analyze_instruction_mul_add(ira, (IrInstructionMulAdd *)instruction);
case IrInstructionIdIntToErr:
return ir_analyze_instruction_int_to_err(ira, (IrInstructionIntToErr *)instruction);
case IrInstructionIdErrToInt:
@ -23835,6 +24005,7 @@ bool ir_has_side_effects(IrInstruction *instruction) {
case IrInstructionIdCoroPromise:
case IrInstructionIdPromiseResultType:
case IrInstructionIdSqrt:
case IrInstructionIdMulAdd:
case IrInstructionIdAtomicLoad:
case IrInstructionIdIntCast:
case IrInstructionIdFloatCast:

View File

@ -1439,6 +1439,22 @@ static void ir_print_sqrt(IrPrint *irp, IrInstructionSqrt *instruction) {
fprintf(irp->f, ")");
}
static void ir_print_mul_add(IrPrint *irp, IrInstructionMulAdd *instruction) {
fprintf(irp->f, "@mulAdd(");
if (instruction->type_value != nullptr) {
ir_print_other_instruction(irp, instruction->type_value);
} else {
fprintf(irp->f, "null");
}
fprintf(irp->f, ",");
ir_print_other_instruction(irp, instruction->op1);
fprintf(irp->f, ",");
ir_print_other_instruction(irp, instruction->op2);
fprintf(irp->f, ",");
ir_print_other_instruction(irp, instruction->op3);
fprintf(irp->f, ")");
}
static void ir_print_decl_var_gen(IrPrint *irp, IrInstructionDeclVarGen *decl_var_instruction) {
ZigVar *var = decl_var_instruction->var;
const char *var_or_const = decl_var_instruction->var->gen_is_const ? "const" : "var";
@ -1905,6 +1921,9 @@ static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction) {
case IrInstructionIdSqrt:
ir_print_sqrt(irp, (IrInstructionSqrt *)instruction);
break;
case IrInstructionIdMulAdd:
ir_print_mul_add(irp, (IrInstructionMulAdd *)instruction);
break;
case IrInstructionIdAtomicLoad:
ir_print_atomic_load(irp, (IrInstructionAtomicLoad *)instruction);
break;

View File

@ -0,0 +1,34 @@
const expect = @import("std").testing.expect;
test "@mulAdd" {
comptime testMulAdd();
testMulAdd();
}
fn testMulAdd() void {
{
var a: f16 = 5.5;
var b: f16 = 2.5;
var c: f16 = 6.25;
expect(@mulAdd(f16, a, b, c) == 20);
}
{
var a: f32 = 5.5;
var b: f32 = 2.5;
var c: f32 = 6.25;
expect(@mulAdd(f32, a, b, c) == 20);
}
{
var a: f64 = 5.5;
var b: f64 = 2.5;
var c: f64 = 6.25;
expect(@mulAdd(f64, a, b, c) == 20);
}
// Awaits implementation in libm.zig
//{
// var a: f16 = 5.5;
// var b: f128 = 2.5;
// var c: f128 = 6.25;
// expect(@mulAdd(f128, a, b, c) == 20);
//}
}