// zig/std/rand.zig

const assert = @import("debug.zig").assert;
const rand_test = @import("rand_test.zig");
const mem = @import("mem.zig");
/// 32-bit Mersenne Twister (MT19937) instantiated with the reference parameters.
pub const MT19937_32 = MersenneTwister(
    u32,            // word type
    624,            // n: number of state words
    397,            // m: offset of the word mixed in during a twist
    31,             // r: bit split between the upper and lower masks
    0x9908B0DF,     // a: XORed in when the combined word's low bit is set
    11, 0xFFFFFFFF, // u, d: tempering right shift and mask
    7, 0x9D2C5680,  // s, b: tempering left shift and mask
    15, 0xEFC60000, // t, c: tempering left shift and mask
    18,             // l: final tempering right shift
    1812433253);    // f: seed-expansion multiplier used by `init`
/// 64-bit Mersenne Twister (MT19937-64) instantiated with the reference parameters.
pub const MT19937_64 = MersenneTwister(
    u64,                        // word type
    312,                        // n: number of state words
    156,                        // m: offset of the word mixed in during a twist
    31,                         // r: bit split between the upper and lower masks
    0xB5026F5AA96619E9,         // a: XORed in when the combined word's low bit is set
    29, 0x5555555555555555,     // u, d: tempering right shift and mask
    17, 0x71D67FFFEDA60000,     // s, b: tempering left shift and mask
    37, 0xFFF7EEE000000000,     // t, c: tempering left shift and mask
    43,                         // l: final tempering right shift
    6364136223846793005);       // f: seed-expansion multiplier used by `init`
/// Random number generator state. Use `init` to initialize this state.
pub const Rand = struct {
    // 64-bit twister when `usize` is at least 8 bytes wide, 32-bit otherwise.
    // `scalar` and `fillBytes` treat one `rng.get()` result as `@sizeOf(usize)`
    // bytes, so they rely on the twister word matching `usize` in width.
    const Rng = if (@sizeOf(usize) >= 8) MT19937_64 else MT19937_32;

    rng: Rng,

    /// Initialize random state with the given seed.
    pub fn init(seed: usize) -> Rand {
        Rand {
            .rng = Rng.init(seed),
        }
    }

    /// Get an integer or boolean with random bits.
    /// `usize` is one raw generator word, `bool` is derived from that word's
    /// lowest bit, and any other type is decoded with `mem.readInt` from
    /// `@sizeOf(T)` bytes produced by `fillBytes`.
    pub fn scalar(r: &Rand, comptime T: type) -> T {
        if (T == usize) {
            return r.rng.get();
        } else if (T == bool) {
            return (r.rng.get() & 0b1) == 0;
        } else {
            var result: [@sizeOf(T)]u8 = undefined;
            r.fillBytes(result[0...]);
            return mem.readInt(result, T, false);
        }
    }

    /// Fill `buf` with randomness.
    /// Whole `usize`-sized chunks are written straight from generator words;
    /// a shorter tail, if any, is copied from the high-index bytes of one
    /// extra word.
    pub fn fillBytes(r: &Rand, buf: []u8) {
        var bytes_left = buf.len;
        while (bytes_left >= @sizeOf(usize)) {
            mem.writeInt(buf[buf.len - bytes_left...], r.rng.get(), false);
            bytes_left -= @sizeOf(usize);
        }
        if (bytes_left > 0) {
            var rand_val_array: [@sizeOf(usize)]u8 = undefined;
            mem.writeInt(rand_val_array[0...], r.rng.get(), false);
            while (bytes_left > 0) {
                buf[buf.len - bytes_left] = rand_val_array[@sizeOf(usize) - bytes_left];
                bytes_left -= 1;
            }
        }
    }

    /// Get a random unsigned integer with even distribution between `start`
    /// inclusive and `end` exclusive.
    ///
    /// Asserts `start < end`: an empty or reversed range has no value to return.
    // TODO support signed integers and then rename to "range"
    pub fn rangeUnsigned(r: &Rand, comptime T: type, start: T, end: T) -> T {
        // Fail loudly on an empty/reversed range instead of underflowing
        // `end - start` or taking `% 0` below.
        assert(start < end);
        const range = end - start;
        // Rejection sampling: `upper_bound` is a multiple of `range`, so only
        // accepting draws below it keeps `rand_val % range` unbiased.
        const leftover = @maxValue(T) % range;
        const upper_bound = @maxValue(T) - leftover;
        var rand_val_array: [@sizeOf(T)]u8 = undefined;
        while (true) {
            r.fillBytes(rand_val_array[0...]);
            const rand_val = mem.readInt(rand_val_array, T, false);
            if (rand_val < upper_bound) {
                return start + (rand_val % range);
            }
        }
    }

    /// Get a floating point value in the range 0.0..1.0 (1.0 excluded).
    /// Draws a uniform integer in [0, precision) and divides by `precision`,
    /// which is 2^24 for f32 and 2^53 for f64.
    pub fn float(r: &Rand, comptime T: type) -> T {
        // TODO Implement this way instead:
        // const int = @int_type(false, @sizeOf(T) * 8);
        // const mask = ((1 << @float_mantissa_bit_count(T)) - 1);
        // const rand_bits = r.rng.scalar(int) & mask;
        // return @float_compose(T, false, 0, rand_bits) - 1.0
        const int_type = @IntType(false, @sizeOf(T) * 8);
        const precision = if (T == f32) {
            16777216
        } else if (T == f64) {
            9007199254740992
        } else {
            @compileError("unknown floating point type")
        };
        return T(r.rangeUnsigned(int_type, 0, precision)) / T(precision);
    }
};
/// Build a Mersenne Twister generator type over the unsigned word type `int`.
///
/// Parameters, by how the generated code uses them:
///   n              - number of state words (length of `array`)
///   m              - offset of the partner word mixed in during a twist
///   r              - bit index splitting each word into lower/upper masks
///   a              - value XORed in when the combined word's low bit is set
///   u, d / s, b / t, c / l - tempering shifts and masks applied in `get`
///   f              - multiplier used to expand the seed in `init`
fn MersenneTwister(
    comptime int: type, comptime n: usize, comptime m: usize, comptime r: int,
    comptime a: int,
    comptime u: int, comptime d: int,
    comptime s: int, comptime b: int,
    comptime t: int, comptime c: int,
    comptime l: int, comptime f: int) -> type
{
    struct {
        const Self = this;

        array: [n]int,
        index: usize,

        /// Seed a fresh generator. Every state word is expanded from `seed`,
        /// and `index` starts at `n` so the first `get` performs a full twist.
        pub fn init(seed: int) -> Self {
            var state = Self {
                .array = undefined,
                .index = n,
            };

            state.array[0] = seed;
            var k: usize = 1;
            while (k < n; k += 1) {
                const prev = state.array[k - 1];
                // Wrapping ops on purpose: the recurrence is defined modulo the word size.
                state.array[k] = int(k) +% f *% (prev ^ (prev >> (int.bit_count - 2)));
            }
            return state;
        }

        /// Return the next tempered output word, regenerating the whole state
        /// first once all `n` words have been handed out.
        pub fn get(mt: &Self) -> int {
            if (mt.index >= n) {
                mt.twist();
            }

            var y = mt.array[mt.index];
            mt.index += 1;

            y ^= (y >> u) & d;
            y ^= (y <<% s) & b;
            y ^= (y <<% t) & c;
            y ^= y >> l;
            return y;
        }

        /// Regenerate all `n` state words in place (in ascending index order)
        /// and rewind `index` to the start.
        fn twist(mt: &Self) {
            const mag01 = []int{0, a};
            const lower_mask: int = (1 << r) - 1;
            const upper_mask = ~lower_mask;

            var k: usize = 0;
            while (k < n; k += 1) {
                // Both the neighbour and the partner index wrap past the end of
                // the array; earlier words are already updated when read here.
                const next = if (k + 1 < n) k + 1 else 0;
                const partner = if (k < n - m) k + m else k + m - n;
                const y = (mt.array[k] & upper_mask) | (mt.array[next] & lower_mask);
                mt.array[k] = mt.array[partner] ^ (y >> 1) ^ mag01[y & 0x1];
            }
            mt.index = 0;
        }
    }
}
test "rand float 32" {
    var gen = Rand.init(42);
    var remaining: usize = 1000;
    while (remaining > 0; remaining -= 1) {
        const sample = gen.float(f32);
        // Each sample must fall in the half-open interval [0.0, 1.0).
        assert(sample >= 0.0);
        assert(sample < 1.0);
    }
}
test "testMT19937_64" {
    // Outputs must match the recorded reference sequence for the known seed.
    var twister = MT19937_64.init(rand_test.mt64_seed);
    for (rand_test.mt64_data) |expected| {
        const actual = twister.get();
        assert(expected == actual);
    }
}
test "testMT19937_32" {
    // Outputs must match the recorded reference sequence for the known seed.
    var twister = MT19937_32.init(rand_test.mt32_seed);
    for (rand_test.mt32_data) |expected| {
        const actual = twister.get();
        assert(expected == actual);
    }
}