ComposeDecodeHookFunc
diff --git a/decode_hooks.go b/decode_hooks.go index 1321ba1..f9ae144 100644 --- a/decode_hooks.go +++ b/decode_hooks.go
@@ -5,6 +5,31 @@ "strings" ) +// ComposeDecodeHookFunc creates a single DecodeHookFunc that +// automatically composes multiple DecodeHookFuncs. +// +// The composed funcs are called in order, with the result of the +// previous transformation. +func ComposeDecodeHookFunc(fs ...DecodeHookFunc) DecodeHookFunc { + return func( + f reflect.Kind, + t reflect.Kind, + data interface{}) (interface{}, error) { + var err error + for _, f1 := range fs { + data, err = f1(f, t, data) + if err != nil { + return nil, err + } + + // Modify the from kind to be correct with the new data + f = getKind(reflect.ValueOf(data)) + } + + return data, nil + } +} + // StringToSliceHookFunc returns a DecodeHookFunc that converts // string to []string by splitting on the given sep. func StringToSliceHookFunc(sep string) DecodeHookFunc {
diff --git a/decode_hooks_test.go b/decode_hooks_test.go index 2e22e22..6a226a7 100644 --- a/decode_hooks_test.go +++ b/decode_hooks_test.go
@@ -1,10 +1,83 @@ package mapstructure import ( + "errors" "reflect" "testing" ) +func TestComposeDecodeHookFunc(t *testing.T) { + f1 := func( + f reflect.Kind, + t reflect.Kind, + data interface{}) (interface{}, error) { + return data.(string) + "foo", nil + } + + f2 := func( + f reflect.Kind, + t reflect.Kind, + data interface{}) (interface{}, error) { + return data.(string) + "bar", nil + } + + f := ComposeDecodeHookFunc(f1, f2) + + result, err := f(reflect.String, reflect.Slice, "") + if err != nil { + t.Fatalf("bad: %s", err) + } + if result.(string) != "foobar" { + t.Fatalf("bad: %#v", result) + } +} + +func TestComposeDecodeHookFunc_err(t *testing.T) { + f1 := func(reflect.Kind, reflect.Kind, interface{}) (interface{}, error) { + return nil, errors.New("foo") + } + + f2 := func(reflect.Kind, reflect.Kind, interface{}) (interface{}, error) { + panic("NOPE") + } + + f := ComposeDecodeHookFunc(f1, f2) + + _, err := f(reflect.String, reflect.Slice, 42) + if err.Error() != "foo" { + t.Fatalf("bad: %s", err) + } +} + +func TestComposeDecodeHookFunc_kinds(t *testing.T) { + var f2From reflect.Kind + + f1 := func( + f reflect.Kind, + t reflect.Kind, + data interface{}) (interface{}, error) { + return int(42), nil + } + + f2 := func( + f reflect.Kind, + t reflect.Kind, + data interface{}) (interface{}, error) { + f2From = f + return data, nil + } + + f := ComposeDecodeHookFunc(f1, f2) + + _, err := f(reflect.String, reflect.Slice, "") + if err != nil { + t.Fatalf("bad: %s", err) + } + if f2From != reflect.Int { + t.Fatalf("bad: %#v", f2From) + } +} + func TestStringToSliceHookFunc(t *testing.T) { f := StringToSliceHookFunc(",")
diff --git a/mapstructure.go b/mapstructure.go index 66c6042..0c23d8b 100644 --- a/mapstructure.go +++ b/mapstructure.go
@@ -163,14 +163,14 @@ if d.config.DecodeHook != nil { // We have a DecodeHook, so let's pre-process the data. var err error - data, err = d.config.DecodeHook(d.getKind(dataVal), d.getKind(val), data) + data, err = d.config.DecodeHook(getKind(dataVal), getKind(val), data) if err != nil { return err } } var err error - dataKind := d.getKind(val) + dataKind := getKind(val) switch dataKind { case reflect.Bool: err = d.decodeBool(name, data, val) @@ -206,21 +206,6 @@ return err } -func (d *Decoder) getKind(val reflect.Value) reflect.Kind { - kind := val.Kind() - - switch { - case kind >= reflect.Int && kind <= reflect.Int64: - return reflect.Int - case kind >= reflect.Uint && kind <= reflect.Uint64: - return reflect.Uint - case kind >= reflect.Float32 && kind <= reflect.Float64: - return reflect.Float32 - default: - return kind - } -} - // This decodes a basic type (bool, int, string, etc.) and sets the // value to "data" of that type. func (d *Decoder) decodeBasic(name string, data interface{}, val reflect.Value) error { @@ -238,7 +223,7 @@ func (d *Decoder) decodeString(name string, data interface{}, val reflect.Value) error { dataVal := reflect.ValueOf(data) - dataKind := d.getKind(dataVal) + dataKind := getKind(dataVal) converted := true switch { @@ -280,7 +265,7 @@ func (d *Decoder) decodeInt(name string, data interface{}, val reflect.Value) error { dataVal := reflect.ValueOf(data) - dataKind := d.getKind(dataVal) + dataKind := getKind(dataVal) switch { case dataKind == reflect.Int: @@ -313,7 +298,7 @@ func (d *Decoder) decodeUint(name string, data interface{}, val reflect.Value) error { dataVal := reflect.ValueOf(data) - dataKind := d.getKind(dataVal) + dataKind := getKind(dataVal) switch { case dataKind == reflect.Int: @@ -346,7 +331,7 @@ func (d *Decoder) decodeBool(name string, data interface{}, val reflect.Value) error { dataVal := reflect.ValueOf(data) - dataKind := d.getKind(dataVal) + dataKind := getKind(dataVal) switch { case dataKind == reflect.Bool: @@ -377,7 +362,7 @@ func (d *Decoder) decodeFloat(name string, data interface{}, val reflect.Value) error { dataVal := reflect.ValueOf(data) - dataKind := d.getKind(dataVal) + dataKind := getKind(dataVal) switch { case dataKind == reflect.Int: @@ -685,3 +670,18 @@ return nil } + +func getKind(val reflect.Value) reflect.Kind { + kind := val.Kind() + + switch { + case kind >= reflect.Int && kind <= reflect.Int64: + return reflect.Int + case kind >= reflect.Uint && kind <= reflect.Uint64: + return reflect.Uint + case kind >= reflect.Float32 && kind <= reflect.Float64: + return reflect.Float32 + default: + return kind + } +}