Native OPTIONS handling Fixes #98
diff --git a/router.go b/router.go index 5d6c81c..f18a678 100644 --- a/router.go +++ b/router.go
@@ -138,6 +138,10 @@ // handler. HandleMethodNotAllowed bool + // If enabled, the router automatically replies to OPTIONS requests. + // Custom OPTIONS handlers take priority over automatic replies. + HandleOPTIONS bool + // Configurable http.Handler which is called when no matching route is // found. If it is not set, http.NotFound is used. NotFound http.Handler @@ -167,6 +171,7 @@ RedirectTrailingSlash: true, RedirectFixedPath: true, HandleMethodNotAllowed: true, + HandleOPTIONS: true, } } @@ -288,15 +293,35 @@ return nil, nil, false } +func (r *Router) allowed(path, reqMethod string) (allow string) { + for method := range r.trees { + // Skip the requested method - we already tried this one + if method == reqMethod || method == "OPTIONS" { + continue + } + + handle, _, _ := r.trees[method].getValue(path) + if handle != nil { + // add request method to list of allowed methods + if len(allow) == 0 { + allow = method + } else { + allow += ", " + method + } + } + } + return +} + // ServeHTTP makes the router implement the http.Handler interface. func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { if r.PanicHandler != nil { defer r.recv(w, req) } - if root := r.trees[req.Method]; root != nil { - path := req.URL.Path + path := req.URL.Path + if root := r.trees[req.Method]; root != nil { if handle, ps, tsr := root.getValue(path); handle != nil { handle(w, req, ps) return @@ -333,37 +358,46 @@ } } - // Handle 405 - if r.HandleMethodNotAllowed { - var allow string - for method := range r.trees { - // Skip the requested method - we already tried this one - if method == req.Method || method == "OPTIONS" { - continue - } + if req.Method == "OPTIONS" { + // Handle OPTIONS requests + if r.HandleOPTIONS { + var allow string + if path == "*" { // server OPTIONS + for method := range r.trees { + if method == "OPTIONS" { + continue + } - handle, _, _ := r.trees[method].getValue(req.URL.Path) - if handle != nil { - // add request method to list of allowed methods - if len(allow) == 0 { - allow = method - } else { - allow += ", " + method + // add request method to list of allowed methods + if len(allow) == 0 { + allow = method + } else { + allow += ", " + method + } } + } else { // path OPTIONS + allow = r.allowed(path, req.Method) + } + if len(allow) > 0 { + w.Header().Set("Allow", allow) + return } } - - if len(allow) > 0 { - w.Header().Set("Allow", allow) - if r.MethodNotAllowed != nil { - r.MethodNotAllowed.ServeHTTP(w, req) - } else { - http.Error(w, - http.StatusText(http.StatusMethodNotAllowed), - http.StatusMethodNotAllowed, - ) + } else { + // Handle 405 + if r.HandleMethodNotAllowed { + if allow := r.allowed(path, req.Method); len(allow) > 0 { + w.Header().Set("Allow", allow) + if r.MethodNotAllowed != nil { + r.MethodNotAllowed.ServeHTTP(w, req) + } else { + http.Error(w, + http.StatusText(http.StatusMethodNotAllowed), + http.StatusMethodNotAllowed, + ) + } + return } - return } }
diff --git a/router_test.go b/router_test.go index 265bd55..db57740 100644 --- a/router_test.go +++ b/router_test.go
@@ -216,6 +216,96 @@ } } +func TestRouterOPTIONS(t *testing.T) { + handlerFunc := func(_ http.ResponseWriter, _ *http.Request, _ Params) {} + + router := New() + router.POST("/path", handlerFunc) + + // test not allowed + // * (server) + r, _ := http.NewRequest("OPTIONS", "*", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, r) + if !(w.Code == http.StatusOK) { + t.Errorf("OPTIONS handling failed: Code=%d, Header=%v", w.Code, w.Header()) + } else if allow := w.Header().Get("Allow"); allow != "POST" { + t.Error("unexpected Allow header value: " + allow) + } + + // path + r, _ = http.NewRequest("OPTIONS", "/path", nil) + w = httptest.NewRecorder() + router.ServeHTTP(w, r) + if !(w.Code == http.StatusOK) { + t.Errorf("OPTIONS handling failed: Code=%d, Header=%v", w.Code, w.Header()) + } else if allow := w.Header().Get("Allow"); allow != "POST" { + t.Error("unexpected Allow header value: " + allow) + } + + r, _ = http.NewRequest("OPTIONS", "/doesnotexist", nil) + w = httptest.NewRecorder() + router.ServeHTTP(w, r) + if !(w.Code == http.StatusNotFound) { + t.Errorf("OPTIONS handling failed: Code=%d, Header=%v", w.Code, w.Header()) + } + + // add another method + router.GET("/path", handlerFunc) + + // test again + // * (server) + r, _ = http.NewRequest("OPTIONS", "*", nil) + w = httptest.NewRecorder() + router.ServeHTTP(w, r) + if !(w.Code == http.StatusOK) { + t.Errorf("OPTIONS handling failed: Code=%d, Header=%v", w.Code, w.Header()) + } else if allow := w.Header().Get("Allow"); allow != "POST, GET" && allow != "GET, POST" { + t.Error("unexpected Allow header value: " + allow) + } + + // path + r, _ = http.NewRequest("OPTIONS", "/path", nil) + w = httptest.NewRecorder() + router.ServeHTTP(w, r) + if !(w.Code == http.StatusOK) { + t.Errorf("OPTIONS handling failed: Code=%d, Header=%v", w.Code, w.Header()) + } else if allow := w.Header().Get("Allow"); allow != "POST, GET" && allow != "GET, POST" { + t.Error("unexpected Allow header value: " + allow) + } + + // custom handler + var custom bool + router.OPTIONS("/path", func(w http.ResponseWriter, r *http.Request, _ Params) { + custom = true + }) + + // test again + // * (server) + r, _ = http.NewRequest("OPTIONS", "*", nil) + w = httptest.NewRecorder() + router.ServeHTTP(w, r) + if !(w.Code == http.StatusOK) { + t.Errorf("OPTIONS handling failed: Code=%d, Header=%v", w.Code, w.Header()) + } else if allow := w.Header().Get("Allow"); allow != "POST, GET" && allow != "GET, POST" { + t.Error("unexpected Allow header value: " + allow) + } + if custom { + t.Error("custom handler called on *") + } + + // path + r, _ = http.NewRequest("OPTIONS", "/path", nil) + w = httptest.NewRecorder() + router.ServeHTTP(w, r) + if !(w.Code == http.StatusOK) { + t.Errorf("OPTIONS handling failed: Code=%d, Header=%v", w.Code, w.Header()) + } + if !custom { + t.Error("custom handler not called") + } +} + func TestRouterNotAllowed(t *testing.T) { handlerFunc := func(_ http.ResponseWriter, _ *http.Request, _ Params) {}