Metadata tracking works properly with embedded structs
diff --git a/mapstructure.go b/mapstructure.go
index 70da05b..253859f 100644
--- a/mapstructure.go
+++ b/mapstructure.go
@@ -87,6 +87,16 @@
return nil, errors.New("result must be addressable (a pointer)")
}
+ if config.Metadata != nil {
+ if config.Metadata.Keys == nil {
+ config.Metadata.Keys = make([]string, 0)
+ }
+
+ if config.Metadata.Unused == nil {
+ config.Metadata.Unused = make([]string, 0)
+ }
+ }
+
if config.TagName == "" {
config.TagName = "mapstructure"
}
@@ -327,34 +337,49 @@
dataValKeysUnused[dataValKey.Interface()] = struct{}{}
}
- errors := make([]string, 0)
- valType := val.Type()
- for i := 0; i < valType.NumField(); i++ {
- fieldType := valType.Field(i)
- fieldName := fieldType.Name
+ // This slice will keep track of all the structs we'll be decoding.
+ // There can be more than one struct if there are embedded structs
+ // that are squashed.
+ structs := make([]reflect.Value, 1, 5)
+ structs[0] = val
- if fieldType.Anonymous {
- // We have an embedded field. We "squash" the fields down if
- // specified in the tag.
- squash := false
- tagParts := strings.Split(fieldType.Tag.Get(d.config.TagName), ",")
- for _, tag := range tagParts[1:] {
- if tag == "squash" {
- squash = true
- break
+ // Compile the list of all the fields that we're going to be decoding
+ // from all the structs.
+ fields := make(map[*reflect.StructField]reflect.Value)
+ for len(structs) > 0 {
+ structVal := structs[0]
+ structs = structs[1:]
+
+ structType := structVal.Type()
+ for i := 0; i < structType.NumField(); i++ {
+ fieldType := structType.Field(i)
+
+ if fieldType.Anonymous {
+ // We have an embedded field. We "squash" the fields down
+ // if specified in the tag.
+ squash := false
+ tagParts := strings.Split(fieldType.Tag.Get(d.config.TagName), ",")
+ for _, tag := range tagParts[1:] {
+ if tag == "squash" {
+ squash = true
+ break
+ }
+ }
+
+ if squash {
+ structs = append(structs, val.FieldByName(fieldType.Name))
+ continue
}
}
- if squash {
- inner := val.FieldByName(fieldName)
- err := d.decodeStruct(name, data, inner)
- if err != nil {
- errors = appendErrors(errors, err)
- }
-
- continue
- }
+ // Normal struct field, store it away
+ fields[&fieldType] = structVal.Field(i)
}
+ }
+
+ errors := make([]string, 0)
+ for fieldType, field := range fields {
+ fieldName := fieldType.Name
tagValue := fieldType.Tag.Get(d.config.TagName)
if tagValue != "" {
@@ -386,7 +411,6 @@
// Delete the key we're using from the unused map so we stop tracking
delete(dataValKeysUnused, rawMapKey.Interface())
- field := val.Field(i)
if !field.IsValid() {
// This should never happen
panic("field is not valid")
diff --git a/mapstructure_test.go b/mapstructure_test.go
index 12e31c2..5535f8e 100644
--- a/mapstructure_test.go
+++ b/mapstructure_test.go
@@ -471,6 +471,42 @@
}
}
+func TestMetadata_Embedded(t *testing.T) {
+ t.Parallel()
+
+ input := map[string]interface{}{
+ "vstring": "foo",
+ "vunique": "bar",
+ }
+
+ var md Metadata
+ var result EmbeddedSquash
+ config := &DecoderConfig{
+ Metadata: &md,
+ Result: &result,
+ }
+
+ decoder, err := NewDecoder(config)
+ if err != nil {
+ t.Fatalf("err: %s", err)
+ }
+
+ err = decoder.Decode(input)
+ if err != nil {
+ t.Fatalf("err: %s", err.Error())
+ }
+
+ expectedKeys := []string{"Vunique", "Vstring"}
+ if !reflect.DeepEqual(md.Keys, expectedKeys) {
+ t.Fatalf("bad keys: %#v", md.Keys)
+ }
+
+ expectedUnused := []string{}
+ if !reflect.DeepEqual(md.Unused, expectedUnused) {
+ t.Fatalf("bad unused: %#v", md.Unused)
+ }
+}
+
func TestNonPtrValue(t *testing.T) {
t.Parallel()