decode hooks can take types or kinds
diff --git a/decode_hooks.go b/decode_hooks.go
index 087a392..e5a949b 100644
--- a/decode_hooks.go
+++ b/decode_hooks.go
@@ -1,11 +1,59 @@
package mapstructure
import (
+ "errors"
"reflect"
"strconv"
"strings"
)
+// typedDecodeHook takes a raw DecodeHookFunc (an interface{}) and turns
+// it into the proper DecodeHookFunc type, such as DecodeHookFuncType.
+func typedDecodeHook(h DecodeHookFunc) DecodeHookFunc {
+ // Create variables here so we can reference them with the reflect pkg
+ var f1 DecodeHookFuncType
+ var f2 DecodeHookFuncKind
+
+ // Fill in the variables into this interface and the rest is done
+ // automatically using the reflect package.
+ potential := []interface{}{f1, f2}
+
+ v := reflect.ValueOf(h)
+ vt := v.Type()
+ for _, raw := range potential {
+ p := reflect.ValueOf(raw)
+ pt := p.Type()
+ if vt.ConvertibleTo(pt) {
+ return v.Convert(pt).Interface()
+ }
+ }
+
+ return nil
+}
+
+// DecodeHookExec executes the given decode hook. This should be used
+// since it'll naturally degrade to the older backwards compatible DecodeHookFunc
+// that took reflect.Kind instead of reflect.Type.
+func DecodeHookExec(
+ raw DecodeHookFunc,
+ from reflect.Type, to reflect.Type,
+ data interface{}) (interface{}, error) {
+ // Build our arguments that reflect expects
+ argVals := make([]reflect.Value, 3)
+ argVals[0] = reflect.ValueOf(from)
+ argVals[1] = reflect.ValueOf(to)
+ argVals[2] = reflect.ValueOf(data)
+
+ switch f := typedDecodeHook(raw).(type) {
+ case DecodeHookFuncType:
+ return f(from, to, data)
+ case DecodeHookFuncKind:
+ return f(from.Kind(), to.Kind(), data)
+ default:
+ return nil, errors.New("invalid decode hook signature")
+ }
+}
+
// ComposeDecodeHookFunc creates a single DecodeHookFunc that
// automatically composes multiple DecodeHookFuncs.
//
@@ -13,18 +61,18 @@
// previous transformation.
func ComposeDecodeHookFunc(fs ...DecodeHookFunc) DecodeHookFunc {
return func(
- f reflect.Kind,
- t reflect.Kind,
+ f reflect.Type,
+ t reflect.Type,
data interface{}) (interface{}, error) {
var err error
for _, f1 := range fs {
- data, err = f1(f, t, data)
+ data, err = DecodeHookExec(f1, f, t, data)
if err != nil {
return nil, err
}
// Modify the from kind to be correct with the new data
- f = getKind(reflect.ValueOf(data))
+ f = reflect.ValueOf(data).Type()
}
return data, nil
diff --git a/decode_hooks_test.go b/decode_hooks_test.go
index b417dee..84e3fc5 100644
--- a/decode_hooks_test.go
+++ b/decode_hooks_test.go
@@ -23,7 +23,8 @@
f := ComposeDecodeHookFunc(f1, f2)
- result, err := f(reflect.String, reflect.Slice, "")
+ result, err := DecodeHookExec(
+ f, reflect.TypeOf(""), reflect.TypeOf([]byte("")), "")
if err != nil {
t.Fatalf("bad: %s", err)
}
@@ -43,7 +44,8 @@
f := ComposeDecodeHookFunc(f1, f2)
- _, err := f(reflect.String, reflect.Slice, 42)
+ _, err := DecodeHookExec(
+ f, reflect.TypeOf(""), reflect.TypeOf([]byte("")), 42)
if err.Error() != "foo" {
t.Fatalf("bad: %s", err)
}
@@ -69,7 +71,8 @@
f := ComposeDecodeHookFunc(f1, f2)
- _, err := f(reflect.String, reflect.Slice, "")
+ _, err := DecodeHookExec(
+ f, reflect.TypeOf(""), reflect.TypeOf([]byte("")), "")
if err != nil {
t.Fatalf("bad: %s", err)
}
@@ -81,24 +84,26 @@
func TestStringToSliceHookFunc(t *testing.T) {
f := StringToSliceHookFunc(",")
+ strType := reflect.TypeOf("")
+ sliceType := reflect.TypeOf([]byte(""))
cases := []struct {
- f, t reflect.Kind
+ f, t reflect.Type
data interface{}
result interface{}
err bool
}{
- {reflect.Slice, reflect.Slice, 42, 42, false},
- {reflect.String, reflect.String, 42, 42, false},
+ {sliceType, sliceType, 42, 42, false},
+ {strType, strType, 42, 42, false},
{
- reflect.String,
- reflect.Slice,
+ strType,
+ sliceType,
"foo,bar,baz",
[]string{"foo", "bar", "baz"},
false,
},
{
- reflect.String,
- reflect.Slice,
+ strType,
+ sliceType,
"",
[]string{},
false,
@@ -106,7 +111,7 @@
}
for i, tc := range cases {
- actual, err := f(tc.f, tc.t, tc.data)
+ actual, err := DecodeHookExec(f, tc.f, tc.t, tc.data)
if tc.err != (err != nil) {
t.Fatalf("case %d: expected err %#v", i, tc.err)
}
@@ -121,56 +126,59 @@
func TestWeaklyTypedHook(t *testing.T) {
var f DecodeHookFunc = WeaklyTypedHook
+ boolType := reflect.TypeOf(true)
+ strType := reflect.TypeOf("")
+ sliceType := reflect.TypeOf([]byte(""))
cases := []struct {
- f, t reflect.Kind
+ f, t reflect.Type
data interface{}
result interface{}
err bool
}{
// TO STRING
{
- reflect.Bool,
- reflect.String,
+ boolType,
+ strType,
false,
"0",
false,
},
{
- reflect.Bool,
- reflect.String,
+ boolType,
+ strType,
true,
"1",
false,
},
{
- reflect.Float32,
- reflect.String,
+ reflect.TypeOf(float32(1)),
+ strType,
float32(7),
"7",
false,
},
{
- reflect.Int,
- reflect.String,
+ reflect.TypeOf(int(1)),
+ strType,
int(7),
"7",
false,
},
{
- reflect.Slice,
- reflect.String,
+ sliceType,
+ strType,
[]uint8("foo"),
"foo",
false,
},
{
- reflect.Uint,
- reflect.String,
+ reflect.TypeOf(uint(1)),
+ strType,
uint(7),
"7",
false,
@@ -178,7 +186,7 @@
}
for i, tc := range cases {
- actual, err := f(tc.f, tc.t, tc.data)
+ actual, err := DecodeHookExec(f, tc.f, tc.t, tc.data)
if tc.err != (err != nil) {
t.Fatalf("case %d: expected err %#v", i, tc.err)
}
diff --git a/mapstructure.go b/mapstructure.go
index 9fceaef..070495d 100644
--- a/mapstructure.go
+++ b/mapstructure.go
@@ -19,10 +19,12 @@
// DecodeHookFunc is the callback function that can be used for
// data transformations. See "DecodeHook" in the DecoderConfig
// struct.
-type DecodeHookFunc func(
- from reflect.Kind,
- to reflect.Kind,
- data interface{}) (interface{}, error)
+//
+// The type should be DecodeHookFuncType or DecodeHookFuncKind.
+type DecodeHookFunc interface{}
+
+type DecodeHookFuncType func(reflect.Type, reflect.Type, interface{}) (interface{}, error)
+type DecodeHookFuncKind func(reflect.Kind, reflect.Kind, interface{}) (interface{}, error)
// DecoderConfig is the configuration that is used to create a new decoder
// and allows customization of various aspects of decoding.
@@ -181,7 +183,9 @@
if d.config.DecodeHook != nil {
// We have a DecodeHook, so let's pre-process the data.
var err error
- data, err = d.config.DecodeHook(getKind(dataVal), getKind(val), data)
+ data, err = DecodeHookExec(
+ d.config.DecodeHook,
+ dataVal.Type(), val.Type(), data)
if err != nil {
return err
}