diff --git a/CMakeLists.txt b/CMakeLists.txt index bf90a7ef4..ea21f6fc7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -515,6 +515,7 @@ set(ZIG_STD_FILES "os/windows/util.zig" "os/zen.zig" "rand/index.zig" + "rand/ziggurat.zig" "sort.zig" "special/bootstrap.zig" "special/bootstrap_lib.zig" diff --git a/std/rand/index.zig b/std/rand/index.zig index 6a746fce9..bd6209009 100644 --- a/std/rand/index.zig +++ b/std/rand/index.zig @@ -19,6 +19,7 @@ const builtin = @import("builtin"); const assert = std.debug.assert; const mem = std.mem; const math = std.math; +const ziggurat = @import("ziggurat.zig"); // When you need fast unbiased random numbers pub const DefaultPrng = Xoroshiro128; @@ -109,15 +110,28 @@ pub const Random = struct { } } - /// Return a floating point value normally distributed in the range [0, 1]. + /// Return a floating point value normally distributed with mean = 0, stddev = 1. + /// + /// To use different parameters, use: floatNorm(...) * desiredStddev + desiredMean. pub fn floatNorm(r: &Random, comptime T: type) T { - // TODO(tiehuis): See https://www.doornik.com/research/ziggurat.pdf - @compileError("floatNorm is unimplemented"); + const value = ziggurat.next_f64(r, ziggurat.NormDist); + switch (T) { + f32 => return f32(value), + f64 => return value, + else => @compileError("unknown floating point type"), + } } - /// Return a exponentially distributed float between (0, @maxValue(f64)) + /// Return an exponentially distributed float with a rate parameter of 1. + /// + /// To use a different rate parameter, use: floatExp(...) / desiredRate. pub fn floatExp(r: &Random, comptime T: type) T { - @compileError("floatExp is unimplemented"); + const value = ziggurat.next_f64(r, ziggurat.ExpDist); + switch (T) { + f32 => return f32(value), + f64 => return value, + else => @compileError("unknown floating point type"), + } } /// Shuffle a slice into a random order. diff --git a/std/rand/ziggurat.zig b/std/rand/ziggurat.zig new file mode 100644 index 000000000..7790b71d2 --- /dev/null +++ b/std/rand/ziggurat.zig @@ -0,0 +1,146 @@ +// Implements ZIGNOR [1]. +// +// [1]: Jurgen A. Doornik (2005). [*An Improved Ziggurat Method to Generate Normal Random Samples*] +// (https://www.doornik.com/research/ziggurat.pdf). Nuffield College, Oxford. +// +// rust/rand used as a reference; +// +// NOTE: This seems interesting but reference code is a bit hard to grok: +// https://sbarral.github.io/etf. + +const std = @import("../index.zig"); +const math = std.math; +const Random = std.rand.Random; + +pub fn next_f64(random: &Random, comptime tables: &const ZigTable) f64 { + while (true) { + // We manually construct a float from parts as we can avoid an extra random lookup here by + // using the unused exponent for the lookup table entry. + const bits = random.scalar(u64); + const i = usize(bits & 0xff); + + const u = blk: { + if (tables.is_symmetric) { + // Generate a value in the range [2, 4) and scale into [-1, 1) + const repr = ((0x3ff + 1) << 52) | (bits >> 12); + break :blk @bitCast(f64, repr) - 3.0; + } else { + // Generate a value in the range [1, 2) and scale into (0, 1) + const repr = (0x3ff << 52) | (bits >> 12); + break :blk @bitCast(f64, repr) - (1.0 - math.f64_epsilon / 2.0); + } + }; + + const x = u * tables.x[i]; + const test_x = if (tables.is_symmetric) math.fabs(x) else x; + + // equivalent to |u| < tables.x[i+1] / tables.x[i] (or u < tables.x[i+1] / tables.x[i]) + if (test_x < tables.x[i + 1]) { + return x; + } + + if (i == 0) { + return tables.zero_case(random, u); + } + + // equivalent to f1 + DRanU() * (f0 - f1) < 1 + if (tables.f[i + 1] + (tables.f[i] - tables.f[i + 1]) * random.float(f64) < tables.pdf(x)) { + return x; + } + } +} + +pub const ZigTable = struct { + r: f64, + x: [257]f64, + f: [257]f64, + + // probability density function used as a fallback + pdf: fn(f64) f64, + // whether the distribution is symmetric + is_symmetric: bool, + // fallback calculation in the case we are in the 0 block + zero_case: fn(&Random, f64) f64, +}; + +// zigNorInit +fn ZigTableGen(comptime is_symmetric: bool, comptime r: f64, comptime v: f64, comptime f: fn(f64) f64, + comptime f_inv: fn(f64) f64, comptime zero_case: fn(&Random, f64) f64) ZigTable { + var tables: ZigTable = undefined; + + tables.is_symmetric = is_symmetric; + tables.r = r; + tables.pdf = f; + tables.zero_case = zero_case; + + tables.x[0] = v / f(r); + tables.x[1] = r; + + for (tables.x[2..256]) |*entry, i| { + const last = tables.x[2 + i - 1]; + *entry = f_inv(v / last + f(last)); + } + tables.x[256] = 0; + + for (tables.f[0..]) |*entry, i| { + *entry = f(tables.x[i]); + } + + return tables; +} + +// N(0, 1) +pub const NormDist = blk: { + @setEvalBranchQuota(30000); + break :blk ZigTableGen(true, norm_r, norm_v, norm_f, norm_f_inv, norm_zero_case); +}; + +const norm_r = 3.6541528853610088; +const norm_v = 0.00492867323399; + +fn norm_f(x: f64) f64 { return math.exp(-x * x / 2.0); } +fn norm_f_inv(y: f64) f64 { return math.sqrt(-2.0 * math.ln(y)); } +fn norm_zero_case(random: &Random, u: f64) f64 { + var x: f64 = 1; + var y: f64 = 0; + + while (-2.0 * y < x * x) { + x = math.ln(random.float(f64)) / norm_r; + y = math.ln(random.float(f64)); + } + + if (u < 0) { + return x - norm_r; + } else { + return norm_r - x; + } +} + +test "ziggurant normal dist sanity" { + var prng = std.rand.DefaultPrng.init(0); + var i: usize = 0; + while (i < 1000) : (i += 1) { + _ = prng.random.floatNorm(f64); + } +} + +// Exp(1) +pub const ExpDist = blk: { + @setEvalBranchQuota(30000); + break :blk ZigTableGen(false, exp_r, exp_v, exp_f, exp_f_inv, exp_zero_case); +}; + +const exp_r = 7.69711747013104972; +const exp_v = 0.0039496598225815571993; + +fn exp_f(x: f64) f64 { return math.exp(-x); } +fn exp_f_inv(y: f64) f64 { return -math.ln(y); } +fn exp_zero_case(random: &Random, _: f64) f64 { return exp_r - math.ln(random.float(f64)); } + +test "ziggurant exp dist sanity" { + var prng = std.rand.DefaultPrng.init(0); + var i: usize = 0; + while (i < 1000) : (i += 1) { + _ = prng.random.floatExp(f64); + } +}