properly decode each map key/value, more robust struct decode
diff --git a/mapstructure.go b/mapstructure.go index ec0941e..35b11df 100644 --- a/mapstructure.go +++ b/mapstructure.go
@@ -92,8 +92,27 @@ name, dataValType.Key().Kind()) } - // Just go ahead and set one map to the other... - val.Set(dataVal) + valType := val.Type() + valKeyType := valType.Key() + valElemType := valType.Elem() + + // Make a new map to hold our result + mapType := reflect.MapOf(valKeyType, valElemType) + valMap := reflect.MakeMap(mapType) + + for _, k := range dataVal.MapKeys() { + currentData := dataVal.MapIndex(k).Interface() + currentVal := reflect.Indirect(reflect.New(valElemType)) + + fieldName := fmt.Sprintf("%s[%s]", name, k) + if err := decode(fieldName, currentData, currentVal); err != nil { + return err + } + + valMap.SetMapIndex(k, currentVal) + } + + val.Set(valMap) return nil } @@ -139,32 +158,25 @@ name, dataValType.Key().Kind()) } - // At this point we know that data is a map with string keys, so - // we can properly cast it here. We use the "Interface()" value because - // this gets us the proper interface whether or not data is a pointer - // or not. - m, ok := dataVal.Interface().(map[string]interface{}) - if !ok { - panic("data could not be cast as map[string]interface{}") - } - valType := val.Type() for i := 0; i < valType.NumField(); i++ { fieldType := valType.Field(i) fieldName := fieldType.Name - rawMapVal, ok := m[fieldName] - if !ok { + rawMapVal := dataVal.MapIndex(reflect.ValueOf(fieldName)) + if !rawMapVal.IsValid() { // Do a slower search by iterating over each key and // doing case-insensitive search. - for mK, mV := range m { + for _, dataKeyVal := range dataVal.MapKeys() { + mK := dataKeyVal.Interface().(string) + if strings.EqualFold(mK, fieldName) { - rawMapVal = mV + rawMapVal = dataVal.MapIndex(dataKeyVal) break } } - if rawMapVal == nil { + if !rawMapVal.IsValid() { // There was no matching key in the map for the value in // the struct. Just ignore. continue @@ -178,7 +190,7 @@ } fieldName = fmt.Sprintf("%s.%s", name, fieldName) - if err := decode(fieldName, rawMapVal, field); err != nil { + if err := decode(fieldName, rawMapVal.Interface(), field); err != nil { return err } }
diff --git a/mapstructure_test.go b/mapstructure_test.go index df98537..e913b6c 100644 --- a/mapstructure_test.go +++ b/mapstructure_test.go
@@ -11,7 +11,11 @@ type Map struct { Vfoo string - Vother map[string]interface{} + Vother map[string]string +} + +type MapOfStruct struct { + Value map[string]Basic } type Nested struct { @@ -67,8 +71,8 @@ input := map[string]interface{}{ "vfoo": "foo", "vother": map[string]interface{}{ - "foo": 42, - "bar": true, + "foo": "foo", + "bar": "bar", }, } @@ -92,12 +96,47 @@ t.Error("vother should have two items") } - if result.Vother["foo"].(int) != 42 { - t.Errorf("'foo' key should be 42, got: %#v", result.Vother["foo"]) + if result.Vother["foo"] != "foo" { + t.Errorf("'foo' key should be foo, got: %#v", result.Vother["foo"]) } - if result.Vother["bar"].(bool) != true { - t.Errorf("'bar' key should be true, got: %#v", result.Vother["bar"]) + if result.Vother["bar"] != "bar" { + t.Errorf("'bar' key should be bar, got: %#v", result.Vother["bar"]) + } +} + +func TestMapOfStruct(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "value": map[string]interface{}{ + "foo": map[string]string{"vstring": "one"}, + "bar": map[string]string{"vstring": "two"}, + }, + } + + var result MapOfStruct + err := Decode(input, &result) + if err != nil { + t.Errorf("got an err: %s", err) + t.FailNow() + } + + if result.Value == nil { + t.Error("value should not be nil") + t.FailNow() + } + + if len(result.Value) != 2 { + t.Error("value should have two items") + } + + if result.Value["foo"].Vstring != "one" { + t.Errorf("foo value should be 'one', got: %s", result.Value["foo"].Vstring) + } + + if result.Value["bar"].Vstring != "two" { + t.Errorf("bar value should be 'two', got: %s", result.Value["bar"].Vstring) } }