Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion src/query/users/src/jwt/authenticator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,15 @@ impl JwtAuthenticator {
) -> Result<JWTClaims<CustomClaims>> {
let metadata = Token::decode_metadata(token);
let key_id = metadata.map_or(None, |e| e.key_id().map(|s| s.to_string()));
let pub_key = key_store.get_key(key_id).await?;
let pub_key = match key_store.get_key(&key_id).await? {
None => {
return Err(ErrorCode::AuthenticateFailure(format!(
"key id {} not found in jwk store",
key_id.unwrap_or_default()
)));
}
Some(pk) => pk,
};
let r = match &pub_key {
PubKey::RSA256(pk) => pk.verify_token::<CustomClaims>(token, None),
PubKey::ES256(pk) => pk.verify_token::<CustomClaims>(token, None),
Expand All @@ -128,6 +136,7 @@ impl JwtAuthenticator {
Some(_) => Ok(c),
}
}

#[async_backtrace::framed]
pub async fn parse_jwt_claims(&self, token: &str) -> Result<JWTClaims<CustomClaims>> {
let mut combined_code = ErrorCode::AuthenticateFailure(
Expand Down
136 changes: 101 additions & 35 deletions src/query/users/src/jwt/jwk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

use std::collections::HashMap;
use std::collections::VecDeque;
use std::sync::Arc;
use std::time::Duration;
use std::time::Instant;
Expand Down Expand Up @@ -95,23 +96,34 @@ pub struct JwkKeys {
pub keys: Vec<JwkKey>,
}

/// [`JwkKeyStore`] is a store for JWKS keys, it will cache the keys for a while and refresh the
/// keys periodically. When the keys are refreshed, the older keys will still be kept for a while.
///
/// When the keys rotated in the client side first, the server will respond a 401 Authorization Failure
/// error, as the key is not found in the cache. We'll try to refresh the keys and try again.
pub struct JwkKeyStore {
pub(crate) url: String,
cached_keys: Arc<RwLock<HashMap<String, PubKey>>>,
pub(crate) last_refreshed_at: RwLock<Option<Instant>>,
pub(crate) refresh_interval: Duration,
pub(crate) refresh_timeout: Duration,
pub(crate) load_keys_func: Option<Arc<dyn Fn() -> HashMap<String, PubKey> + Send + Sync>>,
url: String,
recent_cached_maps: Arc<RwLock<VecDeque<HashMap<String, PubKey>>>>,
last_refreshed_time: RwLock<Option<Instant>>,
last_retry_time: RwLock<Option<Instant>>,
max_recent_cached_maps: usize,
refresh_interval: Duration,
refresh_timeout: Duration,
retry_interval: Duration,
load_keys_func: Option<Arc<dyn Fn() -> HashMap<String, PubKey> + Send + Sync>>,
}

impl JwkKeyStore {
pub fn new(url: String) -> Self {
Self {
url,
cached_keys: Arc::new(RwLock::new(HashMap::new())),
recent_cached_maps: Arc::new(RwLock::new(VecDeque::new())),
max_recent_cached_maps: 2,
refresh_interval: Duration::from_secs(JWKS_REFRESH_INTERVAL),
refresh_timeout: Duration::from_secs(JWKS_REFRESH_TIMEOUT),
last_refreshed_at: RwLock::new(None),
retry_interval: Duration::from_secs(2),
last_refreshed_time: RwLock::new(None),
last_retry_time: RwLock::new(None),
load_keys_func: None,
}
}
Expand All @@ -135,6 +147,16 @@ impl JwkKeyStore {
self
}

pub fn with_max_recent_cached_maps(mut self, max: usize) -> Self {
self.max_recent_cached_maps = max;
self
}

pub fn with_retry_interval(mut self, interval: u64) -> Self {
self.retry_interval = Duration::from_secs(interval);
self
}

pub fn url(&self) -> String {
self.url.clone()
}
Expand All @@ -153,9 +175,11 @@ impl JwkKeyStore {
.map_err(|e| {
ErrorCode::InvalidConfig(format!("Failed to create jwks client: {}", e))
})?;
let response = client.get(&self.url).send().await.map_err(|e| {
ErrorCode::AuthenticateFailure(format!("Could not download JWKS: {}", e))
})?;
let response = client
.get(&self.url)
.send()
.await
.map_err(|e| ErrorCode::Internal(format!("Could not download JWKS: {}", e)))?;
let jwk_keys: JwkKeys = response
.json()
.await
Expand All @@ -168,16 +192,22 @@ impl JwkKeyStore {
}

#[async_backtrace::framed]
async fn load_keys_with_cache(&self, force: bool) -> Result<HashMap<String, PubKey>> {
async fn maybe_refresh_cached_keys(&self, force: bool) -> Result<()> {
let need_reload = force
|| match *self.last_refreshed_at.read() {
|| match *self.last_refreshed_time.read() {
None => true,
Some(last_refreshed_at) => last_refreshed_at.elapsed() > self.refresh_interval,
};

let old_keys = self.cached_keys.read().clone();
let old_keys = self
.recent_cached_maps
.read()
.iter()
.last()
.cloned()
.unwrap_or(HashMap::new());
if !need_reload {
return Ok(old_keys);
return Ok(());
}

// if got network issues on loading JWKS, fallback to the cached keys if available
Expand All @@ -186,45 +216,81 @@ impl JwkKeyStore {
Err(err) => {
warn!("Failed to load JWKS: {}", err);
if !old_keys.is_empty() {
return Ok(old_keys);
return Ok(());
}
return Err(err.add_message("failed to load JWKS keys, and no available fallback"));
}
};

// the JWKS keys are not always changes, but when it changed, we can have a log about this.
if !new_keys.keys().eq(old_keys.keys()) {
info!("JWKS keys changed.");
// if the new keys are empty, skip save it to the cache
if new_keys.is_empty() {
warn!("got empty JWKS keys, skip");
return Ok(());
}
*self.cached_keys.write() = new_keys.clone();
self.last_refreshed_at.write().replace(Instant::now());
Ok(new_keys)

// only update the cache when the keys are changed
if new_keys.keys().eq(old_keys.keys()) {
return Ok(());
}
info!("JWKS keys changed.");

// append the new keys to the end of recent_cached_maps
{
let mut recent_cached_maps = self.recent_cached_maps.write();
recent_cached_maps.push_back(new_keys);
if recent_cached_maps.len() > self.max_recent_cached_maps {
recent_cached_maps.pop_front();
}
}
self.last_refreshed_time.write().replace(Instant::now());
Ok(())
}

#[async_backtrace::framed]
pub async fn get_key(&self, key_id: Option<String>) -> Result<PubKey> {
let keys = self.load_keys_with_cache(false).await?;
pub async fn get_key(&self, key_id: &Option<String>) -> Result<Option<PubKey>> {
self.maybe_refresh_cached_keys(false).await?;

// if the key_id is not set, and there is only one key in the store, return it
let key_id = match key_id {
Some(key_id) => key_id,
let key_id = match &key_id {
Some(key_id) => key_id.clone(),
None => {
if keys.len() != 1 {
let cached_maps = self.recent_cached_maps.read();
let first_key = cached_maps
.iter()
.last()
.and_then(|keys| keys.iter().next());
if let Some((_, pub_key)) = first_key {
return Ok(Some(pub_key.clone()));
} else {
return Err(ErrorCode::AuthenticateFailure(
"must specify key_id for jwt when multi keys exists ",
));
} else {
return Ok((keys.iter().next().unwrap().1).clone());
}
}
};

match keys.get(&key_id) {
None => Err(ErrorCode::AuthenticateFailure(format!(
"key id {} not found in jwk store",
key_id
))),
Some(key) => Ok(key.clone()),
// if the key is not found, try to refresh the keys and try again. this refresh only
// happens once within retry interval (default 2s).
for _ in 0..2 {
for keys_map in self.recent_cached_maps.read().iter().rev() {
if let Some(key) = keys_map.get(&key_id) {
return Ok(Some(key.clone()));
}
}
let need_retry = match *self.last_retry_time.read() {
None => true,
Some(last_retry_time) => last_retry_time.elapsed() > self.retry_interval,
};
if need_retry {
warn!(
"key id {} not found in jwk store, try to peek the latest keys",
key_id
);
self.maybe_refresh_cached_keys(true).await?;
*self.last_retry_time.write() = Some(Instant::now());
}
}

Ok(None)
}
}
99 changes: 99 additions & 0 deletions src/query/users/tests/it/jwt/jwk.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
use std::collections::HashMap;
use std::sync::Arc;

use databend_common_base::base::tokio;
use databend_common_exception::Result;
use databend_common_users::JwkKeyStore;
use databend_common_users::PubKey;
use jwt_simple::prelude::*;
use parking_lot::Mutex;

struct MockJwksLoader {
keys: Mutex<HashMap<String, PubKey>>,
}

impl MockJwksLoader {
fn new() -> Self {
Self {
keys: Mutex::new(HashMap::new()),
}
}

fn reset_keys(&self, key_names: &[&'static str]) {
let mut keys = self.keys.lock();
keys.clear();
for key_name in key_names {
keys.insert(
key_name.to_string(),
PubKey::RSA256(Box::new(RS256KeyPair::generate(2048).unwrap().public_key())),
);
}
}

fn load_keys(&self) -> HashMap<String, PubKey> {
self.keys.lock().clone()
}
}

#[tokio::test]
async fn test_jwk_store_with_random_keys() -> Result<()> {
let mock_jwks_loader = Arc::new(MockJwksLoader::new());
let mock_jwks_loader_cloned = mock_jwks_loader.clone();
let jwk_store = JwkKeyStore::new("jwks_key".to_string())
.with_load_keys_func(Arc::new(move || mock_jwks_loader_cloned.load_keys()))
.with_max_recent_cached_maps(2)
.with_retry_interval(0);

mock_jwks_loader.reset_keys(&["key1", "key2"]);
let key = jwk_store.get_key(&None).await?;
assert!(key.is_some());
let key = jwk_store.get_key(&Some("key1".to_string())).await?;
assert!(key.is_some());
let key = jwk_store.get_key(&Some("key3".to_string())).await?;
assert!(key.is_none());

mock_jwks_loader.reset_keys(&["key3", "key4"]);
let key = jwk_store.get_key(&Some("key3".to_string())).await?;
assert!(key.is_some());
let key = jwk_store.get_key(&Some("key4".to_string())).await?;
assert!(key.is_some());
let key = jwk_store.get_key(&Some("key5".to_string())).await?;
assert!(key.is_none());
let key = jwk_store.get_key(&Some("key1".to_string())).await?;
assert!(key.is_some());

mock_jwks_loader.reset_keys(&["key5", "key6"]);
let key = jwk_store.get_key(&Some("key5".to_string())).await?;
assert!(key.is_some());
let key = jwk_store.get_key(&Some("key6".to_string())).await?;
assert!(key.is_some());
let key = jwk_store.get_key(&Some("key1".to_string())).await?;
assert!(key.is_none());
Ok(())
}

#[tokio::test]
async fn test_jwk_store_with_random_keys_and_long_retry_interval() -> Result<()> {
let mock_jwks_loader = Arc::new(MockJwksLoader::new());
let mock_jwks_loader_cloned = mock_jwks_loader.clone();
let jwk_store = JwkKeyStore::new("jwks_key".to_string())
.with_load_keys_func(Arc::new(move || mock_jwks_loader_cloned.load_keys()))
.with_max_recent_cached_maps(2)
.with_retry_interval(3600);

mock_jwks_loader.reset_keys(&["key1", "key2"]);
let key = jwk_store.get_key(&None).await?;
assert!(key.is_some());
let key = jwk_store.get_key(&Some("key1".to_string())).await?;
assert!(key.is_some());
let key = jwk_store.get_key(&Some("key3".to_string())).await?;
assert!(key.is_none());

mock_jwks_loader.reset_keys(&["key3", "key4"]);
let key = jwk_store.get_key(&Some("key3".to_string())).await?;
assert!(key.is_none());
let key = jwk_store.get_key(&Some("key4".to_string())).await?;
assert!(key.is_none());

Ok(())
}
1 change: 1 addition & 0 deletions src/query/users/tests/it/jwt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@
// limitations under the License.

mod authenticator;
mod jwk;