Allow chaining of any http.Handler, not just http.HandlerFunc.
diff --git a/README.md b/README.md index 9875c70..785a108 100644 --- a/README.md +++ b/README.md
@@ -297,7 +297,7 @@ **NOTE: It might be required to set [Router.HandleMethodNotAllowed](http://godoc.org/github.com/julienschmidt/httprouter#Router.HandleMethodNotAllowed) to `false` to avoid problems.** -You can use another [http.HandlerFunc](http://golang.org/pkg/net/http/#HandlerFunc), for example another router, to handle requests which could not be matched by this router by using the [Router.NotFound](http://godoc.org/github.com/julienschmidt/httprouter#Router.NotFound) handler. This allows chaining. +You can use another [http.Handler](http://golang.org/pkg/net/http/#Handler), for example another router, to handle requests which could not be matched by this router by using the [Router.NotFound](http://godoc.org/github.com/julienschmidt/httprouter#Router.NotFound) handler. This allows chaining. ### Static files The `NotFound` handler can for example be used to serve static files from the root path `/` (like an index.html file along with other assets):
diff --git a/router.go b/router.go index 155b871..83e86fd 100644 --- a/router.go +++ b/router.go
@@ -138,14 +138,14 @@ // handler. HandleMethodNotAllowed bool - // Configurable http.HandlerFunc which is called when no matching route is + // Configurable http.Handler which is called when no matching route is // found. If it is not set, http.NotFound is used. - NotFound http.HandlerFunc + NotFound http.Handler - // Configurable http.HandlerFunc which is called when a request + // Configurable http.Handler which is called when a request // cannot be routed and HandleMethodNotAllowed is true. // If it is not set, http.Error with http.StatusMethodNotAllowed is used. - MethodNotAllowed http.HandlerFunc + MethodNotAllowed http.Handler // Function to handle panics recovered from http handlers. // It should be used to generate a error page and return the http error code @@ -342,7 +342,7 @@ handle, _, _ := r.trees[method].getValue(req.URL.Path) if handle != nil { if r.MethodNotAllowed != nil { - r.MethodNotAllowed(w, req) + r.MethodNotAllowed.ServeHTTP(w, req) } else { http.Error(w, http.StatusText(http.StatusMethodNotAllowed), @@ -356,7 +356,7 @@ // Handle 404 if r.NotFound != nil { - r.NotFound(w, req) + r.NotFound.ServeHTTP(w, req) } else { http.NotFound(w, req) }
diff --git a/router_test.go b/router_test.go index 9dc6296..e3141bd 100644 --- a/router_test.go +++ b/router_test.go
@@ -174,6 +174,48 @@ } } +func TestRouterChaining(t *testing.T) { + router1 := New() + router2 := New() + router1.NotFound = router2 + + fooHit := false + router1.POST("/foo", func(w http.ResponseWriter, req *http.Request, _ Params) { + fooHit = true + w.WriteHeader(http.StatusOK) + }) + + barHit := false + router2.POST("/bar", func(w http.ResponseWriter, req *http.Request, _ Params) { + barHit = true + w.WriteHeader(http.StatusOK) + }) + + r, _ := http.NewRequest("POST", "/foo", nil) + w := httptest.NewRecorder() + router1.ServeHTTP(w, r) + if !(w.Code == http.StatusOK && fooHit) { + t.Errorf("Regular routing failed with router chaining.") + t.FailNow() + } + + r, _ = http.NewRequest("POST", "/bar", nil) + w = httptest.NewRecorder() + router1.ServeHTTP(w, r) + if !(w.Code == http.StatusOK && barHit) { + t.Errorf("Chained routing failed with router chaining.") + t.FailNow() + } + + r, _ = http.NewRequest("POST", "/qax", nil) + w = httptest.NewRecorder() + router1.ServeHTTP(w, r) + if !(w.Code == http.StatusNotFound) { + t.Errorf("NotFound behavior failed with router chaining.") + t.FailNow() + } +} + func TestRouterNotAllowed(t *testing.T) { handlerFunc := func(_ http.ResponseWriter, _ *http.Request, _ Params) {} @@ -190,10 +232,10 @@ w = httptest.NewRecorder() responseText := "custom method" - router.MethodNotAllowed = func(w http.ResponseWriter, req *http.Request) { + router.MethodNotAllowed = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { w.WriteHeader(http.StatusTeapot) w.Write([]byte(responseText)) - } + }) router.ServeHTTP(w, r) if got := w.Body.String(); !(got == responseText) { t.Errorf("unexpected response got %q want %q", got, responseText) @@ -237,10 +279,10 @@ // Test custom not found handler var notFound bool - router.NotFound = func(rw http.ResponseWriter, r *http.Request) { + router.NotFound = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { rw.WriteHeader(404) notFound = true - } + }) r, _ := http.NewRequest("GET", "/nope", nil) w := httptest.NewRecorder() router.ServeHTTP(w, r)