properly decode each map key/value, more robust struct decode
diff --git a/mapstructure.go b/mapstructure.go
index ec0941e..35b11df 100644
--- a/mapstructure.go
+++ b/mapstructure.go
@@ -92,8 +92,27 @@
name, dataValType.Key().Kind())
}
- // Just go ahead and set one map to the other...
- val.Set(dataVal)
+ valType := val.Type()
+ valKeyType := valType.Key()
+ valElemType := valType.Elem()
+
+ // Make a new map to hold our result
+ mapType := reflect.MapOf(valKeyType, valElemType)
+ valMap := reflect.MakeMap(mapType)
+
+ for _, k := range dataVal.MapKeys() {
+ currentData := dataVal.MapIndex(k).Interface()
+ currentVal := reflect.Indirect(reflect.New(valElemType))
+
+ fieldName := fmt.Sprintf("%s[%s]", name, k)
+ if err := decode(fieldName, currentData, currentVal); err != nil {
+ return err
+ }
+
+ valMap.SetMapIndex(k, currentVal)
+ }
+
+ val.Set(valMap)
return nil
}
@@ -139,32 +158,25 @@
name, dataValType.Key().Kind())
}
- // At this point we know that data is a map with string keys, so
- // we can properly cast it here. We use the "Interface()" value because
- // this gets us the proper interface whether or not data is a pointer
- // or not.
- m, ok := dataVal.Interface().(map[string]interface{})
- if !ok {
- panic("data could not be cast as map[string]interface{}")
- }
-
valType := val.Type()
for i := 0; i < valType.NumField(); i++ {
fieldType := valType.Field(i)
fieldName := fieldType.Name
- rawMapVal, ok := m[fieldName]
- if !ok {
+ rawMapVal := dataVal.MapIndex(reflect.ValueOf(fieldName))
+ if !rawMapVal.IsValid() {
// Do a slower search by iterating over each key and
// doing case-insensitive search.
- for mK, mV := range m {
+ for _, dataKeyVal := range dataVal.MapKeys() {
+ mK := dataKeyVal.Interface().(string)
+
if strings.EqualFold(mK, fieldName) {
- rawMapVal = mV
+ rawMapVal = dataVal.MapIndex(dataKeyVal)
break
}
}
- if rawMapVal == nil {
+ if !rawMapVal.IsValid() {
// There was no matching key in the map for the value in
// the struct. Just ignore.
continue
@@ -178,7 +190,7 @@
}
fieldName = fmt.Sprintf("%s.%s", name, fieldName)
- if err := decode(fieldName, rawMapVal, field); err != nil {
+ if err := decode(fieldName, rawMapVal.Interface(), field); err != nil {
return err
}
}
diff --git a/mapstructure_test.go b/mapstructure_test.go
index df98537..e913b6c 100644
--- a/mapstructure_test.go
+++ b/mapstructure_test.go
@@ -11,7 +11,11 @@
type Map struct {
Vfoo string
- Vother map[string]interface{}
+ Vother map[string]string
+}
+
+type MapOfStruct struct {
+ Value map[string]Basic
}
type Nested struct {
@@ -67,8 +71,8 @@
input := map[string]interface{}{
"vfoo": "foo",
"vother": map[string]interface{}{
- "foo": 42,
- "bar": true,
+ "foo": "foo",
+ "bar": "bar",
},
}
@@ -92,12 +96,47 @@
t.Error("vother should have two items")
}
- if result.Vother["foo"].(int) != 42 {
- t.Errorf("'foo' key should be 42, got: %#v", result.Vother["foo"])
+ if result.Vother["foo"] != "foo" {
+ t.Errorf("'foo' key should be foo, got: %#v", result.Vother["foo"])
}
- if result.Vother["bar"].(bool) != true {
- t.Errorf("'bar' key should be true, got: %#v", result.Vother["bar"])
+ if result.Vother["bar"] != "bar" {
+ t.Errorf("'bar' key should be bar, got: %#v", result.Vother["bar"])
+ }
+}
+
+func TestMapOfStruct(t *testing.T) {
+ t.Parallel()
+
+ input := map[string]interface{}{
+ "value": map[string]interface{}{
+ "foo": map[string]string{"vstring": "one"},
+ "bar": map[string]string{"vstring": "two"},
+ },
+ }
+
+ var result MapOfStruct
+ err := Decode(input, &result)
+ if err != nil {
+ t.Errorf("got an err: %s", err)
+ t.FailNow()
+ }
+
+ if result.Value == nil {
+ t.Error("value should not be nil")
+ t.FailNow()
+ }
+
+ if len(result.Value) != 2 {
+ t.Error("value should have two items")
+ }
+
+ if result.Value["foo"].Vstring != "one" {
+ t.Errorf("foo value should be 'one', got: %s", result.Value["foo"].Vstring)
+ }
+
+ if result.Value["bar"].Vstring != "two" {
+ t.Errorf("bar value should be 'two', got: %s", result.Value["bar"].Vstring)
}
}