Merge pull request #37 from eparis/ip-tests

Two new flag types: IPNet and Count
diff --git a/count.go b/count.go
new file mode 100644
index 0000000..d061368
--- /dev/null
+++ b/count.go
@@ -0,0 +1,84 @@
+package pflag
+
+import (
+	"fmt"
+	"strconv"
+)
+
+// countValue is an int counter used as a flag value: *countValue provides
+// the Set, Type and String methods the flag machinery calls.
+type countValue int
+
+// newCountValue stores val in *p and returns p viewed as a *countValue, so
+// every later Set writes directly into the caller's int.
+func newCountValue(val int, p *int) *countValue {
+	*p = val
+	return (*countValue)(p)
+}
+
+// Set updates the counter from s, which is parsed with strconv.ParseInt in
+// base 0 (so "3", "0x3" and "03" are all accepted).
+//
+// The value "-1" means no explicit value was given (CountVarP installs it as
+// the flag's NoOptDefVal), so the counter is incremented; any other integer
+// replaces the count. If s does not parse as an int, the parse error is
+// returned and the counter is left unchanged.
+func (i *countValue) Set(s string) error {
+	// bitSize 0 limits the result to the range of int, so out-of-range input
+	// is reported instead of being silently truncated on 32-bit platforms.
+	v, err := strconv.ParseInt(s, 0, 0)
+	if err != nil {
+		return err
+	}
+	if v == -1 {
+		*i++
+	} else {
+		*i = countValue(v)
+	}
+	return nil
+}
+
+// Type returns "count", the type name GetCount requests from getFlagType.
+func (i *countValue) Type() string {
+	return "count"
+}
+
+// String returns the current count in decimal. Converting to int first
+// guarantees fmt formats the plain integer and never re-enters a String
+// method.
+func (i *countValue) String() string { return fmt.Sprintf("%d", int(*i)) }
+
+// countConv is the converter GetCount hands to getFlagType: it parses a
+// flag's string form as a base-10 int. On failure it returns a nil value
+// together with the strconv error.
+func countConv(sval string) (interface{}, error) {
+	n, err := strconv.Atoi(sval)
+	if err == nil {
+		return n, nil
+	}
+	return nil, err
+}
+
+// GetCount returns the int value of the count flag with the given name.
+// Any error from the lookup/conversion is returned with a zero count.
+func (f *FlagSet) GetCount(name string) (int, error) {
+	raw, err := f.getFlagType(name, "count", countConv)
+	if err != nil {
+		return 0, err
+	}
+	count := raw.(int)
+	return count, nil
+}
+
+// CountVar defines a count flag with the specified name and usage string.
+// The argument p points to an int variable in which to store the count.
+// A bare occurrence of the flag increments the count; an explicit value
+// replaces it.
+func (f *FlagSet) CountVar(p *int, name string, usage string) {
+	f.CountVarP(p, name, "", usage)
+}
+
+// CountVarP is like CountVar, but accepts a shorthand letter that can be
+// used after a single dash. *p is reset to 0 when the flag is defined.
+func (f *FlagSet) CountVarP(p *int, name, shorthand string, usage string) {
+	flag := f.VarPF(newCountValue(0, p), name, shorthand, usage)
+	// "-1" is the sentinel countValue.Set treats as "increment by one",
+	// used when the flag is given without an explicit value.
+	flag.NoOptDefVal = "-1"
+}
+
+// CountVar defines a count flag with the specified name and usage string on
+// the CommandLine flag set. The argument p points to an int variable in
+// which to store the count.
+func CountVar(p *int, name string, usage string) {
+	CommandLine.CountVar(p, name, usage)
+}
+
+// CountVarP is like CountVar, but accepts a shorthand letter that can be
+// used after a single dash.
+func CountVarP(p *int, name, shorthand string, usage string) {
+	CommandLine.CountVarP(p, name, shorthand, usage)
+}
+
+// Count defines a count flag with the specified name and usage string.
+// The return value is the address of an int variable that stores the count.
+func (f *FlagSet) Count(name string, usage string) *int {
+	p := new(int)
+	f.CountVarP(p, name, "", usage)
+	return p
+}
+
+// CountP is like Count, but accepts a shorthand letter that can be used
+// after a single dash.
+func (f *FlagSet) CountP(name, shorthand string, usage string) *int {
+	p := new(int)
+	f.CountVarP(p, name, shorthand, usage)
+	return p
+}
+
+// Count defines a count flag with the specified name and usage string on the
+// CommandLine flag set. The return value is the address of an int variable
+// that stores the count.
+func Count(name string, usage string) *int {
+	return CommandLine.CountP(name, "", usage)
+}
+
+// CountP is like Count, but accepts a shorthand letter that can be used
+// after a single dash.
+func CountP(name, shorthand string, usage string) *int {
+	return CommandLine.CountP(name, shorthand, usage)
+}
diff --git a/count_test.go b/count_test.go
new file mode 100644
index 0000000..716765c
--- /dev/null
+++ b/count_test.go
@@ -0,0 +1,55 @@
+package pflag
+
+import (
+	"fmt"
+	"os"
+	"testing"
+)
+
+// fmt is not otherwise referenced in this file; this blank reference keeps
+// the import compiling (presumably retained for ad hoc debugging output).
+var _ = fmt.Printf
+
+// setUpCount returns a fresh ContinueOnError flag set named "test" holding a
+// single counter flag, --verbose / -v, bound to c.
+func setUpCount(c *int) *FlagSet {
+	fs := NewFlagSet("test", ContinueOnError)
+	fs.CountVarP(c, "verbose", "v", "a counter")
+	return fs
+}
+
+// TestCount checks that repeated, combined and explicitly-valued uses of a
+// count flag produce the expected total, and that a non-integer value fails.
+func TestCount(t *testing.T) {
+	testCases := []struct {
+		input    []string
+		success  bool
+		expected int
+	}{
+		{[]string{"-vvv"}, true, 3},
+		{[]string{"-v", "-v", "-v"}, true, 3},
+		{[]string{"-v", "--verbose", "-v"}, true, 3},
+		{[]string{"-v=3", "-v"}, true, 4},
+		{[]string{"-v=a"}, false, 0},
+	}
+
+	// Redirect stderr to the null device (as before) to keep failing cases
+	// quiet. Open it for writing, and restore the real stderr when the test
+	// ends so the redirect does not leak into other tests.
+	devnull, err := os.OpenFile(os.DevNull, os.O_WRONLY, 0)
+	if err != nil {
+		t.Fatalf("opening %s: %v", os.DevNull, err)
+	}
+	origStderr := os.Stderr
+	os.Stderr = devnull
+	defer func() {
+		os.Stderr = origStderr
+		devnull.Close() // best-effort cleanup; nothing useful to do on error
+	}()
+
+	for i := range testCases {
+		tc := &testCases[i]
+		var count int
+		f := setUpCount(&count)
+
+		err := f.Parse(tc.input)
+		switch {
+		case err != nil && tc.success:
+			t.Errorf("%v: expected success, got %v", tc.input, err)
+		case err == nil && !tc.success:
+			t.Errorf("%v: expected failure, got success", tc.input)
+		case tc.success:
+			if count != tc.expected {
+				t.Errorf("%v: bound variable holds %d, expected %d", tc.input, count, tc.expected)
+			}
+			c, err := f.GetCount("verbose")
+			if err != nil {
+				t.Errorf("%v: got error trying to fetch the counter flag: %v", tc.input, err)
+				continue
+			}
+			if c != tc.expected {
+				t.Errorf("%v: expected %d, got %d", tc.input, tc.expected, c)
+			}
+		}
+	}
+}
diff --git a/ip.go b/ip.go
index 746eefd..baa442b 100644
--- a/ip.go
+++ b/ip.go
@@ -3,8 +3,11 @@
 import (
 	"fmt"
 	"net"
+	"strings"
 )
 
 // -- net.IP value
 type ipValue net.IP
 
@@ -15,7 +18,7 @@
 
 func (i *ipValue) String() string { return net.IP(*i).String() }
 func (i *ipValue) Set(s string) error {
-	ip := net.ParseIP(s)
+	ip := net.ParseIP(strings.TrimSpace(s))
 	if ip == nil {
 		return fmt.Errorf("failed to parse IP: %q", s)
 	}
diff --git a/ip_test.go b/ip_test.go
new file mode 100644
index 0000000..1fec50e
--- /dev/null
+++ b/ip_test.go
@@ -0,0 +1,63 @@
+package pflag
+
+import (
+	"fmt"
+	"net"
+	"os"
+	"testing"
+)
+
+// setUpIP returns a fresh ContinueOnError flag set named "test" holding one
+// --address flag bound to ip, with a default of 0.0.0.0.
+func setUpIP(ip *net.IP) *FlagSet {
+	fs := NewFlagSet("test", ContinueOnError)
+	defaultAddr := net.ParseIP("0.0.0.0")
+	fs.IPVar(ip, "address", defaultAddr, "IP Address")
+	return fs
+}
+
+// TestIP checks which strings the IP flag accepts (surrounding whitespace is
+// tolerated) and that accepted values round-trip through GetIP.
+func TestIP(t *testing.T) {
+	testCases := []struct {
+		input    string
+		success  bool
+		expected string
+	}{
+		{"0.0.0.0", true, "0.0.0.0"},
+		{" 0.0.0.0 ", true, "0.0.0.0"},
+		{"1.2.3.4", true, "1.2.3.4"},
+		{"127.0.0.1", true, "127.0.0.1"},
+		{"255.255.255.255", true, "255.255.255.255"},
+		{"", false, ""},
+		{"0", false, ""},
+		{"localhost", false, ""},
+		{"0.0.0", false, ""},
+		{"0.0.0.", false, ""},
+		{"0.0.0.0.", false, ""},
+		{"0.0.0.256", false, ""},
+		{"0 . 0 . 0 . 0", false, ""},
+	}
+
+	// Redirect stderr to the null device (as before) to keep failing cases
+	// quiet. Open it for writing, and restore the real stderr when the test
+	// ends so the redirect does not leak into other tests.
+	devnull, err := os.OpenFile(os.DevNull, os.O_WRONLY, 0)
+	if err != nil {
+		t.Fatalf("opening %s: %v", os.DevNull, err)
+	}
+	origStderr := os.Stderr
+	os.Stderr = devnull
+	defer func() {
+		os.Stderr = origStderr
+		devnull.Close() // best-effort cleanup; nothing useful to do on error
+	}()
+
+	for i := range testCases {
+		tc := &testCases[i]
+		var addr net.IP
+		f := setUpIP(&addr)
+
+		arg := fmt.Sprintf("--address=%s", tc.input)
+		err := f.Parse([]string{arg})
+		switch {
+		case err != nil && tc.success:
+			t.Errorf("%q: expected success, got %v", tc.input, err)
+		case err == nil && !tc.success:
+			t.Errorf("%q: expected failure", tc.input)
+		case tc.success:
+			ip, err := f.GetIP("address")
+			if err != nil {
+				t.Errorf("%q: got error trying to fetch the IP flag: %v", tc.input, err)
+				continue
+			}
+			if ip.String() != tc.expected {
+				t.Errorf("%q: expected %q, got %q", tc.input, tc.expected, ip.String())
+			}
+		}
+	}
+}
diff --git a/ipnet.go b/ipnet.go
new file mode 100644
index 0000000..23b7fd3
--- /dev/null
+++ b/ipnet.go
@@ -0,0 +1,100 @@
+package pflag
+
+import (
+	"fmt"
+	"net"
+	"strings"
+)
+
+// IPNetValue adapts net.IPNet for use as a flag: *IPNetValue provides the
+// String, Set and Type methods the flag machinery calls.
+type IPNetValue net.IPNet
+
+// String returns the network as formatted by net.IPNet.String (CIDR notation
+// such as "10.0.0.0/8" for canonical masks). It keeps a value receiver so
+// both IPNetValue and *IPNetValue have a String method.
+func (ipnet IPNetValue) String() string {
+	n := net.IPNet(ipnet)
+	return n.String()
+}
+
+// Set parses value as CIDR notation, ignoring surrounding whitespace, and
+// stores the resulting network with host bits cleared (e.g. "1.2.3.4/8"
+// becomes 1.0.0.0/8). On a parse error the error from net.ParseCIDR is
+// returned and the receiver is left unchanged.
+func (ipnet *IPNetValue) Set(value string) error {
+	_, n, err := net.ParseCIDR(strings.TrimSpace(value))
+	if err != nil {
+		return err
+	}
+	*ipnet = IPNetValue(*n)
+	return nil
+}
+
+// Type returns "ipNet", the type name GetIPNet requests from getFlagType.
+func (*IPNetValue) Type() string {
+	return "ipNet"
+}
+
+// newIPNetValue stores val in *p and returns p viewed as an *IPNetValue, so
+// every later Set writes directly into the caller's variable.
+func newIPNetValue(val net.IPNet, p *net.IPNet) *IPNetValue {
+	*p = val
+	return (*IPNetValue)(p)
+}
+
+// ipNetConv is the converter GetIPNet hands to getFlagType: it parses sval
+// (surrounding whitespace ignored) as CIDR notation and returns the network
+// as a net.IPNet value. Parse failures are reported with sval in the message.
+func ipNetConv(sval string) (interface{}, error) {
+	_, network, err := net.ParseCIDR(strings.TrimSpace(sval))
+	if err != nil {
+		return nil, fmt.Errorf("invalid string being converted to IPNet: %s", sval)
+	}
+	return *network, nil
+}
+
+// GetIPNet return the net.IPNet value of a flag with the given name
+func (f *FlagSet) GetIPNet(name string) (net.IPNet, error) {
+	val, err := f.getFlagType(name, "ipNet", ipNetConv)
+	if err != nil {
+		return net.IPNet{}, err
+	}
+	return val.(net.IPNet), nil
+}
+
+// IPNetVar defines a net.IPNet flag with the specified name, default value,
+// and usage string. The argument p points to a net.IPNet variable in which
+// to store the value of the flag.
+func (f *FlagSet) IPNetVar(p *net.IPNet, name string, value net.IPNet, usage string) {
+	f.VarP(newIPNetValue(value, p), name, "", usage)
+}
+
+// IPNetVarP is like IPNetVar, but accepts a shorthand letter that can be
+// used after a single dash.
+func (f *FlagSet) IPNetVarP(p *net.IPNet, name, shorthand string, value net.IPNet, usage string) {
+	f.VarP(newIPNetValue(value, p), name, shorthand, usage)
+}
+
+// IPNetVar defines a net.IPNet flag with the specified name, default value,
+// and usage string on the CommandLine flag set. The argument p points to a
+// net.IPNet variable in which to store the value of the flag.
+func IPNetVar(p *net.IPNet, name string, value net.IPNet, usage string) {
+	CommandLine.VarP(newIPNetValue(value, p), name, "", usage)
+}
+
+// IPNetVarP is like IPNetVar, but accepts a shorthand letter that can be
+// used after a single dash.
+func IPNetVarP(p *net.IPNet, name, shorthand string, value net.IPNet, usage string) {
+	CommandLine.VarP(newIPNetValue(value, p), name, shorthand, usage)
+}
+
+// IPNet defines a net.IPNet flag with the specified name, default value, and
+// usage string. The return value is the address of a net.IPNet variable that
+// stores the value of the flag.
+func (f *FlagSet) IPNet(name string, value net.IPNet, usage string) *net.IPNet {
+	p := new(net.IPNet)
+	f.IPNetVarP(p, name, "", value, usage)
+	return p
+}
+
+// IPNetP is like IPNet, but accepts a shorthand letter that can be used
+// after a single dash.
+func (f *FlagSet) IPNetP(name, shorthand string, value net.IPNet, usage string) *net.IPNet {
+	p := new(net.IPNet)
+	f.IPNetVarP(p, name, shorthand, value, usage)
+	return p
+}
+
+// IPNet defines a net.IPNet flag with the specified name, default value, and
+// usage string on the CommandLine flag set. The return value is the address
+// of a net.IPNet variable that stores the value of the flag.
+func IPNet(name string, value net.IPNet, usage string) *net.IPNet {
+	return CommandLine.IPNetP(name, "", value, usage)
+}
+
+// IPNetP is like IPNet, but accepts a shorthand letter that can be used
+// after a single dash.
+func IPNetP(name, shorthand string, value net.IPNet, usage string) *net.IPNet {
+	return CommandLine.IPNetP(name, shorthand, value, usage)
+}
diff --git a/ipnet_test.go b/ipnet_test.go
new file mode 100644
index 0000000..335b6fa
--- /dev/null
+++ b/ipnet_test.go
@@ -0,0 +1,70 @@
+package pflag
+
+import (
+	"fmt"
+	"net"
+	"os"
+	"testing"
+)
+
+// setUpIPNet returns a fresh ContinueOnError flag set named "test" holding
+// one --address flag bound to ip, with a default of 0.0.0.0/0.
+func setUpIPNet(ip *net.IPNet) *FlagSet {
+	fs := NewFlagSet("test", ContinueOnError)
+	// The literal is a valid CIDR, so the parse error cannot occur.
+	_, defaultNet, _ := net.ParseCIDR("0.0.0.0/0")
+	fs.IPNetVar(ip, "address", *defaultNet, "IP Address")
+	return fs
+}
+
+// TestIPNet checks which strings the IPNet flag accepts (surrounding
+// whitespace is tolerated, inner whitespace is not), that host bits are
+// masked off, and that accepted values reach both the bound variable and
+// GetIPNet.
+func TestIPNet(t *testing.T) {
+	testCases := []struct {
+		input    string
+		success  bool
+		expected string
+	}{
+		{"0.0.0.0/0", true, "0.0.0.0/0"},
+		{" 0.0.0.0/0 ", true, "0.0.0.0/0"},
+		{"1.2.3.4/8", true, "1.0.0.0/8"},
+		{"127.0.0.1/16", true, "127.0.0.0/16"},
+		{"255.255.255.255/19", true, "255.255.224.0/19"},
+		{"255.255.255.255/32", true, "255.255.255.255/32"},
+		{"", false, ""},
+		{"/0", false, ""},
+		{"0", false, ""},
+		{"0/0", false, ""},
+		{"localhost/0", false, ""},
+		{"0.0.0/4", false, ""},
+		{"0.0.0./8", false, ""},
+		{"0.0.0.0./12", false, ""},
+		{"0.0.0.256/16", false, ""},
+		{"0.0.0.0 /20", false, ""},
+		{"0.0.0.0/ 24", false, ""},
+		{"0 . 0 . 0 . 0 / 28", false, ""},
+		{"0.0.0.0/33", false, ""},
+	}
+
+	// Redirect stderr to the null device (as before) to keep failing cases
+	// quiet. Open it for writing, and restore the real stderr when the test
+	// ends so the redirect does not leak into other tests.
+	devnull, err := os.OpenFile(os.DevNull, os.O_WRONLY, 0)
+	if err != nil {
+		t.Fatalf("opening %s: %v", os.DevNull, err)
+	}
+	origStderr := os.Stderr
+	os.Stderr = devnull
+	defer func() {
+		os.Stderr = origStderr
+		devnull.Close() // best-effort cleanup; nothing useful to do on error
+	}()
+
+	for i := range testCases {
+		tc := &testCases[i]
+		var addr net.IPNet
+		f := setUpIPNet(&addr)
+
+		arg := fmt.Sprintf("--address=%s", tc.input)
+		err := f.Parse([]string{arg})
+		switch {
+		case err != nil && tc.success:
+			t.Errorf("%q: expected success, got %v", tc.input, err)
+		case err == nil && !tc.success:
+			t.Errorf("%q: expected failure", tc.input)
+		case tc.success:
+			if got := addr.String(); got != tc.expected {
+				t.Errorf("%q: bound variable holds %q, expected %q", tc.input, got, tc.expected)
+			}
+			ip, err := f.GetIPNet("address")
+			if err != nil {
+				t.Errorf("%q: got error trying to fetch the IPNet flag: %v", tc.input, err)
+				continue
+			}
+			if ip.String() != tc.expected {
+				t.Errorf("%q: expected %q, got %q", tc.input, tc.expected, ip.String())
+			}
+		}
+	}
+}