Support decoding nested structs
diff --git a/mapstructure.go b/mapstructure.go
index f1a0c80..937c205 100644
--- a/mapstructure.go
+++ b/mapstructure.go
@@ -24,8 +24,79 @@
return errors.New("val must be an addressable struct")
}
- valType := val.Type()
+ return decode("root", m, val)
+}
+// Decodes an unknown data type into a specific reflection value.
+func decode(name string, data interface{}, val reflect.Value) error {
+ k := val.Kind()
+
+ // Some shortcuts because we treat all ints and uints the same way
+ if k >= reflect.Int && k <= reflect.Int64 {
+ k = reflect.Int
+ } else if k >= reflect.Uint && k <= reflect.Uint64 {
+ k = reflect.Uint
+ }
+
+ switch k {
+ case reflect.Bool:
+ fallthrough
+ case reflect.Int:
+ fallthrough
+ case reflect.String:
+ fallthrough
+ case reflect.Uint:
+ return decodeBasic(name, data, val)
+ case reflect.Struct:
+ return decodeStruct(name, data, val)
+ }
+
+ // If we reached this point then we weren't able to decode it
+ return fmt.Errorf("unsupported type: %s", k)
+}
+
+// This decodes a basic type (bool, int, string, etc.) and sets the
+// value to "data" of that type.
+func decodeBasic(name string, data interface{}, val reflect.Value) error {
+ dataVal := reflect.ValueOf(data)
+ if !dataVal.IsValid() {
+ // This should never happen because upstream makes sure it is valid
+ panic("data is invalid")
+ }
+
+ dataValType := dataVal.Type()
+ if !dataValType.AssignableTo(val.Type()) {
+ return fmt.Errorf(
+ "'%s' expected type '%s', got '%s'",
+ name, val.Type(), dataValType)
+ }
+
+ val.Set(dataVal)
+ return nil
+}
+
+func decodeStruct(name string, data interface{}, val reflect.Value) error {
+ dataVal := reflect.ValueOf(data)
+ dataValKind := dataVal.Kind()
+ if dataValKind != reflect.Map {
+ return fmt.Errorf("'%s' expected a map, got '%s'", name, dataValKind)
+ }
+
+ dataValType := dataVal.Type()
+ if dataValType.Key().Kind() != reflect.String {
+ return fmt.Errorf(
+ "'%s' needs a map with string keys, has '%s' keys",
+ name, dataValType.Key().Kind())
+ }
+
+ // At this point we know that data is a map with string keys, so
+ // we can properly cast it here.
+ m, ok := data.(map[string]interface{})
+ if !ok {
+ panic("data could not be cast as map[string]interface{}")
+ }
+
+ valType := val.Type()
for i := 0; i < valType.NumField(); i++ {
fieldType := valType.Field(i)
fieldName := fieldType.Name
@@ -54,23 +125,10 @@
panic("field is not valid")
}
- mapVal := reflect.ValueOf(rawMapVal)
- if !mapVal.IsValid() {
- // This should never happen because we got the value out
- // of the map.
- panic("map value is not valid")
+ fieldName = fmt.Sprintf("%s.%s", name, fieldName)
+ if err := decode(fieldName, rawMapVal, field); err != nil {
+ return err
}
-
- mapValType := mapVal.Type()
- if !mapValType.AssignableTo(field.Type()) {
- // If the value in the map can't be assigned to the field
- // in the struct, then this is a problem...
- return fmt.Errorf(
- "field '%s' expected type '%s', got '%s'",
- fieldName, field.Type(), mapValType)
- }
-
- field.Set(mapVal)
}
return nil
diff --git a/mapstructure_test.go b/mapstructure_test.go
index 3ec2a3a..a55953b 100644
--- a/mapstructure_test.go
+++ b/mapstructure_test.go
@@ -9,6 +9,11 @@
Vextra string
}
+type Nested struct {
+ Vfoo string
+ Vbar Basic
+}
+
func TestBasicTypes(t *testing.T) {
t.Parallel()
@@ -42,6 +47,46 @@
}
}
+func TestNestedType(t *testing.T) {
+ t.Parallel()
+
+ input := map[string]interface{}{
+ "vfoo": "foo",
+ "vbar": map[string]interface{}{
+ "vstring": "foo",
+ "vint": 42,
+ "vbool": true,
+ },
+ }
+
+ var result Nested
+ err := MapToStruct(input, &result)
+ if err != nil {
+ t.Errorf("got an err: %s", err.Error())
+ t.FailNow()
+ }
+
+ if result.Vfoo != "foo" {
+ t.Errorf("vfoo value should be 'foo': %#v", result.Vfoo)
+ }
+
+ if result.Vbar.Vstring != "foo" {
+ t.Errorf("vstring value should be 'foo': %#v", result.Vbar.Vstring)
+ }
+
+ if result.Vbar.Vint != 42 {
+ t.Errorf("vint value should be 42: %#v", result.Vbar.Vint)
+ }
+
+ if result.Vbar.Vbool != true {
+ t.Errorf("vbool value should be true: %#v", result.Vbar.Vbool)
+ }
+
+ if result.Vbar.Vextra != "" {
+ t.Errorf("vextra value should be empty: %#v", result.Vbar.Vextra)
+ }
+}
+
func TestInvalidType(t *testing.T) {
t.Parallel()
@@ -56,7 +101,7 @@
t.FailNow()
}
- if err.Error() != "field 'Vstring' expected type 'string', got 'int'" {
+ if err.Error() != "'root.Vstring' expected type 'string', got 'int'" {
t.Errorf("got unexpected error: %s", err)
}
}