diff --git a/src/query/users/src/jwt/authenticator.rs b/src/query/users/src/jwt/authenticator.rs index c3802f6e97831..245ffec3650b3 100644 --- a/src/query/users/src/jwt/authenticator.rs +++ b/src/query/users/src/jwt/authenticator.rs @@ -115,7 +115,15 @@ impl JwtAuthenticator { ) -> Result> { let metadata = Token::decode_metadata(token); let key_id = metadata.map_or(None, |e| e.key_id().map(|s| s.to_string())); - let pub_key = key_store.get_key(key_id).await?; + let pub_key = match key_store.get_key(&key_id).await? { + None => { + return Err(ErrorCode::AuthenticateFailure(format!( + "key id {} not found in jwk store", + key_id.unwrap_or_default() + ))); + } + Some(pk) => pk, + }; let r = match &pub_key { PubKey::RSA256(pk) => pk.verify_token::(token, None), PubKey::ES256(pk) => pk.verify_token::(token, None), @@ -128,6 +136,7 @@ impl JwtAuthenticator { Some(_) => Ok(c), } } + #[async_backtrace::framed] pub async fn parse_jwt_claims(&self, token: &str) -> Result> { let mut combined_code = ErrorCode::AuthenticateFailure( diff --git a/src/query/users/src/jwt/jwk.rs b/src/query/users/src/jwt/jwk.rs index 00495492ef9ff..97cd6a09e994e 100644 --- a/src/query/users/src/jwt/jwk.rs +++ b/src/query/users/src/jwt/jwk.rs @@ -13,6 +13,7 @@ // limitations under the License. use std::collections::HashMap; +use std::collections::VecDeque; use std::sync::Arc; use std::time::Duration; use std::time::Instant; @@ -95,23 +96,34 @@ pub struct JwkKeys { pub keys: Vec, } +/// [`JwkKeyStore`] is a store for JWKS keys, it will cache the keys for a while and refresh the +/// keys periodically. When the keys are refreshed, the older keys will still be kept for a while. +/// +/// When the keys rotated in the client side first, the server will respond a 401 Authorization Failure +/// error, as the key is not found in the cache. We'll try to refresh the keys and try again. pub struct JwkKeyStore { - pub(crate) url: String, - cached_keys: Arc>>, - pub(crate) last_refreshed_at: RwLock>, - pub(crate) refresh_interval: Duration, - pub(crate) refresh_timeout: Duration, - pub(crate) load_keys_func: Option HashMap + Send + Sync>>, + url: String, + recent_cached_maps: Arc>>>, + last_refreshed_time: RwLock>, + last_retry_time: RwLock>, + max_recent_cached_maps: usize, + refresh_interval: Duration, + refresh_timeout: Duration, + retry_interval: Duration, + load_keys_func: Option HashMap + Send + Sync>>, } impl JwkKeyStore { pub fn new(url: String) -> Self { Self { url, - cached_keys: Arc::new(RwLock::new(HashMap::new())), + recent_cached_maps: Arc::new(RwLock::new(VecDeque::new())), + max_recent_cached_maps: 2, refresh_interval: Duration::from_secs(JWKS_REFRESH_INTERVAL), refresh_timeout: Duration::from_secs(JWKS_REFRESH_TIMEOUT), - last_refreshed_at: RwLock::new(None), + retry_interval: Duration::from_secs(2), + last_refreshed_time: RwLock::new(None), + last_retry_time: RwLock::new(None), load_keys_func: None, } } @@ -135,6 +147,16 @@ impl JwkKeyStore { self } + pub fn with_max_recent_cached_maps(mut self, max: usize) -> Self { + self.max_recent_cached_maps = max; + self + } + + pub fn with_retry_interval(mut self, interval: u64) -> Self { + self.retry_interval = Duration::from_secs(interval); + self + } + pub fn url(&self) -> String { self.url.clone() } @@ -153,9 +175,11 @@ impl JwkKeyStore { .map_err(|e| { ErrorCode::InvalidConfig(format!("Failed to create jwks client: {}", e)) })?; - let response = client.get(&self.url).send().await.map_err(|e| { - ErrorCode::AuthenticateFailure(format!("Could not download JWKS: {}", e)) - })?; + let response = client + .get(&self.url) + .send() + .await + .map_err(|e| ErrorCode::Internal(format!("Could not download JWKS: {}", e)))?; let jwk_keys: JwkKeys = response .json() .await @@ -168,16 +192,22 @@ impl JwkKeyStore { } #[async_backtrace::framed] - async fn load_keys_with_cache(&self, force: bool) -> Result> { + async fn maybe_refresh_cached_keys(&self, force: bool) -> Result<()> { let need_reload = force - || match *self.last_refreshed_at.read() { + || match *self.last_refreshed_time.read() { None => true, Some(last_refreshed_at) => last_refreshed_at.elapsed() > self.refresh_interval, }; - let old_keys = self.cached_keys.read().clone(); + let old_keys = self + .recent_cached_maps + .read() + .iter() + .last() + .cloned() + .unwrap_or(HashMap::new()); if !need_reload { - return Ok(old_keys); + return Ok(()); } // if got network issues on loading JWKS, fallback to the cached keys if available @@ -186,45 +216,81 @@ impl JwkKeyStore { Err(err) => { warn!("Failed to load JWKS: {}", err); if !old_keys.is_empty() { - return Ok(old_keys); + return Ok(()); } return Err(err.add_message("failed to load JWKS keys, and no available fallback")); } }; - // the JWKS keys are not always changes, but when it changed, we can have a log about this. - if !new_keys.keys().eq(old_keys.keys()) { - info!("JWKS keys changed."); + // if the new keys are empty, skip save it to the cache + if new_keys.is_empty() { + warn!("got empty JWKS keys, skip"); + return Ok(()); } - *self.cached_keys.write() = new_keys.clone(); - self.last_refreshed_at.write().replace(Instant::now()); - Ok(new_keys) + + // only update the cache when the keys are changed + if new_keys.keys().eq(old_keys.keys()) { + return Ok(()); + } + info!("JWKS keys changed."); + + // append the new keys to the end of recent_cached_maps + { + let mut recent_cached_maps = self.recent_cached_maps.write(); + recent_cached_maps.push_back(new_keys); + if recent_cached_maps.len() > self.max_recent_cached_maps { + recent_cached_maps.pop_front(); + } + } + self.last_refreshed_time.write().replace(Instant::now()); + Ok(()) } #[async_backtrace::framed] - pub async fn get_key(&self, key_id: Option) -> Result { - let keys = self.load_keys_with_cache(false).await?; + pub async fn get_key(&self, key_id: &Option) -> Result> { + self.maybe_refresh_cached_keys(false).await?; // if the key_id is not set, and there is only one key in the store, return it - let key_id = match key_id { - Some(key_id) => key_id, + let key_id = match &key_id { + Some(key_id) => key_id.clone(), None => { - if keys.len() != 1 { + let cached_maps = self.recent_cached_maps.read(); + let first_key = cached_maps + .iter() + .last() + .and_then(|keys| keys.iter().next()); + if let Some((_, pub_key)) = first_key { + return Ok(Some(pub_key.clone())); + } else { return Err(ErrorCode::AuthenticateFailure( "must specify key_id for jwt when multi keys exists ", )); - } else { - return Ok((keys.iter().next().unwrap().1).clone()); } } }; - match keys.get(&key_id) { - None => Err(ErrorCode::AuthenticateFailure(format!( - "key id {} not found in jwk store", - key_id - ))), - Some(key) => Ok(key.clone()), + // if the key is not found, try to refresh the keys and try again. this refresh only + // happens once within retry interval (default 2s). + for _ in 0..2 { + for keys_map in self.recent_cached_maps.read().iter().rev() { + if let Some(key) = keys_map.get(&key_id) { + return Ok(Some(key.clone())); + } + } + let need_retry = match *self.last_retry_time.read() { + None => true, + Some(last_retry_time) => last_retry_time.elapsed() > self.retry_interval, + }; + if need_retry { + warn!( + "key id {} not found in jwk store, try to peek the latest keys", + key_id + ); + self.maybe_refresh_cached_keys(true).await?; + *self.last_retry_time.write() = Some(Instant::now()); + } } + + Ok(None) } } diff --git a/src/query/users/tests/it/jwt/jwk.rs b/src/query/users/tests/it/jwt/jwk.rs new file mode 100644 index 0000000000000..bcd2e43337ce5 --- /dev/null +++ b/src/query/users/tests/it/jwt/jwk.rs @@ -0,0 +1,99 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use databend_common_base::base::tokio; +use databend_common_exception::Result; +use databend_common_users::JwkKeyStore; +use databend_common_users::PubKey; +use jwt_simple::prelude::*; +use parking_lot::Mutex; + +struct MockJwksLoader { + keys: Mutex>, +} + +impl MockJwksLoader { + fn new() -> Self { + Self { + keys: Mutex::new(HashMap::new()), + } + } + + fn reset_keys(&self, key_names: &[&'static str]) { + let mut keys = self.keys.lock(); + keys.clear(); + for key_name in key_names { + keys.insert( + key_name.to_string(), + PubKey::RSA256(Box::new(RS256KeyPair::generate(2048).unwrap().public_key())), + ); + } + } + + fn load_keys(&self) -> HashMap { + self.keys.lock().clone() + } +} + +#[tokio::test] +async fn test_jwk_store_with_random_keys() -> Result<()> { + let mock_jwks_loader = Arc::new(MockJwksLoader::new()); + let mock_jwks_loader_cloned = mock_jwks_loader.clone(); + let jwk_store = JwkKeyStore::new("jwks_key".to_string()) + .with_load_keys_func(Arc::new(move || mock_jwks_loader_cloned.load_keys())) + .with_max_recent_cached_maps(2) + .with_retry_interval(0); + + mock_jwks_loader.reset_keys(&["key1", "key2"]); + let key = jwk_store.get_key(&None).await?; + assert!(key.is_some()); + let key = jwk_store.get_key(&Some("key1".to_string())).await?; + assert!(key.is_some()); + let key = jwk_store.get_key(&Some("key3".to_string())).await?; + assert!(key.is_none()); + + mock_jwks_loader.reset_keys(&["key3", "key4"]); + let key = jwk_store.get_key(&Some("key3".to_string())).await?; + assert!(key.is_some()); + let key = jwk_store.get_key(&Some("key4".to_string())).await?; + assert!(key.is_some()); + let key = jwk_store.get_key(&Some("key5".to_string())).await?; + assert!(key.is_none()); + let key = jwk_store.get_key(&Some("key1".to_string())).await?; + assert!(key.is_some()); + + mock_jwks_loader.reset_keys(&["key5", "key6"]); + let key = jwk_store.get_key(&Some("key5".to_string())).await?; + assert!(key.is_some()); + let key = jwk_store.get_key(&Some("key6".to_string())).await?; + assert!(key.is_some()); + let key = jwk_store.get_key(&Some("key1".to_string())).await?; + assert!(key.is_none()); + Ok(()) +} + +#[tokio::test] +async fn test_jwk_store_with_random_keys_and_long_retry_interval() -> Result<()> { + let mock_jwks_loader = Arc::new(MockJwksLoader::new()); + let mock_jwks_loader_cloned = mock_jwks_loader.clone(); + let jwk_store = JwkKeyStore::new("jwks_key".to_string()) + .with_load_keys_func(Arc::new(move || mock_jwks_loader_cloned.load_keys())) + .with_max_recent_cached_maps(2) + .with_retry_interval(3600); + + mock_jwks_loader.reset_keys(&["key1", "key2"]); + let key = jwk_store.get_key(&None).await?; + assert!(key.is_some()); + let key = jwk_store.get_key(&Some("key1".to_string())).await?; + assert!(key.is_some()); + let key = jwk_store.get_key(&Some("key3".to_string())).await?; + assert!(key.is_none()); + + mock_jwks_loader.reset_keys(&["key3", "key4"]); + let key = jwk_store.get_key(&Some("key3".to_string())).await?; + assert!(key.is_none()); + let key = jwk_store.get_key(&Some("key4".to_string())).await?; + assert!(key.is_none()); + + Ok(()) +} diff --git a/src/query/users/tests/it/jwt/mod.rs b/src/query/users/tests/it/jwt/mod.rs index d722ef05cd4c9..fba02366edb59 100644 --- a/src/query/users/tests/it/jwt/mod.rs +++ b/src/query/users/tests/it/jwt/mod.rs @@ -13,3 +13,4 @@ // limitations under the License. mod authenticator; +mod jwk;