@@ -8,7 +8,10 @@ use core::iter::{Chain, FromIterator, FusedIterator};
88use core:: mem;
99use core:: ops:: { BitAnd , BitOr , BitXor , Sub } ;
1010
11- use super :: map:: { self , ConsumeAllOnDrop , DefaultHashBuilder , DrainFilterInner , HashMap , Keys } ;
11+ use super :: map:: {
12+ self , make_hash, make_insert_hash, ConsumeAllOnDrop , DefaultHashBuilder , DrainFilterInner ,
13+ HashMap , Keys , RawEntryMut ,
14+ } ;
1215use crate :: raw:: { Allocator , Global } ;
1316
1417// Future Optimization (FIXME!)
@@ -953,6 +956,12 @@ where
953956 /// Inserts a value computed from `f` into the set if the given `value` is
954957 /// not present, then returns a reference to the value in the set.
955958 ///
959+ /// # Panics
960+ ///
961+ /// Panics if the value from the function and the provided lookup value
962+ /// are not equivalent or have different hashes. See [`Equivalent`]
963+ /// and [`Hash`] for more information.
964+ ///
956965 /// # Examples
957966 ///
958967 /// ```
@@ -967,20 +976,40 @@ where
967976 /// assert_eq!(value, pet);
968977 /// }
969978 /// assert_eq!(set.len(), 4); // a new "fish" was inserted
979+ /// assert!(set.contains("fish"));
970980 /// ```
971981 #[ cfg_attr( feature = "inline-more" , inline) ]
972982 pub fn get_or_insert_with < Q : ?Sized , F > ( & mut self , value : & Q , f : F ) -> & T
973983 where
974984 Q : Hash + Equivalent < T > ,
975985 F : FnOnce ( & Q ) -> T ,
976986 {
987+ #[ cold]
988+ #[ inline( never) ]
989+ fn assert_failed ( ) {
990+ panic ! (
991+ "the value from the function and the lookup value \
992+ must be equivalent and have the same hash"
993+ ) ;
994+ }
995+
977996 // Although the raw entry gives us `&mut T`, we only return `&T` to be consistent with
978997 // `get`. Key mutation is "raw" because you're not supposed to affect `Eq` or `Hash`.
979- self . map
980- . raw_entry_mut ( )
981- . from_key ( value)
982- . or_insert_with ( || ( f ( value) , ( ) ) )
983- . 0
998+ let hash = make_hash :: < Q , S > ( & self . map . hash_builder , value) ;
999+ let raw_entry_builder = self . map . raw_entry_mut ( ) ;
1000+ match raw_entry_builder. from_key_hashed_nocheck ( hash, value) {
1001+ RawEntryMut :: Occupied ( entry) => entry. into_key ( ) ,
1002+ RawEntryMut :: Vacant ( entry) => {
1003+ let insert_value = f ( value) ;
1004+ let insert_value_hash = make_insert_hash :: < T , S > ( entry. hasher ( ) , & insert_value) ;
1005+ if !( hash == insert_value_hash && value. equivalent ( & insert_value) ) {
1006+ assert_failed ( ) ;
1007+ }
1008+ entry
1009+ . insert_hashed_nocheck ( insert_value_hash, insert_value, ( ) )
1010+ . 0
1011+ }
1012+ }
9841013 }
9851014
9861015 /// Gets the given value's corresponding entry in the set for in-place manipulation.
@@ -2429,7 +2458,7 @@ fn assert_covariance() {
24292458#[ cfg( test) ]
24302459mod test_set {
24312460 use super :: super :: map:: DefaultHashBuilder ;
2432- use super :: HashSet ;
2461+ use super :: { make_hash , Equivalent , HashSet } ;
24332462 use std:: vec:: Vec ;
24342463
24352464 #[ test]
@@ -2886,4 +2915,100 @@ mod test_set {
28862915 set. insert ( i) ;
28872916 }
28882917 }
2918+
2919+ #[ test]
2920+ fn duplicate_insert ( ) {
2921+ let mut set = HashSet :: new ( ) ;
2922+ set. insert ( 1 ) ;
2923+ set. get_or_insert_with ( & 1 , |_| 1 ) ;
2924+ set. get_or_insert_with ( & 1 , |_| 1 ) ;
2925+ assert ! ( [ 1 ] . iter( ) . eq( set. iter( ) ) ) ;
2926+ }
2927+
2928+ #[ test]
2929+ #[ should_panic]
2930+ fn some_invalid_hash ( ) {
2931+ use core:: hash:: { Hash , Hasher } ;
2932+ struct Invalid {
2933+ count : u32 ,
2934+ }
2935+
2936+ struct InvalidRef {
2937+ count : u32 ,
2938+ }
2939+
2940+ impl PartialEq for Invalid {
2941+ fn eq ( & self , other : & Self ) -> bool {
2942+ self . count == other. count
2943+ }
2944+ }
2945+ impl Eq for Invalid { }
2946+
2947+ impl Equivalent < Invalid > for InvalidRef {
2948+ fn equivalent ( & self , key : & Invalid ) -> bool {
2949+ self . count == key. count
2950+ }
2951+ }
2952+ impl Hash for Invalid {
2953+ fn hash < H : Hasher > ( & self , state : & mut H ) {
2954+ self . count . hash ( state) ;
2955+ }
2956+ }
2957+ impl Hash for InvalidRef {
2958+ fn hash < H : Hasher > ( & self , state : & mut H ) {
2959+ let double = self . count * 2 ;
2960+ double. hash ( state) ;
2961+ }
2962+ }
2963+ let mut set: HashSet < Invalid > = HashSet :: new ( ) ;
2964+ let key = InvalidRef { count : 1 } ;
2965+ let value = Invalid { count : 1 } ;
2966+ if key. equivalent ( & value) {
2967+ set. get_or_insert_with ( & key, |_| value) ;
2968+ }
2969+ }
2970+
2971+ #[ test]
2972+ #[ should_panic]
2973+ fn some_invalid_equivalent ( ) {
2974+ use core:: hash:: { Hash , Hasher } ;
2975+ struct Invalid {
2976+ count : u32 ,
2977+ other : u32 ,
2978+ }
2979+
2980+ struct InvalidRef {
2981+ count : u32 ,
2982+ other : u32 ,
2983+ }
2984+
2985+ impl PartialEq for Invalid {
2986+ fn eq ( & self , other : & Self ) -> bool {
2987+ self . count == other. count && self . other == other. other
2988+ }
2989+ }
2990+ impl Eq for Invalid { }
2991+
2992+ impl Equivalent < Invalid > for InvalidRef {
2993+ fn equivalent ( & self , key : & Invalid ) -> bool {
2994+ self . count == key. count && self . other == key. other
2995+ }
2996+ }
2997+ impl Hash for Invalid {
2998+ fn hash < H : Hasher > ( & self , state : & mut H ) {
2999+ self . count . hash ( state) ;
3000+ }
3001+ }
3002+ impl Hash for InvalidRef {
3003+ fn hash < H : Hasher > ( & self , state : & mut H ) {
3004+ self . count . hash ( state) ;
3005+ }
3006+ }
3007+ let mut set: HashSet < Invalid > = HashSet :: new ( ) ;
3008+ let key = InvalidRef { count : 1 , other : 1 } ;
3009+ let value = Invalid { count : 1 , other : 2 } ;
3010+ if make_hash ( set. hasher ( ) , & key) == make_hash ( set. hasher ( ) , & value) {
3011+ set. get_or_insert_with ( & key, |_| value) ;
3012+ }
3013+ }
28893014}
0 commit comments