Support decoding nested structs
diff --git a/mapstructure.go b/mapstructure.go index f1a0c80..937c205 100644 --- a/mapstructure.go +++ b/mapstructure.go
@@ -24,8 +24,79 @@ return errors.New("val must be an addressable struct") } - valType := val.Type() + return decode("root", m, val) +} +// Decodes an unknown data type into a specific reflection value. +func decode(name string, data interface{}, val reflect.Value) error { + k := val.Kind() + + // Some shortcuts because we treat all ints and uints the same way + if k >= reflect.Int && k <= reflect.Int64 { + k = reflect.Int + } else if k >= reflect.Uint && k <= reflect.Uint64 { + k = reflect.Uint + } + + switch k { + case reflect.Bool: + fallthrough + case reflect.Int: + fallthrough + case reflect.String: + fallthrough + case reflect.Uint: + return decodeBasic(name, data, val) + case reflect.Struct: + return decodeStruct(name, data, val) + } + + // If we reached this point then we weren't able to decode it + return fmt.Errorf("unsupported type: %s", k) +} + +// This decodes a basic type (bool, int, string, etc.) and sets the +// value to "data" of that type. +func decodeBasic(name string, data interface{}, val reflect.Value) error { + dataVal := reflect.ValueOf(data) + if !dataVal.IsValid() { + // This should never happen because upstream makes sure it is valid + panic("data is invalid") + } + + dataValType := dataVal.Type() + if !dataValType.AssignableTo(val.Type()) { + return fmt.Errorf( + "'%s' expected type '%s', got '%s'", + name, val.Type(), dataValType) + } + + val.Set(dataVal) + return nil +} + +func decodeStruct(name string, data interface{}, val reflect.Value) error { + dataVal := reflect.ValueOf(data) + dataValKind := dataVal.Kind() + if dataValKind != reflect.Map { + return fmt.Errorf("'%s' expected a map, got '%s'", name, dataValKind) + } + + dataValType := dataVal.Type() + if dataValType.Key().Kind() != reflect.String { + return fmt.Errorf( + "'%s' needs a map with string keys, has '%s' keys", + name, dataValType.Key().Kind()) + } + + // At this point we know that data is a map with string keys, so + // we can properly cast it here. + m, ok := data.(map[string]interface{}) + if !ok { + panic("data could not be cast as map[string]interface{}") + } + + valType := val.Type() for i := 0; i < valType.NumField(); i++ { fieldType := valType.Field(i) fieldName := fieldType.Name @@ -54,23 +125,10 @@ panic("field is not valid") } - mapVal := reflect.ValueOf(rawMapVal) - if !mapVal.IsValid() { - // This should never happen because we got the value out - // of the map. - panic("map value is not valid") + fieldName = fmt.Sprintf("%s.%s", name, fieldName) + if err := decode(fieldName, rawMapVal, field); err != nil { + return err } - - mapValType := mapVal.Type() - if !mapValType.AssignableTo(field.Type()) { - // If the value in the map can't be assigned to the field - // in the struct, then this is a problem... - return fmt.Errorf( - "field '%s' expected type '%s', got '%s'", - fieldName, field.Type(), mapValType) - } - - field.Set(mapVal) } return nil
diff --git a/mapstructure_test.go b/mapstructure_test.go index 3ec2a3a..a55953b 100644 --- a/mapstructure_test.go +++ b/mapstructure_test.go
@@ -9,6 +9,11 @@ Vextra string } +type Nested struct { + Vfoo string + Vbar Basic +} + func TestBasicTypes(t *testing.T) { t.Parallel() @@ -42,6 +47,46 @@ } } +func TestNestedType(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vfoo": "foo", + "vbar": map[string]interface{}{ + "vstring": "foo", + "vint": 42, + "vbool": true, + }, + } + + var result Nested + err := MapToStruct(input, &result) + if err != nil { + t.Errorf("got an err: %s", err.Error()) + t.FailNow() + } + + if result.Vfoo != "foo" { + t.Errorf("vfoo value should be 'foo': %#v", result.Vfoo) + } + + if result.Vbar.Vstring != "foo" { + t.Errorf("vstring value should be 'foo': %#v", result.Vbar.Vstring) + } + + if result.Vbar.Vint != 42 { + t.Errorf("vint value should be 42: %#v", result.Vbar.Vint) + } + + if result.Vbar.Vbool != true { + t.Errorf("vbool value should be true: %#v", result.Vbar.Vbool) + } + + if result.Vbar.Vextra != "" { + t.Errorf("vextra value should be empty: %#v", result.Vbar.Vextra) + } +} + func TestInvalidType(t *testing.T) { t.Parallel() @@ -56,7 +101,7 @@ t.FailNow() } - if err.Error() != "field 'Vstring' expected type 'string', got 'int'" { + if err.Error() != "'root.Vstring' expected type 'string', got 'int'" { t.Errorf("got unexpected error: %s", err) } }