Merge pull request #18 from agutl/master
Add Extend()
diff --git a/chain.go b/chain.go
index 6ae8c26..d281ad3 100644
--- a/chain.go
+++ b/chain.go
@@ -92,3 +92,28 @@
newChain := New(newCons...)
return newChain
}
+
+// Extend returns a new chain made of c's constructors followed by the
+// constructors of chain, so chain becomes the last part of the request flow.
+// It delegates to Append; c and chain themselves are left untouched.
+//
+//	stdChain := alice.New(m1, m2)
+//	ext1Chain := alice.New(m3, m4)
+//	ext2Chain := stdChain.Extend(ext1Chain)
+//	// requests in stdChain go m1 -> m2
+//	// requests in ext1Chain go m3 -> m4
+//	// requests in ext2Chain go m1 -> m2 -> m3 -> m4
+//
+// Another example, reusing the tail of a chain as a failure handler:
+//
+//	aHTMLAfterNosurf := alice.New(m2)
+//	aHTML := alice.New(m1, func(h http.Handler) http.Handler {
+//		csrf := nosurf.New(h)
+//		csrf.SetFailureHandler(aHTMLAfterNosurf.ThenFunc(csrfFail))
+//		return csrf
+//	}).Extend(aHTMLAfterNosurf)
+//	// requests to aHTML hitting nosurf's success handler go m1 -> nosurf -> m2 -> target-handler
+//	// requests to aHTML hitting nosurf's failure handler go m1 -> nosurf -> m2 -> csrfFail
+func (c Chain) Extend(chain Chain) Chain {
+	extended := c.Append(chain.constructors...)
+	return extended
+}
diff --git a/chain_test.go b/chain_test.go
index f1cf749..49c0470 100644
--- a/chain_test.go
+++ b/chain_test.go
@@ -113,3 +113,32 @@
assert.NotEqual(t, &chain.constructors[0], &newChain.constructors[0])
}
+
+// TestExtendAddsHandlersCorrectly checks that Extend places the argument
+// chain's constructors after the receiver's, that both input chains keep
+// their original length, and that a request passes through the middlewares
+// in that order.
+func TestExtendAddsHandlersCorrectly(t *testing.T) {
+	chain1 := New(tagMiddleware("t1\n"), tagMiddleware("t2\n"))
+	chain2 := New(tagMiddleware("t3\n"), tagMiddleware("t4\n"))
+	newChain := chain1.Extend(chain2)
+
+	// assert.Equal takes (t, expected, actual); keeping that order makes
+	// failure output label the values correctly.
+	assert.Equal(t, 2, len(chain1.constructors))
+	assert.Equal(t, 2, len(chain2.constructors))
+	assert.Equal(t, 4, len(newChain.constructors))
+
+	chained := newChain.Then(testApp)
+
+	w := httptest.NewRecorder()
+	r, err := http.NewRequest(http.MethodGet, "/", nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	chained.ServeHTTP(w, r)
+
+	assert.Equal(t, "t1\nt2\nt3\nt4\napp\n", w.Body.String())
+}
+
+// TestExtendRespectsImmutability checks that the chain returned by Extend
+// does not share constructor storage with either the receiver or the chain
+// passed as the argument.
+func TestExtendRespectsImmutability(t *testing.T) {
+	chain := New(tagMiddleware(""))
+	ext := New(tagMiddleware(""))
+	newChain := chain.Extend(ext)
+
+	assert.NotEqual(t, &chain.constructors[0], &newChain.constructors[0])
+	assert.NotEqual(t, &ext.constructors[0], &newChain.constructors[1])
+}