Merge pull request #37 from eparis/ip-tests
Two New Flag Types - IPNet and Count
diff --git a/count.go b/count.go
new file mode 100644
index 0000000..d061368
--- /dev/null
+++ b/count.go
@@ -0,0 +1,122 @@
+package pflag
+
+import "strconv"
+
+// -- count Value
+
+// countValue is an int flag value that counts occurrences of its flag.
+// CountVarP registers it with NoOptDefVal "-1", so a bare occurrence such
+// as -v or --verbose increments the count, while an explicit value such as
+// --verbose=3 replaces it.
+type countValue int
+
+// newCountValue stores val in *p and returns p viewed as a *countValue, so
+// updates made through the flag are visible through p.
+func newCountValue(val int, p *int) *countValue {
+	*p = val
+	return (*countValue)(p)
+}
+
+// Set parses s with strconv.ParseInt in base 0 (so "0x10" and "010" are
+// accepted). The value -1, which CountVarP uses as NoOptDefVal, increments
+// the current count; any other value replaces it, which also means an
+// explicit --flag=-1 behaves like a bare --flag. If s cannot be parsed as
+// an int, the count is left unchanged and the parse error is returned.
+func (i *countValue) Set(s string) error {
+	v, err := strconv.ParseInt(s, 0, strconv.IntSize)
+	if err != nil {
+		// ParseInt returns 0 (or a clamped value) alongside the error;
+		// storing it would silently reset the running count.
+		return err
+	}
+	if v == -1 {
+		*i++
+		return nil
+	}
+	*i = countValue(v)
+	return nil
+}
+
+// Type returns the type name reported for count flags.
+func (i *countValue) Type() string {
+	return "count"
+}
+
+// String returns the current count in decimal.
+func (i *countValue) String() string { return strconv.Itoa(int(*i)) }
+
+// countConv converts the string form of a count flag's value to an int.
+func countConv(sval string) (interface{}, error) {
+	i, err := strconv.Atoi(sval)
+	if err != nil {
+		return nil, err
+	}
+	return i, nil
+}
+
+// GetCount returns the int value of the count flag with the given name.
+func (f *FlagSet) GetCount(name string) (int, error) {
+	val, err := f.getFlagType(name, "count", countConv)
+	if err != nil {
+		return 0, err
+	}
+	return val.(int), nil
+}
+
+// CountVar defines a count flag with the specified name and usage string.
+// The argument p points to an int variable in which to store the count;
+// each occurrence of the flag without a value increments it.
+func (f *FlagSet) CountVar(p *int, name string, usage string) {
+	f.CountVarP(p, name, "", usage)
+}
+
+// CountVarP is like CountVar, but accepts a shorthand letter that can be
+// used after a single dash; repeated shorthands such as -vvv each count.
+func (f *FlagSet) CountVarP(p *int, name, shorthand string, usage string) {
+	flag := f.VarPF(newCountValue(0, p), name, shorthand, usage)
+	// A bare occurrence of the flag is handed to Set as "-1", which Set
+	// treats as "increment".
+	flag.NoOptDefVal = "-1"
+}
+
+// CountVar defines a count flag with the specified name and usage string
+// on CommandLine. The argument p points to an int variable in which to
+// store the count.
+func CountVar(p *int, name string, usage string) {
+	CommandLine.CountVar(p, name, usage)
+}
+
+// CountVarP is like CountVar, but accepts a shorthand letter that can be
+// used after a single dash.
+func CountVarP(p *int, name, shorthand string, usage string) {
+	CommandLine.CountVarP(p, name, shorthand, usage)
+}
+
+// Count defines a count flag with the specified name and usage string.
+// The return value is the address of an int variable that stores the count.
+func (f *FlagSet) Count(name string, usage string) *int {
+	p := new(int)
+	f.CountVarP(p, name, "", usage)
+	return p
+}
+
+// CountP is like Count, but accepts a shorthand letter that can be used
+// after a single dash.
+func (f *FlagSet) CountP(name, shorthand string, usage string) *int {
+	p := new(int)
+	f.CountVarP(p, name, shorthand, usage)
+	return p
+}
+
+// Count defines a count flag with the specified name and usage string on
+// CommandLine. The return value is the address of an int variable that
+// stores the count.
+func Count(name string, usage string) *int {
+	return CommandLine.CountP(name, "", usage)
+}
+
+// CountP is like Count, but accepts a shorthand letter that can be used
+// after a single dash.
+func CountP(name, shorthand string, usage string) *int {
+	return CommandLine.CountP(name, shorthand, usage)
+}
diff --git a/count_test.go b/count_test.go
new file mode 100644
index 0000000..716765c
--- /dev/null
+++ b/count_test.go
@@ -0,0 +1,61 @@
+package pflag
+
+import (
+	"os"
+	"testing"
+)
+
+// setUpCount returns a FlagSet with one count flag, --verbose/-v, bound to c.
+func setUpCount(c *int) *FlagSet {
+	f := NewFlagSet("test", ContinueOnError)
+	f.CountVarP(c, "verbose", "v", "a counter")
+	return f
+}
+
+func TestCount(t *testing.T) {
+	testCases := []struct {
+		input    []string
+		success  bool
+		expected int
+	}{
+		{[]string{"-vvv"}, true, 3},
+		{[]string{"-v", "-v", "-v"}, true, 3},
+		{[]string{"-v", "--verbose", "-v"}, true, 3},
+		{[]string{"-v=3", "-v"}, true, 4},
+		{[]string{"-v=a"}, false, 0},
+	}
+
+	// Discard the error output printed for the failing cases. Open the null
+	// device for writing (os.Open is read-only) and restore the real stderr
+	// afterwards so later tests in the package are unaffected.
+	devnull, openErr := os.OpenFile(os.DevNull, os.O_WRONLY, 0)
+	if openErr != nil {
+		t.Fatalf("opening %s: %v", os.DevNull, openErr)
+	}
+	defer devnull.Close()
+	stderr := os.Stderr
+	os.Stderr = devnull
+	defer func() { os.Stderr = stderr }()
+
+	for _, tc := range testCases {
+		var count int
+		f := setUpCount(&count)
+
+		err := f.Parse(tc.input)
+		switch {
+		case err != nil && tc.success:
+			t.Errorf("%v: expected success, got %q", tc.input, err)
+		case err == nil && !tc.success:
+			t.Errorf("%v: expected failure, got success", tc.input)
+		case tc.success:
+			c, err := f.GetCount("verbose")
+			if err != nil {
+				t.Errorf("%v: got error trying to fetch the counter flag: %v", tc.input, err)
+				continue
+			}
+			if c != tc.expected {
+				t.Errorf("%v: expected %d, got %d", tc.input, tc.expected, c)
+			}
+		}
+	}
+}
diff --git a/ip.go b/ip.go
index 746eefd..baa442b 100644
--- a/ip.go
+++ b/ip.go
@@ -3,6 +3,7 @@
 import (
 	"fmt"
 	"net"
+	"strings"
 )
 
 // -- net.IP value
@@ -15,7 +16,9 @@
 
 func (i *ipValue) String() string { return net.IP(*i).String() }
+
+// Set parses s as an IP address, ignoring leading and trailing whitespace.
 func (i *ipValue) Set(s string) error {
-	ip := net.ParseIP(s)
+	ip := net.ParseIP(strings.TrimSpace(s))
 	if ip == nil {
 		return fmt.Errorf("failed to parse IP: %q", s)
 	}
diff --git a/ip_test.go b/ip_test.go
new file mode 100644
index 0000000..1fec50e
--- /dev/null
+++ b/ip_test.go
@@ -0,0 +1,71 @@
+package pflag
+
+import (
+	"net"
+	"os"
+	"testing"
+)
+
+// setUpIP returns a FlagSet with one IP flag, --address, bound to ip and
+// defaulting to 0.0.0.0.
+func setUpIP(ip *net.IP) *FlagSet {
+	f := NewFlagSet("test", ContinueOnError)
+	f.IPVar(ip, "address", net.ParseIP("0.0.0.0"), "IP Address")
+	return f
+}
+
+func TestIP(t *testing.T) {
+	testCases := []struct {
+		input    string
+		success  bool
+		expected string
+	}{
+		{"0.0.0.0", true, "0.0.0.0"},
+		{" 0.0.0.0 ", true, "0.0.0.0"},
+		{"1.2.3.4", true, "1.2.3.4"},
+		{"127.0.0.1", true, "127.0.0.1"},
+		{"255.255.255.255", true, "255.255.255.255"},
+		{"", false, ""},
+		{"0", false, ""},
+		{"localhost", false, ""},
+		{"0.0.0", false, ""},
+		{"0.0.0.", false, ""},
+		{"0.0.0.0.", false, ""},
+		{"0.0.0.256", false, ""},
+		{"0 . 0 . 0 . 0", false, ""},
+	}
+
+	// Discard the error output printed for the failing cases. Open the null
+	// device for writing (os.Open is read-only) and restore the real stderr
+	// afterwards so later tests in the package are unaffected.
+	devnull, openErr := os.OpenFile(os.DevNull, os.O_WRONLY, 0)
+	if openErr != nil {
+		t.Fatalf("opening %s: %v", os.DevNull, openErr)
+	}
+	defer devnull.Close()
+	stderr := os.Stderr
+	os.Stderr = devnull
+	defer func() { os.Stderr = stderr }()
+
+	for _, tc := range testCases {
+		var addr net.IP
+		f := setUpIP(&addr)
+
+		err := f.Parse([]string{"--address=" + tc.input})
+		switch {
+		case err != nil && tc.success:
+			t.Errorf("%q: expected success, got %q", tc.input, err)
+		case err == nil && !tc.success:
+			t.Errorf("%q: expected failure", tc.input)
+		case tc.success:
+			ip, err := f.GetIP("address")
+			if err != nil {
+				t.Errorf("%q: got error trying to fetch the IP flag: %v", tc.input, err)
+				continue
+			}
+			if ip.String() != tc.expected {
+				t.Errorf("%q: expected %q, got %q", tc.input, tc.expected, ip.String())
+			}
+		}
+	}
+}
diff --git a/ipnet.go b/ipnet.go
new file mode 100644
index 0000000..23b7fd3
--- /dev/null
+++ b/ipnet.go
@@ -0,0 +1,117 @@
+package pflag
+
+import (
+	"fmt"
+	"net"
+	"strings"
+)
+
+// IPNetValue adapts net.IPNet for use as a flag value.
+type IPNetValue net.IPNet
+
+// String returns the network in CIDR notation, as formatted by
+// net.IPNet.String.
+func (ipnet IPNetValue) String() string {
+	n := net.IPNet(ipnet)
+	return n.String()
+}
+
+// Set parses value, ignoring surrounding whitespace, as a CIDR such as
+// "192.168.1.0/24" and stores the resulting network. net.ParseCIDR clears
+// the host bits, so "1.2.3.4/8" is stored as 1.0.0.0/8.
+func (ipnet *IPNetValue) Set(value string) error {
+	_, n, err := net.ParseCIDR(strings.TrimSpace(value))
+	if err != nil {
+		return err
+	}
+	*ipnet = IPNetValue(*n)
+	return nil
+}
+
+// Type returns the type name reported for IPNet flags.
+func (ipnet *IPNetValue) Type() string {
+	return "ipNet"
+}
+
+// newIPNetValue stores val in *p and returns p viewed as an *IPNetValue, so
+// updates made through the flag are visible through p.
+func newIPNetValue(val net.IPNet, p *net.IPNet) *IPNetValue {
+	*p = val
+	return (*IPNetValue)(p)
+}
+
+// ipNetConv parses sval, ignoring surrounding whitespace, as a CIDR and
+// returns the network as a net.IPNet value.
+func ipNetConv(sval string) (interface{}, error) {
+	_, n, err := net.ParseCIDR(strings.TrimSpace(sval))
+	if err != nil {
+		// Keep the underlying parse error so the cause is not lost.
+		return nil, fmt.Errorf("invalid string being converted to IPNet: %s: %v", sval, err)
+	}
+	return *n, nil
+}
+
+// GetIPNet returns the net.IPNet value of the flag with the given name.
+func (f *FlagSet) GetIPNet(name string) (net.IPNet, error) {
+	val, err := f.getFlagType(name, "ipNet", ipNetConv)
+	if err != nil {
+		return net.IPNet{}, err
+	}
+	return val.(net.IPNet), nil
+}
+
+// IPNetVar defines a net.IPNet flag with the specified name, default value,
+// and usage string. The argument p points to a net.IPNet variable in which
+// to store the value of the flag.
+func (f *FlagSet) IPNetVar(p *net.IPNet, name string, value net.IPNet, usage string) {
+	f.IPNetVarP(p, name, "", value, usage)
+}
+
+// IPNetVarP is like IPNetVar, but accepts a shorthand letter that can be
+// used after a single dash.
+func (f *FlagSet) IPNetVarP(p *net.IPNet, name, shorthand string, value net.IPNet, usage string) {
+	f.VarP(newIPNetValue(value, p), name, shorthand, usage)
+}
+
+// IPNetVar defines a net.IPNet flag with the specified name, default value,
+// and usage string on CommandLine. The argument p points to a net.IPNet
+// variable in which to store the value of the flag.
+func IPNetVar(p *net.IPNet, name string, value net.IPNet, usage string) {
+	CommandLine.IPNetVarP(p, name, "", value, usage)
+}
+
+// IPNetVarP is like IPNetVar, but accepts a shorthand letter that can be
+// used after a single dash.
+func IPNetVarP(p *net.IPNet, name, shorthand string, value net.IPNet, usage string) {
+	CommandLine.IPNetVarP(p, name, shorthand, value, usage)
+}
+
+// IPNet defines a net.IPNet flag with the specified name, default value, and
+// usage string. The return value is the address of a net.IPNet variable that
+// stores the value of the flag.
+func (f *FlagSet) IPNet(name string, value net.IPNet, usage string) *net.IPNet {
+	p := new(net.IPNet)
+	f.IPNetVarP(p, name, "", value, usage)
+	return p
+}
+
+// IPNetP is like IPNet, but accepts a shorthand letter that can be used
+// after a single dash.
+func (f *FlagSet) IPNetP(name, shorthand string, value net.IPNet, usage string) *net.IPNet {
+	p := new(net.IPNet)
+	f.IPNetVarP(p, name, shorthand, value, usage)
+	return p
+}
+
+// IPNet defines a net.IPNet flag with the specified name, default value, and
+// usage string on CommandLine. The return value is the address of a
+// net.IPNet variable that stores the value of the flag.
+func IPNet(name string, value net.IPNet, usage string) *net.IPNet {
+	return CommandLine.IPNetP(name, "", value, usage)
+}
+
+// IPNetP is like IPNet, but accepts a shorthand letter that can be used
+// after a single dash.
+func IPNetP(name, shorthand string, value net.IPNet, usage string) *net.IPNet {
+	return CommandLine.IPNetP(name, shorthand, value, usage)
+}
diff --git a/ipnet_test.go b/ipnet_test.go
new file mode 100644
index 0000000..335b6fa
--- /dev/null
+++ b/ipnet_test.go
@@ -0,0 +1,79 @@
+package pflag
+
+import (
+	"net"
+	"os"
+	"testing"
+)
+
+// setUpIPNet returns a FlagSet with one IPNet flag, --address, bound to ip
+// and defaulting to 0.0.0.0/0.
+func setUpIPNet(ip *net.IPNet) *FlagSet {
+	f := NewFlagSet("test", ContinueOnError)
+	// The literal is a valid CIDR, so the parse error can be ignored.
+	_, def, _ := net.ParseCIDR("0.0.0.0/0")
+	f.IPNetVar(ip, "address", *def, "IP Address")
+	return f
+}
+
+func TestIPNet(t *testing.T) {
+	testCases := []struct {
+		input    string
+		success  bool
+		expected string
+	}{
+		{"0.0.0.0/0", true, "0.0.0.0/0"},
+		{" 0.0.0.0/0 ", true, "0.0.0.0/0"},
+		{"1.2.3.4/8", true, "1.0.0.0/8"},
+		{"127.0.0.1/16", true, "127.0.0.0/16"},
+		{"255.255.255.255/19", true, "255.255.224.0/19"},
+		{"255.255.255.255/32", true, "255.255.255.255/32"},
+		{"", false, ""},
+		{"/0", false, ""},
+		{"0", false, ""},
+		{"0/0", false, ""},
+		{"localhost/0", false, ""},
+		{"0.0.0/4", false, ""},
+		{"0.0.0./8", false, ""},
+		{"0.0.0.0./12", false, ""},
+		{"0.0.0.256/16", false, ""},
+		{"0.0.0.0 /20", false, ""},
+		{"0.0.0.0/ 24", false, ""},
+		{"0 . 0 . 0 . 0 / 28", false, ""},
+		{"0.0.0.0/33", false, ""},
+	}
+
+	// Discard the error output printed for the failing cases. Open the null
+	// device for writing (os.Open is read-only) and restore the real stderr
+	// afterwards so later tests in the package are unaffected.
+	devnull, openErr := os.OpenFile(os.DevNull, os.O_WRONLY, 0)
+	if openErr != nil {
+		t.Fatalf("opening %s: %v", os.DevNull, openErr)
+	}
+	defer devnull.Close()
+	stderr := os.Stderr
+	os.Stderr = devnull
+	defer func() { os.Stderr = stderr }()
+
+	for _, tc := range testCases {
+		var addr net.IPNet
+		f := setUpIPNet(&addr)
+
+		err := f.Parse([]string{"--address=" + tc.input})
+		switch {
+		case err != nil && tc.success:
+			t.Errorf("%q: expected success, got %q", tc.input, err)
+		case err == nil && !tc.success:
+			t.Errorf("%q: expected failure", tc.input)
+		case tc.success:
+			ip, err := f.GetIPNet("address")
+			if err != nil {
+				t.Errorf("%q: got error trying to fetch the IP flag: %v", tc.input, err)
+				continue
+			}
+			if ip.String() != tc.expected {
+				t.Errorf("%q: expected %q, got %q", tc.input, tc.expected, ip.String())
+			}
+		}
+	}
+}