Add an HTTP wrapper that does graceful shutdown.
diff --git a/accept_test.go b/accept_test.go
index 2a80b21..f9cfe75 100644
--- a/accept_test.go
+++ b/accept_test.go
@@ -25,6 +25,12 @@
[]string{"text/plain", "text/xml"})).Should(Equal("text/plain"))
})
+ It("No header, two choices 2", func() {
+ Expect(SelectMediaType(
+ makeRequest(""),
+ []string{"application/json", "text/plain"})).Should(Equal("application/json"))
+ })
+
It("One Header, two choices", func() {
Expect(SelectMediaType(
makeRequest("application/json"),
diff --git a/httpstop.go b/httpstop.go
new file mode 100644
index 0000000..49f014d
--- /dev/null
+++ b/httpstop.go
@@ -0,0 +1 @@
+package goscaffold
diff --git a/scaffold.go b/scaffold.go
new file mode 100644
index 0000000..0bfdc51
--- /dev/null
+++ b/scaffold.go
@@ -0,0 +1,138 @@
+package goscaffold
+
+import (
+ "encoding/json"
+ "net"
+ "net/http"
+ "time"
+)
+
+const (
+ // DefaultGraceTimeout is the default amount of time to wait for a request to complete
+ DefaultGraceTimeout = 30 * time.Second
+)
+
+/*
+An HTTPScaffold provides a set of features on top of a standard HTTP
+listener. It includes an HTTP handler that may be plugged in to any
+standard Go HTTP server. It is intended to be placed before any other
+handlers.
+*/
+type HTTPScaffold struct {
+ insecurePort int
+ tracker *requestTracker
+ insecureListener net.Listener
+}
+
+/*
+CreateHTTPScaffold makes a new scaffold. The default scaffold will
+do nothing.
+*/
+func CreateHTTPScaffold() *HTTPScaffold {
+ return &HTTPScaffold{}
+}
+
+/*
+SetInsecurePort sets the port number to listen on in regular "HTTP" mode.
+It may be set to zero, which indicates to listen on an ephemeral port.
+It must be called before "listen".
+*/
+func (s *HTTPScaffold) SetInsecurePort(ip int) {
+ s.insecurePort = ip
+}
+
+/*
+InsecureAddress returns the actual address (including the port if an
+ephemeral port was used) where we are listening. It must only be
+called after "Listen."
+*/
+func (s *HTTPScaffold) InsecureAddress() string {
+ return s.insecureListener.Addr().String()
+}
+
+/*
+Open opens up the port that was created when the scaffold was set up.
+*/
+func (s *HTTPScaffold) Open() error {
+ s.tracker = startRequestTracker(DefaultGraceTimeout)
+
+ il, err := net.ListenTCP("tcp", &net.TCPAddr{
+ Port: s.insecurePort,
+ })
+ if err != nil {
+ return err
+ }
+ s.insecureListener = il
+ return nil
+}
+
+/*
+Listen should be called instead of using the standard "http" and "net"
+libraries. It will open a port (or ports) and begin listening for
+HTTP traffic. It will block until the server is shut down by
+the various methods in this class.
+It will use the graceful shutdown logic to ensure that once marked down,
+the server will not exit until all the requests have completed,
+or until the shutdown timeout has expired.
+Like http.Serve, this function will block until we are done serving HTTP.
+If "SetInsecurePort" or "SetSecurePort" were not set, then it will listen on
+a dynamic port.
+Listen will block until the server is shutdown using "Shutdown" or one of
+the other shutdown mechanisms. It must not be called until after "Open"
+has been called.
+*/
+func (s *HTTPScaffold) Listen(baseHandler http.Handler) {
+ handler := &httpHandler{
+ s: s,
+ handler: baseHandler,
+ }
+ go http.Serve(s.insecureListener, handler)
+ <-s.tracker.C
+ s.insecureListener.Close()
+}
+
+/*
+Shutdown indicates that the server should stop handling incoming requests
+and exit from the "Serve" call. This may be called automatically by
+calling "CatchSignals," or automatically using this call.
+*/
+func (s *HTTPScaffold) Shutdown(reason error) {
+ s.tracker.shutdown(reason)
+}
+
+/*
+CatchSignals directs the scaffold to catch SIGINT and SIGTERM (the signals
+sent by "Control-C" and "kill" by default) to trigger the markdown
+logic. Using this logic, when these signals are caught, the server will
+catch
+*/
+func (s *HTTPScaffold) CatchSignals() {
+}
+
+type httpHandler struct {
+ s *HTTPScaffold
+ handler http.Handler
+}
+
+func (h *httpHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
+ startErr := h.s.tracker.start()
+ if startErr == nil {
+ h.handler.ServeHTTP(resp, req)
+ h.s.tracker.end()
+ } else {
+ mt := SelectMediaType(req, []string{"text/plain", "application/json"})
+ resp.Header().Set("Content-Type", mt)
+ resp.WriteHeader(http.StatusServiceUnavailable)
+ switch mt {
+ case "application/json":
+ re := map[string]string{
+ "error": "Stopping",
+ "message": startErr.Error(),
+ }
+ buf, _ := json.Marshal(&re)
+ resp.Write(buf)
+ default:
+ resp.Write([]byte(startErr.Error()))
+ }
+ }
+}
diff --git a/scaffold_test.go b/scaffold_test.go
new file mode 100644
index 0000000..9422a94
--- /dev/null
+++ b/scaffold_test.go
@@ -0,0 +1,109 @@
+package goscaffold
+
+import (
+ "errors"
+ "fmt"
+ "net/http"
+ "time"
+
+ . "github.com/onsi/ginkgo"
+ . "github.com/onsi/gomega"
+)
+
+var _ = Describe("Scaffold Tests", func() {
+ It("Validate framework", func() {
+ s := CreateHTTPScaffold()
+ stopChan := make(chan bool)
+ err := s.Open()
+ Expect(err).Should(Succeed())
+
+ go func() {
+ fmt.Fprintf(GinkgoWriter, "Gonna listen on %s\n", s.InsecureAddress())
+ s.Listen(&testHandler{})
+ fmt.Fprintf(GinkgoWriter, "Done listening\n")
+ stopChan <- true
+ }()
+
+ Eventually(func() bool {
+ return testGet(s, "")
+ }, 5*time.Second).Should(BeTrue())
+ resp, err := http.Get(fmt.Sprintf("http://%s", s.InsecureAddress()))
+ Expect(err).Should(Succeed())
+ Expect(resp.StatusCode).Should(Equal(200))
+ s.Shutdown(errors.New("Validate"))
+ Eventually(stopChan).Should(Receive(BeTrue()))
+ })
+
+ It("Shutdown", func() {
+ s := CreateHTTPScaffold()
+ stopChan := make(chan bool)
+ err := s.Open()
+ Expect(err).Should(Succeed())
+
+ go func() {
+ s.Listen(&testHandler{})
+ stopChan <- true
+ }()
+
+ go func() {
+ resp2, err2 := http.Get(fmt.Sprintf("http://%s?delay=1s", s.InsecureAddress()))
+ Expect(err2).Should(Succeed())
+ Expect(resp2.StatusCode).Should(Equal(200))
+ }()
+
+ // Just make sure server is listening
+ Eventually(func() bool {
+ return testGet(s, "")
+ }, 5*time.Second).Should(BeTrue())
+
+ // Previous call prevents server from exiting
+ Consistently(stopChan, 250*time.Millisecond).ShouldNot(Receive())
+
+ // Tell the server to try and exit
+ s.Shutdown(errors.New("Stop"))
+ // Should take one second -- in the meantime, calls should fail with 503
+ resp, err := http.Get(fmt.Sprintf("http://%s?", s.InsecureAddress()))
+ Expect(err).Should(Succeed())
+ Expect(resp.StatusCode).Should(Equal(503))
+ // But in less than two seconds, server should be down
+ Eventually(stopChan, 2*time.Second).Should(Receive(BeTrue()))
+ // Calls should now fail
+ Eventually(func() bool {
+ return testGet(s, "")
+ }, time.Second).Should(BeFalse())
+ })
+})
+
+func testGet(s *HTTPScaffold, path string) bool {
+ resp, err := http.Get(fmt.Sprintf("http://%s", s.InsecureAddress()))
+ if err != nil {
+ fmt.Fprintf(GinkgoWriter, "Get %s = %s\n", path, err)
+ return false
+ }
+ if resp.StatusCode != 200 {
+ fmt.Fprintf(GinkgoWriter, "Get %s = %d\n", path, resp.StatusCode)
+ return false
+ }
+ return true
+}
+
+type testHandler struct {
+}
+
+func (h *testHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
+ var err error
+ var delayTime time.Duration
+
+ delayStr := req.URL.Query().Get("delay")
+ if delayStr != "" {
+ delayTime, err = time.ParseDuration(delayStr)
+ if err != nil {
+ resp.WriteHeader(http.StatusBadRequest)
+ return
+ }
+ }
+
+ if delayTime > 0 {
+ time.Sleep(delayTime)
+ }
+}
diff --git a/tracker.go b/tracker.go
new file mode 100644
index 0000000..0b624ea
--- /dev/null
+++ b/tracker.go
@@ -0,0 +1,130 @@
+package goscaffold
+
+import (
+ "math"
+ "sync/atomic"
+ "time"
+)
+
+/*
+values for the command channel.
+*/
+const (
+ startRequest = iota
+ endRequest
+ shutdown
+)
+
+/*
+The requestTracker keeps track of HTTP requests. In normal operations it
+just counts. Once the server has been marked for shutdown, however, it
+counts down to zero and returns a shutdown indication when that
+happens.
+*/
+type requestTracker struct {
+ // A value will be delivered to this channel when the server can stop.
+ // If "shutdown" is never called then this will never happen.
+ C chan error
+ shutdownWait time.Duration
+ shuttingDown int32
+ shutdownReason *atomic.Value
+ commandChan chan int
+}
+
+/*
+startRequestTracker creates a new tracker. shutdownWait defines the
+maximum amount of time that we should wait for shutdown in case some
+do not complete in a timely way.
+*/
+func startRequestTracker(shutdownWait time.Duration) *requestTracker {
+ rt := &requestTracker{
+ C: make(chan error, 1),
+ commandChan: make(chan int, 100),
+ shutdownWait: shutdownWait,
+ shutdownReason: &atomic.Value{},
+ }
+ go rt.trackerLoop()
+ return rt
+}
+
+/*
+start indicates that a request started. It returns true if the request
+should proceed, and false if the request should fail because the server is
+shutting down.
+*/
+func (t *requestTracker) start() error {
+ sd := atomic.LoadInt32(&t.shuttingDown)
+ if sd != 0 {
+ reason := t.shutdownReason.Load().(*error)
+ if reason == nil {
+ return nil
+ }
+ return *reason
+ }
+ t.commandChan <- startRequest
+ return nil
+}
+
+/*
+end indicates that a request ended. In order for this thing to work, the
+caller needs to ensure that start and end are always paired.
+*/
+func (t *requestTracker) end() {
+ t.commandChan <- endRequest
+}
+
+/*
+shutdown indicates that the tracker should start counting down until
+the number of running requests reaches zero. The "reason" will be returned
+as the result of the "start" call.
+*/
+func (t *requestTracker) shutdown(reason error) {
+ t.shutdownReason.Store(&reason)
+ t.commandChan <- shutdown
+}
+
+func (t *requestTracker) sendStop(sent bool) bool {
+ if !sent {
+ reason := t.shutdownReason.Load().(*error)
+ if reason == nil {
+ return false
+ }
+ t.C <- *reason
+ }
+ return true
+}
+
+/*
+trackerLoop runs all day and manages stuff.
+*/
+func (t *requestTracker) trackerLoop() {
+ activeRequests := 0
+ stopping := false
+ sentStop := false
+ graceTimer := time.NewTimer(time.Duration(math.MaxInt64))
+
+ for !sentStop {
+ select {
+ case cmd := <-t.commandChan:
+ switch cmd {
+ case startRequest:
+ activeRequests++
+ case endRequest:
+ activeRequests--
+ if stopping && activeRequests == 0 {
+ sentStop = t.sendStop(sentStop)
+ }
+ case shutdown:
+ stopping = true
+ atomic.StoreInt32(&t.shuttingDown, 1)
+ if activeRequests <= 0 {
+ sentStop = t.sendStop(sentStop)
+ } else {
+ graceTimer.Reset(t.shutdownWait)
+ }
+ }
+ case <-graceTimer.C:
+ sentStop = t.sendStop(sentStop)
+ }
+ }
+}
diff --git a/tracker_test.go b/tracker_test.go
new file mode 100644
index 0000000..fc3ad0e
--- /dev/null
+++ b/tracker_test.go
@@ -0,0 +1,34 @@
+package goscaffold
+
+import (
+ "errors"
+ "time"
+
+ . "github.com/onsi/ginkgo"
+ . "github.com/onsi/gomega"
+)
+
+var _ = Describe("Tracker tests", func() {
+ It("Basic tracker", func() {
+ t := startRequestTracker(10 * time.Second)
+ t.start()
+ Consistently(t.C, 250*time.Millisecond).ShouldNot(Receive())
+ t.shutdown(errors.New("Basic"))
+ Consistently(t.C, 250*time.Millisecond).ShouldNot(Receive())
+ t.end()
+ Eventually(t.C).Should(Receive(MatchError("Basic")))
+ })
+
+ It("Tracker stop idle", func() {
+ t := startRequestTracker(10 * time.Second)
+ t.shutdown(errors.New("Stop"))
+ Eventually(t.C).Should(Receive(MatchError("Stop")))
+ })
+
+ It("Tracker grace timeout", func() {
+ t := startRequestTracker(time.Second)
+ t.start()
+ t.shutdown(errors.New("Stop"))
+ Eventually(t.C, 2*time.Second).Should(Receive(MatchError("Stop")))
+ })
+})