Merge pull request #10 from eparis/declare_by_name
Declare Flag{} by name instead of order
diff --git a/flag.go b/flag.go
index fd03681..33db47c 100644
--- a/flag.go
+++ b/flag.go
@@ -120,6 +120,10 @@
PanicOnError
)
+// normalizedName is a flag name that has been normalized according to rules
+// for the FlagSet (e.g. making '-' and '_' equivalent). FlagSet.formal and
+// FlagSet.actual are keyed by this type; converting a string to it marks that
+// normalization has been done (see normalizeFlagName).
+type normalizedName string
+
// A FlagSet represents a set of defined flags.
type FlagSet struct {
// Usage is the function called when an error occurs while parsing flags.
@@ -127,16 +131,17 @@
// a custom error handler.
Usage func()
- name string
- parsed bool
- actual map[string]*Flag
- formal map[string]*Flag
- shorthands map[byte]*Flag
- args []string // arguments after flags
- exitOnError bool // does the program exit if there's an error?
- errorHandling ErrorHandling
- output io.Writer // nil means stderr; use out() accessor
- interspersed bool // allow interspersed option/non-option args
+ name string
+ parsed bool
+ actual map[normalizedName]*Flag
+ formal map[normalizedName]*Flag
+ shorthands map[byte]*Flag
+ args []string // arguments after flags
+ exitOnError bool // does the program exit if there's an error?
+ errorHandling ErrorHandling
+ output io.Writer // nil means stderr; use out() accessor
+ interspersed bool // allow interspersed option/non-option args
+ wordSeparators []string
}
// A Flag represents the state of a flag.
@@ -159,21 +164,30 @@
}
// sortFlags returns the flags as a slice in lexicographical sorted order.
-func sortFlags(flags map[string]*Flag) []*Flag {
+func sortFlags(flags map[normalizedName]*Flag) []*Flag {
list := make(sort.StringSlice, len(flags))
i := 0
- for _, f := range flags {
- list[i] = f.Name
+ for k := range flags {
+ list[i] = string(k)
i++
}
list.Sort()
result := make([]*Flag, len(list))
for i, name := range list {
- result[i] = flags[name]
+ result[i] = flags[normalizedName(name)]
}
return result
}
+// normalizeFlagName returns the key under which a flag called name is stored
+// in f.formal and f.actual. Every occurrence of each string in
+// f.wordSeparators is replaced, in list order, with "-"; when no separators
+// are set the name is returned unchanged. Because the replacement target is
+// "-", a literal "-" in name always behaves like a separator, whether or not
+// "-" itself appears in f.wordSeparators.
+func (f *FlagSet) normalizeFlagName(name string) normalizedName {
+ result := name
+ for _, sep := range f.wordSeparators {
+ result = strings.Replace(result, sep, "-", -1)
+ }
+ // Type convert to indicate normalization has been done.
+ return normalizedName(result)
+}
+
func (f *FlagSet) out() io.Writer {
if f.output == nil {
return os.Stderr
@@ -221,18 +235,24 @@
// Lookup returns the Flag structure of the named flag, returning nil if none exists.
func (f *FlagSet) Lookup(name string) *Flag {
+ return f.lookup(f.normalizeFlagName(name))
+}
+
+// lookup is the internal counterpart of Lookup: name must already have been
+// normalized with normalizeFlagName. It returns the Flag stored under that
+// key in f.formal, or nil if none exists.
+func (f *FlagSet) lookup(name normalizedName) *Flag {
return f.formal[name]
}
// Lookup returns the Flag structure of the named command-line flag,
// returning nil if none exists.
func Lookup(name string) *Flag {
- return CommandLine.formal[name]
+ return CommandLine.Lookup(name)
}
// Set sets the value of the named flag.
func (f *FlagSet) Set(name, value string) error {
- flag, ok := f.formal[name]
+ normalName := f.normalizeFlagName(name)
+ flag, ok := f.formal[normalName]
if !ok {
return fmt.Errorf("no such flag -%v", name)
}
@@ -241,10 +261,10 @@
return err
}
if f.actual == nil {
- f.actual = make(map[string]*Flag)
+ f.actual = make(map[normalizedName]*Flag)
}
- f.actual[name] = flag
- f.Lookup(name).Changed = true
+ f.actual[normalName] = flag
+ f.lookup(normalName).Changed = true
return nil
}
@@ -370,16 +390,16 @@
}
func (f *FlagSet) AddFlag(flag *Flag) {
- _, alreadythere := f.formal[flag.Name]
+ _, alreadythere := f.formal[f.normalizeFlagName(flag.Name)]
if alreadythere {
msg := fmt.Sprintf("%s flag redefined: %s", f.name, flag.Name)
fmt.Fprintln(f.out(), msg)
panic(msg) // Happens only if flags are declared with identical names
}
if f.formal == nil {
- f.formal = make(map[string]*Flag)
+ f.formal = make(map[normalizedName]*Flag)
}
- f.formal[flag.Name] = flag
+ f.formal[f.normalizeFlagName(flag.Name)] = flag
if len(flag.Shorthand) == 0 {
return
@@ -442,9 +462,9 @@
}
// mark as visited for Visit()
if f.actual == nil {
- f.actual = make(map[string]*Flag)
+ f.actual = make(map[normalizedName]*Flag)
}
- f.actual[flag.Name] = flag
+ f.actual[f.normalizeFlagName(flag.Name)] = flag
flag.Changed = true
return nil
}
@@ -463,7 +483,7 @@
split := strings.SplitN(name, "=", 2)
name = split[0]
m := f.formal
- flag, alreadythere := m[name] // BUG
+ flag, alreadythere := m[f.normalizeFlagName(name)] // BUG
if !alreadythere {
if name == "help" { // special case for nice help message.
f.usage()
@@ -513,7 +533,8 @@
continue
}
if i < len(shorthands)-1 {
- if e := f.setFlag(flag, shorthands[i+1:], s); e != nil {
+ v := strings.TrimPrefix(shorthands[i+1:], "=")
+ if e := f.setFlag(flag, v, s); e != nil {
err = e
return
}
@@ -604,6 +625,19 @@
CommandLine.SetInterspersed(interspersed)
}
+// SetWordSeparators sets a list of strings to be considered as word
+// separators and normalized for the purposes of lookups. For example, if this
+// is set to {"-", "_", "."} then --foo_bar, --foo-bar, and --foo.bar are
+// considered equivalent flags. This must be called before flags are parsed,
+// and may only be called once.
+//
+// Flags defined before the call are re-keyed in f.formal and f.actual from
+// their original Flag.Name, not from their previously normalized key. If two
+// already-defined flags become equivalent under the new separators,
+// SetWordSeparators panics with the same "flag redefined" message AddFlag
+// uses, rather than silently dropping one of them.
+func (f *FlagSet) SetWordSeparators(separators []string) {
+	f.wordSeparators = separators
+
+	// Build fresh maps instead of deleting/inserting while ranging over the
+	// old ones: a newly normalized key may collide with another entry.
+	formal := make(map[normalizedName]*Flag, len(f.formal))
+	for _, flag := range f.formal {
+		name := f.normalizeFlagName(flag.Name)
+		if _, dup := formal[name]; dup {
+			msg := fmt.Sprintf("%s flag redefined: %s", f.name, flag.Name)
+			fmt.Fprintln(f.out(), msg)
+			panic(msg)
+		}
+		formal[name] = flag
+	}
+	f.formal = formal
+
+	// Flags already marked as set must use the new keys too, otherwise a later
+	// Set would record the same flag under a second key.
+	if f.actual != nil {
+		actual := make(map[normalizedName]*Flag, len(f.actual))
+		for _, flag := range f.actual {
+			actual[f.normalizeFlagName(flag.Name)] = flag
+		}
+		f.actual = actual
+	}
+}
+
// Parsed returns true if the command-line flags have been parsed.
func Parsed() bool {
return CommandLine.Parsed()
diff --git a/flag_test.go b/flag_test.go
index a33c601..c4055ed 100644
--- a/flag_test.go
+++ b/flag_test.go
@@ -185,7 +185,8 @@
boolaFlag := f.BoolP("boola", "a", false, "bool value")
boolbFlag := f.BoolP("boolb", "b", false, "bool2 value")
boolcFlag := f.BoolP("boolc", "c", false, "bool3 value")
- stringFlag := f.StringP("string", "s", "0", "string value")
+ stringaFlag := f.StringP("stringa", "s", "0", "string value")
+ stringzFlag := f.StringP("stringz", "z", "0", "string value")
extra := "interspersed-argument"
notaflag := "--i-look-like-a-flag"
args := []string{
@@ -193,6 +194,7 @@
extra,
"-cs",
"hello",
+ "-z=something",
"--",
notaflag,
}
@@ -212,8 +214,11 @@
if *boolcFlag != true {
t.Error("boolc flag should be true, is ", *boolcFlag)
}
- if *stringFlag != "hello" {
- t.Error("string flag should be `hello`, is ", *stringFlag)
+ if *stringaFlag != "hello" {
+ t.Error("stringa flag should be `hello`, is ", *stringaFlag)
+ }
+ if *stringzFlag != "something" {
+ t.Error("stringz flag should be `something`, is ", *stringzFlag)
}
if len(f.Args()) != 2 {
t.Error("expected one argument, got", len(f.Args()))
@@ -233,6 +238,56 @@
testParse(NewFlagSet("test", ContinueOnError), t)
}
+// testNormalizedNames defines three bool flags spelled with dashes, with
+// underscores, and with a mix of both. It calls SetWordSeparators with "-" and
+// "_" after the first definition and before the other two, parses args, and
+// reports an error unless all three flags end up true.
+func testNormalizedNames(args []string, t *testing.T) {
+ f := NewFlagSet("normalized", ContinueOnError)
+ if f.Parsed() {
+ t.Error("f.Parse() = true before Parse")
+ }
+ withDashFlag := f.Bool("with-dash-flag", false, "bool value")
+ // Set this after some flags have been added and before others.
+ f.SetWordSeparators([]string{"-", "_"})
+ withUnderFlag := f.Bool("with_under_flag", false, "bool value")
+ withBothFlag := f.Bool("with-both_flag", false, "bool value")
+ if err := f.Parse(args); err != nil {
+ t.Fatal(err)
+ }
+ if !f.Parsed() {
+ t.Error("f.Parse() = false after Parse")
+ }
+ if *withDashFlag != true {
+ t.Error("withDashFlag flag should be true, is ", *withDashFlag)
+ }
+ if *withUnderFlag != true {
+ t.Error("withUnderFlag flag should be true, is ", *withUnderFlag)
+ }
+ if *withBothFlag != true {
+ t.Error("withBothFlag flag should be true, is ", *withBothFlag)
+ }
+}
+
+// TestNormalizedNames checks that, with "-" and "_" as word separators, every
+// flag can be set on the command line using all dashes, all underscores, or a
+// mix of both, regardless of how the flag was spelled when it was defined.
+func TestNormalizedNames(t *testing.T) {
+ args := []string{
+ "--with-dash-flag",
+ "--with-under-flag",
+ "--with-both-flag",
+ }
+ testNormalizedNames(args, t)
+
+ args = []string{
+ "--with_dash_flag",
+ "--with_under_flag",
+ "--with_both_flag",
+ }
+ testNormalizedNames(args, t)
+
+ args = []string{
+ "--with-dash_flag",
+ "--with-under_flag",
+ "--with-both_flag",
+ }
+ testNormalizedNames(args, t)
+}
+
// Declare a user-defined flag type.
type flagVar []string