Metadata tracking works properly with embedded structs
diff --git a/mapstructure.go b/mapstructure.go index 70da05b..253859f 100644 --- a/mapstructure.go +++ b/mapstructure.go
@@ -87,6 +87,16 @@ return nil, errors.New("result must be addressable (a pointer)") } + if config.Metadata != nil { + if config.Metadata.Keys == nil { + config.Metadata.Keys = make([]string, 0) + } + + if config.Metadata.Unused == nil { + config.Metadata.Unused = make([]string, 0) + } + } + if config.TagName == "" { config.TagName = "mapstructure" } @@ -327,34 +337,49 @@ dataValKeysUnused[dataValKey.Interface()] = struct{}{} } - errors := make([]string, 0) - valType := val.Type() - for i := 0; i < valType.NumField(); i++ { - fieldType := valType.Field(i) - fieldName := fieldType.Name + // This slice will keep track of all the structs we'll be decoding. + // There can be more than one struct if there are embedded structs + // that are squashed. + structs := make([]reflect.Value, 1, 5) + structs[0] = val - if fieldType.Anonymous { - // We have an embedded field. We "squash" the fields down if - // specified in the tag. - squash := false - tagParts := strings.Split(fieldType.Tag.Get(d.config.TagName), ",") - for _, tag := range tagParts[1:] { - if tag == "squash" { - squash = true - break + // Compile the list of all the fields that we're going to be decoding + // from all the structs. + fields := make(map[*reflect.StructField]reflect.Value) + for len(structs) > 0 { + structVal := structs[0] + structs = structs[1:] + + structType := structVal.Type() + for i := 0; i < structType.NumField(); i++ { + fieldType := structType.Field(i) + + if fieldType.Anonymous { + // We have an embedded field. We "squash" the fields down + // if specified in the tag. + squash := false + tagParts := strings.Split(fieldType.Tag.Get(d.config.TagName), ",") + for _, tag := range tagParts[1:] { + if tag == "squash" { + squash = true + break + } + } + + if squash { + structs = append(structs, val.FieldByName(fieldType.Name)) + continue } } - if squash { - inner := val.FieldByName(fieldName) - err := d.decodeStruct(name, data, inner) - if err != nil { - errors = appendErrors(errors, err) - } - - continue - } + // Normal struct field, store it away + fields[&fieldType] = structVal.Field(i) } + } + + errors := make([]string, 0) + for fieldType, field := range fields { + fieldName := fieldType.Name tagValue := fieldType.Tag.Get(d.config.TagName) if tagValue != "" { @@ -386,7 +411,6 @@ // Delete the key we're using from the unused map so we stop tracking delete(dataValKeysUnused, rawMapKey.Interface()) - field := val.Field(i) if !field.IsValid() { // This should never happen panic("field is not valid")
diff --git a/mapstructure_test.go b/mapstructure_test.go index 12e31c2..5535f8e 100644 --- a/mapstructure_test.go +++ b/mapstructure_test.go
@@ -471,6 +471,42 @@ } } +func TestMetadata_Embedded(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vstring": "foo", + "vunique": "bar", + } + + var md Metadata + var result EmbeddedSquash + config := &DecoderConfig{ + Metadata: &md, + Result: &result, + } + + decoder, err := NewDecoder(config) + if err != nil { + t.Fatalf("err: %s", err) + } + + err = decoder.Decode(input) + if err != nil { + t.Fatalf("err: %s", err.Error()) + } + + expectedKeys := []string{"Vunique", "Vstring"} + if !reflect.DeepEqual(md.Keys, expectedKeys) { + t.Fatalf("bad keys: %#v", md.Keys) + } + + expectedUnused := []string{} + if !reflect.DeepEqual(md.Unused, expectedUnused) { + t.Fatalf("bad unused: %#v", md.Unused) + } +} + func TestNonPtrValue(t *testing.T) { t.Parallel()