Pass JWT validation failures to the downstream handler via request headers instead of writing an error response.
diff --git a/oauth.go b/oauth.go
index 350c780..fb34930 100644
--- a/oauth.go
+++ b/oauth.go
@@ -5,6 +5,7 @@
"crypto/rsa"
"encoding/json"
"net/http"
+ "strconv"
"sync"
"time"
@@ -97,49 +98,45 @@
return func(rw http.ResponseWriter, r *http.Request, ps httprouter.Params) {
+ var err2 error = nil
+
+ /* Set the input params in the request if valid */
+ r = SetParamsInRequest(r, ps)
+ /* Set Default as OK */
+ WriteStatusResponse(http.StatusOK, "", r)
+
/* Parse the JWT from the input request */
- jwt, err := jws.ParseJWTFromRequest(r)
- if err != nil {
- WriteErrorResponse(http.StatusBadRequest, err.Error(), rw)
- return
+ jwt, err1 := jws.ParseJWTFromRequest(r)
+ if err1 != nil {
+ WriteStatusResponse(http.StatusBadRequest, err1.Error(), r)
}
/* Get the pulic key from cache */
- pk := a.getPkSafe()
- if pk == nil {
- WriteErrorResponse(http.StatusBadRequest, "Public key not configured. Validation failed.", rw)
- return
+ if err1 == nil {
+ pk := a.getPkSafe()
+ if pk == nil {
+ WriteStatusResponse(http.StatusBadRequest,
+ "Public key not configured. Validation failed.", r)
+ } else {
+ err2 = jwt.Validate(pk, crypto.SigningMethodRS256)
+ if err2 != nil {
+ WriteStatusResponse(http.StatusBadRequest,
+ err2.Error(), r)
+ }
+ }
}
-
- /* Validate the token */
- err = jwt.Validate(pk, crypto.SigningMethodRS256)
- if err != nil {
- WriteErrorResponse(http.StatusBadRequest, err.Error(), rw)
- return
- }
-
- /* Set the input params in the request */
- r = SetParamsInRequest(r, ps)
next.ServeHTTP(rw, r)
}
-
}
/*
-WriteErrorResponse write a non 200 error response
+WriteStatusResponse records the validation outcome in the request headers: "StatusCode" always, and "ErrorMessage" for non-200 codes.
*/
-func WriteErrorResponse(statusCode int, message string, w http.ResponseWriter) {
- errors := Errors{message}
- WriteErrorResponses(statusCode, errors, w)
-}
-
-/*
-WriteErrorResponses write our error responses
-*/
-func WriteErrorResponses(statusCode int, errors Errors, w http.ResponseWriter) {
- w.WriteHeader(statusCode)
- w.Header().Set("Content-Type", "application/json")
- json.NewEncoder(w).Encode(errors)
+func WriteStatusResponse(statusCode int, message string, r *http.Request) {
+	// The outcome is stored on the inbound request so the next handler can
+	// inspect it; StatusCode is always overwritten.
+	r.Header.Set("StatusCode", strconv.Itoa(statusCode))
+	if statusCode == http.StatusOK {
+		// r.Header comes from the client, so drop any caller-supplied
+		// ErrorMessage rather than letting it reach the handler on success.
+		r.Header.Del("ErrorMessage")
+		return
+	}
+	r.Header.Set("ErrorMessage", message)
+}
/*
diff --git a/scaffold_test.go b/scaffold_test.go
index b2007b7..d476b6e 100644
--- a/scaffold_test.go
+++ b/scaffold_test.go
@@ -404,17 +404,38 @@
Expect(reqerr).Should(Succeed())
defer resp.Body.Close()
return resp.StatusCode
+ dataValue := resp.Header.Get("Status")
+ Expect(dataValue).To(Equal("200"))
+ return resp.StatusCode
}, 2*time.Second).Should(Equal(200))
- req, err := http.NewRequest("GET",
- "http://"+scaf.InsecureAddress()+"/foobar/xyz/123", nil)
+ })
+ It("SSO handler validation Bad Key", func() {
+ router := httprouter.New()
+ Expect(router).ShouldNot(BeNil())
+ scaf := CreateHTTPScaffold()
+ Expect(scaf).ShouldNot(BeNil())
+ err := scaf.Open()
Expect(err).Should(Succeed())
- req.Header.Set("Authorization", "Bearer DEADBEEF")
- client := &http.Client{}
- resp, err := client.Do(req)
- Expect(err).Should(Succeed())
- defer resp.Body.Close()
- Expect(resp.StatusCode).Should(Equal(400))
+ oauth := scaf.CreateOAuth(validJWTSigner)
+ Expect(oauth).ShouldNot(BeNil())
+ go func() {
+ fmt.Fprintf(GinkgoWriter, "Gonna listen on %s\n", scaf.InsecureAddress())
+ router.GET(oauth.SSOHandler("/foobar/:param1/:param2", buslogicHandlerFail1))
+ scaf.Listen(router)
+ }()
+
+ Eventually(func() int {
+ req, err := http.NewRequest("GET",
+ "http://"+scaf.InsecureAddress()+"/foobar/xyz/123", nil)
+ Expect(err).Should(Succeed())
+ req.Header.Set("Authorization", "Bearer DEADBEEF")
+ client := &http.Client{}
+ resp, err := client.Do(req)
+ Expect(err).Should(Succeed())
+ defer resp.Body.Close()
+ return resp.StatusCode
+ }, 2*time.Second).Should(Equal(200))
})
It("SSO handler validation bad public key", func() {
@@ -428,7 +449,7 @@
Expect(oauth).ShouldNot(BeNil())
go func() {
fmt.Fprintf(GinkgoWriter, "Gonna listen on %s\n", scaf.InsecureAddress())
- router.GET(oauth.SSOHandler("/foobar/:param1/:param2", buslogicHandler))
+ router.GET(oauth.SSOHandler("/foobar/:param1/:param2", buslogicHandlerFail2))
scaf.Listen(router)
}()
@@ -440,7 +461,7 @@
resp, err := client.Do(req)
Expect(err).Should(Succeed())
defer resp.Body.Close()
- Expect(resp.StatusCode).Should(Equal(400))
+ Expect(resp.StatusCode).Should(Equal(200))
})
It("Get stack trace", func() {
@@ -468,6 +489,27 @@
})
func buslogicHandler(w http.ResponseWriter, r *http.Request) {
+	// Happy-path handler: expects a successful validation (StatusCode 200)
+	// and the route params xyz/123 to have been forwarded.
+	Expect(r.Header.Get("StatusCode")).Should(Equal("200"))
+	params := FetchParams(r)
+	Expect(params.ByName("param1")).To(Equal("xyz"))
+	Expect(params.ByName("param2")).To(Equal("123"))
+}
+
+// buslogicHandlerFail1 expects SSOHandler to have rejected a malformed
+// bearer token (StatusCode 400, "not a compact JWS") while still
+// forwarding the route params xyz/123 to this handler.
+func buslogicHandlerFail1(w http.ResponseWriter, r *http.Request) {
+	status := r.Header.Get("StatusCode")
+	Expect(status).Should(Equal("400"))
+	reason := r.Header.Get("ErrorMessage")
+	Expect(reason).Should(Equal("not a compact JWS"))
+	params := FetchParams(r)
+	Expect(params.ByName("param1")).To(Equal("xyz"))
+	Expect(params.ByName("param2")).To(Equal("123"))
+}
+
+func buslogicHandlerFail2(w http.ResponseWriter, r *http.Request) {
+ Expect(r.Header.Get("StatusCode")).Should(Equal("400"))
+ Expect(r.Header.Get("ErrorMessage")).Should(Equal("Public key not configured. Validation failed."))
p := FetchParams(r)
cid := p.ByName("param1")
Expect(cid).To(Equal("xyz"))