Cache public key, so we dont hit server each time.
diff --git a/oauth.go b/oauth.go index c55beb7..5ce1cbf 100644 --- a/oauth.go +++ b/oauth.go
@@ -2,13 +2,20 @@ import ( "context" + "crypto/rsa" "encoding/json" "github.com/SermoDigital/jose/crypto" "github.com/SermoDigital/jose/jws" - "github.com/SermoDigital/jose/jwt" "github.com/julienschmidt/httprouter" "github.com/justinas/alice" "net/http" + "sync" + "time" +) + +var ( + gPkey *rsa.PublicKey + rwMutex sync.RWMutex ) const params = "params" @@ -33,8 +40,6 @@ key for verifying the JWT token */ type OAuth struct { - keyURL string - client *http.Client } /* @@ -47,6 +52,29 @@ } /* +CreateOAuth is a constructor that creates OAuth for OAuthService +interface. OAuthService interface offers method:- +(1) SSOHandler(): Offers the user to attach http handler for JWT +verification. +*/ +func (s *HTTPScaffold) CreateOAuth(keyURL string) OAuthService { + + pk, err := getPubicKey(keyURL) + if err != nil { + panic("Unable to retreive Public Key") + } + setPkSafe(pk) + /* + Routine that will fetch & update the public keys in the global + variable periodically + */ + updatePulicKeysPeriodic(keyURL) + + return &OAuth{} + +} + +/* SetParamsInRequest Sets the params and its values in the request */ func SetParamsInRequest(r *http.Request, ps httprouter.Params) *http.Request { @@ -85,7 +113,7 @@ } /* Validate the JWT */ - err = a.Validate(jwt) + err = jwt.Validate(getPkSafe(), crypto.SigningMethodRS256) if err != nil { WriteErrorResponse(http.StatusBadRequest, err.Error(), rw) return @@ -99,34 +127,6 @@ } /* -ValidateKey validate the jwt and return an error if it fails -*/ -func (a *OAuth) Validate(jwt jwt.JWT) error { - - r, err := a.client.Get(a.keyURL) - - if err != nil { - return err - } - - defer r.Body.Close() - ssoKey := &ssoKey{} - err = json.NewDecoder(r.Body).Decode(ssoKey) - if err != nil { - return err - } - - /* Retrieve the Public Key */ - publieKey, err := crypto.ParseRSAPublicKeyFromPEM([]byte(ssoKey.Value)) - if err != nil { - return err - } - - /* Return the status of validation */ - return jwt.Validate(publieKey, crypto.SigningMethodRS256) -} - -/* WriteErrorResponse write a non 200 error response */ func WriteErrorResponse(statusCode int, message string, w http.ResponseWriter) { @@ -142,3 +142,72 @@ w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(errors) } + +/* +updatePulicKeysPeriodic updates the cache periodically (every day) +*/ +func updatePulicKeysPeriodic(keyURL string) { + + ticker := time.NewTicker(24 * 3600 * time.Second) + quit := make(chan struct{}) + go func() { + for { + select { + case <-ticker.C: + pk, err := getPubicKey(keyURL) + if err == nil { + setPkSafe(pk) + } + case <-quit: + ticker.Stop() + return + } + } + }() +} + +/* +getPubicKey: Loads the Public key in to memory and returns it. +*/ +func getPubicKey(keyURL string) (*rsa.PublicKey, error) { + + client := &http.Client{} + r, err := client.Get(keyURL) + if err != nil { + return nil, err + } + + defer r.Body.Close() + ssoKey := &ssoKey{} + err = json.NewDecoder(r.Body).Decode(ssoKey) + if err != nil { + return nil, err + } + + /* Retrieve the Public Key */ + publicKey, err := crypto.ParseRSAPublicKeyFromPEM([]byte(ssoKey.Value)) + if err != nil { + return nil, err + } + return publicKey, nil + +} + +/* +setPkSafe Safely stores the Public Key (via a Write Lock) +*/ +func setPkSafe(pk *rsa.PublicKey) { + rwMutex.Lock() + gPkey = pk + rwMutex.Unlock() +} + +/* +getPkSafe returns the stored key (via a read lock) +*/ +func getPkSafe() *rsa.PublicKey { + rwMutex.RLock() + pk := gPkey + rwMutex.RUnlock() + return pk +}
diff --git a/scaffold.go b/scaffold.go index 55013d4..9ab8d2f 100644 --- a/scaffold.go +++ b/scaffold.go
@@ -95,8 +95,8 @@ markdownPath string markdownMethod string markdownHandler MarkdownHandler - keyFile string certFile string + keyFile string } /*