diff --git a/middleware/jwt/jwt.go b/middleware/jwt/jwt.go index c97a567..9d89a17 100644 --- a/middleware/jwt/jwt.go +++ b/middleware/jwt/jwt.go @@ -57,37 +57,14 @@ func NewHook(cfg Config) middleware.Hook { closing: make(chan struct{}), } + h.updateKeys() go func() { for { select { case <-h.closing: return case <-time.After(cfg.JWKUpdateInterval): - resp, err := http.Get(cfg.JWKSetURL) - if err != nil { - log.Errorln("failed to fetch JWK Set: " + err.Error()) - continue - } - - parsedJWKs := map[string]gojwk.Key{} - err = json.NewDecoder(resp.Body).Decode(&parsedJWKs) - if err != nil { - resp.Body.Close() - log.Errorln("failed to decode JWK JSON: " + err.Error()) - continue - } - resp.Body.Close() - - keys := map[string]crypto.PublicKey{} - for kid, parsedJWK := range parsedJWKs { - publicKey, err := parsedJWK.DecodePublicKey() - if err != nil { - log.Errorln("failed to decode JWK into public key: " + err.Error()) - continue - } - keys[kid] = publicKey - } - h.publicKeys = keys + h.updateKeys() } } }() @@ -95,6 +72,34 @@ func NewHook(cfg Config) middleware.Hook { return h } +func (h *hook) updateKeys() { + resp, err := http.Get(h.cfg.JWKSetURL) + if err != nil { + log.Errorln("failed to fetch JWK Set: " + err.Error()) + return + } + + parsedJWKs := map[string]gojwk.Key{} + err = json.NewDecoder(resp.Body).Decode(&parsedJWKs) + if err != nil { + resp.Body.Close() + log.Errorln("failed to decode JWK JSON: " + err.Error()) + return + } + resp.Body.Close() + + keys := map[string]crypto.PublicKey{} + for kid, parsedJWK := range parsedJWKs { + publicKey, err := parsedJWK.DecodePublicKey() + if err != nil { + log.Errorln("failed to decode JWK into public key: " + err.Error()) + return + } + keys[kid] = publicKey + } + h.publicKeys = keys +} + func (h *hook) Stop() <-chan error { select { case <-h.closing: