overhaul api for getting random integers (#1578)
* rand api overhaul * no retry limits. instead documented a recommendation to call int(T) % len directly.master
parent
1c26c2f4d5
commit
e7d9d00ac8
|
@ -5,11 +5,11 @@
|
|||
// ```
|
||||
// var buf: [8]u8 = undefined;
|
||||
// try std.os.getRandomBytes(buf[0..]);
|
||||
// const seed = mem.readInt(buf[0..8], u64, builtin.Endian.Little);
|
||||
// const seed = mem.readIntLE(u64, buf[0..8]);
|
||||
//
|
||||
// var r = DefaultPrng.init(seed);
|
||||
//
|
||||
// const s = r.random.scalar(u64);
|
||||
// const s = r.random.int(u64);
|
||||
// ```
|
||||
//
|
||||
// TODO(tiehuis): Benchmark these against other reference implementations.
|
||||
|
@ -35,60 +35,117 @@ pub const Random = struct {
|
|||
r.fillFn(r, buf);
|
||||
}
|
||||
|
||||
/// Return a random integer/boolean type.
|
||||
pub fn scalar(r: *Random, comptime T: type) T {
|
||||
var rand_bytes: [@sizeOf(T)]u8 = undefined;
|
||||
pub fn boolean(r: *Random) bool {
|
||||
return r.int(u1) != 0;
|
||||
}
|
||||
|
||||
/// Returns a random int `i` such that `0 <= i <= @maxValue(T)`.
|
||||
/// `i` is evenly distributed.
|
||||
pub fn int(r: *Random, comptime T: type) T {
|
||||
const UnsignedT = @IntType(false, T.bit_count);
|
||||
const ByteAlignedT = @IntType(false, @divTrunc(T.bit_count + 7, 8) * 8);
|
||||
|
||||
var rand_bytes: [@sizeOf(ByteAlignedT)]u8 = undefined;
|
||||
r.bytes(rand_bytes[0..]);
|
||||
|
||||
if (T == bool) {
|
||||
return rand_bytes[0] & 0b1 == 0;
|
||||
} else {
|
||||
// NOTE: Cannot @bitCast array to integer type.
|
||||
return mem.readInt(rand_bytes, T, builtin.Endian.Little);
|
||||
// use LE instead of native endian for better portability maybe?
|
||||
// TODO: endian portability is pointless if the underlying prng isn't endian portable.
|
||||
// TODO: document the endian portability of this library.
|
||||
const byte_aligned_result = mem.readIntLE(ByteAlignedT, rand_bytes);
|
||||
const unsigned_result = @truncate(UnsignedT, byte_aligned_result);
|
||||
return @bitCast(T, unsigned_result);
|
||||
}
|
||||
|
||||
/// Returns an evenly distributed random unsigned integer `0 <= i < less_than`.
|
||||
/// This function assumes that the underlying ::fillFn produces evenly distributed values.
|
||||
/// Within this assumption, the runtime of this function is exponentially distributed.
|
||||
/// If ::fillFn were backed by a true random generator,
|
||||
/// the runtime of this function would technically be unbounded.
|
||||
/// However, if ::fillFn is backed by any evenly distributed pseudo random number generator,
|
||||
/// this function is guaranteed to return.
|
||||
/// If you need deterministic runtime bounds, consider instead using `r.int(T) % less_than`,
|
||||
/// which will usually be biased toward smaller values.
|
||||
pub fn uintLessThan(r: *Random, comptime T: type, less_than: T) T {
|
||||
assert(T.is_signed == false);
|
||||
assert(0 < less_than);
|
||||
|
||||
const last_group_size_minus_one: T = @maxValue(T) % less_than;
|
||||
if (last_group_size_minus_one == less_than - 1) {
|
||||
// less_than is a power of two.
|
||||
assert(math.floorPowerOfTwo(T, less_than) == less_than);
|
||||
// There is no retry zone. The optimal retry_zone_start would be @maxValue(T) + 1.
|
||||
return r.int(T) % less_than;
|
||||
}
|
||||
const retry_zone_start = @maxValue(T) - last_group_size_minus_one;
|
||||
|
||||
while (true) {
|
||||
const rand_val = r.int(T);
|
||||
if (rand_val < retry_zone_start) {
|
||||
return rand_val % less_than;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns an evenly distributed random unsigned integer `0 <= i <= at_most`.
|
||||
/// See ::uintLessThan, which this function uses in most cases,
|
||||
/// for commentary on the runtime of this function.
|
||||
pub fn uintAtMost(r: *Random, comptime T: type, at_most: T) T {
|
||||
assert(T.is_signed == false);
|
||||
if (at_most == @maxValue(T)) {
|
||||
// have the full range
|
||||
return r.int(T);
|
||||
}
|
||||
return r.uintLessThan(T, at_most + 1);
|
||||
}
|
||||
|
||||
/// Returns an evenly distributed random integer `at_least <= i < less_than`.
|
||||
/// See ::uintLessThan, which this function uses in most cases,
|
||||
/// for commentary on the runtime of this function.
|
||||
pub fn intRangeLessThan(r: *Random, comptime T: type, at_least: T, less_than: T) T {
|
||||
assert(at_least < less_than);
|
||||
if (T.is_signed) {
|
||||
// Two's complement makes this math pretty easy.
|
||||
const UnsignedT = @IntType(false, T.bit_count);
|
||||
const lo = @bitCast(UnsignedT, at_least);
|
||||
const hi = @bitCast(UnsignedT, less_than);
|
||||
const result = lo +% r.uintLessThan(UnsignedT, hi -% lo);
|
||||
return @bitCast(T, result);
|
||||
} else {
|
||||
// The signed implementation would work fine, but we can use stricter arithmetic operators here.
|
||||
return at_least + r.uintLessThan(T, less_than - at_least);
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns an evenly distributed random integer `at_least <= i <= at_most`.
|
||||
/// See ::uintLessThan, which this function uses in most cases,
|
||||
/// for commentary on the runtime of this function.
|
||||
pub fn intRangeAtMost(r: *Random, comptime T: type, at_least: T, at_most: T) T {
|
||||
assert(at_least <= at_most);
|
||||
if (T.is_signed) {
|
||||
// Two's complement makes this math pretty easy.
|
||||
const UnsignedT = @IntType(false, T.bit_count);
|
||||
const lo = @bitCast(UnsignedT, at_least);
|
||||
const hi = @bitCast(UnsignedT, at_most);
|
||||
const result = lo +% r.uintAtMost(UnsignedT, hi -% lo);
|
||||
return @bitCast(T, result);
|
||||
} else {
|
||||
// The signed implementation would work fine, but we can use stricter arithmetic operators here.
|
||||
return at_least + r.uintAtMost(T, at_most - at_least);
|
||||
}
|
||||
}
|
||||
|
||||
/// Return a random integer/boolean type.
|
||||
/// TODO: deprecated. use ::boolean or ::int instead.
|
||||
pub fn scalar(r: *Random, comptime T: type) T {
|
||||
if (T == bool) return r.boolean();
|
||||
return r.int(T);
|
||||
}
|
||||
|
||||
/// Return a random integer with even distribution between `start`
|
||||
/// inclusive and `end` exclusive. `start` must be less than `end`.
|
||||
/// TODO: deprecated. renamed to ::intRangeLessThan
|
||||
pub fn range(r: *Random, comptime T: type, start: T, end: T) T {
|
||||
assert(start < end);
|
||||
if (T.is_signed) {
|
||||
const uint = @IntType(false, T.bit_count);
|
||||
if (start >= 0 and end >= 0) {
|
||||
return @intCast(T, r.range(uint, @intCast(uint, start), @intCast(uint, end)));
|
||||
} else if (start < 0 and end < 0) {
|
||||
// Can't overflow because the range is over signed ints
|
||||
return math.negateCast(r.range(uint, math.absCast(end), math.absCast(start)) + 1) catch unreachable;
|
||||
} else if (start < 0 and end >= 0) {
|
||||
const end_uint = @intCast(uint, end);
|
||||
const total_range = math.absCast(start) + end_uint;
|
||||
const value = r.range(uint, 0, total_range);
|
||||
const result = if (value < end_uint) x: {
|
||||
break :x @intCast(T, value);
|
||||
} else if (value == end_uint) x: {
|
||||
break :x start;
|
||||
} else x: {
|
||||
// Can't overflow because the range is over signed ints
|
||||
break :x math.negateCast(value - end_uint) catch unreachable;
|
||||
};
|
||||
return result;
|
||||
} else {
|
||||
unreachable;
|
||||
}
|
||||
} else {
|
||||
const total_range = end - start;
|
||||
const leftover = @maxValue(T) % total_range;
|
||||
const upper_bound = @maxValue(T) - leftover;
|
||||
var rand_val_array: [@sizeOf(T)]u8 = undefined;
|
||||
|
||||
while (true) {
|
||||
r.bytes(rand_val_array[0..]);
|
||||
const rand_val = mem.readInt(rand_val_array, T, builtin.Endian.Little);
|
||||
if (rand_val < upper_bound) {
|
||||
return start + (rand_val % total_range);
|
||||
}
|
||||
}
|
||||
}
|
||||
return r.intRangeLessThan(T, start, end);
|
||||
}
|
||||
|
||||
/// Return a floating point value evenly distributed in the range [0, 1).
|
||||
|
@ -97,12 +154,12 @@ pub const Random = struct {
|
|||
// Note: The lowest mantissa bit is always set to 0 so we only use half the available range.
|
||||
switch (T) {
|
||||
f32 => {
|
||||
const s = r.scalar(u32);
|
||||
const s = r.int(u32);
|
||||
const repr = (0x7f << 23) | (s >> 9);
|
||||
return @bitCast(f32, repr) - 1.0;
|
||||
},
|
||||
f64 => {
|
||||
const s = r.scalar(u64);
|
||||
const s = r.int(u64);
|
||||
const repr = (0x3ff << 52) | (s >> 12);
|
||||
return @bitCast(f64, repr) - 1.0;
|
||||
},
|
||||
|
@ -142,12 +199,167 @@ pub const Random = struct {
|
|||
|
||||
var i: usize = 0;
|
||||
while (i < buf.len - 1) : (i += 1) {
|
||||
const j = r.range(usize, i, buf.len);
|
||||
const j = r.intRangeLessThan(usize, i, buf.len);
|
||||
mem.swap(T, &buf[i], &buf[j]);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const SequentialPrng = struct {
|
||||
const Self = @This();
|
||||
random: Random,
|
||||
next_value: u8,
|
||||
|
||||
pub fn init() Self {
|
||||
return Self{
|
||||
.random = Random{ .fillFn = fill },
|
||||
.next_value = 0,
|
||||
};
|
||||
}
|
||||
|
||||
fn fill(r: *Random, buf: []u8) void {
|
||||
const self = @fieldParentPtr(Self, "random", r);
|
||||
for (buf) |*b| {
|
||||
b.* = self.next_value;
|
||||
}
|
||||
self.next_value +%= 1;
|
||||
}
|
||||
};
|
||||
|
||||
test "Random int" {
|
||||
testRandomInt();
|
||||
comptime testRandomInt();
|
||||
}
|
||||
fn testRandomInt() void {
|
||||
var r = SequentialPrng.init();
|
||||
|
||||
assert(r.random.int(u0) == 0);
|
||||
|
||||
r.next_value = 0;
|
||||
assert(r.random.int(u1) == 0);
|
||||
assert(r.random.int(u1) == 1);
|
||||
assert(r.random.int(u2) == 2);
|
||||
assert(r.random.int(u2) == 3);
|
||||
assert(r.random.int(u2) == 0);
|
||||
|
||||
r.next_value = 0xff;
|
||||
assert(r.random.int(u8) == 0xff);
|
||||
r.next_value = 0x11;
|
||||
assert(r.random.int(u8) == 0x11);
|
||||
|
||||
r.next_value = 0xff;
|
||||
assert(r.random.int(u32) == 0xffffffff);
|
||||
r.next_value = 0x11;
|
||||
assert(r.random.int(u32) == 0x11111111);
|
||||
|
||||
r.next_value = 0xff;
|
||||
assert(r.random.int(i32) == -1);
|
||||
r.next_value = 0x11;
|
||||
assert(r.random.int(i32) == 0x11111111);
|
||||
|
||||
r.next_value = 0xff;
|
||||
assert(r.random.int(i8) == -1);
|
||||
r.next_value = 0x11;
|
||||
assert(r.random.int(i8) == 0x11);
|
||||
|
||||
r.next_value = 0xff;
|
||||
assert(r.random.int(u33) == 0x1ffffffff);
|
||||
r.next_value = 0xff;
|
||||
assert(r.random.int(i1) == -1);
|
||||
r.next_value = 0xff;
|
||||
assert(r.random.int(i2) == -1);
|
||||
r.next_value = 0xff;
|
||||
assert(r.random.int(i33) == -1);
|
||||
}
|
||||
|
||||
test "Random boolean" {
|
||||
testRandomBoolean();
|
||||
comptime testRandomBoolean();
|
||||
}
|
||||
fn testRandomBoolean() void {
|
||||
var r = SequentialPrng.init();
|
||||
assert(r.random.boolean() == false);
|
||||
assert(r.random.boolean() == true);
|
||||
assert(r.random.boolean() == false);
|
||||
assert(r.random.boolean() == true);
|
||||
}
|
||||
|
||||
test "Random intLessThan" {
|
||||
@setEvalBranchQuota(10000);
|
||||
testRandomIntLessThan();
|
||||
comptime testRandomIntLessThan();
|
||||
}
|
||||
fn testRandomIntLessThan() void {
|
||||
var r = SequentialPrng.init();
|
||||
r.next_value = 0xff;
|
||||
assert(r.random.uintLessThan(u8, 4) == 3);
|
||||
r.next_value = 0xff;
|
||||
assert(r.random.uintLessThan(u8, 3) == 0);
|
||||
assert(r.next_value == 1);
|
||||
|
||||
r.next_value = 0xff;
|
||||
assert(r.random.intRangeLessThan(u8, 0, 0x80) == 0x7f);
|
||||
r.next_value = 0xff;
|
||||
assert(r.random.intRangeLessThan(u8, 0x7f, 0xff) == 0xfe);
|
||||
|
||||
r.next_value = 0xff;
|
||||
assert(r.random.intRangeLessThan(i8, 0, 0x40) == 0x3f);
|
||||
r.next_value = 0xff;
|
||||
assert(r.random.intRangeLessThan(i8, -0x40, 0x40) == 0x3f);
|
||||
r.next_value = 0xff;
|
||||
assert(r.random.intRangeLessThan(i8, -0x80, 0) == -1);
|
||||
|
||||
r.next_value = 0xff;
|
||||
assert(r.random.intRangeLessThan(i64, -0x8000000000000000, 0) == -1);
|
||||
r.next_value = 0xff;
|
||||
assert(r.random.intRangeLessThan(i3, -4, 0) == -1);
|
||||
r.next_value = 0xff;
|
||||
assert(r.random.intRangeLessThan(i3, -2, 2) == 1);
|
||||
|
||||
// test retrying and eventually getting a good value
|
||||
// start just out of bounds
|
||||
r.next_value = 0x81;
|
||||
assert(r.random.uintLessThan(u8, 0x81) == 0);
|
||||
}
|
||||
|
||||
test "Random intAtMost" {
|
||||
@setEvalBranchQuota(10000);
|
||||
testRandomIntAtMost();
|
||||
comptime testRandomIntAtMost();
|
||||
}
|
||||
fn testRandomIntAtMost() void {
|
||||
var r = SequentialPrng.init();
|
||||
r.next_value = 0xff;
|
||||
assert(r.random.uintAtMost(u8, 3) == 3);
|
||||
r.next_value = 0xff;
|
||||
assert(r.random.uintAtMost(u8, 2) == 0);
|
||||
assert(r.next_value == 1);
|
||||
|
||||
r.next_value = 0xff;
|
||||
assert(r.random.intRangeAtMost(u8, 0, 0x7f) == 0x7f);
|
||||
r.next_value = 0xff;
|
||||
assert(r.random.intRangeAtMost(u8, 0x7f, 0xfe) == 0xfe);
|
||||
|
||||
r.next_value = 0xff;
|
||||
assert(r.random.intRangeAtMost(i8, 0, 0x3f) == 0x3f);
|
||||
r.next_value = 0xff;
|
||||
assert(r.random.intRangeAtMost(i8, -0x40, 0x3f) == 0x3f);
|
||||
r.next_value = 0xff;
|
||||
assert(r.random.intRangeAtMost(i8, -0x80, -1) == -1);
|
||||
|
||||
r.next_value = 0xff;
|
||||
assert(r.random.intRangeAtMost(i64, -0x8000000000000000, -1) == -1);
|
||||
r.next_value = 0xff;
|
||||
assert(r.random.intRangeAtMost(i3, -4, -1) == -1);
|
||||
r.next_value = 0xff;
|
||||
assert(r.random.intRangeAtMost(i3, -2, 1) == 1);
|
||||
|
||||
// test retrying and eventually getting a good value
|
||||
// start just out of bounds
|
||||
r.next_value = 0x81;
|
||||
assert(r.random.uintAtMost(u8, 0x80) == 0);
|
||||
}
|
||||
|
||||
// Generator to extend 64-bit seed values into longer sequences.
|
||||
//
|
||||
// The number of cycles is thus limited to 64-bits regardless of the engine, but this
|
||||
|
@ -622,17 +834,6 @@ test "Random float" {
|
|||
}
|
||||
}
|
||||
|
||||
test "Random scalar" {
|
||||
var prng = DefaultPrng.init(0);
|
||||
const s = prng.random.scalar(u64);
|
||||
}
|
||||
|
||||
test "Random bytes" {
|
||||
var prng = DefaultPrng.init(0);
|
||||
var buf: [2048]u8 = undefined;
|
||||
prng.random.bytes(buf[0..]);
|
||||
}
|
||||
|
||||
test "Random shuffle" {
|
||||
var prng = DefaultPrng.init(0);
|
||||
|
||||
|
@ -664,16 +865,16 @@ test "Random range" {
|
|||
testRange(&prng.random, -4, 3);
|
||||
testRange(&prng.random, -4, -1);
|
||||
testRange(&prng.random, 10, 14);
|
||||
// TODO: test that prng.random.range(1, 1) causes an assertion error
|
||||
testRange(&prng.random, -0x80, 0x7f);
|
||||
}
|
||||
|
||||
fn testRange(r: *Random, start: i32, end: i32) void {
|
||||
const count = @intCast(usize, end - start);
|
||||
var values_buffer = []bool{false} ** 20;
|
||||
fn testRange(r: *Random, start: i8, end: i8) void {
|
||||
const count = @intCast(usize, i32(end) - i32(start));
|
||||
var values_buffer = []bool{false} ** 0x100;
|
||||
const values = values_buffer[0..count];
|
||||
var i: usize = 0;
|
||||
while (i < count) {
|
||||
const value = r.range(i32, start, end);
|
||||
const value: i32 = r.intRangeLessThan(i8, start, end);
|
||||
const index = @intCast(usize, value - start);
|
||||
if (!values[index]) {
|
||||
i += 1;
|
||||
|
|
Loading…
Reference in New Issue