Skip to content

Commit 7d902d7

Browse files
authored
feat(auth): auto retry on jwks key not found (#17410)
* refactor: not expose fields in JwkKeyStore * feat: add auto retry about jwks * refactor: the loop * record the recent jwks keys * save * fix typo * prepare to add test * add ut * add tests * limit the lock scope
1 parent f56350e commit 7d902d7

File tree

4 files changed

+211
-36
lines changed

4 files changed

+211
-36
lines changed

src/query/users/src/jwt/authenticator.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,15 @@ impl JwtAuthenticator {
115115
) -> Result<JWTClaims<CustomClaims>> {
116116
let metadata = Token::decode_metadata(token);
117117
let key_id = metadata.map_or(None, |e| e.key_id().map(|s| s.to_string()));
118-
let pub_key = key_store.get_key(key_id).await?;
118+
let pub_key = match key_store.get_key(&key_id).await? {
119+
None => {
120+
return Err(ErrorCode::AuthenticateFailure(format!(
121+
"key id {} not found in jwk store",
122+
key_id.unwrap_or_default()
123+
)));
124+
}
125+
Some(pk) => pk,
126+
};
119127
let r = match &pub_key {
120128
PubKey::RSA256(pk) => pk.verify_token::<CustomClaims>(token, None),
121129
PubKey::ES256(pk) => pk.verify_token::<CustomClaims>(token, None),
@@ -128,6 +136,7 @@ impl JwtAuthenticator {
128136
Some(_) => Ok(c),
129137
}
130138
}
139+
131140
#[async_backtrace::framed]
132141
pub async fn parse_jwt_claims(&self, token: &str) -> Result<JWTClaims<CustomClaims>> {
133142
let mut combined_code = ErrorCode::AuthenticateFailure(

src/query/users/src/jwt/jwk.rs

Lines changed: 101 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
// limitations under the License.
1414

1515
use std::collections::HashMap;
16+
use std::collections::VecDeque;
1617
use std::sync::Arc;
1718
use std::time::Duration;
1819
use std::time::Instant;
@@ -95,23 +96,34 @@ pub struct JwkKeys {
9596
pub keys: Vec<JwkKey>,
9697
}
9798

99+
/// [`JwkKeyStore`] is a store for JWKS keys, it will cache the keys for a while and refresh the
100+
/// keys periodically. When the keys are refreshed, the older keys will still be kept for a while.
101+
///
102+
/// When the keys rotated in the client side first, the server will respond a 401 Authorization Failure
103+
/// error, as the key is not found in the cache. We'll try to refresh the keys and try again.
98104
pub struct JwkKeyStore {
99-
pub(crate) url: String,
100-
cached_keys: Arc<RwLock<HashMap<String, PubKey>>>,
101-
pub(crate) last_refreshed_at: RwLock<Option<Instant>>,
102-
pub(crate) refresh_interval: Duration,
103-
pub(crate) refresh_timeout: Duration,
104-
pub(crate) load_keys_func: Option<Arc<dyn Fn() -> HashMap<String, PubKey> + Send + Sync>>,
105+
url: String,
106+
recent_cached_maps: Arc<RwLock<VecDeque<HashMap<String, PubKey>>>>,
107+
last_refreshed_time: RwLock<Option<Instant>>,
108+
last_retry_time: RwLock<Option<Instant>>,
109+
max_recent_cached_maps: usize,
110+
refresh_interval: Duration,
111+
refresh_timeout: Duration,
112+
retry_interval: Duration,
113+
load_keys_func: Option<Arc<dyn Fn() -> HashMap<String, PubKey> + Send + Sync>>,
105114
}
106115

107116
impl JwkKeyStore {
108117
pub fn new(url: String) -> Self {
109118
Self {
110119
url,
111-
cached_keys: Arc::new(RwLock::new(HashMap::new())),
120+
recent_cached_maps: Arc::new(RwLock::new(VecDeque::new())),
121+
max_recent_cached_maps: 2,
112122
refresh_interval: Duration::from_secs(JWKS_REFRESH_INTERVAL),
113123
refresh_timeout: Duration::from_secs(JWKS_REFRESH_TIMEOUT),
114-
last_refreshed_at: RwLock::new(None),
124+
retry_interval: Duration::from_secs(2),
125+
last_refreshed_time: RwLock::new(None),
126+
last_retry_time: RwLock::new(None),
115127
load_keys_func: None,
116128
}
117129
}
@@ -135,6 +147,16 @@ impl JwkKeyStore {
135147
self
136148
}
137149

150+
pub fn with_max_recent_cached_maps(mut self, max: usize) -> Self {
151+
self.max_recent_cached_maps = max;
152+
self
153+
}
154+
155+
pub fn with_retry_interval(mut self, interval: u64) -> Self {
156+
self.retry_interval = Duration::from_secs(interval);
157+
self
158+
}
159+
138160
pub fn url(&self) -> String {
139161
self.url.clone()
140162
}
@@ -153,9 +175,11 @@ impl JwkKeyStore {
153175
.map_err(|e| {
154176
ErrorCode::InvalidConfig(format!("Failed to create jwks client: {}", e))
155177
})?;
156-
let response = client.get(&self.url).send().await.map_err(|e| {
157-
ErrorCode::AuthenticateFailure(format!("Could not download JWKS: {}", e))
158-
})?;
178+
let response = client
179+
.get(&self.url)
180+
.send()
181+
.await
182+
.map_err(|e| ErrorCode::Internal(format!("Could not download JWKS: {}", e)))?;
159183
let jwk_keys: JwkKeys = response
160184
.json()
161185
.await
@@ -168,16 +192,22 @@ impl JwkKeyStore {
168192
}
169193

170194
#[async_backtrace::framed]
171-
async fn load_keys_with_cache(&self, force: bool) -> Result<HashMap<String, PubKey>> {
195+
async fn maybe_refresh_cached_keys(&self, force: bool) -> Result<()> {
172196
let need_reload = force
173-
|| match *self.last_refreshed_at.read() {
197+
|| match *self.last_refreshed_time.read() {
174198
None => true,
175199
Some(last_refreshed_at) => last_refreshed_at.elapsed() > self.refresh_interval,
176200
};
177201

178-
let old_keys = self.cached_keys.read().clone();
202+
let old_keys = self
203+
.recent_cached_maps
204+
.read()
205+
.iter()
206+
.last()
207+
.cloned()
208+
.unwrap_or(HashMap::new());
179209
if !need_reload {
180-
return Ok(old_keys);
210+
return Ok(());
181211
}
182212

183213
// if got network issues on loading JWKS, fallback to the cached keys if available
@@ -186,45 +216,81 @@ impl JwkKeyStore {
186216
Err(err) => {
187217
warn!("Failed to load JWKS: {}", err);
188218
if !old_keys.is_empty() {
189-
return Ok(old_keys);
219+
return Ok(());
190220
}
191221
return Err(err.add_message("failed to load JWKS keys, and no available fallback"));
192222
}
193223
};
194224

195-
// the JWKS keys are not always changes, but when it changed, we can have a log about this.
196-
if !new_keys.keys().eq(old_keys.keys()) {
197-
info!("JWKS keys changed.");
225+
// if the new keys are empty, skip save it to the cache
226+
if new_keys.is_empty() {
227+
warn!("got empty JWKS keys, skip");
228+
return Ok(());
198229
}
199-
*self.cached_keys.write() = new_keys.clone();
200-
self.last_refreshed_at.write().replace(Instant::now());
201-
Ok(new_keys)
230+
231+
// only update the cache when the keys are changed
232+
if new_keys.keys().eq(old_keys.keys()) {
233+
return Ok(());
234+
}
235+
info!("JWKS keys changed.");
236+
237+
// append the new keys to the end of recent_cached_maps
238+
{
239+
let mut recent_cached_maps = self.recent_cached_maps.write();
240+
recent_cached_maps.push_back(new_keys);
241+
if recent_cached_maps.len() > self.max_recent_cached_maps {
242+
recent_cached_maps.pop_front();
243+
}
244+
}
245+
self.last_refreshed_time.write().replace(Instant::now());
246+
Ok(())
202247
}
203248

204249
#[async_backtrace::framed]
205-
pub async fn get_key(&self, key_id: Option<String>) -> Result<PubKey> {
206-
let keys = self.load_keys_with_cache(false).await?;
250+
pub async fn get_key(&self, key_id: &Option<String>) -> Result<Option<PubKey>> {
251+
self.maybe_refresh_cached_keys(false).await?;
207252

208253
// if the key_id is not set, and there is only one key in the store, return it
209-
let key_id = match key_id {
210-
Some(key_id) => key_id,
254+
let key_id = match &key_id {
255+
Some(key_id) => key_id.clone(),
211256
None => {
212-
if keys.len() != 1 {
257+
let cached_maps = self.recent_cached_maps.read();
258+
let first_key = cached_maps
259+
.iter()
260+
.last()
261+
.and_then(|keys| keys.iter().next());
262+
if let Some((_, pub_key)) = first_key {
263+
return Ok(Some(pub_key.clone()));
264+
} else {
213265
return Err(ErrorCode::AuthenticateFailure(
214266
"must specify key_id for jwt when multi keys exists ",
215267
));
216-
} else {
217-
return Ok((keys.iter().next().unwrap().1).clone());
218268
}
219269
}
220270
};
221271

222-
match keys.get(&key_id) {
223-
None => Err(ErrorCode::AuthenticateFailure(format!(
224-
"key id {} not found in jwk store",
225-
key_id
226-
))),
227-
Some(key) => Ok(key.clone()),
272+
// if the key is not found, try to refresh the keys and try again. this refresh only
273+
// happens once within retry interval (default 2s).
274+
for _ in 0..2 {
275+
for keys_map in self.recent_cached_maps.read().iter().rev() {
276+
if let Some(key) = keys_map.get(&key_id) {
277+
return Ok(Some(key.clone()));
278+
}
279+
}
280+
let need_retry = match *self.last_retry_time.read() {
281+
None => true,
282+
Some(last_retry_time) => last_retry_time.elapsed() > self.retry_interval,
283+
};
284+
if need_retry {
285+
warn!(
286+
"key id {} not found in jwk store, try to peek the latest keys",
287+
key_id
288+
);
289+
self.maybe_refresh_cached_keys(true).await?;
290+
*self.last_retry_time.write() = Some(Instant::now());
291+
}
228292
}
293+
294+
Ok(None)
229295
}
230296
}
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
use std::collections::HashMap;
2+
use std::sync::Arc;
3+
4+
use databend_common_base::base::tokio;
5+
use databend_common_exception::Result;
6+
use databend_common_users::JwkKeyStore;
7+
use databend_common_users::PubKey;
8+
use jwt_simple::prelude::*;
9+
use parking_lot::Mutex;
10+
11+
struct MockJwksLoader {
12+
keys: Mutex<HashMap<String, PubKey>>,
13+
}
14+
15+
impl MockJwksLoader {
16+
fn new() -> Self {
17+
Self {
18+
keys: Mutex::new(HashMap::new()),
19+
}
20+
}
21+
22+
fn reset_keys(&self, key_names: &[&'static str]) {
23+
let mut keys = self.keys.lock();
24+
keys.clear();
25+
for key_name in key_names {
26+
keys.insert(
27+
key_name.to_string(),
28+
PubKey::RSA256(Box::new(RS256KeyPair::generate(2048).unwrap().public_key())),
29+
);
30+
}
31+
}
32+
33+
fn load_keys(&self) -> HashMap<String, PubKey> {
34+
self.keys.lock().clone()
35+
}
36+
}
37+
38+
#[tokio::test]
39+
async fn test_jwk_store_with_random_keys() -> Result<()> {
40+
let mock_jwks_loader = Arc::new(MockJwksLoader::new());
41+
let mock_jwks_loader_cloned = mock_jwks_loader.clone();
42+
let jwk_store = JwkKeyStore::new("jwks_key".to_string())
43+
.with_load_keys_func(Arc::new(move || mock_jwks_loader_cloned.load_keys()))
44+
.with_max_recent_cached_maps(2)
45+
.with_retry_interval(0);
46+
47+
mock_jwks_loader.reset_keys(&["key1", "key2"]);
48+
let key = jwk_store.get_key(&None).await?;
49+
assert!(key.is_some());
50+
let key = jwk_store.get_key(&Some("key1".to_string())).await?;
51+
assert!(key.is_some());
52+
let key = jwk_store.get_key(&Some("key3".to_string())).await?;
53+
assert!(key.is_none());
54+
55+
mock_jwks_loader.reset_keys(&["key3", "key4"]);
56+
let key = jwk_store.get_key(&Some("key3".to_string())).await?;
57+
assert!(key.is_some());
58+
let key = jwk_store.get_key(&Some("key4".to_string())).await?;
59+
assert!(key.is_some());
60+
let key = jwk_store.get_key(&Some("key5".to_string())).await?;
61+
assert!(key.is_none());
62+
let key = jwk_store.get_key(&Some("key1".to_string())).await?;
63+
assert!(key.is_some());
64+
65+
mock_jwks_loader.reset_keys(&["key5", "key6"]);
66+
let key = jwk_store.get_key(&Some("key5".to_string())).await?;
67+
assert!(key.is_some());
68+
let key = jwk_store.get_key(&Some("key6".to_string())).await?;
69+
assert!(key.is_some());
70+
let key = jwk_store.get_key(&Some("key1".to_string())).await?;
71+
assert!(key.is_none());
72+
Ok(())
73+
}
74+
75+
#[tokio::test]
76+
async fn test_jwk_store_with_random_keys_and_long_retry_interval() -> Result<()> {
77+
let mock_jwks_loader = Arc::new(MockJwksLoader::new());
78+
let mock_jwks_loader_cloned = mock_jwks_loader.clone();
79+
let jwk_store = JwkKeyStore::new("jwks_key".to_string())
80+
.with_load_keys_func(Arc::new(move || mock_jwks_loader_cloned.load_keys()))
81+
.with_max_recent_cached_maps(2)
82+
.with_retry_interval(3600);
83+
84+
mock_jwks_loader.reset_keys(&["key1", "key2"]);
85+
let key = jwk_store.get_key(&None).await?;
86+
assert!(key.is_some());
87+
let key = jwk_store.get_key(&Some("key1".to_string())).await?;
88+
assert!(key.is_some());
89+
let key = jwk_store.get_key(&Some("key3".to_string())).await?;
90+
assert!(key.is_none());
91+
92+
mock_jwks_loader.reset_keys(&["key3", "key4"]);
93+
let key = jwk_store.get_key(&Some("key3".to_string())).await?;
94+
assert!(key.is_none());
95+
let key = jwk_store.get_key(&Some("key4".to_string())).await?;
96+
assert!(key.is_none());
97+
98+
Ok(())
99+
}

src/query/users/tests/it/jwt/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,4 @@
1313
// limitations under the License.
1414

1515
mod authenticator;
16+
mod jwk;

0 commit comments

Comments
 (0)