This repository was archived by the owner on Jun 12, 2024. It is now read-only.

Commit db1a120

feat(CON-4069): add in-memory cache for JWT public keys (#184)
Now it's possible to maintain an in-memory cache of multiple public keys in order to use them for JWT signature validation.
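A rough usage sketch (not part of this commit; the import path, key directory, refresh interval, and token string are illustrative assumptions): the cache is created once, refreshed in the background via MaintainCache, and plugged into golang-jwt parsing through KeyFunction.

package main

import (
	"context"
	"log"
	"time"

	"github.com/golang-jwt/jwt/v4"

	"example.com/yourservice/pkg/tokens" // assumed module path for this package
)

func main() {
	// assumed directory holding PEM-encoded RSA public keys
	pkm, err := tokens.NewPublicKeyMap("/etc/jwt/keys")
	if err != nil {
		log.Fatalf("loading public keys: %v", err)
	}

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// refresh the in-memory cache every minute until ctx is cancelled
	go func() {
		if err := pkm.MaintainCache(ctx, time.Minute); err != nil {
			log.Printf("cache maintenance stopped: %v", err)
		}
	}()

	// placeholder for an incoming JWT; its "kid" header must match the hash of a cached key
	tokenString := "<jwt from the request>"
	token, err := jwt.Parse(tokenString, pkm.KeyFunction)
	if err != nil {
		log.Printf("token rejected: %v", err)
		return
	}
	log.Printf("token is valid: %v", token.Valid)
}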
1 parent b2e2dc0 commit db1a120

2 files changed: +713 -0 lines changed

pkg/tokens/public_keys.go

Lines changed: 369 additions & 0 deletions
@@ -0,0 +1,369 @@
package tokens

import (
	"context"
	"crypto/rsa"
	"crypto/sha512"
	"encoding/hex"
	"io/fs"
	"os"
	"path"
	"strings"
	"sync"
	"time"

	"github.com/golang-jwt/jwt/v4"
	"github.com/pkg/errors"
	"github.com/sirupsen/logrus"
)

var (
	// ErrKeyNotFound occurs when the key function cannot find a key in the cache
	ErrKeyNotFound = errors.New("specified key not found")
	// ErrMalformedKeyID occurs when the `KeyIDHeaderName` value in JWT header is absent or has a wrong type
	ErrMalformedKeyID = errors.New("malformed key ID in the JWT header")
	// ErrUnsupportedSigningMethod occurs when a JWT header specifies an unsupported signing method
	ErrUnsupportedSigningMethod = errors.New("signing method is not supported")
)

const (
	// KeyIDHeaderName is the expected header name in a JWT token
	KeyIDHeaderName = "kid"
)

type keyEntry struct {
	// Filename is the filename the key was loaded from
	Filename string
	// ModTime is the last modification timestamp of the file
	ModTime time.Time
	// Size is the file size
	Size int64
	// Hash is the hex-encoded SHA512/256 hash of the file content
	Hash string
	// Key is an RSA public key ready to be used for JWT signature validation
	Key *rsa.PublicKey
}

// PublicKeyMap defines operations on the map of public keys used for JWT validation
type PublicKeyMap interface {
	// MaintainCache runs a synchronization loop that reads the public keys directory
	// and refreshes the in-memory cache for quick access.
	MaintainCache(ctx context.Context, interval time.Duration) error
	// KeyFunction is a key function that can be used in the JWT library
	KeyFunction(token *jwt.Token) (interface{}, error)
}

// NewPublicKeyMapWithFS returns a public key map for a given directory path in the given FS
func NewPublicKeyMapWithFS(fileSys fs.FS, directoryPath string) (PublicKeyMap, error) {
	m := &publicKeyMap{
		rw:            &sync.RWMutex{},
		directoryPath: directoryPath,
		fileSys:       fileSys,
	}
	return m, m.init()
}

// NewPublicKeyMap returns a public key map for a given directory path
func NewPublicKeyMap(directoryPath string) (PublicKeyMap, error) {
	return NewPublicKeyMapWithFS(os.DirFS("/"), strings.TrimLeft(directoryPath, "/"))
}

type publicKeyMap struct {
	keysByHashes  map[string]*keyEntry
	fileSys       fs.FS
	directoryPath string
	rw            *sync.RWMutex
}

func (m *publicKeyMap) MaintainCache(ctx context.Context, interval time.Duration) (err error) {
	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	for range ticker.C {
		// the only error we should expect is the context cancellation,
		// the rest of the errors are just logged
		err = m.sync(ctx)
		if err != nil {
			return err
		}
	}

	return nil
}

func (m *publicKeyMap) lookup(hash string) (entry *keyEntry, ok bool) {
	m.rw.RLock()
	entry, ok = m.keysByHashes[hash]
	m.rw.RUnlock()
	return entry, ok
}

func (m *publicKeyMap) KeyFunction(token *jwt.Token) (interface{}, error) {
	_, ok := token.Method.(*jwt.SigningMethodRSA)
	if !ok {
		return nil, errors.Wrapf(
			ErrUnsupportedSigningMethod,
			"signing method: %v",
			token.Header["alg"],
		)
	}
	kid, ok := token.Header[KeyIDHeaderName].(string)
	if !ok {
		return nil, errors.Wrapf(ErrMalformedKeyID, "%s=%+v", KeyIDHeaderName, token.Header[KeyIDHeaderName])
	}

	entry, ok := m.lookup(kid)
	if !ok {
		return nil, errors.Wrapf(ErrKeyNotFound, "%s=%s", KeyIDHeaderName, kid)
	}

	return entry.Key, nil
}

func (m *publicKeyMap) init() (err error) {
	files, err := fs.ReadDir(m.fileSys, m.directoryPath)
	if err != nil {
		return errors.Wrapf(err, "failed to read directory with keys: %s", m.directoryPath)
	}

	keysByHashes := make(map[string]*keyEntry, len(files))

	for _, file := range files {
		if !file.Type().IsRegular() {
			continue
		}
		key, err := m.fileToKeyEntry(file)
		if err != nil {
			logrus.
				WithField("file", file.Name()).
				WithError(err).
				Warn("failed to read the public key, skipped")
			continue
		}

		keysByHashes[key.Hash] = key
	}

	keysAdded := len(keysByHashes)

	m.rw.Lock()
	m.keysByHashes = keysByHashes
	m.rw.Unlock()

	logrus.
		WithField("added", keysAdded).
		Debug("public keys cache has been initialized.")

	return nil
}

// clear drops all cached keys; replacing the map is a write, so it takes the write lock
func (m *publicKeyMap) clear() {
	m.rw.Lock()
	m.keysByHashes = map[string]*keyEntry{}
	m.rw.Unlock()
}

func (m *publicKeyMap) sync(ctx context.Context) (err error) {
	// sync algorithm:

	// 1. ReadLock
	// 2. Clone `keysByHashes` into `currentKeys`, a map of the same key entries keyed by their filenames
	// 3. Unlock
	// 4. Initialize `deletes := map[string]struct{}` of known filenames (keys of the `currentKeys` map)
	// 5. Define counters `keysAdded`, `keysKept` and `keysUpdated`
	// 6. Read files (only first level) in the given directory using fs.ReadDir (not recursive); for each file:
	//    1. delete(deletes, filename) – we mark a seen file, no need to delete it
	//    2. for a file that has a known filename, matching modtime and size, do nothing and continue to the next file. We don't check the hashes, it's too expensive to do for each file
	//    3. for a known filename but not matching properties, try to load a public key and store it in `currentKeys`, increment `keysUpdated`
	//    4. for a new file, try to load a public key and compute the file's hash, store it into `currentKeys`, increment `keysAdded`
	// 7. Delete from `currentKeys` those files that are left in the `deletes` set.
	// 8. Build an updated `currentKeyHashes := map[string]*keyEntry`, a map of hashes to key entries.
	// 9. WriteLock
	// 10. Replace `keysByHashes` with `currentKeyHashes`
	// 11. Unlock

	err = ctx.Err()
	if err != nil {
		return err
	}

	logrus.Debug("updating the public keys cache...")

	logrus.
		WithField("path", m.directoryPath).
		Debug("reading the keys directory...")

	// first check whether we can read the directory at all
	files, err := fs.ReadDir(m.fileSys, m.directoryPath)
	if err != nil {
		logrus.
			WithField("path", m.directoryPath).
			WithError(err).
			Error("failed to read directory with keys, clearing the cache...")

		m.clear()

		logrus.
			WithField("path", m.directoryPath).
			Debug("cache cleared.")

		return nil
	}

	if len(files) == 0 {
		logrus.
			WithField("path", m.directoryPath).
			Debug("no keys have been found, clearing the cache...")

		m.clear()

		logrus.
			WithField("path", m.directoryPath).
			Debug("cache cleared.")

		return nil
	}

	logrus.
		WithField("path", m.directoryPath).
		WithField("file_count", len(files)).
		Debug("key files have been found.")

	err = ctx.Err()
	if err != nil {
		return err
	}

	// we keep the lock time very short, and we don't expect too many keys

	m.rw.RLock()

	// Clone `keysByHashes` into `currentKeys`, a map of the same key entries keyed by their filenames.
	// Initialize `deletes := map[string]struct{}` of known filenames.
	currentKeys := make(map[string]*keyEntry, len(m.keysByHashes))
	deletes := make(map[string]struct{}, len(m.keysByHashes))
	for _, entry := range m.keysByHashes {
		currentKeys[entry.Filename] = entry
		deletes[entry.Filename] = struct{}{}
	}

	m.rw.RUnlock()

	var (
		keysAdded, keysUpdated, keysKept int
	)

	for _, file := range files {
		err = ctx.Err()
		if err != nil {
			return err
		}
		if !file.Type().IsRegular() {
			continue
		}

		filename := file.Name()

		knownKey, exists := currentKeys[filename]

		// for a file that has a known filename, matching modtime and size,
		// do nothing and continue to the next file.
		// We don't check the hashes, it's too expensive to do for each file
		if exists {
			fileInfo, err := file.Info()
			if err != nil {
				logrus.
					WithField("file", file.Name()).
					WithError(err).
					Warn("failed to compare file change, skipped")
				continue
			}

			if fileInfo.ModTime() == knownKey.ModTime && fileInfo.Size() == knownKey.Size {
				// mark the key as valid, so it's not deleted later
				delete(deletes, filename)
				continue
			}
		}

		key, err := m.fileToKeyEntry(file)
		if err != nil {
			logrus.
				WithField("file", file.Name()).
				WithError(err).
				Warn("failed to read the public key, skipped")
			continue
		}
		currentKeys[filename] = key
		if exists {
			keysUpdated++
		} else {
			keysAdded++
		}
		// mark the key as valid, so it's not deleted later
		delete(deletes, filename)
	}

	if keysAdded == 0 && keysUpdated == 0 && len(deletes) == 0 {
		logrus.
			WithField("key_count", len(currentKeys)).
			Debug("no change detected, keeping the current public keys cache")
		return nil
	}

	// delete from `currentKeys` those files that are left in the `deletes` set.
	for filename := range deletes {
		delete(currentKeys, filename)
	}
	// at this point `currentKeys` holds only the kept, updated and added entries
	keysKept = len(currentKeys) - keysAdded - keysUpdated

	// build an updated map of hashes to key entries.
	currentKeyHashes := make(map[string]*keyEntry, len(currentKeys))
	for _, key := range currentKeys {
		currentKeyHashes[key.Hash] = key
	}

	m.rw.Lock()
	m.keysByHashes = currentKeyHashes
	m.rw.Unlock()

	logrus.
		WithField("added", keysAdded).
		WithField("updated", keysUpdated).
		WithField("deleted", len(deletes)).
		WithField("kept", keysKept).
		Debug("public keys cache has been updated.")

	return nil
}

func (m *publicKeyMap) fileToKeyEntry(file os.DirEntry) (key *keyEntry, err error) {
	fileInfo, err := file.Info()
	if err != nil {
		return nil, err
	}

	filename := file.Name()
	fullPath := path.Join(m.directoryPath, filename)

	bytes, err := fs.ReadFile(m.fileSys, fullPath)
	if err != nil {
		return nil, err
	}

	rsaKey, err := jwt.ParseRSAPublicKeyFromPEM(bytes)
	if err != nil {
		return nil, err
	}

	hashBytes := sha512.Sum512_256(bytes)
	hash := hex.EncodeToString(hashBytes[:])

	return &keyEntry{
		Filename: filename,
		ModTime:  fileInfo.ModTime(),
		Size:     fileInfo.Size(),
		Hash:     hash,
		Key:      rsaKey,
	}, nil
}
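Worth noting for callers: the cache is keyed by the hex-encoded SHA512/256 hash of each PEM file's content (see fileToKeyEntry), so a token issuer has to place exactly that value into the token's "kid" header for KeyFunction to find the key. A small sketch for computing it; the file path is an assumed example.

package main

import (
	"crypto/sha512"
	"encoding/hex"
	"fmt"
	"os"
)

func main() {
	// read the same PEM file that publicKeyMap loads from its keys directory (path is assumed)
	pemBytes, err := os.ReadFile("/etc/jwt/keys/service-a.pem")
	if err != nil {
		panic(err)
	}

	// mirrors keyEntry.Hash: hex-encoded SHA512/256 of the raw file content
	sum := sha512.Sum512_256(pemBytes)
	fmt.Println(hex.EncodeToString(sum[:])) // use this value as the JWT "kid" header
}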
