@@ -46,14 +46,14 @@ type RemoteKeySet struct {
4646 inflight * inflight
4747
4848 // A set of cached keys.
49- cachedKeys [ ]jose.JSONWebKey
49+ cachedKeys map [ string ]jose.JSONWebKey
5050}
5151
5252// inflight is used to wait on some in-flight request from multiple goroutines.
5353type inflight struct {
5454 doneCh chan struct {}
5555
56- keys [ ]jose.JSONWebKey
56+ keys map [ string ]jose.JSONWebKey
5757 err error
5858}
5959
@@ -70,14 +70,14 @@ func (i *inflight) wait() <-chan struct{} {
7070// done can only be called by a single goroutine. It records the result of the
7171// inflight request and signals other goroutines that the result is safe to
7272// inspect.
73- func (i * inflight ) done (keys [ ]jose.JSONWebKey , err error ) {
73+ func (i * inflight ) done (keys map [ string ]jose.JSONWebKey , err error ) {
7474 i .keys = keys
7575 i .err = err
7676 close (i .doneCh )
7777}
7878
7979// result cannot be called until the wait() channel has returned a value.
80- func (i * inflight ) result () ([ ]jose.JSONWebKey , error ) {
80+ func (i * inflight ) result () (map [ string ]jose.JSONWebKey , error ) {
8181 return i .keys , i .err
8282}
8383
@@ -102,43 +102,53 @@ func (r *RemoteKeySet) verify(ctx context.Context, jws *jose.JSONWebSignature) (
102102 break
103103 }
104104
105- keys := r .keysFromCache ()
106- for _ , key := range keys {
107- if keyID == "" || key .KeyID == keyID {
108- if payload , err := jws .Verify (& key ); err == nil {
109- return payload , nil
110- }
111- }
105+ if payload , ok := r .verifyWithKey (keyID , jws ); ok {
106+ return payload , nil
112107 }
113-
114108 // If the kid doesn't match, check for new keys from the remote. This is the
115109 // strategy recommended by the spec.
116110 //
117111 // https://openid.net/specs/openid-connect-core-1_0.html#RotateSigKeys
118- keys , err := r .keysFromRemote (ctx )
112+ _ , err := r .keysFromRemote (ctx )
119113 if err != nil {
120114 return nil , fmt .Errorf ("fetching keys %v" , err )
121115 }
122116
123- for _ , key := range keys {
124- if keyID == "" || key .KeyID == keyID {
117+ if payload , ok := r .verifyWithKey (keyID , jws ); ok {
118+ return payload , nil
119+ }
120+
121+ return nil , errors .New ("failed to verify id token signature" )
122+ }
123+
124+ // verifyWithKey attempts to verify the jws using the key with keyID from the cache
125+ // if keyID is the empty string, it tries each key in the cache
126+ func (r * RemoteKeySet ) verifyWithKey (keyID string , jws * jose.JSONWebSignature ) (payload []byte , ok bool ) {
127+ if keyID == "" {
128+ for _ , key := range r .keysFromCache () {
125129 if payload , err := jws .Verify (& key ); err == nil {
126- return payload , nil
130+ return payload , true
131+ }
132+ }
133+ } else {
134+ if key , ok := r .keysFromCache ()[keyID ]; ok {
135+ if payload , err := jws .Verify (& key ); err == nil {
136+ return payload , true
127137 }
128138 }
129139 }
130- return nil , errors . New ( "failed to verify id token signature" )
140+ return nil , false
131141}
132142
133- func (r * RemoteKeySet ) keysFromCache () (keys [ ]jose.JSONWebKey ) {
143+ func (r * RemoteKeySet ) keysFromCache () (keys map [ string ]jose.JSONWebKey ) {
134144 r .mu .Lock ()
135145 defer r .mu .Unlock ()
136146 return r .cachedKeys
137147}
138148
139149// keysFromRemote syncs the key set from the remote set, records the values in the
140150// cache, and returns the key set.
141- func (r * RemoteKeySet ) keysFromRemote (ctx context.Context ) ([ ]jose.JSONWebKey , error ) {
151+ func (r * RemoteKeySet ) keysFromRemote (ctx context.Context ) (map [ string ]jose.JSONWebKey , error ) {
142152 // Need to lock to inspect the inflight request field.
143153 r .mu .Lock ()
144154 // If there's not a current inflight request, create one.
@@ -178,7 +188,7 @@ func (r *RemoteKeySet) keysFromRemote(ctx context.Context) ([]jose.JSONWebKey, e
178188 }
179189}
180190
181- func (r * RemoteKeySet ) updateKeys () ([ ]jose.JSONWebKey , error ) {
191+ func (r * RemoteKeySet ) updateKeys () (map [ string ]jose.JSONWebKey , error ) {
182192 req , err := http .NewRequest ("GET" , r .jwksURL , nil )
183193 if err != nil {
184194 return nil , fmt .Errorf ("oidc: can't create request: %v" , err )
@@ -204,5 +214,9 @@ func (r *RemoteKeySet) updateKeys() ([]jose.JSONWebKey, error) {
204214 if err != nil {
205215 return nil , fmt .Errorf ("oidc: failed to decode keys: %v %s" , err , body )
206216 }
207- return keySet .Keys , nil
217+ keys := make (map [string ]jose.JSONWebKey )
218+ for _ , key := range keySet .Keys {
219+ keys [key .KeyID ] = key
220+ }
221+ return keys , nil
208222}
0 commit comments