std.rand: Refactor Random interface

These changes have been made to resolve issue #10037. The `Random`
interface was implemented in such a way that causes significant slowdown
when calling the `fill` function of the rng used.

The `Random` interface is no longer stored in a field of the rng, and is
instead returned by the child function `random()` of the rng. This
avoids the performance issues caused by the interface.
This commit is contained in:
Ominitay 2021-10-27 15:53:29 +01:00 committed by Andrew Kelley
parent 9024f27d8f
commit c1a5ff34f3
18 changed files with 278 additions and 231 deletions

View file

@ -242,10 +242,11 @@ test "std.atomic.Queue" {
fn startPuts(ctx: *Context) u8 { fn startPuts(ctx: *Context) u8 {
var put_count: usize = puts_per_thread; var put_count: usize = puts_per_thread;
var r = std.rand.DefaultPrng.init(0xdeadbeef); var prng = std.rand.DefaultPrng.init(0xdeadbeef);
const random = prng.random();
while (put_count != 0) : (put_count -= 1) { while (put_count != 0) : (put_count -= 1) {
std.time.sleep(1); // let the os scheduler be our fuzz std.time.sleep(1); // let the os scheduler be our fuzz
const x = @bitCast(i32, r.random.int(u32)); const x = @bitCast(i32, random.int(u32));
const node = ctx.allocator.create(Queue(i32).Node) catch unreachable; const node = ctx.allocator.create(Queue(i32).Node) catch unreachable;
node.* = .{ node.* = .{
.prev = undefined, .prev = undefined,

View file

@ -147,10 +147,11 @@ test "std.atomic.stack" {
fn startPuts(ctx: *Context) u8 { fn startPuts(ctx: *Context) u8 {
var put_count: usize = puts_per_thread; var put_count: usize = puts_per_thread;
var r = std.rand.DefaultPrng.init(0xdeadbeef); var prng = std.rand.DefaultPrng.init(0xdeadbeef);
const random = prng.random();
while (put_count != 0) : (put_count -= 1) { while (put_count != 0) : (put_count -= 1) {
std.time.sleep(1); // let the os scheduler be our fuzz std.time.sleep(1); // let the os scheduler be our fuzz
const x = @bitCast(i32, r.random.int(u32)); const x = @bitCast(i32, random.int(u32));
const node = ctx.allocator.create(Stack(i32).Node) catch unreachable; const node = ctx.allocator.create(Stack(i32).Node) catch unreachable;
node.* = Stack(i32).Node{ node.* = Stack(i32).Node{
.next = undefined, .next = undefined,

View file

@ -11,6 +11,7 @@ const KiB = 1024;
const MiB = 1024 * KiB; const MiB = 1024 * KiB;
var prng = std.rand.DefaultPrng.init(0); var prng = std.rand.DefaultPrng.init(0);
const random = prng.random();
const Crypto = struct { const Crypto = struct {
ty: type, ty: type,
@ -34,7 +35,7 @@ pub fn benchmarkHash(comptime Hash: anytype, comptime bytes: comptime_int) !u64
var h = Hash.init(.{}); var h = Hash.init(.{});
var block: [Hash.digest_length]u8 = undefined; var block: [Hash.digest_length]u8 = undefined;
prng.random.bytes(block[0..]); random.bytes(block[0..]);
var offset: usize = 0; var offset: usize = 0;
var timer = try Timer.start(); var timer = try Timer.start();
@ -66,11 +67,11 @@ const macs = [_]Crypto{
pub fn benchmarkMac(comptime Mac: anytype, comptime bytes: comptime_int) !u64 { pub fn benchmarkMac(comptime Mac: anytype, comptime bytes: comptime_int) !u64 {
var in: [512 * KiB]u8 = undefined; var in: [512 * KiB]u8 = undefined;
prng.random.bytes(in[0..]); random.bytes(in[0..]);
const key_length = if (Mac.key_length == 0) 32 else Mac.key_length; const key_length = if (Mac.key_length == 0) 32 else Mac.key_length;
var key: [key_length]u8 = undefined; var key: [key_length]u8 = undefined;
prng.random.bytes(key[0..]); random.bytes(key[0..]);
var mac: [Mac.mac_length]u8 = undefined; var mac: [Mac.mac_length]u8 = undefined;
var offset: usize = 0; var offset: usize = 0;
@ -94,10 +95,10 @@ pub fn benchmarkKeyExchange(comptime DhKeyExchange: anytype, comptime exchange_c
std.debug.assert(DhKeyExchange.shared_length >= DhKeyExchange.secret_length); std.debug.assert(DhKeyExchange.shared_length >= DhKeyExchange.secret_length);
var secret: [DhKeyExchange.shared_length]u8 = undefined; var secret: [DhKeyExchange.shared_length]u8 = undefined;
prng.random.bytes(secret[0..]); random.bytes(secret[0..]);
var public: [DhKeyExchange.shared_length]u8 = undefined; var public: [DhKeyExchange.shared_length]u8 = undefined;
prng.random.bytes(public[0..]); random.bytes(public[0..]);
var timer = try Timer.start(); var timer = try Timer.start();
const start = timer.lap(); const start = timer.lap();
@ -211,15 +212,15 @@ const aeads = [_]Crypto{
pub fn benchmarkAead(comptime Aead: anytype, comptime bytes: comptime_int) !u64 { pub fn benchmarkAead(comptime Aead: anytype, comptime bytes: comptime_int) !u64 {
var in: [512 * KiB]u8 = undefined; var in: [512 * KiB]u8 = undefined;
prng.random.bytes(in[0..]); random.bytes(in[0..]);
var tag: [Aead.tag_length]u8 = undefined; var tag: [Aead.tag_length]u8 = undefined;
var key: [Aead.key_length]u8 = undefined; var key: [Aead.key_length]u8 = undefined;
prng.random.bytes(key[0..]); random.bytes(key[0..]);
var nonce: [Aead.nonce_length]u8 = undefined; var nonce: [Aead.nonce_length]u8 = undefined;
prng.random.bytes(nonce[0..]); random.bytes(nonce[0..]);
var offset: usize = 0; var offset: usize = 0;
var timer = try Timer.start(); var timer = try Timer.start();
@ -244,7 +245,7 @@ const aes = [_]Crypto{
pub fn benchmarkAes(comptime Aes: anytype, comptime count: comptime_int) !u64 { pub fn benchmarkAes(comptime Aes: anytype, comptime count: comptime_int) !u64 {
var key: [Aes.key_bits / 8]u8 = undefined; var key: [Aes.key_bits / 8]u8 = undefined;
prng.random.bytes(key[0..]); random.bytes(key[0..]);
const ctx = Aes.initEnc(key); const ctx = Aes.initEnc(key);
var in = [_]u8{0} ** 16; var in = [_]u8{0} ** 16;
@ -273,7 +274,7 @@ const aes8 = [_]Crypto{
pub fn benchmarkAes8(comptime Aes: anytype, comptime count: comptime_int) !u64 { pub fn benchmarkAes8(comptime Aes: anytype, comptime count: comptime_int) !u64 {
var key: [Aes.key_bits / 8]u8 = undefined; var key: [Aes.key_bits / 8]u8 = undefined;
prng.random.bytes(key[0..]); random.bytes(key[0..]);
const ctx = Aes.initEnc(key); const ctx = Aes.initEnc(key);
var in = [_]u8{0} ** (8 * 16); var in = [_]u8{0} ** (8 * 16);

View file

@ -11,7 +11,10 @@ const os = std.os;
/// We use this as a layer of indirection because global const pointers cannot /// We use this as a layer of indirection because global const pointers cannot
/// point to thread-local variables. /// point to thread-local variables.
pub var interface = std.rand.Random{ .fillFn = tlsCsprngFill }; pub const interface = std.rand.Random{
.ptr = undefined,
.fillFn = tlsCsprngFill,
};
const os_has_fork = switch (builtin.os.tag) { const os_has_fork = switch (builtin.os.tag) {
.dragonfly, .dragonfly,
@ -55,7 +58,7 @@ var install_atfork_handler = std.once(struct {
threadlocal var wipe_mem: []align(mem.page_size) u8 = &[_]u8{}; threadlocal var wipe_mem: []align(mem.page_size) u8 = &[_]u8{};
fn tlsCsprngFill(_: *const std.rand.Random, buffer: []u8) void { fn tlsCsprngFill(_: *c_void, buffer: []u8) void {
if (builtin.link_libc and @hasDecl(std.c, "arc4random_buf")) { if (builtin.link_libc and @hasDecl(std.c, "arc4random_buf")) {
// arc4random is already a thread-local CSPRNG. // arc4random is already a thread-local CSPRNG.
return std.c.arc4random_buf(buffer.ptr, buffer.len); return std.c.arc4random_buf(buffer.ptr, buffer.len);

View file

@ -11,6 +11,7 @@ const MiB = 1024 * KiB;
const GiB = 1024 * MiB; const GiB = 1024 * MiB;
var prng = std.rand.DefaultPrng.init(0); var prng = std.rand.DefaultPrng.init(0);
const random = prng.random();
const Hash = struct { const Hash = struct {
ty: type, ty: type,
@ -88,7 +89,7 @@ pub fn benchmarkHash(comptime H: anytype, bytes: usize) !Result {
}; };
var block: [block_size]u8 = undefined; var block: [block_size]u8 = undefined;
prng.random.bytes(block[0..]); random.bytes(block[0..]);
var offset: usize = 0; var offset: usize = 0;
var timer = try Timer.start(); var timer = try Timer.start();
@ -110,7 +111,7 @@ pub fn benchmarkHash(comptime H: anytype, bytes: usize) !Result {
pub fn benchmarkHashSmallKeys(comptime H: anytype, key_size: usize, bytes: usize) !Result { pub fn benchmarkHashSmallKeys(comptime H: anytype, key_size: usize, bytes: usize) !Result {
const key_count = bytes / key_size; const key_count = bytes / key_size;
var block: [block_size]u8 = undefined; var block: [block_size]u8 = undefined;
prng.random.bytes(block[0..]); random.bytes(block[0..]);
var i: usize = 0; var i: usize = 0;
var timer = try Timer.start(); var timer = try Timer.start();

View file

@ -1795,10 +1795,11 @@ test "std.hash_map put and remove loop in random order" {
while (i < size) : (i += 1) { while (i < size) : (i += 1) {
try keys.append(i); try keys.append(i);
} }
var rng = std.rand.DefaultPrng.init(0); var prng = std.rand.DefaultPrng.init(0);
const random = prng.random();
while (i < iterations) : (i += 1) { while (i < iterations) : (i += 1) {
std.rand.Random.shuffle(&rng.random, u32, keys.items); random.shuffle(u32, keys.items);
for (keys.items) |key| { for (keys.items) |key| {
try map.put(key, key); try map.put(key, key);
@ -1826,14 +1827,15 @@ test "std.hash_map remove one million elements in random order" {
keys.append(i) catch unreachable; keys.append(i) catch unreachable;
} }
var rng = std.rand.DefaultPrng.init(0); var prng = std.rand.DefaultPrng.init(0);
std.rand.Random.shuffle(&rng.random, u32, keys.items); const random = prng.random();
random.shuffle(u32, keys.items);
for (keys.items) |key| { for (keys.items) |key| {
map.put(key, key) catch unreachable; map.put(key, key) catch unreachable;
} }
std.rand.Random.shuffle(&rng.random, u32, keys.items); random.shuffle(u32, keys.items);
i = 0; i = 0;
while (i < n) : (i += 1) { while (i < n) : (i += 1) {
const key = keys.items[i]; const key = keys.items[i];

View file

@ -20,7 +20,8 @@ test "write a file, read it, then delete it" {
var data: [1024]u8 = undefined; var data: [1024]u8 = undefined;
var prng = DefaultPrng.init(1234); var prng = DefaultPrng.init(1234);
prng.random.bytes(data[0..]); const random = prng.random();
random.bytes(data[0..]);
const tmp_file_name = "temp_test_file.txt"; const tmp_file_name = "temp_test_file.txt";
{ {
var file = try tmp.dir.createFile(tmp_file_name, .{}); var file = try tmp.dir.createFile(tmp_file_name, .{});

View file

@ -589,9 +589,10 @@ test "big.rational set/to Float round-trip" {
var a = try Rational.init(testing.allocator); var a = try Rational.init(testing.allocator);
defer a.deinit(); defer a.deinit();
var prng = std.rand.DefaultPrng.init(0x5EED); var prng = std.rand.DefaultPrng.init(0x5EED);
const random = prng.random();
var i: usize = 0; var i: usize = 0;
while (i < 512) : (i += 1) { while (i < 512) : (i += 1) {
const r = prng.random.float(f64); const r = random.float(f64);
try a.setFloat(f64, r); try a.setFloat(f64, r);
try testing.expect((try a.toFloat(f64)) == r); try testing.expect((try a.toFloat(f64)) == r);
} }

View file

@ -850,17 +850,18 @@ test "std.PriorityDequeue: shrinkAndFree" {
test "std.PriorityDequeue: fuzz testing min" { test "std.PriorityDequeue: fuzz testing min" {
var prng = std.rand.DefaultPrng.init(0x12345678); var prng = std.rand.DefaultPrng.init(0x12345678);
const random = prng.random();
const test_case_count = 100; const test_case_count = 100;
const queue_size = 1_000; const queue_size = 1_000;
var i: usize = 0; var i: usize = 0;
while (i < test_case_count) : (i += 1) { while (i < test_case_count) : (i += 1) {
try fuzzTestMin(&prng.random, queue_size); try fuzzTestMin(random, queue_size);
} }
} }
fn fuzzTestMin(rng: *std.rand.Random, comptime queue_size: usize) !void { fn fuzzTestMin(rng: std.rand.Random, comptime queue_size: usize) !void {
const allocator = testing.allocator; const allocator = testing.allocator;
const items = try generateRandomSlice(allocator, rng, queue_size); const items = try generateRandomSlice(allocator, rng, queue_size);
@ -878,17 +879,18 @@ fn fuzzTestMin(rng: *std.rand.Random, comptime queue_size: usize) !void {
test "std.PriorityDequeue: fuzz testing max" { test "std.PriorityDequeue: fuzz testing max" {
var prng = std.rand.DefaultPrng.init(0x87654321); var prng = std.rand.DefaultPrng.init(0x87654321);
const random = prng.random();
const test_case_count = 100; const test_case_count = 100;
const queue_size = 1_000; const queue_size = 1_000;
var i: usize = 0; var i: usize = 0;
while (i < test_case_count) : (i += 1) { while (i < test_case_count) : (i += 1) {
try fuzzTestMax(&prng.random, queue_size); try fuzzTestMax(random, queue_size);
} }
} }
fn fuzzTestMax(rng: *std.rand.Random, queue_size: usize) !void { fn fuzzTestMax(rng: std.rand.Random, queue_size: usize) !void {
const allocator = testing.allocator; const allocator = testing.allocator;
const items = try generateRandomSlice(allocator, rng, queue_size); const items = try generateRandomSlice(allocator, rng, queue_size);
@ -906,17 +908,18 @@ fn fuzzTestMax(rng: *std.rand.Random, queue_size: usize) !void {
test "std.PriorityDequeue: fuzz testing min and max" { test "std.PriorityDequeue: fuzz testing min and max" {
var prng = std.rand.DefaultPrng.init(0x87654321); var prng = std.rand.DefaultPrng.init(0x87654321);
const random = prng.random();
const test_case_count = 100; const test_case_count = 100;
const queue_size = 1_000; const queue_size = 1_000;
var i: usize = 0; var i: usize = 0;
while (i < test_case_count) : (i += 1) { while (i < test_case_count) : (i += 1) {
try fuzzTestMinMax(&prng.random, queue_size); try fuzzTestMinMax(random, queue_size);
} }
} }
fn fuzzTestMinMax(rng: *std.rand.Random, queue_size: usize) !void { fn fuzzTestMinMax(rng: std.rand.Random, queue_size: usize) !void {
const allocator = testing.allocator; const allocator = testing.allocator;
const items = try generateRandomSlice(allocator, rng, queue_size); const items = try generateRandomSlice(allocator, rng, queue_size);
@ -943,7 +946,7 @@ fn fuzzTestMinMax(rng: *std.rand.Random, queue_size: usize) !void {
} }
} }
fn generateRandomSlice(allocator: *std.mem.Allocator, rng: *std.rand.Random, size: usize) ![]u32 { fn generateRandomSlice(allocator: *std.mem.Allocator, rng: std.rand.Random, size: usize) ![]u32 {
var array = std.ArrayList(u32).init(allocator); var array = std.ArrayList(u32).init(allocator);
try array.ensureTotalCapacity(size); try array.ensureTotalCapacity(size);

View file

@ -29,19 +29,40 @@ pub const Xoshiro256 = @import("rand/Xoshiro256.zig");
pub const Sfc64 = @import("rand/Sfc64.zig"); pub const Sfc64 = @import("rand/Sfc64.zig");
pub const Random = struct { pub const Random = struct {
fillFn: fn (r: *Random, buf: []u8) void, ptr: *c_void,
fillFn: fn (ptr: *c_void, buf: []u8) void,
/// Read random bytes into the specified buffer until full. pub fn init(pointer: anytype) Random {
pub fn bytes(r: *Random, buf: []u8) void { const Ptr = @TypeOf(pointer);
r.fillFn(r, buf); assert(@typeInfo(Ptr) == .Pointer); // Must be a pointer
assert(@typeInfo(Ptr).Pointer.size == .One); // Must be a single-item pointer
assert(@typeInfo(@typeInfo(Ptr).Pointer.child) == .Struct); // Must point to a struct
assert(std.meta.trait.hasFn("fill")(@typeInfo(Ptr).Pointer.child)); // Struct must provide the `fill` function
const gen = struct {
fn fill(ptr: *c_void, buf: []u8) void {
const alignment = @typeInfo(Ptr).Pointer.alignment;
const self = @ptrCast(Ptr, @alignCast(alignment, ptr));
self.fill(buf);
}
};
return .{
.ptr = pointer,
.fillFn = gen.fill,
};
} }
pub fn boolean(r: *Random) bool { /// Read random bytes into the specified buffer until full.
pub fn bytes(r: Random, buf: []u8) void {
r.fillFn(r.ptr, buf);
}
pub fn boolean(r: Random) bool {
return r.int(u1) != 0; return r.int(u1) != 0;
} }
/// Returns a random value from an enum, evenly distributed. /// Returns a random value from an enum, evenly distributed.
pub fn enumValue(r: *Random, comptime EnumType: type) EnumType { pub fn enumValue(r: Random, comptime EnumType: type) EnumType {
if (comptime !std.meta.trait.is(.Enum)(EnumType)) { if (comptime !std.meta.trait.is(.Enum)(EnumType)) {
@compileError("Random.enumValue requires an enum type, not a " ++ @typeName(EnumType)); @compileError("Random.enumValue requires an enum type, not a " ++ @typeName(EnumType));
} }
@ -55,7 +76,7 @@ pub const Random = struct {
/// Returns a random int `i` such that `minInt(T) <= i <= maxInt(T)`. /// Returns a random int `i` such that `minInt(T) <= i <= maxInt(T)`.
/// `i` is evenly distributed. /// `i` is evenly distributed.
pub fn int(r: *Random, comptime T: type) T { pub fn int(r: Random, comptime T: type) T {
const bits = @typeInfo(T).Int.bits; const bits = @typeInfo(T).Int.bits;
const UnsignedT = std.meta.Int(.unsigned, bits); const UnsignedT = std.meta.Int(.unsigned, bits);
const ByteAlignedT = std.meta.Int(.unsigned, @divTrunc(bits + 7, 8) * 8); const ByteAlignedT = std.meta.Int(.unsigned, @divTrunc(bits + 7, 8) * 8);
@ -73,7 +94,7 @@ pub const Random = struct {
/// Constant-time implementation off `uintLessThan`. /// Constant-time implementation off `uintLessThan`.
/// The results of this function may be biased. /// The results of this function may be biased.
pub fn uintLessThanBiased(r: *Random, comptime T: type, less_than: T) T { pub fn uintLessThanBiased(r: Random, comptime T: type, less_than: T) T {
comptime assert(@typeInfo(T).Int.signedness == .unsigned); comptime assert(@typeInfo(T).Int.signedness == .unsigned);
const bits = @typeInfo(T).Int.bits; const bits = @typeInfo(T).Int.bits;
comptime assert(bits <= 64); // TODO: workaround: LLVM ERROR: Unsupported library call operation! comptime assert(bits <= 64); // TODO: workaround: LLVM ERROR: Unsupported library call operation!
@ -93,7 +114,7 @@ pub const Random = struct {
/// However, if `fillFn` is backed by any evenly distributed pseudo random number generator, /// However, if `fillFn` is backed by any evenly distributed pseudo random number generator,
/// this function is guaranteed to return. /// this function is guaranteed to return.
/// If you need deterministic runtime bounds, use `uintLessThanBiased`. /// If you need deterministic runtime bounds, use `uintLessThanBiased`.
pub fn uintLessThan(r: *Random, comptime T: type, less_than: T) T { pub fn uintLessThan(r: Random, comptime T: type, less_than: T) T {
comptime assert(@typeInfo(T).Int.signedness == .unsigned); comptime assert(@typeInfo(T).Int.signedness == .unsigned);
const bits = @typeInfo(T).Int.bits; const bits = @typeInfo(T).Int.bits;
comptime assert(bits <= 64); // TODO: workaround: LLVM ERROR: Unsupported library call operation! comptime assert(bits <= 64); // TODO: workaround: LLVM ERROR: Unsupported library call operation!
@ -130,7 +151,7 @@ pub const Random = struct {
/// Constant-time implementation off `uintAtMost`. /// Constant-time implementation off `uintAtMost`.
/// The results of this function may be biased. /// The results of this function may be biased.
pub fn uintAtMostBiased(r: *Random, comptime T: type, at_most: T) T { pub fn uintAtMostBiased(r: Random, comptime T: type, at_most: T) T {
assert(@typeInfo(T).Int.signedness == .unsigned); assert(@typeInfo(T).Int.signedness == .unsigned);
if (at_most == maxInt(T)) { if (at_most == maxInt(T)) {
// have the full range // have the full range
@ -142,7 +163,7 @@ pub const Random = struct {
/// Returns an evenly distributed random unsigned integer `0 <= i <= at_most`. /// Returns an evenly distributed random unsigned integer `0 <= i <= at_most`.
/// See `uintLessThan`, which this function uses in most cases, /// See `uintLessThan`, which this function uses in most cases,
/// for commentary on the runtime of this function. /// for commentary on the runtime of this function.
pub fn uintAtMost(r: *Random, comptime T: type, at_most: T) T { pub fn uintAtMost(r: Random, comptime T: type, at_most: T) T {
assert(@typeInfo(T).Int.signedness == .unsigned); assert(@typeInfo(T).Int.signedness == .unsigned);
if (at_most == maxInt(T)) { if (at_most == maxInt(T)) {
// have the full range // have the full range
@ -153,7 +174,7 @@ pub const Random = struct {
/// Constant-time implementation off `intRangeLessThan`. /// Constant-time implementation off `intRangeLessThan`.
/// The results of this function may be biased. /// The results of this function may be biased.
pub fn intRangeLessThanBiased(r: *Random, comptime T: type, at_least: T, less_than: T) T { pub fn intRangeLessThanBiased(r: Random, comptime T: type, at_least: T, less_than: T) T {
assert(at_least < less_than); assert(at_least < less_than);
const info = @typeInfo(T).Int; const info = @typeInfo(T).Int;
if (info.signedness == .signed) { if (info.signedness == .signed) {
@ -172,7 +193,7 @@ pub const Random = struct {
/// Returns an evenly distributed random integer `at_least <= i < less_than`. /// Returns an evenly distributed random integer `at_least <= i < less_than`.
/// See `uintLessThan`, which this function uses in most cases, /// See `uintLessThan`, which this function uses in most cases,
/// for commentary on the runtime of this function. /// for commentary on the runtime of this function.
pub fn intRangeLessThan(r: *Random, comptime T: type, at_least: T, less_than: T) T { pub fn intRangeLessThan(r: Random, comptime T: type, at_least: T, less_than: T) T {
assert(at_least < less_than); assert(at_least < less_than);
const info = @typeInfo(T).Int; const info = @typeInfo(T).Int;
if (info.signedness == .signed) { if (info.signedness == .signed) {
@ -190,7 +211,7 @@ pub const Random = struct {
/// Constant-time implementation off `intRangeAtMostBiased`. /// Constant-time implementation off `intRangeAtMostBiased`.
/// The results of this function may be biased. /// The results of this function may be biased.
pub fn intRangeAtMostBiased(r: *Random, comptime T: type, at_least: T, at_most: T) T { pub fn intRangeAtMostBiased(r: Random, comptime T: type, at_least: T, at_most: T) T {
assert(at_least <= at_most); assert(at_least <= at_most);
const info = @typeInfo(T).Int; const info = @typeInfo(T).Int;
if (info.signedness == .signed) { if (info.signedness == .signed) {
@ -209,7 +230,7 @@ pub const Random = struct {
/// Returns an evenly distributed random integer `at_least <= i <= at_most`. /// Returns an evenly distributed random integer `at_least <= i <= at_most`.
/// See `uintLessThan`, which this function uses in most cases, /// See `uintLessThan`, which this function uses in most cases,
/// for commentary on the runtime of this function. /// for commentary on the runtime of this function.
pub fn intRangeAtMost(r: *Random, comptime T: type, at_least: T, at_most: T) T { pub fn intRangeAtMost(r: Random, comptime T: type, at_least: T, at_most: T) T {
assert(at_least <= at_most); assert(at_least <= at_most);
const info = @typeInfo(T).Int; const info = @typeInfo(T).Int;
if (info.signedness == .signed) { if (info.signedness == .signed) {
@ -230,7 +251,7 @@ pub const Random = struct {
pub const range = @compileError("deprecated; use intRangeLessThan()"); pub const range = @compileError("deprecated; use intRangeLessThan()");
/// Return a floating point value evenly distributed in the range [0, 1). /// Return a floating point value evenly distributed in the range [0, 1).
pub fn float(r: *Random, comptime T: type) T { pub fn float(r: Random, comptime T: type) T {
// Generate a uniform value between [1, 2) and scale down to [0, 1). // Generate a uniform value between [1, 2) and scale down to [0, 1).
// Note: The lowest mantissa bit is always set to 0 so we only use half the available range. // Note: The lowest mantissa bit is always set to 0 so we only use half the available range.
switch (T) { switch (T) {
@ -251,7 +272,7 @@ pub const Random = struct {
/// Return a floating point value normally distributed with mean = 0, stddev = 1. /// Return a floating point value normally distributed with mean = 0, stddev = 1.
/// ///
/// To use different parameters, use: floatNorm(...) * desiredStddev + desiredMean. /// To use different parameters, use: floatNorm(...) * desiredStddev + desiredMean.
pub fn floatNorm(r: *Random, comptime T: type) T { pub fn floatNorm(r: Random, comptime T: type) T {
const value = ziggurat.next_f64(r, ziggurat.NormDist); const value = ziggurat.next_f64(r, ziggurat.NormDist);
switch (T) { switch (T) {
f32 => return @floatCast(f32, value), f32 => return @floatCast(f32, value),
@ -263,7 +284,7 @@ pub const Random = struct {
/// Return an exponentially distributed float with a rate parameter of 1. /// Return an exponentially distributed float with a rate parameter of 1.
/// ///
/// To use a different rate parameter, use: floatExp(...) / desiredRate. /// To use a different rate parameter, use: floatExp(...) / desiredRate.
pub fn floatExp(r: *Random, comptime T: type) T { pub fn floatExp(r: Random, comptime T: type) T {
const value = ziggurat.next_f64(r, ziggurat.ExpDist); const value = ziggurat.next_f64(r, ziggurat.ExpDist);
switch (T) { switch (T) {
f32 => return @floatCast(f32, value), f32 => return @floatCast(f32, value),
@ -273,7 +294,7 @@ pub const Random = struct {
} }
/// Shuffle a slice into a random order. /// Shuffle a slice into a random order.
pub fn shuffle(r: *Random, comptime T: type, buf: []T) void { pub fn shuffle(r: Random, comptime T: type, buf: []T) void {
if (buf.len < 2) { if (buf.len < 2) {
return; return;
} }
@ -303,18 +324,19 @@ pub fn limitRangeBiased(comptime T: type, random_int: T, less_than: T) T {
const SequentialPrng = struct { const SequentialPrng = struct {
const Self = @This(); const Self = @This();
random: Random,
next_value: u8, next_value: u8,
pub fn init() Self { pub fn init() Self {
return Self{ return Self{
.random = Random{ .fillFn = fill },
.next_value = 0, .next_value = 0,
}; };
} }
fn fill(r: *Random, buf: []u8) void { pub fn random(self: *Self) Random {
const self = @fieldParentPtr(Self, "random", r); return Random.init(self);
}
pub fn fill(self: *Self, buf: []u8) void {
for (buf) |*b| { for (buf) |*b| {
b.* = self.next_value; b.* = self.next_value;
} }
@ -327,45 +349,46 @@ test "Random int" {
comptime try testRandomInt(); comptime try testRandomInt();
} }
fn testRandomInt() !void { fn testRandomInt() !void {
var r = SequentialPrng.init(); var rng = SequentialPrng.init();
const random = rng.random();
try expect(r.random.int(u0) == 0); try expect(random.int(u0) == 0);
r.next_value = 0; rng.next_value = 0;
try expect(r.random.int(u1) == 0); try expect(random.int(u1) == 0);
try expect(r.random.int(u1) == 1); try expect(random.int(u1) == 1);
try expect(r.random.int(u2) == 2); try expect(random.int(u2) == 2);
try expect(r.random.int(u2) == 3); try expect(random.int(u2) == 3);
try expect(r.random.int(u2) == 0); try expect(random.int(u2) == 0);
r.next_value = 0xff; rng.next_value = 0xff;
try expect(r.random.int(u8) == 0xff); try expect(random.int(u8) == 0xff);
r.next_value = 0x11; rng.next_value = 0x11;
try expect(r.random.int(u8) == 0x11); try expect(random.int(u8) == 0x11);
r.next_value = 0xff; rng.next_value = 0xff;
try expect(r.random.int(u32) == 0xffffffff); try expect(random.int(u32) == 0xffffffff);
r.next_value = 0x11; rng.next_value = 0x11;
try expect(r.random.int(u32) == 0x11111111); try expect(random.int(u32) == 0x11111111);
r.next_value = 0xff; rng.next_value = 0xff;
try expect(r.random.int(i32) == -1); try expect(random.int(i32) == -1);
r.next_value = 0x11; rng.next_value = 0x11;
try expect(r.random.int(i32) == 0x11111111); try expect(random.int(i32) == 0x11111111);
r.next_value = 0xff; rng.next_value = 0xff;
try expect(r.random.int(i8) == -1); try expect(random.int(i8) == -1);
r.next_value = 0x11; rng.next_value = 0x11;
try expect(r.random.int(i8) == 0x11); try expect(random.int(i8) == 0x11);
r.next_value = 0xff; rng.next_value = 0xff;
try expect(r.random.int(u33) == 0x1ffffffff); try expect(random.int(u33) == 0x1ffffffff);
r.next_value = 0xff; rng.next_value = 0xff;
try expect(r.random.int(i1) == -1); try expect(random.int(i1) == -1);
r.next_value = 0xff; rng.next_value = 0xff;
try expect(r.random.int(i2) == -1); try expect(random.int(i2) == -1);
r.next_value = 0xff; rng.next_value = 0xff;
try expect(r.random.int(i33) == -1); try expect(random.int(i33) == -1);
} }
test "Random boolean" { test "Random boolean" {
@ -373,11 +396,13 @@ test "Random boolean" {
comptime try testRandomBoolean(); comptime try testRandomBoolean();
} }
fn testRandomBoolean() !void { fn testRandomBoolean() !void {
var r = SequentialPrng.init(); var rng = SequentialPrng.init();
try expect(r.random.boolean() == false); const random = rng.random();
try expect(r.random.boolean() == true);
try expect(r.random.boolean() == false); try expect(random.boolean() == false);
try expect(r.random.boolean() == true); try expect(random.boolean() == true);
try expect(random.boolean() == false);
try expect(random.boolean() == true);
} }
test "Random enum" { test "Random enum" {
@ -390,11 +415,12 @@ fn testRandomEnumValue() !void {
Second, Second,
Third, Third,
}; };
var r = SequentialPrng.init(); var rng = SequentialPrng.init();
r.next_value = 0; const random = rng.random();
try expect(r.random.enumValue(TestEnum) == TestEnum.First); rng.next_value = 0;
try expect(r.random.enumValue(TestEnum) == TestEnum.First); try expect(random.enumValue(TestEnum) == TestEnum.First);
try expect(r.random.enumValue(TestEnum) == TestEnum.First); try expect(random.enumValue(TestEnum) == TestEnum.First);
try expect(random.enumValue(TestEnum) == TestEnum.First);
} }
test "Random intLessThan" { test "Random intLessThan" {
@ -403,38 +429,40 @@ test "Random intLessThan" {
comptime try testRandomIntLessThan(); comptime try testRandomIntLessThan();
} }
fn testRandomIntLessThan() !void { fn testRandomIntLessThan() !void {
var r = SequentialPrng.init(); var rng = SequentialPrng.init();
r.next_value = 0xff; const random = rng.random();
try expect(r.random.uintLessThan(u8, 4) == 3);
try expect(r.next_value == 0);
try expect(r.random.uintLessThan(u8, 4) == 0);
try expect(r.next_value == 1);
r.next_value = 0; rng.next_value = 0xff;
try expect(r.random.uintLessThan(u64, 32) == 0); try expect(random.uintLessThan(u8, 4) == 3);
try expect(rng.next_value == 0);
try expect(random.uintLessThan(u8, 4) == 0);
try expect(rng.next_value == 1);
rng.next_value = 0;
try expect(random.uintLessThan(u64, 32) == 0);
// trigger the bias rejection code path // trigger the bias rejection code path
r.next_value = 0; rng.next_value = 0;
try expect(r.random.uintLessThan(u8, 3) == 0); try expect(random.uintLessThan(u8, 3) == 0);
// verify we incremented twice // verify we incremented twice
try expect(r.next_value == 2); try expect(rng.next_value == 2);
r.next_value = 0xff; rng.next_value = 0xff;
try expect(r.random.intRangeLessThan(u8, 0, 0x80) == 0x7f); try expect(random.intRangeLessThan(u8, 0, 0x80) == 0x7f);
r.next_value = 0xff; rng.next_value = 0xff;
try expect(r.random.intRangeLessThan(u8, 0x7f, 0xff) == 0xfe); try expect(random.intRangeLessThan(u8, 0x7f, 0xff) == 0xfe);
r.next_value = 0xff; rng.next_value = 0xff;
try expect(r.random.intRangeLessThan(i8, 0, 0x40) == 0x3f); try expect(random.intRangeLessThan(i8, 0, 0x40) == 0x3f);
r.next_value = 0xff; rng.next_value = 0xff;
try expect(r.random.intRangeLessThan(i8, -0x40, 0x40) == 0x3f); try expect(random.intRangeLessThan(i8, -0x40, 0x40) == 0x3f);
r.next_value = 0xff; rng.next_value = 0xff;
try expect(r.random.intRangeLessThan(i8, -0x80, 0) == -1); try expect(random.intRangeLessThan(i8, -0x80, 0) == -1);
r.next_value = 0xff; rng.next_value = 0xff;
try expect(r.random.intRangeLessThan(i3, -4, 0) == -1); try expect(random.intRangeLessThan(i3, -4, 0) == -1);
r.next_value = 0xff; rng.next_value = 0xff;
try expect(r.random.intRangeLessThan(i3, -2, 2) == 1); try expect(random.intRangeLessThan(i3, -2, 2) == 1);
} }
test "Random intAtMost" { test "Random intAtMost" {
@ -443,67 +471,70 @@ test "Random intAtMost" {
comptime try testRandomIntAtMost(); comptime try testRandomIntAtMost();
} }
fn testRandomIntAtMost() !void { fn testRandomIntAtMost() !void {
var r = SequentialPrng.init(); var rng = SequentialPrng.init();
r.next_value = 0xff; const random = rng.random();
try expect(r.random.uintAtMost(u8, 3) == 3);
try expect(r.next_value == 0); rng.next_value = 0xff;
try expect(r.random.uintAtMost(u8, 3) == 0); try expect(random.uintAtMost(u8, 3) == 3);
try expect(rng.next_value == 0);
try expect(random.uintAtMost(u8, 3) == 0);
// trigger the bias rejection code path // trigger the bias rejection code path
r.next_value = 0; rng.next_value = 0;
try expect(r.random.uintAtMost(u8, 2) == 0); try expect(random.uintAtMost(u8, 2) == 0);
// verify we incremented twice // verify we incremented twice
try expect(r.next_value == 2); try expect(rng.next_value == 2);
r.next_value = 0xff; rng.next_value = 0xff;
try expect(r.random.intRangeAtMost(u8, 0, 0x7f) == 0x7f); try expect(random.intRangeAtMost(u8, 0, 0x7f) == 0x7f);
r.next_value = 0xff; rng.next_value = 0xff;
try expect(r.random.intRangeAtMost(u8, 0x7f, 0xfe) == 0xfe); try expect(random.intRangeAtMost(u8, 0x7f, 0xfe) == 0xfe);
r.next_value = 0xff; rng.next_value = 0xff;
try expect(r.random.intRangeAtMost(i8, 0, 0x3f) == 0x3f); try expect(random.intRangeAtMost(i8, 0, 0x3f) == 0x3f);
r.next_value = 0xff; rng.next_value = 0xff;
try expect(r.random.intRangeAtMost(i8, -0x40, 0x3f) == 0x3f); try expect(random.intRangeAtMost(i8, -0x40, 0x3f) == 0x3f);
r.next_value = 0xff; rng.next_value = 0xff;
try expect(r.random.intRangeAtMost(i8, -0x80, -1) == -1); try expect(random.intRangeAtMost(i8, -0x80, -1) == -1);
r.next_value = 0xff; rng.next_value = 0xff;
try expect(r.random.intRangeAtMost(i3, -4, -1) == -1); try expect(random.intRangeAtMost(i3, -4, -1) == -1);
r.next_value = 0xff; rng.next_value = 0xff;
try expect(r.random.intRangeAtMost(i3, -2, 1) == 1); try expect(random.intRangeAtMost(i3, -2, 1) == 1);
try expect(r.random.uintAtMost(u0, 0) == 0); try expect(random.uintAtMost(u0, 0) == 0);
} }
test "Random Biased" { test "Random Biased" {
var r = DefaultPrng.init(0); var prng = DefaultPrng.init(0);
const random = prng.random();
// Not thoroughly checking the logic here. // Not thoroughly checking the logic here.
// Just want to execute all the paths with different types. // Just want to execute all the paths with different types.
try expect(r.random.uintLessThanBiased(u1, 1) == 0); try expect(random.uintLessThanBiased(u1, 1) == 0);
try expect(r.random.uintLessThanBiased(u32, 10) < 10); try expect(random.uintLessThanBiased(u32, 10) < 10);
try expect(r.random.uintLessThanBiased(u64, 20) < 20); try expect(random.uintLessThanBiased(u64, 20) < 20);
try expect(r.random.uintAtMostBiased(u0, 0) == 0); try expect(random.uintAtMostBiased(u0, 0) == 0);
try expect(r.random.uintAtMostBiased(u1, 0) <= 0); try expect(random.uintAtMostBiased(u1, 0) <= 0);
try expect(r.random.uintAtMostBiased(u32, 10) <= 10); try expect(random.uintAtMostBiased(u32, 10) <= 10);
try expect(r.random.uintAtMostBiased(u64, 20) <= 20); try expect(random.uintAtMostBiased(u64, 20) <= 20);
try expect(r.random.intRangeLessThanBiased(u1, 0, 1) == 0); try expect(random.intRangeLessThanBiased(u1, 0, 1) == 0);
try expect(r.random.intRangeLessThanBiased(i1, -1, 0) == -1); try expect(random.intRangeLessThanBiased(i1, -1, 0) == -1);
try expect(r.random.intRangeLessThanBiased(u32, 10, 20) >= 10); try expect(random.intRangeLessThanBiased(u32, 10, 20) >= 10);
try expect(r.random.intRangeLessThanBiased(i32, 10, 20) >= 10); try expect(random.intRangeLessThanBiased(i32, 10, 20) >= 10);
try expect(r.random.intRangeLessThanBiased(u64, 20, 40) >= 20); try expect(random.intRangeLessThanBiased(u64, 20, 40) >= 20);
try expect(r.random.intRangeLessThanBiased(i64, 20, 40) >= 20); try expect(random.intRangeLessThanBiased(i64, 20, 40) >= 20);
// uncomment for broken module error: // uncomment for broken module error:
//expect(r.random.intRangeAtMostBiased(u0, 0, 0) == 0); //expect(random.intRangeAtMostBiased(u0, 0, 0) == 0);
try expect(r.random.intRangeAtMostBiased(u1, 0, 1) >= 0); try expect(random.intRangeAtMostBiased(u1, 0, 1) >= 0);
try expect(r.random.intRangeAtMostBiased(i1, -1, 0) >= -1); try expect(random.intRangeAtMostBiased(i1, -1, 0) >= -1);
try expect(r.random.intRangeAtMostBiased(u32, 10, 20) >= 10); try expect(random.intRangeAtMostBiased(u32, 10, 20) >= 10);
try expect(r.random.intRangeAtMostBiased(i32, 10, 20) >= 10); try expect(random.intRangeAtMostBiased(i32, 10, 20) >= 10);
try expect(r.random.intRangeAtMostBiased(u64, 20, 40) >= 20); try expect(random.intRangeAtMostBiased(u64, 20, 40) >= 20);
try expect(r.random.intRangeAtMostBiased(i64, 20, 40) >= 20); try expect(random.intRangeAtMostBiased(i64, 20, 40) >= 20);
} }
// Generator to extend 64-bit seed values into longer sequences. // Generator to extend 64-bit seed values into longer sequences.
@ -547,14 +578,15 @@ test "splitmix64 sequence" {
// Actual Random helper function tests, pcg engine is assumed correct. // Actual Random helper function tests, pcg engine is assumed correct.
test "Random float" { test "Random float" {
var prng = DefaultPrng.init(0); var prng = DefaultPrng.init(0);
const random = prng.random();
var i: usize = 0; var i: usize = 0;
while (i < 1000) : (i += 1) { while (i < 1000) : (i += 1) {
const val1 = prng.random.float(f32); const val1 = random.float(f32);
try expect(val1 >= 0.0); try expect(val1 >= 0.0);
try expect(val1 < 1.0); try expect(val1 < 1.0);
const val2 = prng.random.float(f64); const val2 = random.float(f64);
try expect(val2 >= 0.0); try expect(val2 >= 0.0);
try expect(val2 < 1.0); try expect(val2 < 1.0);
} }
@ -562,13 +594,14 @@ test "Random float" {
test "Random shuffle" { test "Random shuffle" {
var prng = DefaultPrng.init(0); var prng = DefaultPrng.init(0);
const random = prng.random();
var seq = [_]u8{ 0, 1, 2, 3, 4 }; var seq = [_]u8{ 0, 1, 2, 3, 4 };
var seen = [_]bool{false} ** 5; var seen = [_]bool{false} ** 5;
var i: usize = 0; var i: usize = 0;
while (i < 1000) : (i += 1) { while (i < 1000) : (i += 1) {
prng.random.shuffle(u8, seq[0..]); random.shuffle(u8, seq[0..]);
seen[seq[0]] = true; seen[seq[0]] = true;
try expect(sumArray(seq[0..]) == 10); try expect(sumArray(seq[0..]) == 10);
} }
@ -588,17 +621,19 @@ fn sumArray(s: []const u8) u32 {
test "Random range" { test "Random range" {
var prng = DefaultPrng.init(0); var prng = DefaultPrng.init(0);
try testRange(&prng.random, -4, 3); const random = prng.random();
try testRange(&prng.random, -4, -1);
try testRange(&prng.random, 10, 14); try testRange(random, -4, 3);
try testRange(&prng.random, -0x80, 0x7f); try testRange(random, -4, -1);
try testRange(random, 10, 14);
try testRange(random, -0x80, 0x7f);
} }
fn testRange(r: *Random, start: i8, end: i8) !void { fn testRange(r: Random, start: i8, end: i8) !void {
try testRangeBias(r, start, end, true); try testRangeBias(r, start, end, true);
try testRangeBias(r, start, end, false); try testRangeBias(r, start, end, false);
} }
fn testRangeBias(r: *Random, start: i8, end: i8, biased: bool) !void { fn testRangeBias(r: Random, start: i8, end: i8, biased: bool) !void {
const count = @intCast(usize, @as(i32, end) - @as(i32, start)); const count = @intCast(usize, @as(i32, end) - @as(i32, start));
var values_buffer = [_]bool{false} ** 0x100; var values_buffer = [_]bool{false} ** 0x100;
const values = values_buffer[0..count]; const values = values_buffer[0..count];
@ -617,9 +652,10 @@ test "CSPRNG" {
var secret_seed: [DefaultCsprng.secret_seed_length]u8 = undefined; var secret_seed: [DefaultCsprng.secret_seed_length]u8 = undefined;
std.crypto.random.bytes(&secret_seed); std.crypto.random.bytes(&secret_seed);
var csprng = DefaultCsprng.init(secret_seed); var csprng = DefaultCsprng.init(secret_seed);
const a = csprng.random.int(u64); const random = csprng.random();
const b = csprng.random.int(u64); const a = random.int(u64);
const c = csprng.random.int(u64); const b = random.int(u64);
const c = random.int(u64);
try expect(a ^ b ^ c != 0); try expect(a ^ b ^ c != 0);
} }

View file

@ -5,7 +5,6 @@ const Random = std.rand.Random;
const mem = std.mem; const mem = std.mem;
const Gimli = @This(); const Gimli = @This();
random: Random,
state: std.crypto.core.Gimli, state: std.crypto.core.Gimli,
pub const secret_seed_length = 32; pub const secret_seed_length = 32;
@ -16,15 +15,16 @@ pub fn init(secret_seed: [secret_seed_length]u8) Gimli {
mem.copy(u8, initial_state[0..secret_seed_length], &secret_seed); mem.copy(u8, initial_state[0..secret_seed_length], &secret_seed);
mem.set(u8, initial_state[secret_seed_length..], 0); mem.set(u8, initial_state[secret_seed_length..], 0);
var self = Gimli{ var self = Gimli{
.random = Random{ .fillFn = fill },
.state = std.crypto.core.Gimli.init(initial_state), .state = std.crypto.core.Gimli.init(initial_state),
}; };
return self; return self;
} }
fn fill(r: *Random, buf: []u8) void { pub fn random(self: *Gimli) Random {
const self = @fieldParentPtr(Gimli, "random", r); return Random.init(self);
}
pub fn fill(self: *Gimli, buf: []u8) void {
if (buf.len != 0) { if (buf.len != 0) {
self.state.squeeze(buf); self.state.squeeze(buf);
} else { } else {

View file

@ -8,8 +8,6 @@ const Random = std.rand.Random;
const mem = std.mem; const mem = std.mem;
const Isaac64 = @This(); const Isaac64 = @This();
random: Random,
r: [256]u64, r: [256]u64,
m: [256]u64, m: [256]u64,
a: u64, a: u64,
@ -19,7 +17,6 @@ i: usize,
pub fn init(init_s: u64) Isaac64 { pub fn init(init_s: u64) Isaac64 {
var isaac = Isaac64{ var isaac = Isaac64{
.random = Random{ .fillFn = fill },
.r = undefined, .r = undefined,
.m = undefined, .m = undefined,
.a = undefined, .a = undefined,
@ -33,6 +30,10 @@ pub fn init(init_s: u64) Isaac64 {
return isaac; return isaac;
} }
pub fn random(self: *Isaac64) Random {
return Random.init(self);
}
fn step(self: *Isaac64, mix: u64, base: usize, comptime m1: usize, comptime m2: usize) void { fn step(self: *Isaac64, mix: u64, base: usize, comptime m1: usize, comptime m2: usize) void {
const x = self.m[base + m1]; const x = self.m[base + m1];
self.a = mix +% self.m[base + m2]; self.a = mix +% self.m[base + m2];
@ -149,9 +150,7 @@ fn seed(self: *Isaac64, init_s: u64, comptime rounds: usize) void {
self.i = self.r.len; // trigger refill on first value self.i = self.r.len; // trigger refill on first value
} }
fn fill(r: *Random, buf: []u8) void { pub fn fill(self: *Isaac64, buf: []u8) void {
const self = @fieldParentPtr(Isaac64, "random", r);
var i: usize = 0; var i: usize = 0;
const aligned_len = buf.len - (buf.len & 7); const aligned_len = buf.len - (buf.len & 7);
@ -230,7 +229,7 @@ test "isaac64 fill" {
var buf0: [8]u8 = undefined; var buf0: [8]u8 = undefined;
var buf1: [7]u8 = undefined; var buf1: [7]u8 = undefined;
std.mem.writeIntLittle(u64, &buf0, s); std.mem.writeIntLittle(u64, &buf0, s);
Isaac64.fill(&r.random, &buf1); r.fill(&buf1);
try std.testing.expect(std.mem.eql(u8, buf0[0..7], buf1[0..])); try std.testing.expect(std.mem.eql(u8, buf0[0..7], buf1[0..]));
} }
} }

View file

@ -8,14 +8,11 @@ const Pcg = @This();
const default_multiplier = 6364136223846793005; const default_multiplier = 6364136223846793005;
random: Random,
s: u64, s: u64,
i: u64, i: u64,
pub fn init(init_s: u64) Pcg { pub fn init(init_s: u64) Pcg {
var pcg = Pcg{ var pcg = Pcg{
.random = Random{ .fillFn = fill },
.s = undefined, .s = undefined,
.i = undefined, .i = undefined,
}; };
@ -24,6 +21,10 @@ pub fn init(init_s: u64) Pcg {
return pcg; return pcg;
} }
pub fn random(self: *Pcg) Random {
return Random.init(self);
}
fn next(self: *Pcg) u32 { fn next(self: *Pcg) u32 {
const l = self.s; const l = self.s;
self.s = l *% default_multiplier +% (self.i | 1); self.s = l *% default_multiplier +% (self.i | 1);
@ -48,9 +49,7 @@ fn seedTwo(self: *Pcg, init_s: u64, init_i: u64) void {
self.s = self.s *% default_multiplier +% self.i; self.s = self.s *% default_multiplier +% self.i;
} }
fn fill(r: *Random, buf: []u8) void { pub fn fill(self: *Pcg, buf: []u8) void {
const self = @fieldParentPtr(Pcg, "random", r);
var i: usize = 0; var i: usize = 0;
const aligned_len = buf.len - (buf.len & 7); const aligned_len = buf.len - (buf.len & 7);
@ -113,7 +112,7 @@ test "pcg fill" {
var buf0: [4]u8 = undefined; var buf0: [4]u8 = undefined;
var buf1: [3]u8 = undefined; var buf1: [3]u8 = undefined;
std.mem.writeIntLittle(u32, &buf0, s); std.mem.writeIntLittle(u32, &buf0, s);
Pcg.fill(&r.random, &buf1); r.fill(&buf1);
try std.testing.expect(std.mem.eql(u8, buf0[0..3], buf1[0..])); try std.testing.expect(std.mem.eql(u8, buf0[0..3], buf1[0..]));
} }
} }

View file

@ -7,8 +7,6 @@ const Random = std.rand.Random;
const math = std.math; const math = std.math;
const Sfc64 = @This(); const Sfc64 = @This();
random: Random,
a: u64 = undefined, a: u64 = undefined,
b: u64 = undefined, b: u64 = undefined,
c: u64 = undefined, c: u64 = undefined,
@ -19,14 +17,16 @@ const RightShift = 11;
const LeftShift = 3; const LeftShift = 3;
pub fn init(init_s: u64) Sfc64 { pub fn init(init_s: u64) Sfc64 {
var x = Sfc64{ var x = Sfc64{};
.random = Random{ .fillFn = fill },
};
x.seed(init_s); x.seed(init_s);
return x; return x;
} }
pub fn random(self: *Sfc64) Random {
return Random.init(self);
}
fn next(self: *Sfc64) u64 { fn next(self: *Sfc64) u64 {
const tmp = self.a +% self.b +% self.counter; const tmp = self.a +% self.b +% self.counter;
self.counter += 1; self.counter += 1;
@ -47,9 +47,7 @@ fn seed(self: *Sfc64, init_s: u64) void {
} }
} }
fn fill(r: *Random, buf: []u8) void { pub fn fill(self: *Sfc64, buf: []u8) void {
const self = @fieldParentPtr(Sfc64, "random", r);
var i: usize = 0; var i: usize = 0;
const aligned_len = buf.len - (buf.len & 7); const aligned_len = buf.len - (buf.len & 7);
@ -128,7 +126,7 @@ test "Sfc64 fill" {
var buf0: [8]u8 = undefined; var buf0: [8]u8 = undefined;
var buf1: [7]u8 = undefined; var buf1: [7]u8 = undefined;
std.mem.writeIntLittle(u64, &buf0, s); std.mem.writeIntLittle(u64, &buf0, s);
Sfc64.fill(&r.random, &buf1); r.fill(&buf1);
try std.testing.expect(std.mem.eql(u8, buf0[0..7], buf1[0..])); try std.testing.expect(std.mem.eql(u8, buf0[0..7], buf1[0..]));
} }
} }

View file

@ -7,20 +7,19 @@ const Random = std.rand.Random;
const math = std.math; const math = std.math;
const Xoroshiro128 = @This(); const Xoroshiro128 = @This();
random: Random,
s: [2]u64, s: [2]u64,
pub fn init(init_s: u64) Xoroshiro128 { pub fn init(init_s: u64) Xoroshiro128 {
var x = Xoroshiro128{ var x = Xoroshiro128{ .s = undefined };
.random = Random{ .fillFn = fill },
.s = undefined,
};
x.seed(init_s); x.seed(init_s);
return x; return x;
} }
pub fn random(self: *Xoroshiro128) Random {
return Random.init(self);
}
fn next(self: *Xoroshiro128) u64 { fn next(self: *Xoroshiro128) u64 {
const s0 = self.s[0]; const s0 = self.s[0];
var s1 = self.s[1]; var s1 = self.s[1];
@ -66,9 +65,7 @@ pub fn seed(self: *Xoroshiro128, init_s: u64) void {
self.s[1] = gen.next(); self.s[1] = gen.next();
} }
fn fill(r: *Random, buf: []u8) void { pub fn fill(self: *Xoroshiro128, buf: []u8) void {
const self = @fieldParentPtr(Xoroshiro128, "random", r);
var i: usize = 0; var i: usize = 0;
const aligned_len = buf.len - (buf.len & 7); const aligned_len = buf.len - (buf.len & 7);
@ -144,7 +141,7 @@ test "xoroshiro fill" {
var buf0: [8]u8 = undefined; var buf0: [8]u8 = undefined;
var buf1: [7]u8 = undefined; var buf1: [7]u8 = undefined;
std.mem.writeIntLittle(u64, &buf0, s); std.mem.writeIntLittle(u64, &buf0, s);
Xoroshiro128.fill(&r.random, &buf1); r.fill(&buf1);
try std.testing.expect(std.mem.eql(u8, buf0[0..7], buf1[0..])); try std.testing.expect(std.mem.eql(u8, buf0[0..7], buf1[0..]));
} }
} }

View file

@ -7,13 +7,10 @@ const Random = std.rand.Random;
const math = std.math; const math = std.math;
const Xoshiro256 = @This(); const Xoshiro256 = @This();
random: Random,
s: [4]u64, s: [4]u64,
pub fn init(init_s: u64) Xoshiro256 { pub fn init(init_s: u64) Xoshiro256 {
var x = Xoshiro256{ var x = Xoshiro256{
.random = Random{ .fillFn = fill },
.s = undefined, .s = undefined,
}; };
@ -21,6 +18,10 @@ pub fn init(init_s: u64) Xoshiro256 {
return x; return x;
} }
pub fn random(self: *Xoshiro256) Random {
return Random.init(self);
}
fn next(self: *Xoshiro256) u64 { fn next(self: *Xoshiro256) u64 {
const r = math.rotl(u64, self.s[0] +% self.s[3], 23) +% self.s[0]; const r = math.rotl(u64, self.s[0] +% self.s[3], 23) +% self.s[0];
@ -64,9 +65,7 @@ pub fn seed(self: *Xoshiro256, init_s: u64) void {
self.s[3] = gen.next(); self.s[3] = gen.next();
} }
fn fill(r: *Random, buf: []u8) void { pub fn fill(self: *Xoshiro256, buf: []u8) void {
const self = @fieldParentPtr(Xoshiro256, "random", r);
var i: usize = 0; var i: usize = 0;
const aligned_len = buf.len - (buf.len & 7); const aligned_len = buf.len - (buf.len & 7);
@ -138,7 +137,7 @@ test "xoroshiro fill" {
var buf0: [8]u8 = undefined; var buf0: [8]u8 = undefined;
var buf1: [7]u8 = undefined; var buf1: [7]u8 = undefined;
std.mem.writeIntLittle(u64, &buf0, s); std.mem.writeIntLittle(u64, &buf0, s);
Xoshiro256.fill(&r.random, &buf1); r.fill(&buf1);
try std.testing.expect(std.mem.eql(u8, buf0[0..7], buf1[0..])); try std.testing.expect(std.mem.eql(u8, buf0[0..7], buf1[0..]));
} }
} }

View file

@ -13,7 +13,7 @@ const builtin = @import("builtin");
const math = std.math; const math = std.math;
const Random = std.rand.Random; const Random = std.rand.Random;
pub fn next_f64(random: *Random, comptime tables: ZigTable) f64 { pub fn next_f64(random: Random, comptime tables: ZigTable) f64 {
while (true) { while (true) {
// We manually construct a float from parts as we can avoid an extra random lookup here by // We manually construct a float from parts as we can avoid an extra random lookup here by
// using the unused exponent for the lookup table entry. // using the unused exponent for the lookup table entry.
@ -61,7 +61,7 @@ pub const ZigTable = struct {
// whether the distribution is symmetric // whether the distribution is symmetric
is_symmetric: bool, is_symmetric: bool,
// fallback calculation in the case we are in the 0 block // fallback calculation in the case we are in the 0 block
zero_case: fn (*Random, f64) f64, zero_case: fn (Random, f64) f64,
}; };
// zigNorInit // zigNorInit
@ -71,7 +71,7 @@ fn ZigTableGen(
comptime v: f64, comptime v: f64,
comptime f: fn (f64) f64, comptime f: fn (f64) f64,
comptime f_inv: fn (f64) f64, comptime f_inv: fn (f64) f64,
comptime zero_case: fn (*Random, f64) f64, comptime zero_case: fn (Random, f64) f64,
) ZigTable { ) ZigTable {
var tables: ZigTable = undefined; var tables: ZigTable = undefined;
@ -111,7 +111,7 @@ fn norm_f(x: f64) f64 {
fn norm_f_inv(y: f64) f64 { fn norm_f_inv(y: f64) f64 {
return math.sqrt(-2.0 * math.ln(y)); return math.sqrt(-2.0 * math.ln(y));
} }
fn norm_zero_case(random: *Random, u: f64) f64 { fn norm_zero_case(random: Random, u: f64) f64 {
var x: f64 = 1; var x: f64 = 1;
var y: f64 = 0; var y: f64 = 0;
@ -133,9 +133,11 @@ test "normal dist sanity" {
if (please_windows_dont_oom) return error.SkipZigTest; if (please_windows_dont_oom) return error.SkipZigTest;
var prng = std.rand.DefaultPrng.init(0); var prng = std.rand.DefaultPrng.init(0);
const random = prng.random();
var i: usize = 0; var i: usize = 0;
while (i < 1000) : (i += 1) { while (i < 1000) : (i += 1) {
_ = prng.random.floatNorm(f64); _ = random.floatNorm(f64);
} }
} }
@ -154,7 +156,7 @@ fn exp_f(x: f64) f64 {
fn exp_f_inv(y: f64) f64 { fn exp_f_inv(y: f64) f64 {
return -math.ln(y); return -math.ln(y);
} }
fn exp_zero_case(random: *Random, _: f64) f64 { fn exp_zero_case(random: Random, _: f64) f64 {
return exp_r - math.ln(random.float(f64)); return exp_r - math.ln(random.float(f64));
} }
@ -162,9 +164,11 @@ test "exp dist sanity" {
if (please_windows_dont_oom) return error.SkipZigTest; if (please_windows_dont_oom) return error.SkipZigTest;
var prng = std.rand.DefaultPrng.init(0); var prng = std.rand.DefaultPrng.init(0);
const random = prng.random();
var i: usize = 0; var i: usize = 0;
while (i < 1000) : (i += 1) { while (i < 1000) : (i += 1) {
_ = prng.random.floatExp(f64); _ = random.floatExp(f64);
} }
} }

View file

@ -1328,16 +1328,17 @@ test "another sort case" {
test "sort fuzz testing" { test "sort fuzz testing" {
var prng = std.rand.DefaultPrng.init(0x12345678); var prng = std.rand.DefaultPrng.init(0x12345678);
const random = prng.random();
const test_case_count = 10; const test_case_count = 10;
var i: usize = 0; var i: usize = 0;
while (i < test_case_count) : (i += 1) { while (i < test_case_count) : (i += 1) {
try fuzzTest(&prng.random); try fuzzTest(random);
} }
} }
var fixed_buffer_mem: [100 * 1024]u8 = undefined; var fixed_buffer_mem: [100 * 1024]u8 = undefined;
fn fuzzTest(rng: *std.rand.Random) !void { fn fuzzTest(rng: std.rand.Random) !void {
const array_size = rng.intRangeLessThan(usize, 0, 1000); const array_size = rng.intRangeLessThan(usize, 0, 1000);
var array = try testing.allocator.alloc(IdAndValue, array_size); var array = try testing.allocator.alloc(IdAndValue, array_size);
defer testing.allocator.free(array); defer testing.allocator.free(array);