From 6707e515338aa02dbe928ded1fca16a0154d9426 Mon Sep 17 00:00:00 2001
From: Max Heller
Date: Tue, 9 Jun 2026 10:44:54 -0400
Subject: [PATCH] vec: add `extract_if()`

Adapts alloc::vec::Vec::extract_if()
---
NOTE(review): this copy of the patch had lost its line breaks and every `<...>` generic list that
starts with a letter; both are restored here. `Vec::<i32, 8>` in `extract_if_empty` and the
`Vec<_, 16>` targets of the `collect` calls in `extract_if_ranges` and in the `ones` doc example
are best guesses, and the author e-mail could not be recovered -- confirm against the original
commit.

 CHANGELOG.md          |   1 +
 src/vec/extract_if.rs | 296 ++++++++++++++++++++++++++++++++++++++++++
 src/vec/mod.rs        |  89 +++++++++++++
 3 files changed, 386 insertions(+)
 create mode 100644 src/vec/extract_if.rs

diff --git a/CHANGELOG.md b/CHANGELOG.md
index caa938303f..4c7b1f488c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
 ## [Unreleased]
 
 - Added `resize_with` to `Vec`
+- Added `Vec::extract_if`.
 
 ## [v0.9.3] 2025-04-15
 
diff --git a/src/vec/extract_if.rs b/src/vec/extract_if.rs
new file mode 100644
index 0000000000..5d0faddf72
--- /dev/null
+++ b/src/vec/extract_if.rs
@@ -0,0 +1,296 @@
+use core::{
+    fmt,
+    ops::{Range, RangeBounds},
+    ptr, slice,
+};
+
+use crate::LenType;
+
+use super::VecView;
+
+/// An iterator which uses a closure to determine if an element should be removed.
+///
+/// This struct is created by [`Vec::extract_if`].
+/// See its documentation for more.
+///
+/// [`Vec::extract_if`]: crate::vec::VecInner::extract_if
+///
+/// # Example
+///
+/// ```
+/// use heapless::Vec;
+///
+/// let mut v = Vec::<_, 4>::from_array([0, 1, 2]);
+/// let iter: heapless::vec::ExtractIf<'_, _, _, _> = v.extract_if(.., |x| *x % 2 == 0);
+/// ```
+#[must_use = "iterators are lazy and do nothing unless consumed; \
+    use `retain_mut` or `extract_if().for_each(drop)` to remove and discard elements"]
+pub struct ExtractIf<'a, T, F, LenT: LenType> {
+    vec: &'a mut VecView<T, LenT>,
+    /// The index of the item that will be inspected by the next call to `next`.
+    idx: usize,
+    /// Elements at and beyond this point will be retained. Must be equal or smaller than
+    /// `old_len`.
+    end: usize,
+    /// The number of items that have been drained (removed) thus far.
+    del: usize,
+    /// The original length of `vec` prior to draining.
+    old_len: usize,
+    /// The filter test predicate.
+    pred: F,
+}
+
+impl<'a, T, F, LenT: LenType> ExtractIf<'a, T, F, LenT> {
+    pub(super) fn new<R: RangeBounds<usize>>(
+        vec: &'a mut VecView<T, LenT>,
+        pred: F,
+        range: R,
+    ) -> Self {
+        let old_len = vec.len();
+        let Range { start, end } = crate::slice::range(range, ..old_len);
+
+        // SAFETY: 0 is a valid length. Guards against leaking the iterator (leak amplification).
+        unsafe {
+            vec.set_len(0);
+        }
+        ExtractIf {
+            vec,
+            idx: start,
+            del: 0,
+            end,
+            old_len,
+            pred,
+        }
+    }
+}
+
+impl<T, F, LenT: LenType> Iterator for ExtractIf<'_, T, F, LenT>
+where
+    F: FnMut(&mut T) -> bool,
+{
+    type Item = T;
+
+    fn next(&mut self) -> Option<T> {
+        while self.idx < self.end {
+            let i = self.idx;
+            let buf_ptr = self.vec.as_mut_ptr();
+            // SAFETY:
+            // We know that `i < self.end` from the loop guard and that `self.end <= self.old_len`
+            // from the validity of `Self`. Therefore `i` points to an element within `vec`.
+            //
+            // Additionally, the i-th element is valid because each element is visited at most once
+            // and it is the first time we access vec[i].
+            //
+            // Note: we can't use `vec.get_unchecked_mut(i)` here since the precondition for that
+            // function is that i < vec.len(), but we've set vec's length to zero.
+            let cur = unsafe { &mut *buf_ptr.add(i) };
+            let drained = (self.pred)(cur);
+            // Update the index *after* the predicate is called. If the index
+            // is updated prior and the predicate panics, the element at this
+            // index would be leaked.
+            self.idx += 1;
+            if drained {
+                self.del += 1;
+                // SAFETY: We never touch this element again after returning it.
+                return Some(unsafe { ptr::read(cur) });
+            } else if self.del > 0 {
+                // SAFETY: `self.del` > 0, so the hole slot must not overlap with current element.
+                // We use copy for move, and never touch this element again.
+                unsafe {
+                    let hole_slot = buf_ptr.add(i - self.del);
+                    ptr::copy_nonoverlapping(cur, hole_slot, 1);
+                }
+            }
+        }
+        None
+    }
+
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        (0, Some(self.end - self.idx))
+    }
+}
+
+impl<T, F, LenT: LenType> Drop for ExtractIf<'_, T, F, LenT> {
+    fn drop(&mut self) {
+        if self.del > 0 {
+            let ptr = self.vec.as_mut_ptr();
+            // SAFETY: Trailing unchecked items must be valid since we never touch them.
+            unsafe {
+                ptr::copy(
+                    ptr.cast_const().add(self.idx),
+                    ptr.add(self.idx - self.del),
+                    self.old_len - self.idx,
+                );
+            }
+        }
+        // SAFETY: After filling holes, all items are in contiguous memory.
+        unsafe {
+            self.vec.set_len(self.old_len - self.del);
+        }
+    }
+}
+
+impl<T, F, LenT: LenType> fmt::Debug for ExtractIf<'_, T, F, LenT>
+where
+    T: fmt::Debug,
+{
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        // We have to use pointer arithmetic here,
+        // because the length of `self.vec` is temporarily set to 0.
+        let start = self.vec.as_ptr();
+
+        // SAFETY: we always keep first `self.idx - self.del` elements valid.
+        let retained = unsafe { slice::from_raw_parts(start, self.idx - self.del) };
+
+        // SAFETY: we have not yet touched elements starting at `self.idx`.
+        let valid_tail =
+            unsafe { slice::from_raw_parts(start.add(self.idx), self.old_len - self.idx) };
+
+        // SAFETY: `end - idx <= old_len - idx`, because `end <= old_len`. Also `idx <= end` by
+        // invariant.
+        let (remainder, skipped_tail) =
+            unsafe { valid_tail.split_at_unchecked(self.end - self.idx) };
+
+        f.debug_struct("ExtractIf")
+            .field("retained", &retained)
+            .field("remainder", &remainder)
+            .field("skipped_tail", &skipped_tail)
+            .finish_non_exhaustive()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::super::Vec;
+
+    #[test]
+    fn extract_if_empty() {
+        let mut vec = Vec::<i32, 8>::new();
+
+        {
+            let mut iter = vec.extract_if(.., |_| true);
+            assert_eq!(iter.size_hint(), (0, Some(0)));
+            assert_eq!(iter.next(), None);
+            assert_eq!(iter.size_hint(), (0, Some(0)));
+            assert_eq!(iter.next(), None);
+            assert_eq!(iter.size_hint(), (0, Some(0)));
+        }
+        assert_eq!(vec.len(), 0);
+        assert_eq!(vec, []);
+    }
+
+    #[test]
+    fn extract_if_zst() {
+        let mut vec = Vec::<(), 8>::from_array([(), (), (), (), ()]);
+        let initial_len = vec.len();
+        let mut count = 0;
+        {
+            let mut iter = vec.extract_if(.., |_| true);
+            assert_eq!(iter.size_hint(), (0, Some(initial_len)));
+            while let Some(_) = iter.next() {
+                count += 1;
+                assert_eq!(iter.size_hint(), (0, Some(initial_len - count)));
+            }
+            assert_eq!(iter.size_hint(), (0, Some(0)));
+            assert_eq!(iter.next(), None);
+            assert_eq!(iter.size_hint(), (0, Some(0)));
+        }
+
+        assert_eq!(count, initial_len);
+        assert_eq!(vec.len(), 0);
+        assert_eq!(vec, []);
+    }
+
+    #[test]
+    fn extract_if_false() {
+        let mut vec = Vec::<_, 16>::from_array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
+
+        let initial_len = vec.len();
+        let mut count = 0;
+        {
+            let mut iter = vec.extract_if(.., |_| false);
+            assert_eq!(iter.size_hint(), (0, Some(initial_len)));
+            for _ in iter.by_ref() {
+                count += 1;
+            }
+            assert_eq!(iter.size_hint(), (0, Some(0)));
+            assert_eq!(iter.next(), None);
+            assert_eq!(iter.size_hint(), (0, Some(0)));
+        }
+
+        assert_eq!(count, 0);
+        assert_eq!(vec.len(), initial_len);
+        assert_eq!(vec, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
+    }
+
+    #[test]
+    fn extract_if_true() {
+        let mut vec = Vec::<_, 16>::from_array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
+
+        let initial_len = vec.len();
+        let mut count = 0;
+        {
+            let mut iter = vec.extract_if(.., |_| true);
+            assert_eq!(iter.size_hint(), (0, Some(initial_len)));
+            while let Some(_) = iter.next() {
+                count += 1;
+                assert_eq!(iter.size_hint(), (0, Some(initial_len - count)));
+            }
+            assert_eq!(iter.size_hint(), (0, Some(0)));
+            assert_eq!(iter.next(), None);
+            assert_eq!(iter.size_hint(), (0, Some(0)));
+        }
+
+        assert_eq!(count, initial_len);
+        assert_eq!(vec.len(), 0);
+        assert_eq!(vec, []);
+    }
+
+    #[test]
+    fn extract_if_ranges() {
+        let mut vec = Vec::<_, 16>::from_array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
+
+        let mut count = 0;
+        let it = vec.extract_if(1..=3, |_| {
+            count += 1;
+            true
+        });
+        assert_eq!(it.collect::<Vec<_, 16>>(), [1, 2, 3]);
+        assert_eq!(vec, [0, 4, 5, 6, 7, 8, 9, 10]);
+        assert_eq!(count, 3);
+
+        let it = vec.extract_if(1..=3, |_| false);
+        assert_eq!(it.collect::<Vec<_, 16>>(), []);
+        assert_eq!(vec, [0, 4, 5, 6, 7, 8, 9, 10]);
+    }
+
+    #[test]
+    #[should_panic]
+    fn extract_if_out_of_bounds() {
+        let mut vec = Vec::<_, 8>::from_array([0, 1]);
+        vec.extract_if(5.., |_| true).for_each(drop);
+    }
+
+    #[test]
+    fn extract_if_unconsumed() {
+        let mut vec = Vec::<_, 4>::from_array([1, 2, 3, 4]);
+        let drain = vec.extract_if(.., |&mut x| x % 2 != 0);
+        drop(drain);
+        assert_eq!(vec, [1, 2, 3, 4]);
+    }
+
+    #[test]
+    fn extract_if_debug() {
+        let mut vec = Vec::<_, 8>::from_array([1, 2, 3, 4, 5, 6, 7, 8]);
+        let mut drain = vec.extract_if(1..5, |&mut x| x % 2 != 0);
+        assert_eq!(
+            format!("{drain:?}"),
+            "ExtractIf { retained: [1], remainder: [2, 3, 4, 5], skipped_tail: [6, 7, 8], .. }"
+        );
+        drain.next().unwrap();
+        assert_eq!(
+            format!("{drain:?}"),
+            "ExtractIf { retained: [1, 2], remainder: [4, 5], skipped_tail: [6, 7, 8], .. }"
+        );
+    }
+}
diff --git a/src/vec/mod.rs b/src/vec/mod.rs
index 8d0648eb5a..daf314ee9c 100644
--- a/src/vec/mod.rs
+++ b/src/vec/mod.rs
@@ -21,6 +21,7 @@ use crate::{
 };
 
 mod drain;
+mod extract_if;
 
 mod storage {
     use core::mem::MaybeUninit;
@@ -218,6 +219,7 @@ pub use storage::{OwnedVecStorage, VecStorage, ViewVecStorage};
 pub(crate) use storage::VecStorageInner;
 
 pub use drain::Drain;
+pub use extract_if::ExtractIf;
 
 /// Base struct for [`Vec`] and [`VecView`], generic over the [`VecStorage`].
 ///
@@ -511,6 +513,93 @@ impl<T, LenT: LenType, S: VecStorage<T> + ?Sized> VecInner<T, LenT, S> {
         }
     }
 
+    /// Creates an iterator which uses a closure to determine if an element in the range should be
+    /// removed.
+    ///
+    /// If the closure returns `true`, the element is removed from the vector
+    /// and yielded. If the closure returns `false`, or panics, the element
+    /// remains in the vector and will not be yielded.
+    ///
+    /// Only elements that fall in the provided range are considered for extraction, but any
+    /// elements after the range will still have to be moved if any element has been extracted.
+    ///
+    /// If the returned `ExtractIf` is not exhausted, e.g. because it is dropped without iterating
+    /// or the iteration short-circuits, then the remaining elements will be retained.
+    /// Use `extract_if().for_each(drop)` if you do not need the returned iterator,
+    /// or [`retain_mut`] with a negated predicate if you also do not need to restrict the range.
+    ///
+    /// [`retain_mut`]: VecInner::retain_mut
+    ///
+    /// Using this method is equivalent to the following code:
+    ///
+    /// ```
+    /// use heapless::Vec;
+    ///
+    /// # let some_predicate = |x: &mut i32| { *x % 2 == 1 };
+    /// # let mut vec = Vec::<_, 8>::from_array([0, 1, 2, 3, 4, 5, 6]);
+    /// # let mut vec2 = vec.clone();
+    /// # let range = 1..5;
+    /// let mut i = range.start;
+    /// let end_items = vec.len() - range.end;
+    /// # let mut extracted = Vec::<_, 8>::new();
+    ///
+    /// while i < vec.len() - end_items {
+    ///     if some_predicate(&mut vec[i]) {
+    ///         let val = vec.remove(i);
+    ///         // your code here
+    /// #         extracted.push(val);
+    ///     } else {
+    ///         i += 1;
+    ///     }
+    /// }
+    ///
+    /// # let extracted2: Vec<_, 8> = vec2.extract_if(range, some_predicate).collect();
+    /// # assert_eq!(vec, vec2);
+    /// # assert_eq!(extracted, extracted2);
+    /// ```
+    ///
+    /// But `extract_if` is easier to use. `extract_if` is also more efficient,
+    /// because it can backshift the elements of the array in bulk.
+    ///
+    /// The iterator also lets you mutate the value of each element in the
+    /// closure, regardless of whether you choose to keep or remove it.
+    ///
+    /// # Panics
+    ///
+    /// If `range` is out of bounds.
+    ///
+    /// # Examples
+    ///
+    /// Splitting a vector into even and odd values, reusing the original vector:
+    ///
+    /// ```
+    /// # use heapless::Vec;
+    /// let mut numbers = Vec::<_, 16>::from_array([1, 2, 3, 4, 5, 6, 8, 9, 11, 13, 14, 15]);
+    ///
+    /// let evens = numbers.extract_if(.., |x| *x % 2 == 0).collect::<Vec<_, 16>>();
+    /// let odds = numbers;
+    ///
+    /// assert_eq!(evens, Vec::<_, 16>::from_array([2, 4, 6, 8, 14]));
+    /// assert_eq!(odds, Vec::<_, 16>::from_array([1, 3, 5, 9, 11, 13, 15]));
+    /// ```
+    ///
+    /// Using the range argument to only process a part of the vector:
+    ///
+    /// ```
+    /// # use heapless::Vec;
+    /// let mut items = Vec::<_, 16>::from_array([0, 0, 0, 0, 0, 0, 0, 1, 2, 1, 2, 1, 2]);
+    /// let ones = items.extract_if(7.., |x| *x == 1).collect::<Vec<_, 16>>();
+    /// assert_eq!(items, [0, 0, 0, 0, 0, 0, 0, 2, 2, 2]);
+    /// assert_eq!(ones.len(), 3);
+    /// ```
+    pub fn extract_if<F, R>(&mut self, range: R, filter: F) -> ExtractIf<'_, T, F, LenT>
+    where
+        F: FnMut(&mut T) -> bool,
+        R: RangeBounds<usize>,
+    {
+        ExtractIf::new(self.as_mut_view(), filter, range)
+    }
+
     /// Get a reference to the `Vec`, erasing the `N` const-generic.
     ///
     ///