Merge pull request #56 from eparis/len-at-dash
Store the length of Args when a -- is found in commandline
diff --git a/flag.go b/flag.go
index 66ef096..57d9339 100644
--- a/flag.go
+++ b/flag.go
@@ -140,6 +140,7 @@
formal map[NormalizedName]*Flag
shorthands map[byte]*Flag
args []string // arguments after flags
+ argsLenAtDash int // len(args) when a '--' was located when parsing, or -1 if no --
exitOnError bool // does the program exit if there's an error?
errorHandling ErrorHandling
output io.Writer // nil means stderr; use out() accessor
@@ -292,6 +293,13 @@
return result, nil
}
+// ArgsLenAtDash returns the length of f.Args at the moment a -- was found
+// during arg parsing, or -1 if no -- was seen. This lets your program know
+// which args came before the -- and which came after.
+func (f *FlagSet) ArgsLenAtDash() int {
+ return f.argsLenAtDash
+}
+
// MarkDeprecated indicated that a flag is deprecated in your program. It will
// continue to function but will not show up in help or usage messages. Using
// this flag will also print the given usageMessage.
@@ -740,6 +748,7 @@
if s[1] == '-' {
if len(s) == 2 { // "--" terminates the flags
+ f.argsLenAtDash = len(f.args)
f.args = append(f.args, args...)
break
}
@@ -806,6 +815,7 @@
f := &FlagSet{
name: name,
errorHandling: errorHandling,
+ argsLenAtDash: -1,
interspersed: true,
}
return f
@@ -822,4 +832,5 @@
func (f *FlagSet) Init(name string, errorHandling ErrorHandling) {
f.name = name
f.errorHandling = errorHandling
+ f.argsLenAtDash = -1
}
diff --git a/flag_test.go b/flag_test.go
index 1b4005c..e17b2aa 100644
--- a/flag_test.go
+++ b/flag_test.go
@@ -387,6 +387,9 @@
} else if f.Args()[1] != notaflag {
t.Errorf("expected argument %q got %q", notaflag, f.Args()[1])
}
+ if f.ArgsLenAtDash() != 1 {
+ t.Errorf("expected argsLenAtDash %d got %d", 1, f.ArgsLenAtDash())
+ }
}
func TestParse(t *testing.T) {
@@ -424,6 +427,9 @@
if f.Changed("invalid") {
t.Errorf("--invalid was changed!")
}
+ if f.ArgsLenAtDash() != -1 {
+ t.Errorf("Expected argsLenAtDash: %d but got %d", -1, f.ArgsLenAtDash())
+ }
}
func replaceSeparators(name string, from []string, to string) string {
@@ -713,6 +719,9 @@
if f.Args()[1] != arg2 {
t.Errorf("expected argument %q got %q", arg2, f.Args()[1])
}
+ if f.ArgsLenAtDash() != 0 {
+ t.Errorf("expected argsLenAtDash %d got %d", 0, f.ArgsLenAtDash())
+ }
}
func TestDeprecatedFlagInDocs(t *testing.T) {