Restore performance for the simple case
```
BenchmarkGetBool-4 1021 481 -52.89%
BenchmarkGet-4 879 403 -54.15%
BenchmarkGetBoolFromMap-4 6.56 6.40 -2.44%
benchmark old allocs new allocs delta
BenchmarkGetBool-4 6 4 -33.33%
BenchmarkGet-4 6 4 -33.33%
BenchmarkGetBoolFromMap-4 0 0 +0.00%
benchmark old bytes new bytes delta
BenchmarkGetBool-4 113 49 -56.64%
BenchmarkGet-4 112 48 -57.14%
BenchmarkGetBoolFromMap-4 0 0 +0.00%
```
Fixes #249
Fixes https://github.com/spf13/hugo/issues/2536
diff --git a/viper.go b/viper.go
index 8f27849..a03540c 100644
--- a/viper.go
+++ b/viper.go
@@ -399,17 +399,42 @@
return false
}
+// searchMapForKey may end up traversing the map if the key references a nested
+// item (foo.bar), but will use a fast path for the common case.
+// Note: This assumes that the key given is already lowercase.
+func (v *Viper) searchMapForKey(source map[string]interface{}, lcaseKey string) interface{} {
+ if !strings.Contains(lcaseKey, v.keyDelim) {
+ v, ok := source[lcaseKey]
+ if ok {
+ return v
+ }
+ return nil
+ }
+
+ path := strings.Split(lcaseKey, v.keyDelim)
+ return v.searchMap(source, path)
+}
+
// searchMap recursively searches for a value for path in source map.
// Returns nil if not found.
+// Note: This assumes that the path entries are lower cased.
func (v *Viper) searchMap(source map[string]interface{}, path []string) interface{} {
if len(path) == 0 {
return source
}
+ // Fast path
+ if len(path) == 1 {
+ if v, ok := source[path[0]]; ok {
+ return v
+ }
+ return nil
+ }
+
var ok bool
var next interface{}
for k, v := range source {
- if strings.ToLower(k) == strings.ToLower(path[0]) {
+ if k == path[0] {
ok = true
next = v
break
@@ -594,8 +619,8 @@
valType := val
if v.typeByDefValue {
- path := strings.Split(lcaseKey, v.keyDelim)
- defVal := v.searchMap(v.defaults, path)
+ // TODO(bep) this branch isn't covered by a single test.
+ defVal := v.searchMapForKey(v.defaults, lcaseKey)
if defVal != nil {
valType = defVal
}
@@ -841,32 +866,39 @@
// Viper will check in the following order:
// flag, env, config file, key/value store, default.
// Viper will check to see if an alias exists first.
-func (v *Viper) find(key string) interface{} {
- var val interface{}
- var exists bool
+// Note: this assumes a lower-cased key given.
+func (v *Viper) find(lcaseKey string) interface{} {
+
+ var (
+ val interface{}
+ exists bool
+ path = strings.Split(lcaseKey, v.keyDelim)
+ nested = len(path) > 1
+ )
// compute the path through the nested maps to the nested value
- path := strings.Split(key, v.keyDelim)
- if shadow := v.isPathShadowedInDeepMap(path, castMapStringToMapInterface(v.aliases)); shadow != "" {
+ if nested && v.isPathShadowedInDeepMap(path, castMapStringToMapInterface(v.aliases)) != "" {
return nil
}
// if the requested key is an alias, then return the proper key
- key = v.realKey(key)
- // re-compute the path
- path = strings.Split(key, v.keyDelim)
+ lcaseKey = v.realKey(lcaseKey)
// Set() override first
- val = v.searchMap(v.override, path)
+ val = v.searchMapForKey(v.override, lcaseKey)
if val != nil {
return val
}
- if shadow := v.isPathShadowedInDeepMap(path, v.override); shadow != "" {
+
+ path = strings.Split(lcaseKey, v.keyDelim)
+ nested = len(path) > 1
+
+ if nested && v.isPathShadowedInDeepMap(path, v.override) != "" {
return nil
}
// PFlag override next
- flag, exists := v.pflags[key]
+ flag, exists := v.pflags[lcaseKey]
if exists && flag.HasChanged() {
switch flag.ValueType() {
case "int", "int8", "int16", "int32", "int64":
@@ -880,7 +912,8 @@
return flag.ValueString()
}
}
- if shadow := v.isPathShadowedInFlatMap(path, v.pflags); shadow != "" {
+
+ if nested && v.isPathShadowedInFlatMap(path, v.pflags) != "" {
return nil
}
@@ -888,14 +921,14 @@
if v.automaticEnvApplied {
// even if it hasn't been registered, if automaticEnv is used,
// check any Get request
- if val = v.getEnv(v.mergeWithEnvPrefix(key)); val != "" {
+ if val = v.getEnv(v.mergeWithEnvPrefix(lcaseKey)); val != "" {
return val
}
- if shadow := v.isPathShadowedInAutoEnv(path); shadow != "" {
+ if nested && v.isPathShadowedInAutoEnv(path) != "" {
return nil
}
}
- envkey, exists := v.env[key]
+ envkey, exists := v.env[lcaseKey]
if exists {
if val = v.getEnv(envkey); val != "" {
return val
@@ -934,7 +967,7 @@
// last chance: if no other value is returned and a flag does exist for the value,
// get the flag's value even if the flag's value has not changed
- if flag, exists := v.pflags[key]; exists {
+ if flag, exists := v.pflags[lcaseKey]; exists {
switch flag.ValueType() {
case "int", "int8", "int16", "int32", "int64":
return cast.ToInt(flag.ValueString())
diff --git a/viper_test.go b/viper_test.go
index c04b5f8..963c3ce 100644
--- a/viper_test.go
+++ b/viper_test.go
@@ -18,6 +18,8 @@
"testing"
"time"
+ "github.com/spf13/cast"
+
"github.com/spf13/pflag"
"github.com/stretchr/testify/assert"
)
@@ -131,12 +133,18 @@
unmarshalReader(remote, v.kvstore)
}
-func initYAML() {
+func initConfig(typ, config string) {
Reset()
- SetConfigType("yaml")
- r := bytes.NewReader(yamlExample)
+ SetConfigType(typ)
+ r := strings.NewReader(config)
- unmarshalReader(r, v.config)
+ if err := unmarshalReader(r, v.config); err != nil {
+ panic(err)
+ }
+}
+
+func initYAML() {
+ initConfig("yaml", string(yamlExample))
}
func initJSON() {
@@ -435,13 +443,8 @@
assert.Equal(t, all, AllSettings())
}
-func TestCaseInSensitive(t *testing.T) {
- assert.Equal(t, true, Get("hacker"))
- Set("Title", "Checking Case")
- assert.Equal(t, "Checking Case", Get("tItle"))
-}
-
func TestAliasesOfAliases(t *testing.T) {
+ Set("Title", "Checking Case")
RegisterAlias("Foo", "Bar")
RegisterAlias("Bar", "Title")
assert.Equal(t, "Checking Case", Get("FOO"))
@@ -538,7 +541,6 @@
}
func TestBoundCaseSensitivity(t *testing.T) {
-
assert.Equal(t, "brown", Get("eyes"))
BindEnv("eYEs", "TURTLE_EYES")
@@ -917,8 +919,19 @@
}
func TestShadowedNestedValue(t *testing.T) {
+
+ config := `name: steve
+clothing:
+ jacket: leather
+ trousers: denim
+ pants:
+ size: large
+`
+ initConfig("yaml", config)
+
+ assert.Equal(t, "steve", GetString("name"))
+
polyester := "polyester"
- initYAML()
SetDefault("clothing.shirt", polyester)
SetDefault("clothing.jacket.price", 100)
@@ -942,16 +955,63 @@
assert.Equal(t, expected, actual)
}
-func TestGetBool(t *testing.T) {
- key := "BooleanKey"
- v = New()
- v.Set(key, true)
- if !v.GetBool(key) {
- t.Fatal("GetBool returned false")
+func TestCaseInSensitive(t *testing.T) {
+ for _, config := range []struct {
+ typ string
+ content string
+ }{
+ {"yaml", `
+aBcD: 1
+eF:
+ gH: 2
+ iJk: 3
+ Lm:
+ nO: 4
+ P:
+ Q: 5
+ R: 6
+`},
+ {"json", `{
+ "aBcD": 1,
+ "eF": {
+ "iJk": 3,
+ "Lm": {
+ "P": {
+ "Q": 5,
+ "R": 6
+ },
+ "nO": 4
+ },
+ "gH": 2
+ }
+}`},
+ {"toml", `aBcD = 1
+[eF]
+gH = 2
+iJk = 3
+[eF.Lm]
+nO = 4
+[eF.Lm.P]
+Q = 5
+R = 6
+`},
+ } {
+ doTestCaseInSensitive(t, config.typ, config.content)
}
- if v.GetBool("NotFound") {
- t.Fatal("GetBool returned true")
- }
+}
+
+func doTestCaseInSensitive(t *testing.T, typ, config string) {
+ initConfig(typ, config)
+ Set("RfD", true)
+ assert.Equal(t, true, Get("rfd"))
+ assert.Equal(t, true, Get("rFD"))
+ assert.Equal(t, 1, cast.ToInt(Get("abcd")))
+ assert.Equal(t, 1, cast.ToInt(Get("Abcd")))
+ assert.Equal(t, 2, cast.ToInt(Get("ef.gh")))
+ assert.Equal(t, 3, cast.ToInt(Get("ef.ijk")))
+ assert.Equal(t, 4, cast.ToInt(Get("ef.lm.no")))
+ assert.Equal(t, 5, cast.ToInt(Get("ef.lm.p.q")))
+
}
func BenchmarkGetBool(b *testing.B) {