Copy and insensitivise maps in Set
Fixes #261
Closes #265
diff --git a/util.go b/util.go
index aec8a99..3ebada9 100644
--- a/util.go
+++ b/util.go
@@ -39,6 +39,39 @@
return fmt.Sprintf("While parsing config: %s", pe.err.Error())
}
+// toCaseInsensitiveValue checks if the value is a map;
+// if so, create a copy and lower-case the keys recursively.
+func toCaseInsensitiveValue(value interface{}) interface{} {
+ switch v := value.(type) {
+ case map[interface{}]interface{}:
+ value = copyAndInsensitiviseMap(cast.ToStringMap(v))
+ case map[string]interface{}:
+ value = copyAndInsensitiviseMap(v)
+ }
+
+ return value
+}
+
+// copyAndInsensitiviseMap behaves like insensitiviseMap, but creates a copy of
+// any map it makes case insensitive.
+func copyAndInsensitiviseMap(m map[string]interface{}) map[string]interface{} {
+ nm := make(map[string]interface{})
+
+ for key, val := range m {
+ lkey := strings.ToLower(key)
+ switch v := val.(type) {
+ case map[interface{}]interface{}:
+ nm[lkey] = copyAndInsensitiviseMap(cast.ToStringMap(v))
+ case map[string]interface{}:
+ nm[lkey] = copyAndInsensitiviseMap(v)
+ default:
+ nm[lkey] = v
+ }
+ }
+
+ return nm
+}
+
func insensitiviseMap(m map[string]interface{}) {
for key, val := range m {
switch val.(type) {
diff --git a/util_test.go b/util_test.go
new file mode 100644
index 0000000..5949e09
--- /dev/null
+++ b/util_test.go
@@ -0,0 +1,55 @@
+// Copyright © 2016 Steve Francia <spf@spf13.com>.
+//
+// Use of this source code is governed by an MIT-style
+// license that can be found in the LICENSE file.
+
+// Viper is a application configuration system.
+// It believes that applications can be configured a variety of ways
+// via flags, ENVIRONMENT variables, configuration files retrieved
+// from the file system, or a remote key/value store.
+
+package viper
+
+import (
+ "reflect"
+ "testing"
+)
+
+func TestCopyAndInsensitiviseMap(t *testing.T) {
+
+ var (
+ given = map[string]interface{}{
+ "Foo": 32,
+ "Bar": map[interface{}]interface {
+ }{
+ "ABc": "A",
+ "cDE": "B"},
+ }
+ expected = map[string]interface{}{
+ "foo": 32,
+ "bar": map[string]interface {
+ }{
+ "abc": "A",
+ "cde": "B"},
+ }
+ )
+
+ got := copyAndInsensitiviseMap(given)
+
+ if !reflect.DeepEqual(got, expected) {
+ t.Fatalf("Got %q\nexpected\n%q", got, expected)
+ }
+
+ if _, ok := given["foo"]; ok {
+ t.Fatal("Input map changed")
+ }
+
+ if _, ok := given["bar"]; ok {
+ t.Fatal("Input map changed")
+ }
+
+ m := given["Bar"].(map[interface{}]interface{})
+ if _, ok := m["ABc"]; !ok {
+ t.Fatal("Input map changed")
+ }
+}
diff --git a/viper.go b/viper.go
index dba229b..de92df3 100644
--- a/viper.go
+++ b/viper.go
@@ -1042,6 +1042,7 @@
func (v *Viper) SetDefault(key string, value interface{}) {
// If alias passed in, then set the proper default
key = v.realKey(strings.ToLower(key))
+ value = toCaseInsensitiveValue(value)
path := strings.Split(key, v.keyDelim)
lastKey := strings.ToLower(path[len(path)-1])
@@ -1058,6 +1059,7 @@
func (v *Viper) Set(key string, value interface{}) {
// If alias passed in, then set the proper override
key = v.realKey(strings.ToLower(key))
+ value = toCaseInsensitiveValue(value)
path := strings.Split(key, v.keyDelim)
lastKey := strings.ToLower(path[len(path)-1])
diff --git a/viper_test.go b/viper_test.go
index 83dbac4..60e75a3 100644
--- a/viper_test.go
+++ b/viper_test.go
@@ -978,7 +978,7 @@
assert.Equal(t, expected, actual)
}
-func TestCaseInSensitive(t *testing.T) {
+func TestCaseInsensitive(t *testing.T) {
for _, config := range []struct {
typ string
content string
@@ -1019,11 +1019,70 @@
R = 6
`},
} {
- doTestCaseInSensitive(t, config.typ, config.content)
+ doTestCaseInsensitive(t, config.typ, config.content)
}
}
-func doTestCaseInSensitive(t *testing.T, typ, config string) {
+func TestCaseInsensitiveSet(t *testing.T) {
+ Reset()
+ m1 := map[string]interface{}{
+ "Foo": 32,
+ "Bar": map[interface{}]interface {
+ }{
+ "ABc": "A",
+ "cDE": "B"},
+ }
+
+ m2 := map[string]interface{}{
+ "Foo": 52,
+ "Bar": map[interface{}]interface {
+ }{
+ "bCd": "A",
+ "eFG": "B"},
+ }
+
+ Set("Given1", m1)
+ Set("Number1", 42)
+
+ SetDefault("Given2", m2)
+ SetDefault("Number2", 52)
+
+ // Verify SetDefault
+ if v := Get("number2"); v != 52 {
+ t.Fatalf("Expected 52 got %q", v)
+ }
+
+ if v := Get("given2.foo"); v != 52 {
+ t.Fatalf("Expected 52 got %q", v)
+ }
+
+ if v := Get("given2.bar.bcd"); v != "A" {
+ t.Fatalf("Expected A got %q", v)
+ }
+
+ if _, ok := m2["Foo"]; !ok {
+ t.Fatal("Input map changed")
+ }
+
+ // Verify Set
+ if v := Get("number1"); v != 42 {
+ t.Fatalf("Expected 42 got %q", v)
+ }
+
+ if v := Get("given1.foo"); v != 32 {
+ t.Fatalf("Expected 32 got %q", v)
+ }
+
+ if v := Get("given1.bar.abc"); v != "A" {
+ t.Fatalf("Expected A got %q", v)
+ }
+
+ if _, ok := m1["Foo"]; !ok {
+ t.Fatal("Input map changed")
+ }
+}
+
+func doTestCaseInsensitive(t *testing.T, typ, config string) {
initConfig(typ, config)
Set("RfD", true)
assert.Equal(t, true, Get("rfd"))