zig/std/math/fma.zig

162 lines
4.4 KiB
Zig

const std = @import("../index.zig");
const math = std.math;
const assert = std.debug.assert;
/// Returns `x * y + z` computed as if with a single rounding (fused multiply-add).
///
/// `T` selects the implementation: `f32` and `f64` are supported; any other
/// type is rejected at compile time.
pub fn fma(comptime T: type, x: T, y: T, z: T) T {
    // `T` is comptime-known, so only the selected branch is analyzed.
    return if (T == f32)
        fma32(x, y, z)
    else if (T == f64)
        fma64(x, y, z)
    else
        @compileError("fma not implemented for " ++ @typeName(T));
}
/// Computes `x * y + z` for f32 with a single rounding (round-to-nearest).
///
/// The product of two f32 values is exact in f64 (24 + 24 significand bits fit
/// in 53), so only the f64 addition can round. Narrowing that f64 sum to f32 is
/// a second rounding. The two can disagree only when the f64 sum lies exactly
/// halfway between two f32 values and the f64 addition was itself inexact. In
/// that case the f64 bit pattern is moved one step toward the true result, so
/// the final narrowing rounds the right way.
///
/// NOTE(review): like the FreeBSD/musl algorithm this follows, the halfway test
/// assumes a normal f32 result (24-bit significand); results in the f32
/// subnormal range are not specially handled.
fn fma32(x: f32, y: f32, z: f32) f32 {
    const xy = f64(x) * y; // exact
    const zd = f64(z);
    const xy_z = xy + zd;
    var u = @bitCast(u64, xy_z);
    const e = (u >> 52) & 0x7FF;

    // Common case: narrowing the f64 sum is already correct. That holds when
    // the 29 bits dropped by narrowing are not exactly the halfway pattern,
    // when the sum is Inf/NaN, or when the f64 addition was exact (both
    // back-subtractions recover the other addend).
    if ((u & 0x1FFFFFFF) != 0x10000000 or e == 0x7FF or (xy_z - xy == zd and xy_z - zd == xy)) {
        return f32(xy_z);
    }

    // Halfway and inexact: compute the rounding error of `xy + z`. The larger
    // addend is subtracted first so the error is computed exactly. Then step
    // the significand one f64 ulp in the direction of that error, which breaks
    // the apparent tie correctly.
    const neg = (u >> 63) != 0;
    const err = if (neg == (zd > xy)) xy - xy_z + zd else zd - xy_z + xy;
    if (neg == (err < 0)) {
        u += 1; // error points away from zero: increase magnitude
    } else {
        u -= 1; // error points toward zero: decrease magnitude
    }
    return f32(@bitCast(f64, u));
}
/// Computes `x * y + z` for f64 with a single rounding (round-to-nearest).
///
/// Follows the FreeBSD/musl algorithm, with xs, ys, zs the `frexp` significands
/// and zs rescaled to the scale of x * y:
///   (xy.hi, xy.lo) = xs * ys        exact double-double product
///   (r.hi, r.lo)   = xy.hi + zs     exact sum
///   adj            = r.lo + xy.lo   inexact, low bit used as a sticky bit
///   result         = r.hi + adj     single final rounding, then rescaled
/// Results that land in the subnormal range are finished by `add_and_denorm`.
fn fma64(x: f64, y: f64, z: f64) f64 {
    // Special cases. The order of these checks and the exact expressions
    // returned give the expected Inf/NaN propagation and signed-zero results.
    if (!math.isFinite(x) or !math.isFinite(y)) {
        return x * y + z;
    }
    if (!math.isFinite(z)) {
        return z;
    }
    if (x == 0.0 or y == 0.0) {
        return x * y + z;
    }
    if (z == 0.0) {
        return x * y;
    }

    const x1 = math.frexp(x);
    const ex = x1.exponent;
    const xs = x1.significand;
    const x2 = math.frexp(y);
    const ey = x2.exponent;
    const ys = x2.significand;
    const x3 = math.frexp(z);
    const ez = x3.exponent;
    var zs = x3.significand;

    // |x * y| < 2^(ex + ey) and |z| >= 2^(ez - 1). So `spread` measures how
    // many binary orders of magnitude x * y sits above (positive) or below
    // (negative) z.
    var spread = ex + ey - ez;
    if (spread < -53 * 2) {
        // |x * y| < 2^(ez - 107). That is below half the smallest ulp adjacent
        // to z (even when z is a power of two), so the correctly rounded result
        // is z itself. Without this early return, the scalbn below overflows to
        // infinity once -spread exceeds the f64 exponent range. The TwoSum that
        // follows then turns inf into NaN.
        return z;
    }
    if (spread <= 53 * 2) {
        zs = math.scalbn(zs, -spread);
    } else {
        // z is far below x * y. Keep only its sign, as a tiny value, so it can
        // still influence the final rounding without its magnitude mattering.
        zs = math.copysign(f64, math.f64_min, zs);
    }

    const xy = dd_mul(xs, ys);
    const r = dd_add(xy.hi, zs);

    spread = ex + ey;

    if (r.hi == 0.0) {
        // The addends cancelled exactly; this expression yields the correctly
        // signed result.
        return xy.hi + zs + math.scalbn(xy.lo, spread);
    }

    const adj = add_adjusted(r.lo, xy.lo);
    if (spread + math.ilogb(r.hi) > -1023) {
        return math.scalbn(r.hi + adj, spread);
    } else {
        // Subnormal result: scaling would round a second time, so let
        // add_and_denorm perform the denormalization with one rounding.
        return add_and_denorm(r.hi, adj, spread);
    }
}
/// A "double-double": the unevaluated sum `hi + lo` of two f64 values.
/// It is used to carry intermediate results with more precision than one f64.
const dd = struct {
    /// Leading (rounded) part.
    hi: f64,
    /// Trailing part: what `hi` alone does not capture.
    lo: f64,
};
/// Error-free addition (Knuth's TwoSum).
///
/// `hi` is the rounded sum `a + b` and `lo` is its rounding error, so that
/// `hi + lo == a + b` exactly (barring overflow). No ordering of |a| and |b|
/// is required.
fn dd_add(a: f64, b: f64) dd {
    const hi = a + b;
    // Recover the portions of `hi` attributable to each addend.
    const b_part = hi - a;
    const a_part = hi - b_part;
    return dd{
        .hi = hi,
        .lo = (a - a_part) + (b - b_part),
    };
}
/// Error-free product (Dekker's algorithm with Veltkamp splitting).
///
/// Each operand is split into a high and a low part whose pairwise products
/// are exact in f64. The partial products are then recombined into
/// `hi + lo == a * b`, barring overflow/underflow.
fn dd_mul(a: f64, b: f64) dd {
    // Veltkamp splitting constant for a 53-bit significand.
    const split: f64 = 0x1.0p27 + 1.0;

    const ca = a * split;
    const a_hi = (a - ca) + ca;
    const a_lo = a - a_hi;

    const cb = b * split;
    const b_hi = (b - cb) + cb;
    const b_lo = b - b_hi;

    const head = a_hi * b_hi;
    const cross = a_hi * b_lo + a_lo * b_hi;
    const hi = head + cross;
    return dd{
        .hi = hi,
        .lo = head - hi + cross + a_lo * b_lo,
    };
}
/// Returns the rounded sum `a + b`, with its last significand bit turned into
/// a sticky bit.
///
/// If the addition was inexact (`sum.lo != 0`) and the last bit of `sum.hi`
/// is 0, `sum.hi` is moved one step toward `sum.lo`, which makes that bit 1.
/// The bit then records that nonzero bits were discarded, so a later addition
/// to a larger number rounds once instead of suffering double rounding.
fn add_adjusted(a: f64, b: f64) f64 {
    var sum = dd_add(a, b);
    if (sum.lo != 0) {
        var uhii = @bitCast(u64, sum.hi);
        if (uhii & 1 == 0) {
            // Step by copysign(1, sum.hi * sum.lo), decided from the sign bits
            // alone. This is written as explicit += / -= because an expression
            // such as `1 - ((uhii ^ uloi) >> 62)` is a u64 underflow (a checked
            // overflow in Zig) whenever the signs differ. uhii is nonzero here,
            // since TwoSum gives lo == 0 whenever hi == 0.
            const uloi = @bitCast(u64, sum.lo);
            if (((uhii ^ uloi) >> 63) == 0) {
                uhii += 1;
            } else {
                uhii -= 1;
            }
            sum.hi = @bitCast(f64, uhii);
        }
    }
    return sum.hi;
}
/// Computes `scalbn(a + b, scale)` with a single rounding, for results that
/// are subnormal.
///
/// Scaling into the subnormal range discards `bits_lost` low significand bits
/// of `sum.hi`, which is a second rounding. When `sum.lo` is nonzero, the last
/// bit of `sum.hi` is nudged one step toward `sum.lo` in exactly two cases:
///   * two or more bits are lost and the last bit is 0: setting it makes it a
///     sticky bit, so the rounding inside `scalbn` breaks ties correctly;
///   * exactly one bit is lost and it is 1: the scaled value would look like an
///     exact tie, and the nudge resolves it in the direction of `sum.lo`.
/// In the other two combinations no adjustment is needed.
fn add_and_denorm(a: f64, b: f64, scale: i32) f64 {
    var sum = dd_add(a, b);
    if (sum.lo != 0) {
        var uhii = @bitCast(u64, sum.hi);
        const bits_lost = -i32((uhii >> 52) & 0x7FF) - scale + 1;
        // Exactly one of the two conditions must hold (an XOR). Comparing
        // them with `==` would adjust in precisely the wrong cases.
        if ((bits_lost != 1) != (uhii & 1 != 0)) {
            const uloi = @bitCast(u64, sum.lo);
            // Step by copysign(1, sum.hi * sum.lo) with explicit += / -=;
            // `1 - x` on a u64 would be a checked underflow in Zig.
            if (((uhii ^ uloi) >> 63) == 0) {
                uhii += 1;
            } else {
                uhii -= 1;
            }
            sum.hi = @bitCast(f64, uhii);
        }
    }
    return math.scalbn(sum.hi, scale);
}
test "math.fma" {
    // The generic entry point must agree with the width-specific versions.
    const via_generic_32 = fma(f32, 0.0, 1.0, 1.0);
    assert(via_generic_32 == fma32(0.0, 1.0, 1.0));

    const via_generic_64 = fma(f64, 0.0, 1.0, 1.0);
    assert(via_generic_64 == fma64(0.0, 1.0, 1.0));
}
test "math.fma32" {
    // Every case computes x * 5.0 + 9.124; only x varies.
    const eps = 0.000001;
    const y: f32 = 5.0;
    const z: f32 = 9.124;
    assert(math.approxEq(f32, fma32(0.0, y, z), 9.124, eps));
    assert(math.approxEq(f32, fma32(0.2, y, z), 10.124, eps));
    assert(math.approxEq(f32, fma32(0.8923, y, z), 13.5855, eps));
    assert(math.approxEq(f32, fma32(1.5, y, z), 16.624, eps));
    assert(math.approxEq(f32, fma32(37.45, y, z), 196.374004, eps));
    assert(math.approxEq(f32, fma32(89.123, y, z), 454.739005, eps));
    assert(math.approxEq(f32, fma32(123123.234375, y, z), 615625.295875, eps));
}
test "math.fma64" {
    // Every case computes x * 5.0 + 9.124; only x varies.
    const eps = 0.000001;
    const y: f64 = 5.0;
    const z: f64 = 9.124;
    assert(math.approxEq(f64, fma64(0.0, y, z), 9.124, eps));
    assert(math.approxEq(f64, fma64(0.2, y, z), 10.124, eps));
    assert(math.approxEq(f64, fma64(0.8923, y, z), 13.5855, eps));
    assert(math.approxEq(f64, fma64(1.5, y, z), 16.624, eps));
    assert(math.approxEq(f64, fma64(37.45, y, z), 196.374, eps));
    assert(math.approxEq(f64, fma64(89.123, y, z), 454.739, eps));
    assert(math.approxEq(f64, fma64(123123.234375, y, z), 615625.295875, eps));
}