Pass JWT validation failures to the wrapped handler (via the "StatusCode" and "ErrorMessage" request headers) instead of writing an HTTP error response.
diff --git a/oauth.go b/oauth.go index 350c780..fb34930 100644 --- a/oauth.go +++ b/oauth.go
@@ -5,6 +5,7 @@ "crypto/rsa" "encoding/json" "net/http" + "strconv" "sync" "time" @@ -97,49 +98,45 @@ return func(rw http.ResponseWriter, r *http.Request, ps httprouter.Params) { + var err2 error = nil + + /* Set the input params in the request if valid */ + r = SetParamsInRequest(r, ps) + /* Set Default as OK */ + WriteStatusResponse(http.StatusOK, "", r) + /* Parse the JWT from the input request */ - jwt, err := jws.ParseJWTFromRequest(r) - if err != nil { - WriteErrorResponse(http.StatusBadRequest, err.Error(), rw) - return + jwt, err1 := jws.ParseJWTFromRequest(r) + if err1 != nil { + WriteStatusResponse(http.StatusBadRequest, err1.Error(), r) } /* Get the pulic key from cache */ - pk := a.getPkSafe() - if pk == nil { - WriteErrorResponse(http.StatusBadRequest, "Public key not configured. Validation failed.", rw) - return + if err1 == nil { + pk := a.getPkSafe() + if pk == nil { + WriteStatusResponse(http.StatusBadRequest, + "Public key not configured. Validation failed.", r) + } else { + err2 = jwt.Validate(pk, crypto.SigningMethodRS256) + if err2 != nil { + WriteStatusResponse(http.StatusBadRequest, + err2.Error(), r) + } + } } - - /* Validate the token */ - err = jwt.Validate(pk, crypto.SigningMethodRS256) - if err != nil { - WriteErrorResponse(http.StatusBadRequest, err.Error(), rw) - return - } - - /* Set the input params in the request */ - r = SetParamsInRequest(r, ps) next.ServeHTTP(rw, r) } - } /* -WriteErrorResponse write a non 200 error response +WriteStatusResponse updates the validation outcome in the header. 
*/ -func WriteErrorResponse(statusCode int, message string, w http.ResponseWriter) { - errors := Errors{message} - WriteErrorResponses(statusCode, errors, w) -} - -/* -WriteErrorResponses write our error responses -*/ -func WriteErrorResponses(statusCode int, errors Errors, w http.ResponseWriter) { - w.WriteHeader(statusCode) - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(errors) +func WriteStatusResponse(statusCode int, message string, r *http.Request) { + r.Header.Set("StatusCode", strconv.Itoa(statusCode)) + if statusCode != http.StatusOK { + r.Header.Set("ErrorMessage", message) + } } /*
diff --git a/scaffold_test.go b/scaffold_test.go index b2007b7..d476b6e 100644 --- a/scaffold_test.go +++ b/scaffold_test.go
@@ -404,17 +404,38 @@ Expect(reqerr).Should(Succeed()) defer resp.Body.Close() return resp.StatusCode + dataValue := resp.Header.Get("Status") + Expect(dataValue).To(Equal("200")) + return resp.StatusCode }, 2*time.Second).Should(Equal(200)) - req, err := http.NewRequest("GET", - "http://"+scaf.InsecureAddress()+"/foobar/xyz/123", nil) + }) + It("SSO handler validation Bad Key", func() { + router := httprouter.New() + Expect(router).ShouldNot(BeNil()) + scaf := CreateHTTPScaffold() + Expect(scaf).ShouldNot(BeNil()) + err := scaf.Open() Expect(err).Should(Succeed()) - req.Header.Set("Authorization", "Bearer DEADBEEF") - client := &http.Client{} - resp, err := client.Do(req) - Expect(err).Should(Succeed()) - defer resp.Body.Close() - Expect(resp.StatusCode).Should(Equal(400)) + oauth := scaf.CreateOAuth(validJWTSigner) + Expect(oauth).ShouldNot(BeNil()) + go func() { + fmt.Fprintf(GinkgoWriter, "Gonna listen on %s\n", scaf.InsecureAddress()) + router.GET(oauth.SSOHandler("/foobar/:param1/:param2", buslogicHandlerFail1)) + scaf.Listen(router) + }() + + Eventually(func() int { + req, err := http.NewRequest("GET", + "http://"+scaf.InsecureAddress()+"/foobar/xyz/123", nil) + Expect(err).Should(Succeed()) + req.Header.Set("Authorization", "Bearer DEADBEEF") + client := &http.Client{} + resp, err := client.Do(req) + Expect(err).Should(Succeed()) + defer resp.Body.Close() + return resp.StatusCode + }, 2*time.Second).Should(Equal(200)) }) It("SSO handler validation bad public key", func() { @@ -428,7 +449,7 @@ Expect(oauth).ShouldNot(BeNil()) go func() { fmt.Fprintf(GinkgoWriter, "Gonna listen on %s\n", scaf.InsecureAddress()) - router.GET(oauth.SSOHandler("/foobar/:param1/:param2", buslogicHandler)) + router.GET(oauth.SSOHandler("/foobar/:param1/:param2", buslogicHandlerFail2)) scaf.Listen(router) }() @@ -440,7 +461,7 @@ resp, err := client.Do(req) Expect(err).Should(Succeed()) defer resp.Body.Close() - Expect(resp.StatusCode).Should(Equal(400)) + 
Expect(resp.StatusCode).Should(Equal(200)) }) It("Get stack trace", func() { @@ -468,6 +489,27 @@ }) func buslogicHandler(w http.ResponseWriter, r *http.Request) { + Expect(r.Header.Get("StatusCode")).Should(Equal("200")) + p := FetchParams(r) + cid := p.ByName("param1") + Expect(cid).To(Equal("xyz")) + cid = p.ByName("param2") + Expect(cid).To(Equal("123")) +} + +func buslogicHandlerFail1(w http.ResponseWriter, r *http.Request) { + Expect(r.Header.Get("StatusCode")).Should(Equal("400")) + Expect(r.Header.Get("ErrorMessage")).Should(Equal("not a compact JWS")) + p := FetchParams(r) + cid := p.ByName("param1") + Expect(cid).To(Equal("xyz")) + cid = p.ByName("param2") + Expect(cid).To(Equal("123")) +} + +func buslogicHandlerFail2(w http.ResponseWriter, r *http.Request) { + Expect(r.Header.Get("StatusCode")).Should(Equal("400")) + Expect(r.Header.Get("ErrorMessage")).Should(Equal("Public key not configured. Validation failed.")) p := FetchParams(r) cid := p.ByName("param1") Expect(cid).To(Equal("xyz"))