Merge pull request #12 from fabianofranz/master
Add support to short form with equal sign: -p=value
diff --git a/flag.go b/flag.go
index c43edad..7fb0990 100644
--- a/flag.go
+++ b/flag.go
@@ -120,6 +120,10 @@
PanicOnError
)
+// normalizedName is a flag name that has been normalized according to rules
+// for the FlagSet (e.g. making '-' and '_' equivalent).
+type normalizedName string
+
// A FlagSet represents a set of defined flags.
type FlagSet struct {
// Usage is the function called when an error occurs while parsing flags.
@@ -127,16 +131,17 @@
// a custom error handler.
Usage func()
- name string
- parsed bool
- actual map[string]*Flag
- formal map[string]*Flag
- shorthands map[byte]*Flag
- args []string // arguments after flags
- exitOnError bool // does the program exit if there's an error?
- errorHandling ErrorHandling
- output io.Writer // nil means stderr; use out() accessor
- interspersed bool // allow interspersed option/non-option args
+ name string
+ parsed bool
+ actual map[normalizedName]*Flag
+ formal map[normalizedName]*Flag
+ shorthands map[byte]*Flag
+ args []string // arguments after flags
+ exitOnError bool // does the program exit if there's an error?
+ errorHandling ErrorHandling
+ output io.Writer // nil means stderr; use out() accessor
+ interspersed bool // allow interspersed option/non-option args
+ wordSeparators []string
}
// A Flag represents the state of a flag.
@@ -159,21 +164,30 @@
}
// sortFlags returns the flags as a slice in lexicographical sorted order.
-func sortFlags(flags map[string]*Flag) []*Flag {
+func sortFlags(flags map[normalizedName]*Flag) []*Flag {
list := make(sort.StringSlice, len(flags))
i := 0
- for _, f := range flags {
- list[i] = f.Name
+ for k := range flags {
+ list[i] = string(k)
i++
}
list.Sort()
result := make([]*Flag, len(list))
for i, name := range list {
- result[i] = flags[name]
+ result[i] = flags[normalizedName(name)]
}
return result
}
+func (f *FlagSet) normalizeFlagName(name string) normalizedName {
+ result := name
+ for _, sep := range f.wordSeparators {
+ result = strings.Replace(result, sep, "-", -1)
+ }
+ // Type convert to indicate normalization has been done.
+ return normalizedName(result)
+}
+
func (f *FlagSet) out() io.Writer {
if f.output == nil {
return os.Stderr
@@ -221,18 +235,24 @@
// Lookup returns the Flag structure of the named flag, returning nil if none exists.
func (f *FlagSet) Lookup(name string) *Flag {
+ return f.lookup(f.normalizeFlagName(name))
+}
+
+// lookup returns the Flag structure of the named flag, returning nil if none exists.
+func (f *FlagSet) lookup(name normalizedName) *Flag {
return f.formal[name]
}
// Lookup returns the Flag structure of the named command-line flag,
// returning nil if none exists.
func Lookup(name string) *Flag {
- return CommandLine.formal[name]
+ return CommandLine.Lookup(name)
}
// Set sets the value of the named flag.
func (f *FlagSet) Set(name, value string) error {
- flag, ok := f.formal[name]
+ normalName := f.normalizeFlagName(name)
+ flag, ok := f.formal[normalName]
if !ok {
return fmt.Errorf("no such flag -%v", name)
}
@@ -241,10 +261,10 @@
return err
}
if f.actual == nil {
- f.actual = make(map[string]*Flag)
+ f.actual = make(map[normalizedName]*Flag)
}
- f.actual[name] = flag
- f.Lookup(name).Changed = true
+ f.actual[normalName] = flag
+ f.lookup(normalName).Changed = true
return nil
}
@@ -364,16 +384,16 @@
}
func (f *FlagSet) AddFlag(flag *Flag) {
- _, alreadythere := f.formal[flag.Name]
+ _, alreadythere := f.formal[f.normalizeFlagName(flag.Name)]
if alreadythere {
msg := fmt.Sprintf("%s flag redefined: %s", f.name, flag.Name)
fmt.Fprintln(f.out(), msg)
panic(msg) // Happens only if flags are declared with identical names
}
if f.formal == nil {
- f.formal = make(map[string]*Flag)
+ f.formal = make(map[normalizedName]*Flag)
}
- f.formal[flag.Name] = flag
+ f.formal[f.normalizeFlagName(flag.Name)] = flag
if len(flag.Shorthand) == 0 {
return
@@ -436,9 +456,9 @@
}
// mark as visited for Visit()
if f.actual == nil {
- f.actual = make(map[string]*Flag)
+ f.actual = make(map[normalizedName]*Flag)
}
- f.actual[flag.Name] = flag
+ f.actual[f.normalizeFlagName(flag.Name)] = flag
flag.Changed = true
return nil
}
@@ -457,7 +477,7 @@
split := strings.SplitN(name, "=", 2)
name = split[0]
m := f.formal
- flag, alreadythere := m[name] // BUG
+ flag, alreadythere := m[f.normalizeFlagName(name)] // BUG
if !alreadythere {
if name == "help" { // special case for nice help message.
f.usage()
@@ -599,6 +619,19 @@
CommandLine.SetInterspersed(interspersed)
}
+// SetWordSeparators sets a list of strings to be considered as word
+// separators and normalized for the purposes of lookups. For example, if this
+// is set to {"-", "_", "."} then --foo_bar, --foo-bar, and --foo.bar are
+// considered equivalent flags. This must be called before flags are parsed,
+// and may only be called once.
+func (f *FlagSet) SetWordSeparators(separators []string) {
+ f.wordSeparators = separators
+ for k, v := range f.formal {
+ delete(f.formal, k)
+ f.formal[f.normalizeFlagName(string(k))] = v
+ }
+}
+
// Parsed returns true if the command-line flags have been parsed.
func Parsed() bool {
return CommandLine.Parsed()
diff --git a/flag_test.go b/flag_test.go
index d6aa308..c4055ed 100644
--- a/flag_test.go
+++ b/flag_test.go
@@ -238,6 +238,56 @@
testParse(NewFlagSet("test", ContinueOnError), t)
}
+func testNormalizedNames(args []string, t *testing.T) {
+ f := NewFlagSet("normalized", ContinueOnError)
+ if f.Parsed() {
+ t.Error("f.Parse() = true before Parse")
+ }
+ withDashFlag := f.Bool("with-dash-flag", false, "bool value")
+ // Set this after some flags have been added and before others.
+ f.SetWordSeparators([]string{"-", "_"})
+ withUnderFlag := f.Bool("with_under_flag", false, "bool value")
+ withBothFlag := f.Bool("with-both_flag", false, "bool value")
+ if err := f.Parse(args); err != nil {
+ t.Fatal(err)
+ }
+ if !f.Parsed() {
+ t.Error("f.Parse() = false after Parse")
+ }
+ if *withDashFlag != true {
+ t.Error("withDashFlag flag should be true, is ", *withDashFlag)
+ }
+ if *withUnderFlag != true {
+ t.Error("withUnderFlag flag should be true, is ", *withUnderFlag)
+ }
+ if *withBothFlag != true {
+ t.Error("withBothFlag flag should be true, is ", *withBothFlag)
+ }
+}
+
+func TestNormalizedNames(t *testing.T) {
+ args := []string{
+ "--with-dash-flag",
+ "--with-under-flag",
+ "--with-both-flag",
+ }
+ testNormalizedNames(args, t)
+
+ args = []string{
+ "--with_dash_flag",
+ "--with_under_flag",
+ "--with_both_flag",
+ }
+ testNormalizedNames(args, t)
+
+ args = []string{
+ "--with-dash_flag",
+ "--with-under_flag",
+ "--with-both_flag",
+ }
+ testNormalizedNames(args, t)
+}
+
// Declare a user-defined flag type.
type flagVar []string