1313// limitations under the License.
1414
1515use std:: collections:: HashMap ;
16+ use std:: collections:: VecDeque ;
1617use std:: sync:: Arc ;
1718use std:: time:: Duration ;
1819use std:: time:: Instant ;
@@ -95,23 +96,34 @@ pub struct JwkKeys {
9596 pub keys : Vec < JwkKey > ,
9697}
9798
99+ /// [`JwkKeyStore`] is a store for JWKS keys, it will cache the keys for a while and refresh the
100+ /// keys periodically. When the keys are refreshed, the older keys will still be kept for a while.
101+ ///
102+ /// When the keys rotated in the client side first, the server will respond a 401 Authorization Failure
103+ /// error, as the key is not found in the cache. We'll try to refresh the keys and try again.
98104pub struct JwkKeyStore {
99- pub ( crate ) url : String ,
100- cached_keys : Arc < RwLock < HashMap < String , PubKey > > > ,
101- pub ( crate ) last_refreshed_at : RwLock < Option < Instant > > ,
102- pub ( crate ) refresh_interval : Duration ,
103- pub ( crate ) refresh_timeout : Duration ,
104- pub ( crate ) load_keys_func : Option < Arc < dyn Fn ( ) -> HashMap < String , PubKey > + Send + Sync > > ,
105+ url : String ,
106+ recent_cached_maps : Arc < RwLock < VecDeque < HashMap < String , PubKey > > > > ,
107+ last_refreshed_time : RwLock < Option < Instant > > ,
108+ last_retry_time : RwLock < Option < Instant > > ,
109+ max_recent_cached_maps : usize ,
110+ refresh_interval : Duration ,
111+ refresh_timeout : Duration ,
112+ retry_interval : Duration ,
113+ load_keys_func : Option < Arc < dyn Fn ( ) -> HashMap < String , PubKey > + Send + Sync > > ,
105114}
106115
107116impl JwkKeyStore {
108117 pub fn new ( url : String ) -> Self {
109118 Self {
110119 url,
111- cached_keys : Arc :: new ( RwLock :: new ( HashMap :: new ( ) ) ) ,
120+ recent_cached_maps : Arc :: new ( RwLock :: new ( VecDeque :: new ( ) ) ) ,
121+ max_recent_cached_maps : 2 ,
112122 refresh_interval : Duration :: from_secs ( JWKS_REFRESH_INTERVAL ) ,
113123 refresh_timeout : Duration :: from_secs ( JWKS_REFRESH_TIMEOUT ) ,
114- last_refreshed_at : RwLock :: new ( None ) ,
124+ retry_interval : Duration :: from_secs ( 2 ) ,
125+ last_refreshed_time : RwLock :: new ( None ) ,
126+ last_retry_time : RwLock :: new ( None ) ,
115127 load_keys_func : None ,
116128 }
117129 }
@@ -135,6 +147,16 @@ impl JwkKeyStore {
135147 self
136148 }
137149
150+ pub fn with_max_recent_cached_maps ( mut self , max : usize ) -> Self {
151+ self . max_recent_cached_maps = max;
152+ self
153+ }
154+
155+ pub fn with_retry_interval ( mut self , interval : u64 ) -> Self {
156+ self . retry_interval = Duration :: from_secs ( interval) ;
157+ self
158+ }
159+
138160 pub fn url ( & self ) -> String {
139161 self . url . clone ( )
140162 }
@@ -153,9 +175,11 @@ impl JwkKeyStore {
153175 . map_err ( |e| {
154176 ErrorCode :: InvalidConfig ( format ! ( "Failed to create jwks client: {}" , e) )
155177 } ) ?;
156- let response = client. get ( & self . url ) . send ( ) . await . map_err ( |e| {
157- ErrorCode :: AuthenticateFailure ( format ! ( "Could not download JWKS: {}" , e) )
158- } ) ?;
178+ let response = client
179+ . get ( & self . url )
180+ . send ( )
181+ . await
182+ . map_err ( |e| ErrorCode :: Internal ( format ! ( "Could not download JWKS: {}" , e) ) ) ?;
159183 let jwk_keys: JwkKeys = response
160184 . json ( )
161185 . await
@@ -168,16 +192,22 @@ impl JwkKeyStore {
168192 }
169193
170194 #[ async_backtrace:: framed]
171- async fn load_keys_with_cache ( & self , force : bool ) -> Result < HashMap < String , PubKey > > {
195+ async fn maybe_refresh_cached_keys ( & self , force : bool ) -> Result < ( ) > {
172196 let need_reload = force
173- || match * self . last_refreshed_at . read ( ) {
197+ || match * self . last_refreshed_time . read ( ) {
174198 None => true ,
175199 Some ( last_refreshed_at) => last_refreshed_at. elapsed ( ) > self . refresh_interval ,
176200 } ;
177201
178- let old_keys = self . cached_keys . read ( ) . clone ( ) ;
202+ let old_keys = self
203+ . recent_cached_maps
204+ . read ( )
205+ . iter ( )
206+ . last ( )
207+ . cloned ( )
208+ . unwrap_or ( HashMap :: new ( ) ) ;
179209 if !need_reload {
180- return Ok ( old_keys ) ;
210+ return Ok ( ( ) ) ;
181211 }
182212
183213 // if got network issues on loading JWKS, fallback to the cached keys if available
@@ -186,45 +216,81 @@ impl JwkKeyStore {
186216 Err ( err) => {
187217 warn ! ( "Failed to load JWKS: {}" , err) ;
188218 if !old_keys. is_empty ( ) {
189- return Ok ( old_keys ) ;
219+ return Ok ( ( ) ) ;
190220 }
191221 return Err ( err. add_message ( "failed to load JWKS keys, and no available fallback" ) ) ;
192222 }
193223 } ;
194224
195- // the JWKS keys are not always changes, but when it changed, we can have a log about this.
196- if !new_keys. keys ( ) . eq ( old_keys. keys ( ) ) {
197- info ! ( "JWKS keys changed." ) ;
225+ // if the new keys are empty, skip save it to the cache
226+ if new_keys. is_empty ( ) {
227+ warn ! ( "got empty JWKS keys, skip" ) ;
228+ return Ok ( ( ) ) ;
198229 }
199- * self . cached_keys . write ( ) = new_keys. clone ( ) ;
200- self . last_refreshed_at . write ( ) . replace ( Instant :: now ( ) ) ;
201- Ok ( new_keys)
230+
231+ // only update the cache when the keys are changed
232+ if new_keys. keys ( ) . eq ( old_keys. keys ( ) ) {
233+ return Ok ( ( ) ) ;
234+ }
235+ info ! ( "JWKS keys changed." ) ;
236+
237+ // append the new keys to the end of recent_cached_maps
238+ {
239+ let mut recent_cached_maps = self . recent_cached_maps . write ( ) ;
240+ recent_cached_maps. push_back ( new_keys) ;
241+ if recent_cached_maps. len ( ) > self . max_recent_cached_maps {
242+ recent_cached_maps. pop_front ( ) ;
243+ }
244+ }
245+ self . last_refreshed_time . write ( ) . replace ( Instant :: now ( ) ) ;
246+ Ok ( ( ) )
202247 }
203248
204249 #[ async_backtrace:: framed]
205- pub async fn get_key ( & self , key_id : Option < String > ) -> Result < PubKey > {
206- let keys = self . load_keys_with_cache ( false ) . await ?;
250+ pub async fn get_key ( & self , key_id : & Option < String > ) -> Result < Option < PubKey > > {
251+ self . maybe_refresh_cached_keys ( false ) . await ?;
207252
208253 // if the key_id is not set, and there is only one key in the store, return it
209- let key_id = match key_id {
210- Some ( key_id) => key_id,
254+ let key_id = match & key_id {
255+ Some ( key_id) => key_id. clone ( ) ,
211256 None => {
212- if keys. len ( ) != 1 {
257+ let cached_maps = self . recent_cached_maps . read ( ) ;
258+ let first_key = cached_maps
259+ . iter ( )
260+ . last ( )
261+ . and_then ( |keys| keys. iter ( ) . next ( ) ) ;
262+ if let Some ( ( _, pub_key) ) = first_key {
263+ return Ok ( Some ( pub_key. clone ( ) ) ) ;
264+ } else {
213265 return Err ( ErrorCode :: AuthenticateFailure (
214266 "must specify key_id for jwt when multi keys exists " ,
215267 ) ) ;
216- } else {
217- return Ok ( ( keys. iter ( ) . next ( ) . unwrap ( ) . 1 ) . clone ( ) ) ;
218268 }
219269 }
220270 } ;
221271
222- match keys. get ( & key_id) {
223- None => Err ( ErrorCode :: AuthenticateFailure ( format ! (
224- "key id {} not found in jwk store" ,
225- key_id
226- ) ) ) ,
227- Some ( key) => Ok ( key. clone ( ) ) ,
272+ // if the key is not found, try to refresh the keys and try again. this refresh only
273+ // happens once within retry interval (default 2s).
274+ for _ in 0 ..2 {
275+ for keys_map in self . recent_cached_maps . read ( ) . iter ( ) . rev ( ) {
276+ if let Some ( key) = keys_map. get ( & key_id) {
277+ return Ok ( Some ( key. clone ( ) ) ) ;
278+ }
279+ }
280+ let need_retry = match * self . last_retry_time . read ( ) {
281+ None => true ,
282+ Some ( last_retry_time) => last_retry_time. elapsed ( ) > self . retry_interval ,
283+ } ;
284+ if need_retry {
285+ warn ! (
286+ "key id {} not found in jwk store, try to peek the latest keys" ,
287+ key_id
288+ ) ;
289+ self . maybe_refresh_cached_keys ( true ) . await ?;
290+ * self . last_retry_time . write ( ) = Some ( Instant :: now ( ) ) ;
291+ }
228292 }
293+
294+ Ok ( None )
229295 }
230296}
0 commit comments