Change AppendChain(...Constructor) to Extend(Constructor). Add tests for Extend().
diff --git a/chain.go b/chain.go index f3afddf..d281ad3 100644 --- a/chain.go +++ b/chain.go
@@ -93,14 +93,14 @@ return newChain } -// AppendChain extends a chain, adding the specified chains -// as the last ones in the request flow. +// Extend extends a chain by adding the specified chain +// as the last one in the request flow. // -// AppendChain returns a new chain, leaving the original one untouched. +// Extend returns a new chain, leaving the original one untouched. // // stdChain := alice.New(m1, m2) // ext1Chain := alice.New(m3, m4) -// ext2Chain := stdChain.AppendChain(ext1Chain) +// ext2Chain := stdChain.Extend(ext1Chain) // // requests in stdChain go m1 -> m2 // // requests in ext1Chain go m3 -> m4 // // requests in ext2Chain go m1 -> m2 -> m3 -> m4 @@ -111,13 +111,9 @@ // csrf := nosurf.New(h) // csrf.SetFailureHandler(aHtmlAfterNosurf.ThenFunc(csrfFail)) // return csrf -// }).AppendChain(aHtmlAfterNosurf) +// }).Extend(aHtmlAfterNosurf) // // requests to aHtml hitting nosurfs success handler go m1 -> nosurf -> m2 -> target-handler // // requests to aHtml hitting nosurfs failure handler go m1 -> nosurf -> m2 -> csrfFail -func (c Chain) AppendChain(chains ...Chain) Chain { - newChain := c - for _, ch := range chains { - newChain = newChain.Append(ch.constructors...) - } - return newChain +func (c Chain) Extend(chain Chain) Chain { + return c.Append(chain.constructors...) }
diff --git a/chain_test.go b/chain_test.go index f1cf749..49c0470 100644 --- a/chain_test.go +++ b/chain_test.go
@@ -113,3 +113,32 @@ assert.NotEqual(t, &chain.constructors[0], &newChain.constructors[0]) } + +func TestExtendAddsHandlersCorrectly(t *testing.T) { + chain1 := New(tagMiddleware("t1\n"), tagMiddleware("t2\n")) + chain2 := New(tagMiddleware("t3\n"), tagMiddleware("t4\n")) + newChain := chain1.Extend(chain2) + + assert.Equal(t, len(chain1.constructors), 2) + assert.Equal(t, len(chain2.constructors), 2) + assert.Equal(t, len(newChain.constructors), 4) + + chained := newChain.Then(testApp) + + w := httptest.NewRecorder() + r, err := http.NewRequest("GET", "/", nil) + if err != nil { + t.Fatal(err) + } + + chained.ServeHTTP(w, r) + + assert.Equal(t, w.Body.String(), "t1\nt2\nt3\nt4\napp\n") +} + +func TestExtendRespectsImmutability(t *testing.T) { + chain := New(tagMiddleware("")) + newChain := chain.Extend(New(tagMiddleware(""))) + + assert.NotEqual(t, &chain.constructors[0], &newChain.constructors[0]) +}