mirror of
https://codeberg.org/ziglang/zig.git
synced 2025-12-06 05:44:20 +00:00
std.random: add weightedIndex function
`weightedIndex` picks from a selection of weighted indices.
This commit is contained in:
parent
0f27836c21
commit
5bb8c03697
2 changed files with 62 additions and 0 deletions
|
|
@ -337,6 +337,42 @@ pub const Random = struct {
|
|||
mem.swap(T, &buf[i], &buf[j]);
|
||||
}
|
||||
}
|
||||
|
||||
/// Randomly selects an index into `proportions`, where the likelihood of each
|
||||
/// index is weighted by that proportion.
|
||||
///
|
||||
/// This is useful for selecting an item from a slice where weights are not equal.
|
||||
/// `T` must be a numeric type capable of holding the sum of `proportions`.
|
||||
pub fn weightedIndex(r: std.rand.Random, comptime T: type, proportions: []T) usize {
|
||||
// This implementation works by summing the proportions and picking a random
|
||||
// point in [0, sum). We then loop over the proportions, accumulating
|
||||
// until our accumulator is greater than the random point.
|
||||
|
||||
var sum: T = 0;
|
||||
for (proportions) |v| {
|
||||
sum += v;
|
||||
}
|
||||
|
||||
const point = if (comptime std.meta.trait.isSignedInt(T))
|
||||
r.intRangeLessThan(T, 0, sum)
|
||||
else if (comptime std.meta.trait.isUnsignedInt(T))
|
||||
r.uintLessThan(T, sum)
|
||||
else if (comptime std.meta.trait.isFloat(T))
|
||||
// take care that imprecision doesn't lead to a value slightly greater than sum
|
||||
std.math.min(r.float(T) * sum, sum - std.math.epsilon(T))
|
||||
else
|
||||
@compileError("weightedIndex does not support proportions of type " ++ @typeName(T));
|
||||
|
||||
std.debug.assert(point < sum);
|
||||
|
||||
var accumulator: T = 0;
|
||||
for (proportions) |p, index| {
|
||||
accumulator += p;
|
||||
if (point < accumulator) return index;
|
||||
}
|
||||
|
||||
unreachable;
|
||||
}
|
||||
};
|
||||
|
||||
/// Convert a random integer 0 <= random_int <= maxValue(T),
|
||||
|
|
|
|||
|
|
@ -445,3 +445,29 @@ test "CSPRNG" {
|
|||
const c = random.int(u64);
|
||||
try expect(a ^ b ^ c != 0);
|
||||
}
|
||||
|
||||
test "Random weightedIndex" {
|
||||
// Make sure weightedIndex works for various integers and floats
|
||||
inline for (.{ u64, i4, f32, f64 }) |T| {
|
||||
var prng = DefaultPrng.init(0);
|
||||
const random = prng.random();
|
||||
|
||||
var proportions = [_]T{ 2, 1, 1, 2 };
|
||||
var counts = [_]f64{ 0, 0, 0, 0 };
|
||||
|
||||
const n_trials: u64 = 10_000;
|
||||
var i: usize = 0;
|
||||
while (i < n_trials) : (i += 1) {
|
||||
const pick = random.weightedIndex(T, &proportions);
|
||||
counts[pick] += 1;
|
||||
}
|
||||
|
||||
// We expect the first and last counts to be roughly 2x the second and third
|
||||
const approxEqRel = std.math.approxEqRel;
|
||||
// Define "roughly" to be within 10%
|
||||
const tolerance = 0.1;
|
||||
try std.testing.expect(approxEqRel(f64, counts[0], counts[1] * 2, tolerance));
|
||||
try std.testing.expect(approxEqRel(f64, counts[1], counts[2], tolerance));
|
||||
try std.testing.expect(approxEqRel(f64, counts[2] * 2, counts[3], tolerance));
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue