272 lines
6.4 KiB
Rust
272 lines
6.4 KiB
Rust
//! ```rust
|
|
//! use median_accumulator::*;
|
|
//!
|
|
//! let mut acc = vec::MedianAcc::new();
|
|
//!
|
|
//! assert_eq!(acc.get_median(), None);
|
|
//! acc.push(7);
|
|
//! assert_eq!(acc.get_median(), Some(MedianResult::One(7)));
|
|
//! acc.push(5);
|
|
//! assert_eq!(acc.get_median(), Some(MedianResult::Two(5, 7)));
|
|
//! acc.push(7);
|
|
//! assert_eq!(acc.get_median(), Some(MedianResult::One(7)));
|
|
//! ```
|
|
//!
|
|
//! In doc comments, _N_ represents the number of samples, _D_ represents the number of different values taken by the samples.
|
|
|
|
#![cfg_attr(not(feature = "std"), no_std)]
|
|
|
|
mod traits;
|
|
|
|
pub use traits::*;
|
|
|
|
use core::{cmp::Ordering, ops::DerefMut};
|
|
|
|
/// Accumulator for computing median
|
|
#[derive(Clone, Debug, Default)]
|
|
pub struct MedianAcc<
|
|
T: Clone + Ord,
|
|
V: DerefMut<Target = [(T, u32)]> + cc_traits::VecMut<(T, u32)> + InsertIndex,
|
|
> {
|
|
samples: V,
|
|
median_index: Option<usize>,
|
|
median_subindex: u32,
|
|
_t: core::marker::PhantomData<T>,
|
|
}
|
|
|
|
/// Accumulator flavour backed by the standard library `Vec`
#[cfg(feature = "std")]
pub mod vec {
    /// Median accumulator storing its `(value, occurrences)` pairs in a `Vec`
    pub type MedianAcc<T> = super::MedianAcc<T, Vec<(T, u32)>>;
}
|
|
|
|
/// Computed median
///
/// The accumulator never averages samples itself: when the median falls
/// between two different values, both are returned and the caller decides
/// how to combine them.
#[derive(PartialEq, Eq, Clone, Debug)]
pub enum MedianResult<T>
where
    T: Clone + Ord,
{
    /// The median is exactly this sample value.
    One(T),
    /// The median is the mean of these two values.
    ///
    /// The first value is strictly lower than the second one.
    Two(T, T),
}
|
|
|
|
impl<
|
|
T: Clone + Ord,
|
|
V: DerefMut<Target = [(T, u32)]> + cc_traits::VecMut<(T, u32)> + InsertIndex,
|
|
> MedianAcc<T, V>
|
|
{
|
|
/// Create an empty accumulator
|
|
///
|
|
/// _O(1)_
|
|
///
|
|
/// If using `std::vec::Vec`, does not allocate until the first push.
|
|
pub fn new() -> Self
|
|
where
|
|
V: Default,
|
|
{
|
|
Self {
|
|
samples: Default::default(),
|
|
median_index: None,
|
|
median_subindex: 0,
|
|
_t: Default::default(),
|
|
}
|
|
}
|
|
|
|
/// Create an empty accumulator from an existing (empty) collection
|
|
///
|
|
/// _O(1)_
|
|
///
|
|
/// Useful when using fixed-length collections or to avoid allocations.
|
|
pub fn new_from(collection: V) -> Self {
|
|
assert!(collection.is_empty(), "the collection must be empty");
|
|
|
|
Self {
|
|
samples: collection,
|
|
median_index: None,
|
|
median_subindex: 0,
|
|
_t: Default::default(),
|
|
}
|
|
}
|
|
|
|
/// Push a sample to the accumulator
|
|
///
|
|
/// _O(log(N))_
|
|
pub fn push(&mut self, sample: T) {
|
|
if let Some(median_index) = &mut self.median_index {
|
|
match self
|
|
.samples
|
|
.binary_search_by_key(&sample, |(s, _n)| s.clone())
|
|
{
|
|
Ok(sample_index) => {
|
|
self.samples.get_mut(sample_index).expect("unreachable").1 += 1;
|
|
match sample_index.cmp(median_index) {
|
|
Ordering::Greater => {
|
|
if self.median_subindex
|
|
== self.samples.get(*median_index).expect("unreachable").1 * 2 - 1
|
|
{
|
|
self.median_subindex = 0;
|
|
*median_index += 1;
|
|
} else {
|
|
self.median_subindex += 1;
|
|
}
|
|
}
|
|
Ordering::Equal => {
|
|
self.median_subindex += 1;
|
|
}
|
|
Ordering::Less => {
|
|
if self.median_subindex == 0 {
|
|
*median_index -= 1;
|
|
self.median_subindex =
|
|
self.samples.get(*median_index).expect("unreachable").1 * 2 - 1;
|
|
} else {
|
|
self.median_subindex -= 1;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
Err(sample_index) => {
|
|
self.samples.insert_index(sample_index, (sample, 1));
|
|
if *median_index >= sample_index {
|
|
if self.median_subindex == 0 {
|
|
self.median_subindex =
|
|
self.samples.get(*median_index).expect("unreachable").1 * 2 - 1;
|
|
} else {
|
|
self.median_subindex -= 1;
|
|
*median_index += 1;
|
|
}
|
|
} else if self.median_subindex
|
|
== self.samples.get(*median_index).expect("unreachable").1 * 2 - 1
|
|
{
|
|
self.median_subindex = 0;
|
|
*median_index += 1;
|
|
} else {
|
|
self.median_subindex += 1;
|
|
}
|
|
}
|
|
}
|
|
} else {
|
|
self.samples.push_back((sample, 1));
|
|
self.median_index = Some(0);
|
|
}
|
|
}
|
|
|
|
/// Get the median (if there is at least one sample)
|
|
///
|
|
/// _O(1)_
|
|
pub fn get_median(&self) -> Option<MedianResult<T>> {
|
|
self.median_index.map(|median_index| {
|
|
let (median_sample, median_n) = self.samples.get(median_index).expect("unreachable");
|
|
if self.median_subindex == median_n * 2 - 1 {
|
|
MedianResult::Two(
|
|
median_sample.clone(),
|
|
self.samples
|
|
.get(median_index + 1)
|
|
.expect("unreachable")
|
|
.0
|
|
.clone(),
|
|
)
|
|
} else {
|
|
MedianResult::One(median_sample.clone())
|
|
}
|
|
})
|
|
}
|
|
|
|
/// Get the number of different values for the samples
|
|
///
|
|
/// _O(1)_
|
|
pub fn count_values(&self) -> usize {
|
|
self.samples.len()
|
|
}
|
|
|
|
/// Get the total number of samples
|
|
///
|
|
/// _O(D)_
|
|
pub fn count_samples(&self) -> usize {
|
|
self.samples
|
|
.iter()
|
|
.fold(0, |count, (_sample, n)| count + *n as usize)
|
|
}
|
|
|
|
/// Clear the data
|
|
pub fn clear(&mut self)
|
|
where
|
|
V: cc_traits::Clear,
|
|
{
|
|
self.samples.clear();
|
|
self.median_index = None;
|
|
self.median_subindex = 0;
|
|
}
|
|
|
|
/// Access the underlying collection
|
|
///
|
|
/// Just in case you need finer allocation management.
|
|
///
|
|
/// # Safety
|
|
/// Leaving the vector in an invalid state may cause invalid result or panic (but no UB).
|
|
pub unsafe fn get_samples_mut(&mut self) -> &mut V {
|
|
&mut self.samples
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    use rand::Rng;

    /// Reference median: sort the samples and read the middle element(s).
    ///
    /// Returns `One` when the two middle samples of an even-length input are
    /// equal, matching what the accumulator reports in that situation.
    #[cfg(feature = "std")]
    fn naive_median<T: Clone + Ord>(samples: &mut [T]) -> Option<MedianResult<T>> {
        if samples.is_empty() {
            None
        } else {
            samples.sort_unstable();
            let mid = samples.len() / 2;
            if samples.len() % 2 == 0 {
                let r2 = samples[mid].clone();
                let r1 = samples[mid - 1].clone();
                if r1 == r2 {
                    Some(MedianResult::One(r1))
                } else {
                    Some(MedianResult::Two(r1, r2))
                }
            } else {
                Some(MedianResult::One(samples[mid].clone()))
            }
        }
    }

    /// Compare the `Vec`-backed accumulator against the naive median on random inputs.
    #[cfg(feature = "std")]
    #[test]
    fn correctness() {
        let mut rng = rand::thread_rng();

        for _ in 0..100_000 {
            let len: usize = rng.gen_range(0..100);
            let mut samples: Vec<i32> = (0..len).map(|_| rng.gen_range(-100..100)).collect();

            let mut median = vec::MedianAcc::new();
            for sample in samples.iter() {
                median.push(*sample);
            }

            assert_eq!(median.get_median(), naive_median(&mut samples));
        }
    }

    /// Same check with a `SmallVec`-backed accumulator.
    ///
    /// Requires `std` as well: the test relies on `naive_median` and `Vec`,
    /// which are only available with that feature.
    #[cfg(all(feature = "std", feature = "smallvec"))]
    #[test]
    fn correctness_smallvec() {
        let mut rng = rand::thread_rng();

        for _ in 0..100_000 {
            // At most 63 samples, so the inline capacity of 64 is never exceeded.
            let len: usize = rng.gen_range(0..64);
            let mut samples: Vec<i32> = (0..len).map(|_| rng.gen_range(-100..100)).collect();

            let mut median = MedianAcc::<i32, smallvec::SmallVec<[(i32, u32); 64]>>::new();
            for sample in samples.iter() {
                median.push(*sample);
            }

            assert_eq!(median.get_median(), naive_median(&mut samples));
        }
    }
}
|