decode hooks can take types or kinds
diff --git a/decode_hooks.go b/decode_hooks.go index 087a392..e5a949b 100644 --- a/decode_hooks.go +++ b/decode_hooks.go
@@ -1,11 +1,59 @@ package mapstructure import ( + "errors" "reflect" "strconv" "strings" ) +// typedDecodeHook takes a raw DecodeHookFunc (an interface{}) and turns +// it into the proper DecodeHookFunc type, such as DecodeHookFuncType. +func typedDecodeHook(h DecodeHookFunc) DecodeHookFunc { + // Create variables here so we can reference them with the reflect pkg + var f1 DecodeHookFuncType + var f2 DecodeHookFuncKind + + // Fill in the variables into this interface and the rest is done + // automatically using the reflect package. + potential := []interface{}{f1, f2} + + v := reflect.ValueOf(h) + vt := v.Type() + for _, raw := range potential { + p := reflect.ValueOf(raw) + pt := p.Type() + if vt.ConvertibleTo(pt) { + return v.Convert(pt).Interface() + } + } + + return nil +} + +// DecodeHookExec executes the given decode hook. This should be used +// since it'll naturally degrade to the older backwards compatible DecodeHookFunc +// that took reflect.Kind instead of reflect.Type. +func DecodeHookExec( + raw DecodeHookFunc, + from reflect.Type, to reflect.Type, + data interface{}) (interface{}, error) { + // Build our arguments that reflect expects + argVals := make([]reflect.Value, 3) + argVals[0] = reflect.ValueOf(from) + argVals[1] = reflect.ValueOf(to) + argVals[2] = reflect.ValueOf(data) + + switch f := typedDecodeHook(raw).(type) { + case DecodeHookFuncType: + return f(from, to, data) + case DecodeHookFuncKind: + return f(from.Kind(), to.Kind(), data) + default: + return nil, errors.New("invalid decode hook signature") + } +} + // ComposeDecodeHookFunc creates a single DecodeHookFunc that // automatically composes multiple DecodeHookFuncs. // @@ -13,18 +61,18 @@ // previous transformation. func ComposeDecodeHookFunc(fs ...DecodeHookFunc) DecodeHookFunc { return func( - f reflect.Kind, - t reflect.Kind, + f reflect.Type, + t reflect.Type, data interface{}) (interface{}, error) { var err error for _, f1 := range fs { - data, err = f1(f, t, data) + data, err = DecodeHookExec(f1, f, t, data) if err != nil { return nil, err } // Modify the from kind to be correct with the new data - f = getKind(reflect.ValueOf(data)) + f = reflect.ValueOf(data).Type() } return data, nil
diff --git a/decode_hooks_test.go b/decode_hooks_test.go index b417dee..84e3fc5 100644 --- a/decode_hooks_test.go +++ b/decode_hooks_test.go
@@ -23,7 +23,8 @@ f := ComposeDecodeHookFunc(f1, f2) - result, err := f(reflect.String, reflect.Slice, "") + result, err := DecodeHookExec( + f, reflect.TypeOf(""), reflect.TypeOf([]byte("")), "") if err != nil { t.Fatalf("bad: %s", err) } @@ -43,7 +44,8 @@ f := ComposeDecodeHookFunc(f1, f2) - _, err := f(reflect.String, reflect.Slice, 42) + _, err := DecodeHookExec( + f, reflect.TypeOf(""), reflect.TypeOf([]byte("")), 42) if err.Error() != "foo" { t.Fatalf("bad: %s", err) } @@ -69,7 +71,8 @@ f := ComposeDecodeHookFunc(f1, f2) - _, err := f(reflect.String, reflect.Slice, "") + _, err := DecodeHookExec( + f, reflect.TypeOf(""), reflect.TypeOf([]byte("")), "") if err != nil { t.Fatalf("bad: %s", err) } @@ -81,24 +84,26 @@ func TestStringToSliceHookFunc(t *testing.T) { f := StringToSliceHookFunc(",") + strType := reflect.TypeOf("") + sliceType := reflect.TypeOf([]byte("")) cases := []struct { - f, t reflect.Kind + f, t reflect.Type data interface{} result interface{} err bool }{ - {reflect.Slice, reflect.Slice, 42, 42, false}, - {reflect.String, reflect.String, 42, 42, false}, + {sliceType, sliceType, 42, 42, false}, + {strType, strType, 42, 42, false}, { - reflect.String, - reflect.Slice, + strType, + sliceType, "foo,bar,baz", []string{"foo", "bar", "baz"}, false, }, { - reflect.String, - reflect.Slice, + strType, + sliceType, "", []string{}, false, @@ -106,7 +111,7 @@ } for i, tc := range cases { - actual, err := f(tc.f, tc.t, tc.data) + actual, err := DecodeHookExec(f, tc.f, tc.t, tc.data) if tc.err != (err != nil) { t.Fatalf("case %d: expected err %#v", i, tc.err) } @@ -121,56 +126,59 @@ func TestWeaklyTypedHook(t *testing.T) { var f DecodeHookFunc = WeaklyTypedHook + boolType := reflect.TypeOf(true) + strType := reflect.TypeOf("") + sliceType := reflect.TypeOf([]byte("")) cases := []struct { - f, t reflect.Kind + f, t reflect.Type data interface{} result interface{} err bool }{ // TO STRING { - reflect.Bool, - reflect.String, + boolType, + strType, false, "0", false, }, { - reflect.Bool, - reflect.String, + boolType, + strType, true, "1", false, }, { - reflect.Float32, - reflect.String, + reflect.TypeOf(float32(1)), + strType, float32(7), "7", false, }, { - reflect.Int, - reflect.String, + reflect.TypeOf(int(1)), + strType, int(7), "7", false, }, { - reflect.Slice, - reflect.String, + sliceType, + strType, []uint8("foo"), "foo", false, }, { - reflect.Uint, - reflect.String, + reflect.TypeOf(uint(1)), + strType, uint(7), "7", false, @@ -178,7 +186,7 @@ } for i, tc := range cases { - actual, err := f(tc.f, tc.t, tc.data) + actual, err := DecodeHookExec(f, tc.f, tc.t, tc.data) if tc.err != (err != nil) { t.Fatalf("case %d: expected err %#v", i, tc.err) }
diff --git a/mapstructure.go b/mapstructure.go index 9fceaef..070495d 100644 --- a/mapstructure.go +++ b/mapstructure.go
@@ -19,10 +19,12 @@ // DecodeHookFunc is the callback function that can be used for // data transformations. See "DecodeHook" in the DecoderConfig // struct. -type DecodeHookFunc func( - from reflect.Kind, - to reflect.Kind, - data interface{}) (interface{}, error) +// +// The type should be DecodeHookFuncType or DecodeHookFuncKind. +type DecodeHookFunc interface{} + +type DecodeHookFuncType func(reflect.Type, reflect.Type, interface{}) (interface{}, error) +type DecodeHookFuncKind func(reflect.Kind, reflect.Kind, interface{}) (interface{}, error) // DecoderConfig is the configuration that is used to create a new decoder // and allows customization of various aspects of decoding. @@ -181,7 +183,9 @@ if d.config.DecodeHook != nil { // We have a DecodeHook, so let's pre-process the data. var err error - data, err = d.config.DecodeHook(getKind(dataVal), getKind(val), data) + data, err = DecodeHookExec( + d.config.DecodeHook, + dataVal.Type(), val.Type(), data) if err != nil { return err }