@@ -870,6 +870,53 @@ where
870870 }
871871 }
872872
873+ /// Retains only the elements specified by the predicate until the predicate returns `None`.
874+ ///
875+ /// In other words, remove all elements `e` such that `f(&e)` returns `Ok(false)` until
876+ /// `f(&e)` returns `None`.
877+ ///
878+ /// # Examples
879+ ///
880+ /// ```
881+ /// use hashbrown::{HashTable, DefaultHashBuilder};
882+ /// use std::hash::BuildHasher;
883+ ///
884+ /// let mut table = HashTable::new();
885+ /// let hasher = DefaultHashBuilder::default();
886+ /// let hasher = |val: &_| {
887+ /// use core::hash::Hasher;
888+ /// let mut state = hasher.build_hasher();
889+ /// core::hash::Hash::hash(&val, &mut state);
890+ /// state.finish()
891+ /// };
892+ /// let mut removed = 0;
893+ /// for x in 1..=8 {
894+ /// table.insert_unique(hasher(&x), x, hasher);
895+ /// }
896+ /// table.retain_with_break(|&mut v| if removed < 3 {
897+ /// if v % 2 == 0 {
898+ /// Some(true)
899+ /// } else {
900+ /// removed += 1;
901+ /// Some(false)
902+ /// }
903+ /// } else {
904+ /// None
905+ /// });
906+ /// ```
907+ pub fn retain_with_break ( & mut self , mut f : impl FnMut ( & mut T ) -> Option < bool > ) {
908+ // Here we only use `iter` as a temporary, preventing use-after-free
909+ unsafe {
910+ for item in self . raw . iter ( ) {
911+ match f ( item. as_mut ( ) ) {
912+ Some ( false ) => self . raw . erase ( item) ,
913+ Some ( true ) => continue ,
914+ None => break ,
915+ }
916+ }
917+ }
918+ }
919+
873920 /// Clears the set, returning all elements in an iterator.
874921 ///
875922 /// # Examples
@@ -2372,12 +2419,49 @@ impl<T, F, A: Allocator> FusedIterator for ExtractIf<'_, T, F, A> where F: FnMut
23722419
23732420#[ cfg( test) ]
23742421mod tests {
2422+ use crate :: DefaultHashBuilder ;
2423+
23752424 use super :: HashTable ;
23762425
2426+ use core:: hash:: BuildHasher ;
23772427 #[ test]
23782428 fn test_allocation_info ( ) {
23792429 assert_eq ! ( HashTable :: <( ) >:: new( ) . allocation_size( ) , 0 ) ;
23802430 assert_eq ! ( HashTable :: <u32 >:: new( ) . allocation_size( ) , 0 ) ;
23812431 assert ! ( HashTable :: <u32 >:: with_capacity( 1 ) . allocation_size( ) > core:: mem:: size_of:: <u32 >( ) ) ;
23822432 }
2433+
2434+ #[ test]
2435+ fn test_retain_with_break ( ) {
2436+ let mut table = HashTable :: new ( ) ;
2437+ let hasher = DefaultHashBuilder :: default ( ) ;
2438+ let hasher = |val : & _ | {
2439+ use core:: hash:: Hasher ;
2440+ let mut state = hasher. build_hasher ( ) ;
2441+ core:: hash:: Hash :: hash ( & val, & mut state) ;
2442+ state. finish ( )
2443+ } ;
2444+ for x in 0 ..100 {
2445+ table. insert_unique ( hasher ( & x) , x, hasher) ;
2446+ }
2447+ // looping and removing any value > 50, but stop after 40 iterations
2448+ let mut removed = 0 ;
2449+ table. retain_with_break ( |& mut v| {
2450+ if removed < 40 {
2451+ if v > 50 {
2452+ removed += 1 ;
2453+ Some ( false )
2454+ } else {
2455+ Some ( true )
2456+ }
2457+ } else {
2458+ None
2459+ }
2460+ } ) ;
2461+ assert_eq ! ( table. len( ) , 60 ) ;
2462+ // check nothing up to 50 is removed
2463+ for v in 0 ..=50 {
2464+ assert_eq ! ( table. find( hasher( & v) , |& val| val == v) , Some ( & v) ) ;
2465+ }
2466+ }
23832467}
0 commit comments