Merge pull request #34 from mitchellh/f-decode-hook-type

Decode hook accepts reflect.Type, plus string to time.Duration parsing
diff --git a/decode_hooks.go b/decode_hooks.go
index 087a392..aa91f76 100644
--- a/decode_hooks.go
+++ b/decode_hooks.go
@@ -1,11 +1,59 @@
 package mapstructure
 
 import (
+	"errors"
 	"reflect"
 	"strconv"
 	"strings"
+	"time"
 )
 
+// typedDecodeHook takes a raw DecodeHookFunc (an interface{}) and turns
+// it into the proper DecodeHookFunc type, such as DecodeHookFuncType.
+// It returns nil if h is nil or is not convertible to a supported type.
+func typedDecodeHook(h DecodeHookFunc) DecodeHookFunc {
+	// reflect.ValueOf(nil) is the zero Value, and calling Type() on it panics.
+	if h == nil {
+		return nil
+	}
+
+	// Supported signatures; the richer Type form is tried first.
+	potential := []reflect.Type{
+		reflect.TypeOf(DecodeHookFuncType(nil)),
+		reflect.TypeOf(DecodeHookFuncKind(nil)),
+	}
+	v := reflect.ValueOf(h)
+	for _, pt := range potential {
+		if v.Type().ConvertibleTo(pt) {
+			return v.Convert(pt).Interface()
+		}
+	}
+	return nil
+}
+
+// DecodeHookExec executes the given decode hook and should be used instead
+// of calling a hook directly: raw may be a DecodeHookFuncType or the older
+// DecodeHookFuncKind, which gets only from.Kind() and to.Kind() (Invalid
+// for a nil Type). Any other hook signature yields an error.
+func DecodeHookExec(raw DecodeHookFunc,
+	from, to reflect.Type, data interface{}) (interface{}, error) {
+	switch f := typedDecodeHook(raw).(type) {
+	case DecodeHookFuncType:
+		return f(from, to, data)
+	case DecodeHookFuncKind:
+		fromKind, toKind := reflect.Invalid, reflect.Invalid
+		if from != nil {
+			fromKind = from.Kind()
+		}
+		if to != nil {
+			toKind = to.Kind()
+		}
+		return f(fromKind, toKind, data)
+	default:
+		return nil, errors.New("invalid decode hook signature")
+	}
+}
+
 // ComposeDecodeHookFunc creates a single DecodeHookFunc that
 // automatically composes multiple DecodeHookFuncs.
 //
@@ -13,18 +61,18 @@
-// previous transformation.
+// previous transformation and its reflect.Type (nil if the data is nil).
 func ComposeDecodeHookFunc(fs ...DecodeHookFunc) DecodeHookFunc {
 	return func(
-		f reflect.Kind,
-		t reflect.Kind,
+		f reflect.Type,
+		t reflect.Type,
 		data interface{}) (interface{}, error) {
 		var err error
 		for _, f1 := range fs {
-			data, err = f1(f, t, data)
+			data, err = DecodeHookExec(f1, f, t, data)
 			if err != nil {
 				return nil, err
 			}
 
-			// Modify the from kind to be correct with the new data
-			f = getKind(reflect.ValueOf(data))
+			// Track the new data's type; reflect.TypeOf(nil) is nil, not a panic.
+			f = reflect.TypeOf(data)
 		}
 
 		return data, nil
@@ -51,6 +99,25 @@
 	}
 }
 
+// StringToTimeDurationHookFunc returns a DecodeHookFunc that converts
+// strings (including named string types) to time.Duration using
+// time.ParseDuration. Any other from/to combination passes through as-is.
+func StringToTimeDurationHookFunc() DecodeHookFunc {
+	durationType := reflect.TypeOf(time.Duration(0))
+	return func(f reflect.Type, t reflect.Type, data interface{}) (interface{}, error) {
+		// Test f for nil first: calling Kind on a nil Type panics.
+		if f == nil || f.Kind() != reflect.String || t != durationType {
+			return data, nil
+		}
+		// Use reflect: data.(string) panics for named string types.
+		v := reflect.ValueOf(data)
+		if v.Kind() != reflect.String {
+			return data, nil
+		}
+		return time.ParseDuration(v.String())
+	}
+}
+
 func WeaklyTypedHook(
 	f reflect.Kind,
 	t reflect.Kind,
diff --git a/decode_hooks_test.go b/decode_hooks_test.go
index b417dee..53289af 100644
--- a/decode_hooks_test.go
+++ b/decode_hooks_test.go
@@ -4,6 +4,7 @@
 	"errors"
 	"reflect"
 	"testing"
+	"time"
 )
 
 func TestComposeDecodeHookFunc(t *testing.T) {
@@ -23,7 +24,8 @@
 
 	f := ComposeDecodeHookFunc(f1, f2)
 
-	result, err := f(reflect.String, reflect.Slice, "")
+	strType, sliceType := reflect.TypeOf(""), reflect.TypeOf([]byte(""))
+	result, err := DecodeHookExec(f, strType, sliceType, "")
 	if err != nil {
 		t.Fatalf("bad: %s", err)
 	}
@@ -43,7 +45,8 @@
 
 	f := ComposeDecodeHookFunc(f1, f2)
 
-	_, err := f(reflect.String, reflect.Slice, 42)
-	if err.Error() != "foo" {
+	strType, sliceType := reflect.TypeOf(""), reflect.TypeOf([]byte(""))
+	_, err := DecodeHookExec(f, strType, sliceType, 42)
+	if err == nil || err.Error() != "foo" { // a missing error must fail, not panic
 		t.Fatalf("bad: %s", err)
 	}
@@ -69,7 +72,8 @@
 
 	f := ComposeDecodeHookFunc(f1, f2)
 
-	_, err := f(reflect.String, reflect.Slice, "")
+	strType, sliceType := reflect.TypeOf(""), reflect.TypeOf([]byte(""))
+	_, err := DecodeHookExec(f, strType, sliceType, "")
 	if err != nil {
 		t.Fatalf("bad: %s", err)
 	}
@@ -81,24 +85,26 @@
 func TestStringToSliceHookFunc(t *testing.T) {
 	f := StringToSliceHookFunc(",")
 
+	strType := reflect.TypeOf("")
+	sliceType := reflect.TypeOf([]byte(""))
 	cases := []struct {
-		f, t   reflect.Kind
+		f, t   reflect.Type
 		data   interface{}
 		result interface{}
 		err    bool
 	}{
-		{reflect.Slice, reflect.Slice, 42, 42, false},
-		{reflect.String, reflect.String, 42, 42, false},
+		{sliceType, sliceType, 42, 42, false}, // non-string source: unchanged
+		{strType, strType, 42, 42, false},     // non-slice target: unchanged
 		{
-			reflect.String,
-			reflect.Slice,
+			strType,
+			sliceType,
 			"foo,bar,baz",
 			[]string{"foo", "bar", "baz"},
 			false,
 		},
 		{
-			reflect.String,
-			reflect.Slice,
+			strType,
+			sliceType,
 			"",
 			[]string{},
 			false,
@@ -106,7 +112,36 @@
 	}
 
 	for i, tc := range cases {
-		actual, err := f(tc.f, tc.t, tc.data)
+		actual, err := DecodeHookExec(f, tc.f, tc.t, tc.data)
+		if tc.err != (err != nil) {
+			t.Fatalf("case %d: expected err %#v", i, tc.err)
+		}
+		if !reflect.DeepEqual(actual, tc.result) {
+			t.Fatalf(
+				"case %d: expected %#v, got %#v",
+				i, tc.result, actual)
+		}
+	}
+}
+
+func TestStringToTimeDurationHookFunc(t *testing.T) {
+	f := StringToTimeDurationHookFunc()
+
+	strType := reflect.TypeOf("")
+	timeType := reflect.TypeOf(time.Duration(5))
+	cases := []struct {
+		f, t   reflect.Type
+		data   interface{}
+		result interface{}
+		err    bool
+	}{
+		{strType, timeType, "5s", 5 * time.Second, false},
+		{strType, timeType, "5", time.Duration(0), true}, // no unit suffix: ParseDuration error
+		{strType, strType, "5", "5", false},              // non-Duration target: unchanged
+	}
+
+	for i, tc := range cases {
+		actual, err := DecodeHookExec(f, tc.f, tc.t, tc.data)
 		if tc.err != (err != nil) {
 			t.Fatalf("case %d: expected err %#v", i, tc.err)
 		}
@@ -121,56 +156,59 @@
 func TestWeaklyTypedHook(t *testing.T) {
 	var f DecodeHookFunc = WeaklyTypedHook
 
+	boolType := reflect.TypeOf(true)
+	strType := reflect.TypeOf("")
+	sliceType := reflect.TypeOf([]byte("")) // []uint8, matching the []uint8("foo") case
 	cases := []struct {
-		f, t   reflect.Kind
+		f, t   reflect.Type
 		data   interface{}
 		result interface{}
 		err    bool
 	}{
 		// TO STRING
 		{
-			reflect.Bool,
-			reflect.String,
+			boolType,
+			strType,
 			false,
 			"0",
 			false,
 		},
 
 		{
-			reflect.Bool,
-			reflect.String,
+			boolType,
+			strType,
 			true,
 			"1",
 			false,
 		},
 
 		{
-			reflect.Float32,
-			reflect.String,
+			reflect.TypeOf(float32(1)),
+			strType,
 			float32(7),
 			"7",
 			false,
 		},
 
 		{
-			reflect.Int,
-			reflect.String,
+			reflect.TypeOf(int(1)),
+			strType,
 			int(7),
 			"7",
 			false,
 		},
 
 		{
-			reflect.Slice,
-			reflect.String,
+			sliceType,
+			strType,
 			[]uint8("foo"),
 			"foo",
 			false,
 		},
 
 		{
-			reflect.Uint,
-			reflect.String,
+			reflect.TypeOf(uint(1)),
+			strType,
 			uint(7),
 			"7",
 			false,
@@ -178,7 +216,7 @@
 	}
 
 	for i, tc := range cases {
-		actual, err := f(tc.f, tc.t, tc.data)
+		actual, err := DecodeHookExec(f, tc.f, tc.t, tc.data) // Kind-based hook: receives only Kinds
 		if tc.err != (err != nil) {
 			t.Fatalf("case %d: expected err %#v", i, tc.err)
 		}
diff --git a/mapstructure.go b/mapstructure.go
index 9fceaef..8886d34 100644
--- a/mapstructure.go
+++ b/mapstructure.go
@@ -19,10 +19,20 @@
 // DecodeHookFunc is the callback function that can be used for
 // data transformations. See "DecodeHook" in the DecoderConfig
 // struct.
-type DecodeHookFunc func(
-	from reflect.Kind,
-	to reflect.Kind,
-	data interface{}) (interface{}, error)
+//
+// The value must be a function convertible to DecodeHookFuncType or
+// DecodeHookFuncKind; DecodeHookExec returns an error for other signatures.
+// Types are a superset of Kinds (a Type can report its Kind) and are
+// richer to use, but Kinds are simpler if they suffice. Both are accepted
+// for backwards compatibility, since hooks originally received only Kinds.
+type DecodeHookFunc interface{}
+
+// DecodeHookFuncType is a hook given the from and to reflect.Types.
+type DecodeHookFuncType func(reflect.Type, reflect.Type, interface{}) (interface{}, error)
+
+// DecodeHookFuncKind is the original hook form, given only the from and
+// to reflect.Kinds.
+type DecodeHookFuncKind func(reflect.Kind, reflect.Kind, interface{}) (interface{}, error)
 
 // DecoderConfig is the configuration that is used to create a new decoder
 // and allows customization of various aspects of decoding.
@@ -181,7 +191,9 @@
 	if d.config.DecodeHook != nil {
 		// We have a DecodeHook, so let's pre-process the data.
 		var err error
-		data, err = d.config.DecodeHook(getKind(dataVal), getKind(val), data)
+		data, err = DecodeHookExec(
+			d.config.DecodeHook,
+			dataVal.Type(), val.Type(), data) // NOTE(review): assumes dataVal is valid; Type() panics otherwise — confirm
 		if err != nil {
 			return err
 		}
diff --git a/mapstructure_test.go b/mapstructure_test.go
index 1444b13..0c9a31f 100644
--- a/mapstructure_test.go
+++ b/mapstructure_test.go
@@ -261,6 +261,43 @@
 	}
 }
 
+// TestDecode_DecodeHookType checks that DecoderConfig accepts a hook with
+// the reflect.Type signature and decodes the value that hook returns.
+func TestDecode_DecodeHookType(t *testing.T) {
+	t.Parallel()
+
+	// Replace any string bound for a non-string field with 5.
+	hook := func(from reflect.Type, to reflect.Type, v interface{}) (interface{}, error) {
+		if from.Kind() != reflect.String || to.Kind() == reflect.String {
+			return v, nil
+		}
+		return 5, nil
+	}
+
+	// Pass the hook as an unconverted func literal.
+	var result Basic
+	decoder, err := NewDecoder(&DecoderConfig{
+		DecodeHook: hook,
+		Result:     &result,
+	})
+	if err != nil {
+		t.Fatalf("err: %s", err)
+	}
+
+	// The hook should rewrite "WHAT" to 5 before it reaches Vint.
+	input := map[string]interface{}{
+		"vint": "WHAT",
+	}
+	if err = decoder.Decode(input); err != nil {
+		t.Fatalf("got an err: %s", err)
+	}
+
+	// The hook's replacement value must reach the struct field.
+	if result.Vint != 5 {
+		t.Errorf("vint should be 5: %#v", result.Vint)
+	}
+}
+
 func TestDecode_Nil(t *testing.T) {
 	t.Parallel()