@@ -8,7 +8,10 @@ use core::iter::{Chain, FromIterator, FusedIterator};
88use core:: mem;
99use core:: ops:: { BitAnd , BitOr , BitXor , Sub } ;
1010
11- use super :: map:: { self , ConsumeAllOnDrop , DefaultHashBuilder , DrainFilterInner , HashMap , Keys } ;
11+ use super :: map:: {
12+ self , make_hash, make_insert_hash, ConsumeAllOnDrop , DefaultHashBuilder , DrainFilterInner ,
13+ HashMap , Keys , RawEntryMut ,
14+ } ;
1215use crate :: raw:: { Allocator , Global } ;
1316
1417// Future Optimization (FIXME!)
@@ -953,6 +956,12 @@ where
953956 /// Inserts a value computed from `f` into the set if the given `value` is
954957 /// not present, then returns a reference to the value in the set.
955958 ///
959+ /// # Panics
960+ ///
961+ /// Panics if the value from the function and the provided lookup value
962+ /// are not equivalent or have different hashes. See [`Equivalent`]
963+ /// and [`Hash`] for more information.
964+ ///
956965 /// # Examples
957966 ///
958967 /// ```
@@ -967,20 +976,40 @@ where
967976 /// assert_eq!(value, pet);
968977 /// }
969978 /// assert_eq!(set.len(), 4); // a new "fish" was inserted
979+ /// assert!(set.contains("fish"));
970980 /// ```
971981 #[ cfg_attr( feature = "inline-more" , inline) ]
972982 pub fn get_or_insert_with < Q : ?Sized , F > ( & mut self , value : & Q , f : F ) -> & T
973983 where
974984 Q : Hash + Equivalent < T > ,
975985 F : FnOnce ( & Q ) -> T ,
976986 {
987+ #[ cold]
988+ #[ inline( never) ]
989+ fn assert_failed ( ) {
990+ panic ! (
991+ "the value from the function and the lookup value \
992+ must be equivalent and have the same hash"
993+ ) ;
994+ }
995+
977996 // Although the raw entry gives us `&mut T`, we only return `&T` to be consistent with
978997 // `get`. Key mutation is "raw" because you're not supposed to affect `Eq` or `Hash`.
979- self . map
980- . raw_entry_mut ( )
981- . from_key ( value)
982- . or_insert_with ( || ( f ( value) , ( ) ) )
983- . 0
998+ let hash = make_hash :: < Q , S > ( & self . map . hash_builder , value) ;
999+ let raw_entry_builder = self . map . raw_entry_mut ( ) ;
1000+ match raw_entry_builder. from_key_hashed_nocheck ( hash, value) {
1001+ RawEntryMut :: Occupied ( entry) => entry. into_key ( ) ,
1002+ RawEntryMut :: Vacant ( entry) => {
1003+ let insert_value = f ( value) ;
1004+ let insert_value_hash = make_insert_hash :: < T , S > ( entry. hasher ( ) , & insert_value) ;
1005+ if !( hash == insert_value_hash && value. equivalent ( & insert_value) ) {
1006+ assert_failed ( ) ;
1007+ }
1008+ entry
1009+ . insert_hashed_nocheck ( insert_value_hash, insert_value, ( ) )
1010+ . 0
1011+ }
1012+ }
9841013 }
9851014
9861015 /// Gets the given value's corresponding entry in the set for in-place manipulation.
@@ -2429,7 +2458,7 @@ fn assert_covariance() {
24292458#[ cfg( test) ]
24302459mod test_set {
24312460 use super :: super :: map:: DefaultHashBuilder ;
2432- use super :: HashSet ;
2461+ use super :: { make_hash , Equivalent , HashSet } ;
24332462 use std:: vec:: Vec ;
24342463
24352464 #[ test]
@@ -2886,4 +2915,88 @@ mod test_set {
28862915 set. insert ( i) ;
28872916 }
28882917 }
2918+
2919+ #[ test]
2920+ fn duplicate_insert ( ) {
2921+ let mut set = HashSet :: new ( ) ;
2922+ set. insert ( 1 ) ;
2923+ set. get_or_insert_with ( & 1 , |_| 1 ) ;
2924+ set. get_or_insert_with ( & 1 , |_| 1 ) ;
2925+ assert ! ( [ 1 ] . iter( ) . eq( set. iter( ) ) ) ;
2926+ }
2927+
2928+ #[ test]
2929+ #[ allow( clippy:: derived_hash_with_manual_eq) ]
2930+ #[ should_panic]
2931+ fn some_invalid_hash ( ) {
2932+ use core:: hash:: { Hash , Hasher } ;
2933+ #[ derive( Eq , PartialEq ) ]
2934+ struct Invalid {
2935+ count : u32 ,
2936+ }
2937+
2938+ struct InvalidRef {
2939+ count : u32 ,
2940+ }
2941+ impl Equivalent < Invalid > for InvalidRef {
2942+ fn equivalent ( & self , key : & Invalid ) -> bool {
2943+ self . count == key. count
2944+ }
2945+ }
2946+ impl Hash for Invalid {
2947+ fn hash < H : Hasher > ( & self , state : & mut H ) {
2948+ self . count . hash ( state) ;
2949+ }
2950+ }
2951+ impl Hash for InvalidRef {
2952+ fn hash < H : Hasher > ( & self , state : & mut H ) {
2953+ let double = self . count * 2 ;
2954+ double. hash ( state) ;
2955+ }
2956+ }
2957+ let mut set: HashSet < Invalid > = HashSet :: new ( ) ;
2958+ let key = InvalidRef { count : 1 } ;
2959+ let value = Invalid { count : 1 } ;
2960+ if key. equivalent ( & value) {
2961+ set. get_or_insert_with ( & key, |_| value) ;
2962+ }
2963+ }
2964+
2965+ #[ test]
2966+ #[ allow( clippy:: derived_hash_with_manual_eq) ]
2967+ #[ should_panic]
2968+ fn some_invalid_equivalent ( ) {
2969+ use core:: hash:: { Hash , Hasher } ;
2970+ #[ derive( Eq , PartialEq ) ]
2971+ struct Invalid {
2972+ count : u32 ,
2973+ other : u32 ,
2974+ }
2975+
2976+ struct InvalidRef {
2977+ count : u32 ,
2978+ other : u32 ,
2979+ }
2980+ impl Equivalent < Invalid > for InvalidRef {
2981+ fn equivalent ( & self , key : & Invalid ) -> bool {
2982+ self . count == key. count && self . other == key. other
2983+ }
2984+ }
2985+ impl Hash for Invalid {
2986+ fn hash < H : Hasher > ( & self , state : & mut H ) {
2987+ self . count . hash ( state) ;
2988+ }
2989+ }
2990+ impl Hash for InvalidRef {
2991+ fn hash < H : Hasher > ( & self , state : & mut H ) {
2992+ self . count . hash ( state) ;
2993+ }
2994+ }
2995+ let mut set: HashSet < Invalid > = HashSet :: new ( ) ;
2996+ let key = InvalidRef { count : 1 , other : 1 } ;
2997+ let value = Invalid { count : 1 , other : 2 } ;
2998+ if make_hash ( set. hasher ( ) , & key) == make_hash ( set. hasher ( ) , & value) {
2999+ set. get_or_insert_with ( & key, |_| value) ;
3000+ }
3001+ }
28893002}
0 commit comments