Chaining functionality
diff --git a/chain.go b/chain.go
index 29aeee1..5083c14 100644
--- a/chain.go
+++ b/chain.go
@@ -18,3 +18,25 @@
return c
}
+
+// Chains the middleware and returns the final http.Handler
+// New(m1, m2, m3).Then(h)
+// is equivalent to:
+// m1(m2(m3(h)))
+// When the request comes in, it will be passed to m1, then m2, then m3
+// and finally, the given handler
+// (assuming every middleware calls the following one)
+func (c Chain) Then(h http.Handler) http.Handler {
+ var final http.Handler
+ if h != nil {
+ final = h
+ } else {
+ final = http.DefaultServeMux
+ }
+
+ for i := len(c.constructors) - 1; i >= 0; i-- {
+ final = c.constructors[i](final)
+ }
+
+ return final
+}
diff --git a/chain_test.go b/chain_test.go
index bbea276..43fc2f3 100644
--- a/chain_test.go
+++ b/chain_test.go
@@ -2,11 +2,24 @@
import (
"net/http"
+ "net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
)
+// A constructor for middleware
+// that writes its own "tag" into the RW and does nothing else.
+// Useful in checking if a chain is behaving in the right order.
+func tagMiddleware(tag string) Constructor {
+ return func(h http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Write([]byte(tag))
+ h.ServeHTTP(w, r)
+ })
+ }
+}
+
// Tests creating a new chain
func TestNew(t *testing.T) {
c1 := func(h http.Handler) http.Handler {
@@ -22,3 +35,24 @@
assert.Equal(t, chain.constructors[0], slice[0])
assert.Equal(t, chain.constructors[1], slice[1])
}
+
+func TestThen(t *testing.T) {
+ t1 := tagMiddleware("t1\n")
+ t2 := tagMiddleware("t2\n")
+ t3 := tagMiddleware("t3\n")
+ app := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Write([]byte("app\n"))
+ })
+
+ chained := New(t1, t2, t3).Then(app)
+
+ w := httptest.NewRecorder()
+ r, err := http.NewRequest("GET", "/", nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ chained.ServeHTTP(w, r)
+
+ assert.Equal(t, w.Body.String(), "t1\nt2\nt3\napp\n")
+}