ComposeDecodeHookFunc
diff --git a/decode_hooks.go b/decode_hooks.go
index 1321ba1..f9ae144 100644
--- a/decode_hooks.go
+++ b/decode_hooks.go
@@ -5,6 +5,31 @@
"strings"
)
+// ComposeDecodeHookFunc creates a single DecodeHookFunc that
+// automatically composes multiple DecodeHookFuncs.
+//
+// The composed funcs are called in order, with the result of the
+// previous transformation.
+func ComposeDecodeHookFunc(fs ...DecodeHookFunc) DecodeHookFunc {
+ return func(
+ f reflect.Kind,
+ t reflect.Kind,
+ data interface{}) (interface{}, error) {
+ var err error
+ for _, f1 := range fs {
+ data, err = f1(f, t, data)
+ if err != nil {
+ return nil, err
+ }
+
+ // Modify the from kind to be correct with the new data
+ f = getKind(reflect.ValueOf(data))
+ }
+
+ return data, nil
+ }
+}
+
// StringToSliceHookFunc returns a DecodeHookFunc that converts
// string to []string by splitting on the given sep.
func StringToSliceHookFunc(sep string) DecodeHookFunc {
diff --git a/decode_hooks_test.go b/decode_hooks_test.go
index 2e22e22..6a226a7 100644
--- a/decode_hooks_test.go
+++ b/decode_hooks_test.go
@@ -1,10 +1,83 @@
package mapstructure
import (
+ "errors"
"reflect"
"testing"
)
+func TestComposeDecodeHookFunc(t *testing.T) {
+ f1 := func(
+ f reflect.Kind,
+ t reflect.Kind,
+ data interface{}) (interface{}, error) {
+ return data.(string) + "foo", nil
+ }
+
+ f2 := func(
+ f reflect.Kind,
+ t reflect.Kind,
+ data interface{}) (interface{}, error) {
+ return data.(string) + "bar", nil
+ }
+
+ f := ComposeDecodeHookFunc(f1, f2)
+
+ result, err := f(reflect.String, reflect.Slice, "")
+ if err != nil {
+ t.Fatalf("bad: %s", err)
+ }
+ if result.(string) != "foobar" {
+ t.Fatalf("bad: %#v", result)
+ }
+}
+
+func TestComposeDecodeHookFunc_err(t *testing.T) {
+ f1 := func(reflect.Kind, reflect.Kind, interface{}) (interface{}, error) {
+ return nil, errors.New("foo")
+ }
+
+ f2 := func(reflect.Kind, reflect.Kind, interface{}) (interface{}, error) {
+ panic("NOPE")
+ }
+
+ f := ComposeDecodeHookFunc(f1, f2)
+
+ _, err := f(reflect.String, reflect.Slice, 42)
+ if err.Error() != "foo" {
+ t.Fatalf("bad: %s", err)
+ }
+}
+
+func TestComposeDecodeHookFunc_kinds(t *testing.T) {
+ var f2From reflect.Kind
+
+ f1 := func(
+ f reflect.Kind,
+ t reflect.Kind,
+ data interface{}) (interface{}, error) {
+ return int(42), nil
+ }
+
+ f2 := func(
+ f reflect.Kind,
+ t reflect.Kind,
+ data interface{}) (interface{}, error) {
+ f2From = f
+ return data, nil
+ }
+
+ f := ComposeDecodeHookFunc(f1, f2)
+
+ _, err := f(reflect.String, reflect.Slice, "")
+ if err != nil {
+ t.Fatalf("bad: %s", err)
+ }
+ if f2From != reflect.Int {
+ t.Fatalf("bad: %#v", f2From)
+ }
+}
+
func TestStringToSliceHookFunc(t *testing.T) {
f := StringToSliceHookFunc(",")
diff --git a/mapstructure.go b/mapstructure.go
index 66c6042..0c23d8b 100644
--- a/mapstructure.go
+++ b/mapstructure.go
@@ -163,14 +163,14 @@
if d.config.DecodeHook != nil {
// We have a DecodeHook, so let's pre-process the data.
var err error
- data, err = d.config.DecodeHook(d.getKind(dataVal), d.getKind(val), data)
+ data, err = d.config.DecodeHook(getKind(dataVal), getKind(val), data)
if err != nil {
return err
}
}
var err error
- dataKind := d.getKind(val)
+ dataKind := getKind(val)
switch dataKind {
case reflect.Bool:
err = d.decodeBool(name, data, val)
@@ -206,21 +206,6 @@
return err
}
-func (d *Decoder) getKind(val reflect.Value) reflect.Kind {
- kind := val.Kind()
-
- switch {
- case kind >= reflect.Int && kind <= reflect.Int64:
- return reflect.Int
- case kind >= reflect.Uint && kind <= reflect.Uint64:
- return reflect.Uint
- case kind >= reflect.Float32 && kind <= reflect.Float64:
- return reflect.Float32
- default:
- return kind
- }
-}
-
// This decodes a basic type (bool, int, string, etc.) and sets the
// value to "data" of that type.
func (d *Decoder) decodeBasic(name string, data interface{}, val reflect.Value) error {
@@ -238,7 +223,7 @@
func (d *Decoder) decodeString(name string, data interface{}, val reflect.Value) error {
dataVal := reflect.ValueOf(data)
- dataKind := d.getKind(dataVal)
+ dataKind := getKind(dataVal)
converted := true
switch {
@@ -280,7 +265,7 @@
func (d *Decoder) decodeInt(name string, data interface{}, val reflect.Value) error {
dataVal := reflect.ValueOf(data)
- dataKind := d.getKind(dataVal)
+ dataKind := getKind(dataVal)
switch {
case dataKind == reflect.Int:
@@ -313,7 +298,7 @@
func (d *Decoder) decodeUint(name string, data interface{}, val reflect.Value) error {
dataVal := reflect.ValueOf(data)
- dataKind := d.getKind(dataVal)
+ dataKind := getKind(dataVal)
switch {
case dataKind == reflect.Int:
@@ -346,7 +331,7 @@
func (d *Decoder) decodeBool(name string, data interface{}, val reflect.Value) error {
dataVal := reflect.ValueOf(data)
- dataKind := d.getKind(dataVal)
+ dataKind := getKind(dataVal)
switch {
case dataKind == reflect.Bool:
@@ -377,7 +362,7 @@
func (d *Decoder) decodeFloat(name string, data interface{}, val reflect.Value) error {
dataVal := reflect.ValueOf(data)
- dataKind := d.getKind(dataVal)
+ dataKind := getKind(dataVal)
switch {
case dataKind == reflect.Int:
@@ -685,3 +670,18 @@
return nil
}
+
+func getKind(val reflect.Value) reflect.Kind {
+ kind := val.Kind()
+
+ switch {
+ case kind >= reflect.Int && kind <= reflect.Int64:
+ return reflect.Int
+ case kind >= reflect.Uint && kind <= reflect.Uint64:
+ return reflect.Uint
+ case kind >= reflect.Float32 && kind <= reflect.Float64:
+ return reflect.Float32
+ default:
+ return kind
+ }
+}