Merge pull request #15 from eparis/generic-flag-normalization-thockin

Generic flag normalization
diff --git a/flag.go b/flag.go
index ac14b81..55594df 100644
--- a/flag.go
+++ b/flag.go
@@ -120,9 +120,9 @@
 	PanicOnError
 )
 
-// normalizedName is a flag name that has been normalized according to rules
+// NormalizedName is a flag name that has been normalized according to rules
 // for the FlagSet (e.g. making '-' and '_' equivalent).
-type normalizedName string
+type NormalizedName string
 
 // A FlagSet represents a set of defined flags.
 type FlagSet struct {
@@ -131,17 +131,17 @@
 	// a custom error handler.
 	Usage func()
 
-	name           string
-	parsed         bool
-	actual         map[normalizedName]*Flag
-	formal         map[normalizedName]*Flag
-	shorthands     map[byte]*Flag
-	args           []string // arguments after flags
-	exitOnError    bool     // does the program exit if there's an error?
-	errorHandling  ErrorHandling
-	output         io.Writer // nil means stderr; use out() accessor
-	interspersed   bool      // allow interspersed option/non-option args
-	wordSeparators []string
+	name              string
+	parsed            bool
+	actual            map[NormalizedName]*Flag
+	formal            map[NormalizedName]*Flag
+	shorthands        map[byte]*Flag
+	args              []string // arguments after flags
+	exitOnError       bool     // does the program exit if there's an error?
+	errorHandling     ErrorHandling
+	output            io.Writer // nil means stderr; use out() accessor
+	interspersed      bool      // allow interspersed option/non-option args
+	normalizeNameFunc func(f *FlagSet, name string) NormalizedName
 }
 
 // A Flag represents the state of a flag.
@@ -165,7 +165,7 @@
 }
 
 // sortFlags returns the flags as a slice in lexicographical sorted order.
-func sortFlags(flags map[normalizedName]*Flag) []*Flag {
+func sortFlags(flags map[NormalizedName]*Flag) []*Flag {
 	list := make(sort.StringSlice, len(flags))
 	i := 0
 	for k := range flags {
@@ -175,18 +175,52 @@
 	list.Sort()
 	result := make([]*Flag, len(list))
 	for i, name := range list {
-		result[i] = flags[normalizedName(name)]
+		result[i] = flags[NormalizedName(name)]
 	}
 	return result
 }
 
-func (f *FlagSet) normalizeFlagName(name string) normalizedName {
-	result := name
-	for _, sep := range f.wordSeparators {
-		result = strings.Replace(result, sep, "-", -1)
-	}
-	// Type convert to indicate normalization has been done.
-	return normalizedName(result)
+// SetNormalizeFunc installs n as the function that maps flag names to the
+// keys used for flag lookup, then re-keys every flag already defined or set
+// on f under its name as normalized by n.
+func (f *FlagSet) SetNormalizeFunc(n func(f *FlagSet, name string) NormalizedName) {
+	f.normalizeNameFunc = n
+	f.formal = f.renormalizeKeys(f.formal)
+	f.actual = f.renormalizeKeys(f.actual)
+}
+
+// renormalizeKeys returns a new map holding the flags of m keyed by the
+// current normalization of each flag's Name; a nil m yields nil.
+//
+// A fresh map is built because entries inserted while ranging over a Go map
+// may or may not be visited later in the same loop, so re-keying in place
+// could normalize a key twice. Normalizing flag.Name rather than the old key
+// also keeps the result independent of any previously installed function.
+func (f *FlagSet) renormalizeKeys(m map[NormalizedName]*Flag) map[NormalizedName]*Flag {
+	if m == nil {
+		return nil
+	}
+	out := make(map[NormalizedName]*Flag, len(m))
+	for _, flag := range m {
+		out[f.normalizeFlagName(flag.Name)] = flag
+	}
+	return out
+}
+
+// GetNormalizeFunc returns the function installed by SetNormalizeFunc, or,
+// when none has been set, an identity function that leaves names unchanged.
+func (f *FlagSet) GetNormalizeFunc() func(f *FlagSet, name string) NormalizedName {
+	if f.normalizeNameFunc != nil {
+		return f.normalizeNameFunc
+	}
+	return func(f *FlagSet, name string) NormalizedName { return NormalizedName(name) }
+}
+
+// normalizeFlagName maps name to its lookup key using the function returned
+// by GetNormalizeFunc.
+func (f *FlagSet) normalizeFlagName(name string) NormalizedName {
+	n := f.GetNormalizeFunc()
+	return n(f, name)
 }
 
 func (f *FlagSet) out() io.Writer {
@@ -240,7 +251,7 @@
 }
 
 // lookup returns the Flag structure of the named flag, returning nil if none exists.
-func (f *FlagSet) lookup(name normalizedName) *Flag {
+func (f *FlagSet) lookup(name NormalizedName) *Flag {
 	return f.formal[name]
 }
 
@@ -272,7 +283,7 @@
 		return err
 	}
 	if f.actual == nil {
-		f.actual = make(map[normalizedName]*Flag)
+		f.actual = make(map[NormalizedName]*Flag)
 	}
 	f.actual[normalName] = flag
 	flag.Changed = true
@@ -417,7 +428,7 @@
 		panic(msg) // Happens only if flags are declared with identical names
 	}
 	if f.formal == nil {
-		f.formal = make(map[normalizedName]*Flag)
+		f.formal = make(map[NormalizedName]*Flag)
 	}
 	f.formal[f.normalizeFlagName(flag.Name)] = flag
 
@@ -482,7 +493,7 @@
 	}
 	// mark as visited for Visit()
 	if f.actual == nil {
-		f.actual = make(map[normalizedName]*Flag)
+		f.actual = make(map[NormalizedName]*Flag)
 	}
 	f.actual[f.normalizeFlagName(flag.Name)] = flag
 	flag.Changed = true
@@ -648,19 +659,25 @@
 	CommandLine.SetInterspersed(interspersed)
 }
 
-// SetWordSeparators sets a list of strings to be considerered as word
-// separators and normalized for the pruposes of lookups.  For example, if this
-// is set to {"-", "_", "."} then --foo_bar, --foo-bar, and --foo.bar are
-// considered equivalent flags.  This must be called before flags are parsed,
-// and may only be called once.
-func (f *FlagSet) SetWordSeparators(separators []string) {
-	f.wordSeparators = separators
-	for k, v := range f.formal {
-		delete(f.formal, k)
-		f.formal[f.normalizeFlagName(string(k))] = v
-	}
-}
-
+// SetWordSeparators sets a list of strings to be considered as word
+// separators and normalized for the purposes of lookups. For example, if this
+// is set to {"-", "_", "."} then --foo_bar, --foo-bar, and --foo.bar are
+// considered equivalent flags. This must be called before flags are parsed.
+// It is implemented with SetNormalizeFunc, so it replaces any normalization
+// function installed earlier (and a later SetNormalizeFunc call replaces it).
+//
+// Deprecated: use SetNormalizeFunc, which supports arbitrary normalization.
+func (f *FlagSet) SetWordSeparators(separators []string) {
+	// Copy so that later changes to the caller's slice cannot alter lookups.
+	seps := append([]string(nil), separators...)
+	f.SetNormalizeFunc(func(_ *FlagSet, name string) NormalizedName {
+		for _, sep := range seps {
+			name = strings.Replace(name, sep, "-", -1)
+		}
+		return NormalizedName(name)
+	})
+}
+
 // Parsed returns true if the command-line flags have been parsed.
 func Parsed() bool {
 	return CommandLine.Parsed()
diff --git a/flag_test.go b/flag_test.go
index a1478e2..f552a2f 100644
--- a/flag_test.go
+++ b/flag_test.go
@@ -239,14 +239,33 @@
 	testParse(NewFlagSet("test", ContinueOnError), t)
 }
 
-func testNormalizedNames(args []string, t *testing.T) {
+// replaceSeparators returns name with every occurrence of each string in
+// from replaced by to; the replacements are applied in order.
+func replaceSeparators(name string, from []string, to string) string {
+	result := name
+	for _, sep := range from {
+		result = strings.Replace(result, sep, to, -1)
+	}
+	return result
+}
+
+// wordSepNormalizeFunc makes '-', '_' and '.' interchangeable by rewriting
+// '-' and '_' to '.', so --foo-bar, --foo_bar and --foo.bar name the same
+// flag.
+func wordSepNormalizeFunc(f *FlagSet, name string) NormalizedName {
+	seps := []string{"-", "_"}
+	name = replaceSeparators(name, seps, ".")
+	return NormalizedName(name)
+}
+
+func testWordSepNormalizedNames(args []string, t *testing.T) {
 	f := NewFlagSet("normalized", ContinueOnError)
 	if f.Parsed() {
 		t.Error("f.Parse() = true before Parse")
 	}
 	withDashFlag := f.Bool("with-dash-flag", false, "bool value")
 	// Set this after some flags have been added and before others.
-	f.SetWordSeparators([]string{"-", "_"})
+	f.SetNormalizeFunc(wordSepNormalizeFunc)
 	withUnderFlag := f.Bool("with_under_flag", false, "bool value")
 	withBothFlag := f.Bool("with-both_flag", false, "bool value")
 	if err := f.Parse(args); err != nil {
@@ -266,27 +281,70 @@
 	}
 }
 
-func TestNormalizedNames(t *testing.T) {
+func TestWordSepNormalizedNames(t *testing.T) {
 	args := []string{
 		"--with-dash-flag",
 		"--with-under-flag",
 		"--with-both-flag",
 	}
-	testNormalizedNames(args, t)
+	testWordSepNormalizedNames(args, t)
 
 	args = []string{
 		"--with_dash_flag",
 		"--with_under_flag",
 		"--with_both_flag",
 	}
-	testNormalizedNames(args, t)
+	testWordSepNormalizedNames(args, t)
 
 	args = []string{
 		"--with-dash_flag",
 		"--with-under_flag",
 		"--with-both_flag",
 	}
-	testNormalizedNames(args, t)
+	testWordSepNormalizedNames(args, t)
+}
+
+// aliasAndWordSepFlagNames normalizes like wordSepNormalizeFunc ('-' and '_'
+// become '.') and additionally maps old-valid-flag, in any separator
+// spelling, to valid-flag so the old name works as an alias.
+func aliasAndWordSepFlagNames(f *FlagSet, name string) NormalizedName {
+	seps := []string{"-", "_"}
+
+	oldName := replaceSeparators("old-valid_flag", seps, ".")
+	newName := replaceSeparators("valid-flag", seps, ".")
+
+	name = replaceSeparators(name, seps, ".")
+	switch name {
+	case oldName:
+		name = newName
+	}
+
+	return NormalizedName(name)
+}
+
+func TestCustomNormalizedNames(t *testing.T) {
+	f := NewFlagSet("normalized", ContinueOnError)
+	if f.Parsed() {
+		t.Error("f.Parse() = true before Parse")
+	}
+
+	// Define one flag before installing the normalizer and one after, so both
+	// re-keying of existing flags and normalization of new ones are exercised.
+	validFlag := f.Bool("valid-flag", false, "bool value")
+	f.SetNormalizeFunc(aliasAndWordSepFlagNames)
+	someOtherFlag := f.Bool("some-other-flag", false, "bool value")
+
+	args := []string{"--old_valid_flag", "--some-other_flag"}
+	if err := f.Parse(args); err != nil {
+		t.Fatal(err)
+	}
+
+	if !*validFlag {
+		t.Errorf("validFlag is %v even though we set the alias --old_valid_flag", *validFlag)
+	}
+	if !*someOtherFlag {
+		t.Errorf("someOtherFlag is %v even though we set --some-other_flag", *someOtherFlag)
+	}
 }
 
 // Declare a user-defined flag type.
@@ -503,7 +557,7 @@
 func TestDeprecatedFlagUsageNormalized(t *testing.T) {
 	f := NewFlagSet("bob", ContinueOnError)
 	f.Bool("bad-double_flag", true, "always true")
-	f.SetWordSeparators([]string{"-", "_"})
+	f.SetNormalizeFunc(wordSepNormalizeFunc)
 	usageMsg := "use --good-flag instead"
 	f.MarkDeprecated("bad_double-flag", usageMsg)