median-accumulator/src/lib.rs

210 lines
5 KiB
Rust
Raw Normal View History

2022-09-20 17:02:56 +00:00
//! ```rust
//! use median_accumulator::*;
//!
//! let mut acc = MedianAcc::new();
//!
//! assert_eq!(acc.get_median(), None);
//! acc.push(7);
//! assert_eq!(acc.get_median(), Some(MedianResult::One(7)));
//! acc.push(5);
//! assert_eq!(acc.get_median(), Some(MedianResult::Two(5, 7)));
//! acc.push(7);
//! assert_eq!(acc.get_median(), Some(MedianResult::One(7)));
//! ```
//!
//! In doc comments, _N_ is the number of samples and _D_ is the number of distinct values among the samples.
use std::cmp::Ordering;
/// Accumulator for computing the median of a stream of samples
///
/// Samples are stored as `(value, occurrence count)` pairs rather than one
/// entry per sample, and a cursor to the current median is maintained on each
/// push so that reading the median does not require a scan.
#[derive(Clone, Debug)]
pub struct MedianAcc<T: Clone + Ord> {
    /// Distinct sample values with their occurrence counts, kept sorted by value.
    samples: Vec<(T, u32)>,
    /// Index in `samples` of the entry holding the median (the lower entry when
    /// the median falls between two values); `None` while no sample was pushed.
    median_index: Option<usize>,
    /// Position of the median relative to the entry at `median_index`.
    ///
    /// When it equals `2 * count - 1` for that entry, the median lies between
    /// this entry and the next one.
    median_subindex: u32,
}

// Written by hand: `#[derive(Default)]` would require `T: Default`, although an
// empty accumulator never needs to build a `T`.
impl<T: Clone + Ord> Default for MedianAcc<T> {
    /// Create an empty accumulator (same state as `MedianAcc::new()`).
    fn default() -> Self {
        Self {
            samples: Vec::new(),
            median_index: None,
            median_subindex: 0,
        }
    }
}
/// Computed median
///
/// The accumulator never averages values itself: when the median falls
/// between two distinct values, both are returned as `Two` and the caller
/// decides how to combine them (e.g. by taking their mean).
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum MedianResult<T: Clone + Ord> {
/// The median is exactly this sample value.
One(T),
/// The median lies between these two sample values.
///
/// The first value is strictly less than the second.
Two(T, T),
}
impl<T: Clone + Ord> MedianAcc<T> {
    /// Create an empty accumulator
    ///
    /// _O(1)_
    ///
    /// Does not allocate until the first push.
    pub fn new() -> Self {
        Self {
            samples: Vec::new(),
            median_index: None,
            median_subindex: 0,
        }
    }

    /// Push a sample to the accumulator
    ///
    /// The lookup is _O(log(D))_. When the value was never seen before, it is
    /// inserted in the sorted vector, which shifts the following entries and
    /// costs _O(D)_ in the worst case.
    ///
    /// Samples are compared by reference; no clone of `T` is made.
    ///
    /// Occurrence counts and the median cursor are `u32`: the `count * 2 - 1`
    /// arithmetic overflows once a single value has about 2^31 occurrences.
    pub fn push(&mut self, sample: T) {
        let median_index = match self.median_index {
            Some(index) => index,
            None => {
                // Very first sample: it is the median.
                self.samples.push((sample, 1));
                self.median_index = Some(0);
                return;
            }
        };

        let search = self
            .samples
            .binary_search_by(|(value, _count)| value.cmp(&sample));

        let new_median_index = match search {
            Ok(sample_index) => {
                // Known value: only its occurrence count grows.
                self.samples[sample_index].1 += 1;
                match sample_index.cmp(&median_index) {
                    Ordering::Greater => self.step_right(median_index),
                    Ordering::Equal => {
                        // The median entry itself grew; the cursor moves
                        // forward but cannot leave the entry.
                        self.median_subindex += 1;
                        median_index
                    }
                    Ordering::Less => self.step_left(median_index),
                }
            }
            Err(sample_index) => {
                // New value: insert it at its sorted position.
                self.samples.insert(sample_index, (sample, 1));
                if sample_index <= median_index {
                    // The insertion moved the median entry one slot further.
                    self.step_left(median_index + 1)
                } else {
                    self.step_right(median_index)
                }
            }
        };
        self.median_index = Some(new_median_index);
    }

    /// Largest cursor position inside the entry at `index`
    /// (median between this entry and the next one).
    fn last_subindex(&self, index: usize) -> u32 {
        self.samples[index].1 * 2 - 1
    }

    /// Move the median cursor one step towards larger values, starting from
    /// the entry at `median_index`; returns the entry index the cursor ends on.
    fn step_right(&mut self, median_index: usize) -> usize {
        if self.median_subindex == self.last_subindex(median_index) {
            self.median_subindex = 0;
            median_index + 1
        } else {
            self.median_subindex += 1;
            median_index
        }
    }

    /// Move the median cursor one step towards smaller values, starting from
    /// the entry at `median_index`; returns the entry index the cursor ends on.
    fn step_left(&mut self, median_index: usize) -> usize {
        if self.median_subindex == 0 {
            let previous = median_index - 1;
            self.median_subindex = self.last_subindex(previous);
            previous
        } else {
            self.median_subindex -= 1;
            median_index
        }
    }

    /// Get the median (if there is at least one sample)
    ///
    /// _O(1)_, plus the cost of cloning the one or two returned values.
    ///
    /// Returns `None` when no sample was pushed, `MedianResult::One` when the
    /// median is a sample value, and `MedianResult::Two` when it lies between
    /// two consecutive distinct values.
    pub fn get_median(&self) -> Option<MedianResult<T>> {
        let median_index = self.median_index?;
        let (median_sample, median_n) = &self.samples[median_index];
        let result = if self.median_subindex == median_n * 2 - 1 {
            // Cursor is past the last occurrence: the median straddles this
            // entry and the next one.
            MedianResult::Two(
                median_sample.clone(),
                self.samples[median_index + 1].0.clone(),
            )
        } else {
            MedianResult::One(median_sample.clone())
        };
        Some(result)
    }

    /// Get the number of different values for the samples
    ///
    /// _O(1)_
    pub fn count_values(&self) -> usize {
        self.samples.len()
    }

    /// Get the total number of samples
    ///
    /// _O(D)_
    pub fn count_samples(&self) -> usize {
        self.samples
            .iter()
            .map(|(_sample, n)| *n as usize)
            .sum()
    }

    /// Clear the data
    ///
    /// Removes every sample and resets the median cursor, so that
    /// `get_median` returns `None` again.
    pub fn clear(&mut self) {
        self.samples.clear();
        self.median_index = None;
        self.median_subindex = 0;
    }

    /// Access the underlying vec
    ///
    /// Just in case you need finer allocation management.
    ///
    /// # Safety
    /// Leaving the vector in an invalid state may cause invalid result or panic (but no UB).
    ///
    /// The other methods binary-search this vector and index it with the
    /// stored median cursor, so it must stay sorted by value and the entries
    /// the cursor refers to must remain in place.
    pub unsafe fn get_samples_mut(&mut self) -> &mut Vec<(T, u32)> {
        &mut self.samples
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand::Rng;

    /// Reference implementation: sort the samples in place and read the middle.
    ///
    /// Returns `None` for an empty slice. For an even number of samples the two
    /// middle values are reported as `Two`, unless they are equal, in which
    /// case a single `One` is returned.
    fn naive_median<T: Clone + Ord>(samples: &mut [T]) -> Option<MedianResult<T>> {
        if samples.is_empty() {
            return None;
        }
        samples.sort_unstable();

        let upper_index = samples.len() / 2;
        let upper = samples[upper_index].clone();
        if samples.len() % 2 == 1 {
            return Some(MedianResult::One(upper));
        }

        let lower = samples[upper_index - 1].clone();
        Some(if lower == upper {
            MedianResult::One(lower)
        } else {
            MedianResult::Two(lower, upper)
        })
    }

    /// Compare the accumulator against the naive implementation on many
    /// random sample sets (including empty ones).
    #[test]
    fn correctness() {
        let mut rng = rand::thread_rng();
        for _ in 0..100_000 {
            let len: usize = rng.gen_range(0..100);
            let mut samples: Vec<i32> = Vec::with_capacity(len);
            let mut median = MedianAcc::new();
            for _ in 0..len {
                let sample: i32 = rng.gen_range(-100..100);
                samples.push(sample);
                median.push(sample);
            }
            let expected = naive_median(&mut samples);
            assert_eq!(median.get_median(), expected);
        }
    }
}