Chaining functionality
diff --git a/chain.go b/chain.go index 29aeee1..5083c14 100644 --- a/chain.go +++ b/chain.go
@@ -18,3 +18,25 @@ return c } + +// Chains the middleware and returns the final http.Handler +// New(m1, m2, m3).Then(h) +// is equivalent to: +// m1(m2(m3(h))) +// When the request comes in, it will be passed to m1, then m2, then m3 +// and finally, the given handler +// (assuming every middleware calls the following one) +func (c Chain) Then(h http.Handler) http.Handler { + var final http.Handler + if h != nil { + final = h + } else { + final = http.DefaultServeMux + } + + for i := len(c.constructors) - 1; i >= 0; i-- { + final = c.constructors[i](final) + } + + return final +}
diff --git a/chain_test.go b/chain_test.go index bbea276..43fc2f3 100644 --- a/chain_test.go +++ b/chain_test.go
@@ -2,11 +2,24 @@ import ( "net/http" + "net/http/httptest" "testing" "github.com/stretchr/testify/assert" ) +// A constructor for middleware +// that writes its own "tag" into the RW and does nothing else. +// Useful in checking if a chain is behaving in the right order. +func tagMiddleware(tag string) Constructor { + return func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(tag)) + h.ServeHTTP(w, r) + }) + } +} + // Tests creating a new chain func TestNew(t *testing.T) { c1 := func(h http.Handler) http.Handler { @@ -22,3 +35,24 @@ assert.Equal(t, chain.constructors[0], slice[0]) assert.Equal(t, chain.constructors[1], slice[1]) } + +func TestThen(t *testing.T) { + t1 := tagMiddleware("t1\n") + t2 := tagMiddleware("t2\n") + t3 := tagMiddleware("t3\n") + app := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("app\n")) + }) + + chained := New(t1, t2, t3).Then(app) + + w := httptest.NewRecorder() + r, err := http.NewRequest("GET", "/", nil) + if err != nil { + t.Fatal(err) + } + + chained.ServeHTTP(w, r) + + assert.Equal(t, w.Body.String(), "t1\nt2\nt3\napp\n") +}