std.random: add weightedIndex function

`weightedIndex` picks from a selection of weighted indices.
This commit is contained in:
Justin Whear 2022-08-28 04:19:51 -07:00 committed by GitHub
parent 0f27836c21
commit 5bb8c03697
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 62 additions and 0 deletions

View file

@ -337,6 +337,42 @@ pub const Random = struct {
mem.swap(T, &buf[i], &buf[j]);
}
}
/// Randomly selects an index into `proportions`, where the likelihood of each
/// index is weighted by that proportion.
///
/// This is useful for selecting an item from a slice where weights are not equal.
/// `T` must be a numeric type capable of holding the sum of `proportions`.
pub fn weightedIndex(r: std.rand.Random, comptime T: type, proportions: []T) usize {
// This implementation works by summing the proportions and picking a random
// point in [0, sum). We then loop over the proportions, accumulating
// until our accumulator is greater than the random point.
var sum: T = 0;
for (proportions) |v| {
sum += v;
}
const point = if (comptime std.meta.trait.isSignedInt(T))
r.intRangeLessThan(T, 0, sum)
else if (comptime std.meta.trait.isUnsignedInt(T))
r.uintLessThan(T, sum)
else if (comptime std.meta.trait.isFloat(T))
// take care that imprecision doesn't lead to a value slightly greater than sum
std.math.min(r.float(T) * sum, sum - std.math.epsilon(T))
else
@compileError("weightedIndex does not support proportions of type " ++ @typeName(T));
std.debug.assert(point < sum);
var accumulator: T = 0;
for (proportions) |p, index| {
accumulator += p;
if (point < accumulator) return index;
}
unreachable;
}
};
/// Convert a random integer 0 <= random_int <= maxValue(T),

View file

@ -445,3 +445,29 @@ test "CSPRNG" {
const c = random.int(u64);
try expect(a ^ b ^ c != 0);
}
test "Random weightedIndex" {
// Make sure weightedIndex works for various integers and floats
inline for (.{ u64, i4, f32, f64 }) |T| {
var prng = DefaultPrng.init(0);
const random = prng.random();
var proportions = [_]T{ 2, 1, 1, 2 };
var counts = [_]f64{ 0, 0, 0, 0 };
const n_trials: u64 = 10_000;
var i: usize = 0;
while (i < n_trials) : (i += 1) {
const pick = random.weightedIndex(T, &proportions);
counts[pick] += 1;
}
// We expect the first and last counts to be roughly 2x the second and third
const approxEqRel = std.math.approxEqRel;
// Define "roughly" to be within 10%
const tolerance = 0.1;
try std.testing.expect(approxEqRel(f64, counts[0], counts[1] * 2, tolerance));
try std.testing.expect(approxEqRel(f64, counts[1], counts[2], tolerance));
try std.testing.expect(approxEqRel(f64, counts[2] * 2, counts[3], tolerance));
}
}