Add a new feature (optional) to configure local IP address to bind to.
diff --git a/scaffold.go b/scaffold.go index 9d90a0d..15b5af7 100644 --- a/scaffold.go +++ b/scaffold.go
@@ -83,6 +83,7 @@ securePort int managementPort int open bool + ipAddr net.IP tracker *requestTracker insecureListener net.Listener secureListener net.Listener @@ -106,17 +107,26 @@ insecurePort: 0, securePort: -1, managementPort: -1, + ipAddr: []byte{0, 0, 0, 0}, open: false, } } /* +SetlocalBindIPAddressV4 seta the IP address (IP V4) for the service to +bind on to listen on. If none set, all IP addesses would be accepted. +*/ +func (s *HTTPScaffold) SetlocalBindIPAddressV4(ip net.IP) { + s.ipAddr = ip +} + +/* SetInsecurePort sets the port number to listen on in regular "HTTP" mode. It may be set to zero, which indicates to listen on an ephemeral port. It must be called before "listen". */ -func (s *HTTPScaffold) SetInsecurePort(ip int) { - s.insecurePort = ip +func (s *HTTPScaffold) SetInsecurePort(port int) { + s.insecurePort = port } /* @@ -126,8 +136,8 @@ Listen if this port is set and if the key and secret files are not also set. */ -func (s *HTTPScaffold) SetSecurePort(ip int) { - s.securePort = ip +func (s *HTTPScaffold) SetSecurePort(port int) { + s.securePort = port } /* @@ -259,6 +269,7 @@ if s.insecurePort >= 0 { il, err := net.ListenTCP("tcp", &net.TCPAddr{ + IP: s.ipAddr, Port: s.insecurePort, }) if err != nil { @@ -284,6 +295,7 @@ Certificates: []tls.Certificate{cert}, } sl, err := net.ListenTCP("tcp", &net.TCPAddr{ + IP: s.ipAddr, Port: s.securePort, }) if err != nil { @@ -299,6 +311,7 @@ if s.managementPort >= 0 { ml, err := net.ListenTCP("tcp", &net.TCPAddr{ + IP: s.ipAddr, Port: s.managementPort, }) if err != nil {
diff --git a/scaffold_test.go b/scaffold_test.go index e46bcad..692602a 100644 --- a/scaffold_test.go +++ b/scaffold_test.go
@@ -6,6 +6,7 @@ "errors" "fmt" "io/ioutil" + "net" "net/http" "strings" "sync/atomic" @@ -26,6 +27,7 @@ var _ = Describe("Scaffold Tests", func() { It("Validate framework", func() { s := CreateHTTPScaffold() + s.SetlocalBindIPAddressV4(GetLocalIP()) stopChan := make(chan error) err := s.Open() Expect(err).Should(Succeed()) @@ -51,6 +53,7 @@ It("Separate management port", func() { s := CreateHTTPScaffold() + s.SetlocalBindIPAddressV4(GetLocalIP()) s.SetManagementPort(0) stopChan := make(chan error) err := s.Open() @@ -209,6 +212,7 @@ } s := CreateHTTPScaffold() + s.SetlocalBindIPAddressV4(GetLocalIP()) s.SetManagementPort(0) s.SetHealthPath("/health") s.SetReadyPath("/ready") @@ -332,6 +336,32 @@ s.Shutdown(shutdownErr) Eventually(stopChan).Should(Receive(Equal(shutdownErr))) }) + + It("DisAllow non-localhost", func() { + s := CreateHTTPScaffold() + s.SetInsecurePort(8181) + s.SetlocalBindIPAddressV4([]byte{127, 0, 0, 1}) + stopChan := make(chan error) + err := s.Open() + Expect(err).Should(Succeed()) + + go func() { + fmt.Fprintf(GinkgoWriter, "Gonna listen on %s\n", s.InsecureAddress()) + stopErr := s.Listen(&testHandler{}) + fmt.Fprintf(GinkgoWriter, "Done listening\n") + stopChan <- stopErr + }() + + Eventually(func() bool { + return testGet(s, "") + }, 5*time.Second).Should(BeTrue()) + _, err = http.Get(fmt.Sprintf("http://%s:%s", GetLocalIPStr(), "8181")) + Expect(err).ShouldNot(Succeed()) + shutdownErr := errors.New("Validate") + s.Shutdown(shutdownErr) + Eventually(stopChan).Should(Receive(Equal(shutdownErr))) + + }) }) func getText(url string) (int, string) { @@ -409,3 +439,33 @@ time.Sleep(delayTime) } } + +func GetLocalIP() []byte { + addrs, err := net.InterfaceAddrs() + if err != nil { + return nil + } + for _, address := range addrs { + if ipnet, ok := address.(*net.IPNet); ok && !ipnet.IP.IsLoopback() { + if ipnet.IP.To4() != nil { + return ipnet.IP + } + } + } + return nil +} + +func GetLocalIPStr() string { + addrs, err := net.InterfaceAddrs() + if err != nil { + return "" + } + for _, address := range addrs { + if ipnet, ok := address.(*net.IPNet); ok && !ipnet.IP.IsLoopback() { + if ipnet.IP.To4() != nil { + return ipnet.IP.String() + } + } + } + return "" +}