Add new net.IPNet flag type
diff --git a/ipnet.go b/ipnet.go
new file mode 100644
index 0000000..23b7fd3
--- /dev/null
+++ b/ipnet.go
@@ -0,0 +1,100 @@
+package pflag
+
+import (
+ "fmt"
+ "net"
+ "strings"
+)
+
+// IPNet adapts net.IPNet for use as a flag.
+type IPNetValue net.IPNet
+
+func (ipnet IPNetValue) String() string {
+ n := net.IPNet(ipnet)
+ return n.String()
+}
+
+func (ipnet *IPNetValue) Set(value string) error {
+ _, n, err := net.ParseCIDR(strings.TrimSpace(value))
+ if err != nil {
+ return err
+ }
+ *ipnet = IPNetValue(*n)
+ return nil
+}
+
+func (*IPNetValue) Type() string {
+ return "ipNet"
+}
+
+var _ = strings.TrimSpace
+
+func newIPNetValue(val net.IPNet, p *net.IPNet) *IPNetValue {
+ *p = val
+ return (*IPNetValue)(p)
+}
+
+func ipNetConv(sval string) (interface{}, error) {
+ _, n, err := net.ParseCIDR(strings.TrimSpace(sval))
+ if err == nil {
+ return *n, nil
+ }
+ return nil, fmt.Errorf("invalid string being converted to IPNet: %s", sval)
+}
+
+// GetIPNet return the net.IPNet value of a flag with the given name
+func (f *FlagSet) GetIPNet(name string) (net.IPNet, error) {
+ val, err := f.getFlagType(name, "ipNet", ipNetConv)
+ if err != nil {
+ return net.IPNet{}, err
+ }
+ return val.(net.IPNet), nil
+}
+
+// IPNetVar defines an net.IPNet flag with specified name, default value, and usage string.
+// The argument p points to an net.IPNet variable in which to store the value of the flag.
+func (f *FlagSet) IPNetVar(p *net.IPNet, name string, value net.IPNet, usage string) {
+ f.VarP(newIPNetValue(value, p), name, "", usage)
+}
+
+// Like IPNetVar, but accepts a shorthand letter that can be used after a single dash.
+func (f *FlagSet) IPNetVarP(p *net.IPNet, name, shorthand string, value net.IPNet, usage string) {
+ f.VarP(newIPNetValue(value, p), name, shorthand, usage)
+}
+
+// IPNetVar defines an net.IPNet flag with specified name, default value, and usage string.
+// The argument p points to an net.IPNet variable in which to store the value of the flag.
+func IPNetVar(p *net.IPNet, name string, value net.IPNet, usage string) {
+ CommandLine.VarP(newIPNetValue(value, p), name, "", usage)
+}
+
+// Like IPNetVar, but accepts a shorthand letter that can be used after a single dash.
+func IPNetVarP(p *net.IPNet, name, shorthand string, value net.IPNet, usage string) {
+ CommandLine.VarP(newIPNetValue(value, p), name, shorthand, usage)
+}
+
+// IPNet defines an net.IPNet flag with specified name, default value, and usage string.
+// The return value is the address of an net.IPNet variable that stores the value of the flag.
+func (f *FlagSet) IPNet(name string, value net.IPNet, usage string) *net.IPNet {
+ p := new(net.IPNet)
+ f.IPNetVarP(p, name, "", value, usage)
+ return p
+}
+
+// Like IPNet, but accepts a shorthand letter that can be used after a single dash.
+func (f *FlagSet) IPNetP(name, shorthand string, value net.IPNet, usage string) *net.IPNet {
+ p := new(net.IPNet)
+ f.IPNetVarP(p, name, shorthand, value, usage)
+ return p
+}
+
+// IPNet defines an net.IPNet flag with specified name, default value, and usage string.
+// The return value is the address of an net.IPNet variable that stores the value of the flag.
+func IPNet(name string, value net.IPNet, usage string) *net.IPNet {
+ return CommandLine.IPNetP(name, "", value, usage)
+}
+
+// Like IPNet, but accepts a shorthand letter that can be used after a single dash.
+func IPNetP(name, shorthand string, value net.IPNet, usage string) *net.IPNet {
+ return CommandLine.IPNetP(name, shorthand, value, usage)
+}
diff --git a/ipnet_test.go b/ipnet_test.go
new file mode 100644
index 0000000..335b6fa
--- /dev/null
+++ b/ipnet_test.go
@@ -0,0 +1,70 @@
+package pflag
+
+import (
+ "fmt"
+ "net"
+ "os"
+ "testing"
+)
+
+func setUpIPNet(ip *net.IPNet) *FlagSet {
+ f := NewFlagSet("test", ContinueOnError)
+ _, def, _ := net.ParseCIDR("0.0.0.0/0")
+ f.IPNetVar(ip, "address", *def, "IP Address")
+ return f
+}
+
+func TestIPNet(t *testing.T) {
+ testCases := []struct {
+ input string
+ success bool
+ expected string
+ }{
+ {"0.0.0.0/0", true, "0.0.0.0/0"},
+ {" 0.0.0.0/0 ", true, "0.0.0.0/0"},
+ {"1.2.3.4/8", true, "1.0.0.0/8"},
+ {"127.0.0.1/16", true, "127.0.0.0/16"},
+ {"255.255.255.255/19", true, "255.255.224.0/19"},
+ {"255.255.255.255/32", true, "255.255.255.255/32"},
+ {"", false, ""},
+ {"/0", false, ""},
+ {"0", false, ""},
+ {"0/0", false, ""},
+ {"localhost/0", false, ""},
+ {"0.0.0/4", false, ""},
+ {"0.0.0./8", false, ""},
+ {"0.0.0.0./12", false, ""},
+ {"0.0.0.256/16", false, ""},
+ {"0.0.0.0 /20", false, ""},
+ {"0.0.0.0/ 24", false, ""},
+ {"0 . 0 . 0 . 0 / 28", false, ""},
+ {"0.0.0.0/33", false, ""},
+ }
+
+ devnull, _ := os.Open(os.DevNull)
+ os.Stderr = devnull
+ for i := range testCases {
+ var addr net.IPNet
+ f := setUpIPNet(&addr)
+
+ tc := &testCases[i]
+
+ arg := fmt.Sprintf("--address=%s", tc.input)
+ err := f.Parse([]string{arg})
+ if err != nil && tc.success == true {
+ t.Errorf("expected success, got %q", err)
+ continue
+ } else if err == nil && tc.success == false {
+ t.Errorf("expected failure")
+ continue
+ } else if tc.success {
+ ip, err := f.GetIPNet("address")
+ if err != nil {
+ t.Errorf("Got error trying to fetch the IP flag: %v", err)
+ }
+ if ip.String() != tc.expected {
+ t.Errorf("expected %q, got %q", tc.expected, ip.String())
+ }
+ }
+ }
+}