diff --git a/cmd/chihaya/config.go b/cmd/chihaya/config.go index 12f3d71..5c7c896 100644 --- a/cmd/chihaya/config.go +++ b/cmd/chihaya/config.go @@ -78,7 +78,11 @@ func (cfg ConfigFile) CreateHooks() (preHooks, postHooks []middleware.Hook, err if err != nil { return nil, nil, errors.New("invalid JWT middleware config: " + err.Error()) } - preHooks = append(preHooks, jwt.NewHook(jwtCfg)) + hook, err := jwt.NewHook(jwtCfg) + if err != nil { + return nil, nil, errors.New("invalid JWT middleware config: " + err.Error()) + } + preHooks = append(preHooks, hook) case "client approval": var caCfg clientapproval.Config err := yaml.Unmarshal(cfgBytes, &caCfg) diff --git a/middleware/jwt/jwt.go b/middleware/jwt/jwt.go index c97a567..47b6eec 100644 --- a/middleware/jwt/jwt.go +++ b/middleware/jwt/jwt.go @@ -50,49 +50,60 @@ type hook struct { } // NewHook returns an instance of the JWT middleware. -func NewHook(cfg Config) middleware.Hook { +func NewHook(cfg Config) (middleware.Hook, error) { h := &hook{ cfg: cfg, publicKeys: map[string]crypto.PublicKey{}, closing: make(chan struct{}), } + err := h.updateKeys() + if err != nil { + return nil, errors.New("failed to update initial JWK Set: " + err.Error()) + } + go func() { for { select { case <-h.closing: return case <-time.After(cfg.JWKUpdateInterval): - resp, err := http.Get(cfg.JWKSetURL) - if err != nil { - log.Errorln("failed to fetch JWK Set: " + err.Error()) - continue - } - - parsedJWKs := map[string]gojwk.Key{} - err = json.NewDecoder(resp.Body).Decode(&parsedJWKs) - if err != nil { - resp.Body.Close() - log.Errorln("failed to decode JWK JSON: " + err.Error()) - continue - } - resp.Body.Close() - - keys := map[string]crypto.PublicKey{} - for kid, parsedJWK := range parsedJWKs { - publicKey, err := parsedJWK.DecodePublicKey() - if err != nil { - log.Errorln("failed to decode JWK into public key: " + err.Error()) - continue - } - keys[kid] = publicKey - } - h.publicKeys = keys + h.updateKeys() } } }() - return h + return h, nil +} + +func (h *hook) updateKeys() error { + resp, err := http.Get(h.cfg.JWKSetURL) + if err != nil { + log.Errorln("failed to fetch JWK Set: " + err.Error()) + return err + } + + parsedJWKs := map[string]gojwk.Key{} + err = json.NewDecoder(resp.Body).Decode(&parsedJWKs) + if err != nil { + resp.Body.Close() + log.Errorln("failed to decode JWK JSON: " + err.Error()) + return err + } + resp.Body.Close() + + keys := map[string]crypto.PublicKey{} + for kid, parsedJWK := range parsedJWKs { + publicKey, err := parsedJWK.DecodePublicKey() + if err != nil { + log.Errorln("failed to decode JWK into public key: " + err.Error()) + return err + } + keys[kid] = publicKey + } + h.publicKeys = keys + + return nil } func (h *hook) Stop() <-chan error {