Native OPTIONS handling
Fixes #98
diff --git a/router.go b/router.go
index 5d6c81c..f18a678 100644
--- a/router.go
+++ b/router.go
@@ -138,6 +138,10 @@
// handler.
HandleMethodNotAllowed bool
+ // If enabled, the router automatically replies to OPTIONS requests.
+ // Custom OPTIONS handlers take priority over automatic replies.
+ HandleOPTIONS bool
+
// Configurable http.Handler which is called when no matching route is
// found. If it is not set, http.NotFound is used.
NotFound http.Handler
@@ -167,6 +171,7 @@
RedirectTrailingSlash: true,
RedirectFixedPath: true,
HandleMethodNotAllowed: true,
+ HandleOPTIONS: true,
}
}
@@ -288,15 +293,35 @@
return nil, nil, false
}
+func (r *Router) allowed(path, reqMethod string) (allow string) {
+ for method := range r.trees {
+ // Skip the requested method - we already tried this one
+ if method == reqMethod || method == "OPTIONS" {
+ continue
+ }
+
+ handle, _, _ := r.trees[method].getValue(path)
+ if handle != nil {
+ // add request method to list of allowed methods
+ if len(allow) == 0 {
+ allow = method
+ } else {
+ allow += ", " + method
+ }
+ }
+ }
+ return
+}
+
// ServeHTTP makes the router implement the http.Handler interface.
func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if r.PanicHandler != nil {
defer r.recv(w, req)
}
- if root := r.trees[req.Method]; root != nil {
- path := req.URL.Path
+ path := req.URL.Path
+ if root := r.trees[req.Method]; root != nil {
if handle, ps, tsr := root.getValue(path); handle != nil {
handle(w, req, ps)
return
@@ -333,37 +358,46 @@
}
}
- // Handle 405
- if r.HandleMethodNotAllowed {
- var allow string
- for method := range r.trees {
- // Skip the requested method - we already tried this one
- if method == req.Method || method == "OPTIONS" {
- continue
- }
+ if req.Method == "OPTIONS" {
+ // Handle OPTIONS requests
+ if r.HandleOPTIONS {
+ var allow string
+ if path == "*" { // server OPTIONS
+ for method := range r.trees {
+ if method == "OPTIONS" {
+ continue
+ }
- handle, _, _ := r.trees[method].getValue(req.URL.Path)
- if handle != nil {
- // add request method to list of allowed methods
- if len(allow) == 0 {
- allow = method
- } else {
- allow += ", " + method
+ // add request method to list of allowed methods
+ if len(allow) == 0 {
+ allow = method
+ } else {
+ allow += ", " + method
+ }
}
+ } else { // path OPTIONS
+ allow = r.allowed(path, req.Method)
+ }
+ if len(allow) > 0 {
+ w.Header().Set("Allow", allow)
+ return
}
}
-
- if len(allow) > 0 {
- w.Header().Set("Allow", allow)
- if r.MethodNotAllowed != nil {
- r.MethodNotAllowed.ServeHTTP(w, req)
- } else {
- http.Error(w,
- http.StatusText(http.StatusMethodNotAllowed),
- http.StatusMethodNotAllowed,
- )
+ } else {
+ // Handle 405
+ if r.HandleMethodNotAllowed {
+ if allow := r.allowed(path, req.Method); len(allow) > 0 {
+ w.Header().Set("Allow", allow)
+ if r.MethodNotAllowed != nil {
+ r.MethodNotAllowed.ServeHTTP(w, req)
+ } else {
+ http.Error(w,
+ http.StatusText(http.StatusMethodNotAllowed),
+ http.StatusMethodNotAllowed,
+ )
+ }
+ return
}
- return
}
}
diff --git a/router_test.go b/router_test.go
index 265bd55..db57740 100644
--- a/router_test.go
+++ b/router_test.go
@@ -216,6 +216,96 @@
}
}
+func TestRouterOPTIONS(t *testing.T) {
+ handlerFunc := func(_ http.ResponseWriter, _ *http.Request, _ Params) {}
+
+ router := New()
+ router.POST("/path", handlerFunc)
+
+ // test not allowed
+ // * (server)
+ r, _ := http.NewRequest("OPTIONS", "*", nil)
+ w := httptest.NewRecorder()
+ router.ServeHTTP(w, r)
+ if !(w.Code == http.StatusOK) {
+ t.Errorf("OPTIONS handling failed: Code=%d, Header=%v", w.Code, w.Header())
+ } else if allow := w.Header().Get("Allow"); allow != "POST" {
+ t.Error("unexpected Allow header value: " + allow)
+ }
+
+ // path
+ r, _ = http.NewRequest("OPTIONS", "/path", nil)
+ w = httptest.NewRecorder()
+ router.ServeHTTP(w, r)
+ if !(w.Code == http.StatusOK) {
+ t.Errorf("OPTIONS handling failed: Code=%d, Header=%v", w.Code, w.Header())
+ } else if allow := w.Header().Get("Allow"); allow != "POST" {
+ t.Error("unexpected Allow header value: " + allow)
+ }
+
+ r, _ = http.NewRequest("OPTIONS", "/doesnotexist", nil)
+ w = httptest.NewRecorder()
+ router.ServeHTTP(w, r)
+ if !(w.Code == http.StatusNotFound) {
+ t.Errorf("OPTIONS handling failed: Code=%d, Header=%v", w.Code, w.Header())
+ }
+
+ // add another method
+ router.GET("/path", handlerFunc)
+
+ // test again
+ // * (server)
+ r, _ = http.NewRequest("OPTIONS", "*", nil)
+ w = httptest.NewRecorder()
+ router.ServeHTTP(w, r)
+ if !(w.Code == http.StatusOK) {
+ t.Errorf("OPTIONS handling failed: Code=%d, Header=%v", w.Code, w.Header())
+ } else if allow := w.Header().Get("Allow"); allow != "POST, GET" && allow != "GET, POST" {
+ t.Error("unexpected Allow header value: " + allow)
+ }
+
+ // path
+ r, _ = http.NewRequest("OPTIONS", "/path", nil)
+ w = httptest.NewRecorder()
+ router.ServeHTTP(w, r)
+ if !(w.Code == http.StatusOK) {
+ t.Errorf("OPTIONS handling failed: Code=%d, Header=%v", w.Code, w.Header())
+ } else if allow := w.Header().Get("Allow"); allow != "POST, GET" && allow != "GET, POST" {
+ t.Error("unexpected Allow header value: " + allow)
+ }
+
+ // custom handler
+ var custom bool
+ router.OPTIONS("/path", func(w http.ResponseWriter, r *http.Request, _ Params) {
+ custom = true
+ })
+
+ // test again
+ // * (server)
+ r, _ = http.NewRequest("OPTIONS", "*", nil)
+ w = httptest.NewRecorder()
+ router.ServeHTTP(w, r)
+ if !(w.Code == http.StatusOK) {
+ t.Errorf("OPTIONS handling failed: Code=%d, Header=%v", w.Code, w.Header())
+ } else if allow := w.Header().Get("Allow"); allow != "POST, GET" && allow != "GET, POST" {
+ t.Error("unexpected Allow header value: " + allow)
+ }
+ if custom {
+ t.Error("custom handler called on *")
+ }
+
+ // path
+ r, _ = http.NewRequest("OPTIONS", "/path", nil)
+ w = httptest.NewRecorder()
+ router.ServeHTTP(w, r)
+ if !(w.Code == http.StatusOK) {
+ t.Errorf("OPTIONS handling failed: Code=%d, Header=%v", w.Code, w.Header())
+ }
+ if !custom {
+ t.Error("custom handler not called")
+ }
+}
+
func TestRouterNotAllowed(t *testing.T) {
handlerFunc := func(_ http.ResponseWriter, _ *http.Request, _ Params) {}