Allow chaining of any http.Handler, not just http.HandlerFunc.
diff --git a/README.md b/README.md
index 9875c70..785a108 100644
--- a/README.md
+++ b/README.md
@@ -297,7 +297,7 @@
**NOTE: It might be required to set [Router.HandleMethodNotAllowed](http://godoc.org/github.com/julienschmidt/httprouter#Router.HandleMethodNotAllowed) to `false` to avoid problems.**
-You can use another [http.HandlerFunc](http://golang.org/pkg/net/http/#HandlerFunc), for example another router, to handle requests which could not be matched by this router by using the [Router.NotFound](http://godoc.org/github.com/julienschmidt/httprouter#Router.NotFound) handler. This allows chaining.
+You can use another [http.Handler](http://golang.org/pkg/net/http/#Handler), for example another router, to handle requests which could not be matched by this router by using the [Router.NotFound](http://godoc.org/github.com/julienschmidt/httprouter#Router.NotFound) handler. This allows chaining.
### Static files
The `NotFound` handler can for example be used to serve static files from the root path `/` (like an index.html file along with other assets):
diff --git a/router.go b/router.go
index 155b871..83e86fd 100644
--- a/router.go
+++ b/router.go
@@ -138,14 +138,14 @@
// handler.
HandleMethodNotAllowed bool
- // Configurable http.HandlerFunc which is called when no matching route is
+ // Configurable http.Handler which is called when no matching route is
// found. If it is not set, http.NotFound is used.
- NotFound http.HandlerFunc
+ NotFound http.Handler
- // Configurable http.HandlerFunc which is called when a request
+ // Configurable http.Handler which is called when a request
// cannot be routed and HandleMethodNotAllowed is true.
// If it is not set, http.Error with http.StatusMethodNotAllowed is used.
- MethodNotAllowed http.HandlerFunc
+ MethodNotAllowed http.Handler
// Function to handle panics recovered from http handlers.
// It should be used to generate a error page and return the http error code
@@ -342,7 +342,7 @@
handle, _, _ := r.trees[method].getValue(req.URL.Path)
if handle != nil {
if r.MethodNotAllowed != nil {
- r.MethodNotAllowed(w, req)
+ r.MethodNotAllowed.ServeHTTP(w, req)
} else {
http.Error(w,
http.StatusText(http.StatusMethodNotAllowed),
@@ -356,7 +356,7 @@
// Handle 404
if r.NotFound != nil {
- r.NotFound(w, req)
+ r.NotFound.ServeHTTP(w, req)
} else {
http.NotFound(w, req)
}
diff --git a/router_test.go b/router_test.go
index 9dc6296..e3141bd 100644
--- a/router_test.go
+++ b/router_test.go
@@ -174,6 +174,48 @@
}
}
+func TestRouterChaining(t *testing.T) {
+ router1 := New()
+ router2 := New()
+ router1.NotFound = router2
+
+ fooHit := false
+ router1.POST("/foo", func(w http.ResponseWriter, req *http.Request, _ Params) {
+ fooHit = true
+ w.WriteHeader(http.StatusOK)
+ })
+
+ barHit := false
+ router2.POST("/bar", func(w http.ResponseWriter, req *http.Request, _ Params) {
+ barHit = true
+ w.WriteHeader(http.StatusOK)
+ })
+
+ r, _ := http.NewRequest("POST", "/foo", nil)
+ w := httptest.NewRecorder()
+ router1.ServeHTTP(w, r)
+ if !(w.Code == http.StatusOK && fooHit) {
+ t.Errorf("Regular routing failed with router chaining.")
+ t.FailNow()
+ }
+
+ r, _ = http.NewRequest("POST", "/bar", nil)
+ w = httptest.NewRecorder()
+ router1.ServeHTTP(w, r)
+ if !(w.Code == http.StatusOK && barHit) {
+ t.Errorf("Chained routing failed with router chaining.")
+ t.FailNow()
+ }
+
+ r, _ = http.NewRequest("POST", "/qax", nil)
+ w = httptest.NewRecorder()
+ router1.ServeHTTP(w, r)
+ if !(w.Code == http.StatusNotFound) {
+ t.Errorf("NotFound behavior failed with router chaining.")
+ t.FailNow()
+ }
+}
+
func TestRouterNotAllowed(t *testing.T) {
handlerFunc := func(_ http.ResponseWriter, _ *http.Request, _ Params) {}
@@ -190,10 +232,10 @@
w = httptest.NewRecorder()
responseText := "custom method"
- router.MethodNotAllowed = func(w http.ResponseWriter, req *http.Request) {
+ router.MethodNotAllowed = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.WriteHeader(http.StatusTeapot)
w.Write([]byte(responseText))
- }
+ })
router.ServeHTTP(w, r)
if got := w.Body.String(); !(got == responseText) {
t.Errorf("unexpected response got %q want %q", got, responseText)
@@ -237,10 +279,10 @@
// Test custom not found handler
var notFound bool
- router.NotFound = func(rw http.ResponseWriter, r *http.Request) {
+ router.NotFound = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.WriteHeader(404)
notFound = true
- }
+ })
r, _ := http.NewRequest("GET", "/nope", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, r)