Merge pull request #31 from hectorj/handle-struct
Use struct directly if type matches
diff --git a/.travis.yml b/.travis.yml
new file mode 100644
index 0000000..7f3fe9a
--- /dev/null
+++ b/.travis.yml
@@ -0,0 +1,7 @@
+language: go
+
+go:
+ - 1.4
+
+script:
+ - go test
diff --git a/decode_hooks.go b/decode_hooks.go
index 087a392..aa91f76 100644
--- a/decode_hooks.go
+++ b/decode_hooks.go
@@ -1,11 +1,59 @@
package mapstructure
import (
+ "errors"
"reflect"
"strconv"
"strings"
+ "time"
)
+// typedDecodeHook converts a raw DecodeHookFunc (an interface{}) into one of
+// the concrete hook types, DecodeHookFuncType or DecodeHookFuncKind, by
+// checking which of them the value's dynamic type is convertible to (tried
+// in that order). It returns nil when h is nil or matches neither signature;
+// DecodeHookExec reports a nil result as an invalid hook signature.
+func typedDecodeHook(h DecodeHookFunc) DecodeHookFunc {
+	// reflect.ValueOf(nil) yields the zero Value, whose Type method panics,
+	// so an unset hook must be rejected before its type is inspected.
+	if h == nil {
+		return nil
+	}
+
+	// Zero values of each supported hook type, used only for their
+	// reflect.Type. Order matters: the first convertible match wins.
+	var f1 DecodeHookFuncType
+	var f2 DecodeHookFuncKind
+	potential := []interface{}{f1, f2}
+
+	v := reflect.ValueOf(h)
+	vt := v.Type()
+	for _, raw := range potential {
+		pt := reflect.ValueOf(raw).Type()
+		if vt.ConvertibleTo(pt) {
+			return v.Convert(pt).Interface()
+		}
+	}
+
+	return nil
+}
+
+// DecodeHookExec runs the decode hook raw on data, which is being decoded
+// from type from into type to. Prefer it over calling a hook directly. It
+// accepts both DecodeHookFuncType hooks, which receive the full reflect.Type
+// values, and the older DecodeHookFuncKind hooks, which receive only
+// from.Kind() and to.Kind(). Hooks written against the original Kind-based
+// API therefore keep working.
+//
+// It returns an error if raw matches neither hook signature; otherwise it
+// returns whatever the hook returns.
+func DecodeHookExec(
+	raw DecodeHookFunc,
+	from reflect.Type, to reflect.Type,
+	data interface{}) (interface{}, error) {
+	switch f := typedDecodeHook(raw).(type) {
+	case DecodeHookFuncType:
+		return f(from, to, data)
+	case DecodeHookFuncKind:
+		return f(from.Kind(), to.Kind(), data)
+	default:
+		return nil, errors.New("invalid decode hook signature")
+	}
+}
+
// ComposeDecodeHookFunc creates a single DecodeHookFunc that
// automatically composes multiple DecodeHookFuncs.
//
@@ -13,18 +61,18 @@
// previous transformation.
func ComposeDecodeHookFunc(fs ...DecodeHookFunc) DecodeHookFunc {
return func(
- f reflect.Kind,
- t reflect.Kind,
+ f reflect.Type,
+ t reflect.Type,
data interface{}) (interface{}, error) {
var err error
for _, f1 := range fs {
- data, err = f1(f, t, data)
+ data, err = DecodeHookExec(f1, f, t, data)
if err != nil {
return nil, err
}
// Modify the from kind to be correct with the new data
- f = getKind(reflect.ValueOf(data))
+ f = reflect.ValueOf(data).Type()
}
return data, nil
@@ -51,6 +99,25 @@
}
}
+// StringToTimeDurationHookFunc returns a DecodeHookFunc that converts
+// strings to time.Duration by parsing them with time.ParseDuration
+// (for example "5s" or "1h30m"). Data whose source type is not
+// string-kinded, or whose target type is not exactly time.Duration, is
+// returned unchanged.
+func StringToTimeDurationHookFunc() DecodeHookFunc {
+	// Computed once per hook instead of on every invocation.
+	durationType := reflect.TypeOf(time.Duration(5))
+
+	return func(
+		f reflect.Type,
+		t reflect.Type,
+		data interface{}) (interface{}, error) {
+		if f.Kind() != reflect.String {
+			return data, nil
+		}
+		if t != durationType {
+			return data, nil
+		}
+
+		// Read the string through reflect rather than asserting data.(string):
+		// a named string type (e.g. `type MyStr string`) has Kind String but
+		// would make that type assertion panic.
+		v := reflect.ValueOf(data)
+		if v.Kind() != reflect.String {
+			return data, nil
+		}
+		return time.ParseDuration(v.String())
+	}
+}
+
func WeaklyTypedHook(
f reflect.Kind,
t reflect.Kind,
diff --git a/decode_hooks_test.go b/decode_hooks_test.go
index b417dee..53289af 100644
--- a/decode_hooks_test.go
+++ b/decode_hooks_test.go
@@ -4,6 +4,7 @@
"errors"
"reflect"
"testing"
+ "time"
)
func TestComposeDecodeHookFunc(t *testing.T) {
@@ -23,7 +24,8 @@
f := ComposeDecodeHookFunc(f1, f2)
- result, err := f(reflect.String, reflect.Slice, "")
+ result, err := DecodeHookExec(
+ f, reflect.TypeOf(""), reflect.TypeOf([]byte("")), "")
if err != nil {
t.Fatalf("bad: %s", err)
}
@@ -43,7 +45,8 @@
f := ComposeDecodeHookFunc(f1, f2)
- _, err := f(reflect.String, reflect.Slice, 42)
+ _, err := DecodeHookExec(
+ f, reflect.TypeOf(""), reflect.TypeOf([]byte("")), 42)
if err.Error() != "foo" {
t.Fatalf("bad: %s", err)
}
@@ -69,7 +72,8 @@
f := ComposeDecodeHookFunc(f1, f2)
- _, err := f(reflect.String, reflect.Slice, "")
+ _, err := DecodeHookExec(
+ f, reflect.TypeOf(""), reflect.TypeOf([]byte("")), "")
if err != nil {
t.Fatalf("bad: %s", err)
}
@@ -81,24 +85,26 @@
func TestStringToSliceHookFunc(t *testing.T) {
f := StringToSliceHookFunc(",")
+ strType := reflect.TypeOf("")
+ sliceType := reflect.TypeOf([]byte(""))
cases := []struct {
- f, t reflect.Kind
+ f, t reflect.Type
data interface{}
result interface{}
err bool
}{
- {reflect.Slice, reflect.Slice, 42, 42, false},
- {reflect.String, reflect.String, 42, 42, false},
+ {sliceType, sliceType, 42, 42, false},
+ {strType, strType, 42, 42, false},
{
- reflect.String,
- reflect.Slice,
+ strType,
+ sliceType,
"foo,bar,baz",
[]string{"foo", "bar", "baz"},
false,
},
{
- reflect.String,
- reflect.Slice,
+ strType,
+ sliceType,
"",
[]string{},
false,
@@ -106,7 +112,36 @@
}
for i, tc := range cases {
- actual, err := f(tc.f, tc.t, tc.data)
+ actual, err := DecodeHookExec(f, tc.f, tc.t, tc.data)
+ if tc.err != (err != nil) {
+ t.Fatalf("case %d: expected err %#v", i, tc.err)
+ }
+ if !reflect.DeepEqual(actual, tc.result) {
+ t.Fatalf(
+ "case %d: expected %#v, got %#v",
+ i, tc.result, actual)
+ }
+ }
+}
+
+func TestStringToTimeDurationHookFunc(t *testing.T) {
+ f := StringToTimeDurationHookFunc()
+
+ strType := reflect.TypeOf("")
+ timeType := reflect.TypeOf(time.Duration(5))
+ cases := []struct {
+ f, t reflect.Type
+ data interface{}
+ result interface{}
+ err bool
+ }{
+ {strType, timeType, "5s", 5 * time.Second, false},
+ {strType, timeType, "5", time.Duration(0), true},
+ {strType, strType, "5", "5", false},
+ }
+
+ for i, tc := range cases {
+ actual, err := DecodeHookExec(f, tc.f, tc.t, tc.data)
if tc.err != (err != nil) {
t.Fatalf("case %d: expected err %#v", i, tc.err)
}
@@ -121,56 +156,59 @@
func TestWeaklyTypedHook(t *testing.T) {
var f DecodeHookFunc = WeaklyTypedHook
+ boolType := reflect.TypeOf(true)
+ strType := reflect.TypeOf("")
+ sliceType := reflect.TypeOf([]byte(""))
cases := []struct {
- f, t reflect.Kind
+ f, t reflect.Type
data interface{}
result interface{}
err bool
}{
// TO STRING
{
- reflect.Bool,
- reflect.String,
+ boolType,
+ strType,
false,
"0",
false,
},
{
- reflect.Bool,
- reflect.String,
+ boolType,
+ strType,
true,
"1",
false,
},
{
- reflect.Float32,
- reflect.String,
+ reflect.TypeOf(float32(1)),
+ strType,
float32(7),
"7",
false,
},
{
- reflect.Int,
- reflect.String,
+ reflect.TypeOf(int(1)),
+ strType,
int(7),
"7",
false,
},
{
- reflect.Slice,
- reflect.String,
+ sliceType,
+ strType,
[]uint8("foo"),
"foo",
false,
},
{
- reflect.Uint,
- reflect.String,
+ reflect.TypeOf(uint(1)),
+ strType,
uint(7),
"7",
false,
@@ -178,7 +216,7 @@
}
for i, tc := range cases {
- actual, err := f(tc.f, tc.t, tc.data)
+ actual, err := DecodeHookExec(f, tc.f, tc.t, tc.data)
if tc.err != (err != nil) {
t.Fatalf("case %d: expected err %#v", i, tc.err)
}
diff --git a/mapstructure.go b/mapstructure.go
index 6abe075..d48ac95 100644
--- a/mapstructure.go
+++ b/mapstructure.go
@@ -19,10 +19,20 @@
// DecodeHookFunc is the callback function that can be used for
// data transformations. See "DecodeHook" in the DecoderConfig
// struct.
-type DecodeHookFunc func(
- from reflect.Kind,
- to reflect.Kind,
- data interface{}) (interface{}, error)
+//
+// The value must be a function of type DecodeHookFuncType or
+// DecodeHookFuncKind (a func literal with one of those signatures also
+// works). Types carry strictly more information than Kinds (every
+// reflect.Type reports its Kind), so they are generally the better choice,
+// but Kinds are simpler if that is all you need.
+//
+// The reason DecodeHookFunc is multi-typed is for backwards compatibility:
+// we started with Kinds and then realized Types were the better solution,
+// but have a promise to not break backwards compat so we now support
+// both.
+type DecodeHookFunc interface{}
+
+// DecodeHookFuncType is the preferred DecodeHook signature. It receives the
+// reflect.Type of the source data and of the target value, along with the
+// data itself, and returns the (possibly converted) data to decode.
+type DecodeHookFuncType func(from, to reflect.Type, data interface{}) (interface{}, error)
+
+// DecodeHookFuncKind is the original, Kind-only DecodeHook signature. It is
+// still accepted for backwards compatibility; DecodeHookExec passes it
+// from.Kind() and to.Kind() instead of the full types.
+type DecodeHookFuncKind func(from, to reflect.Kind, data interface{}) (interface{}, error)
// DecoderConfig is the configuration that is used to create a new decoder
// and allows customization of various aspects of decoding.
@@ -51,6 +61,7 @@
// - string to bool (accepts: 1, t, T, TRUE, true, True, 0, f, F,
// FALSE, false, False. Anything else is an error)
// - empty array = empty map and vice versa
+ // - negative numbers to overflowed uint values (base 10)
//
WeaklyTypedInput bool
@@ -180,7 +191,9 @@
if d.config.DecodeHook != nil {
// We have a DecodeHook, so let's pre-process the data.
var err error
- data, err = d.config.DecodeHook(getKind(dataVal), getKind(val), data)
+ data, err = DecodeHookExec(
+ d.config.DecodeHook,
+ dataVal.Type(), val.Type(), data)
if err != nil {
return err
}
@@ -319,11 +332,21 @@
switch {
case dataKind == reflect.Int:
- val.SetUint(uint64(dataVal.Int()))
+ i := dataVal.Int()
+ if i < 0 && !d.config.WeaklyTypedInput {
+ return fmt.Errorf("cannot parse '%s', %d overflows uint",
+ name, i)
+ }
+ val.SetUint(uint64(i))
case dataKind == reflect.Uint:
val.SetUint(dataVal.Uint())
case dataKind == reflect.Float32:
- val.SetUint(uint64(dataVal.Float()))
+ f := dataVal.Float()
+ if f < 0 && !d.config.WeaklyTypedInput {
+ return fmt.Errorf("cannot parse '%s', %f overflows uint",
+ name, f)
+ }
+ val.SetUint(uint64(f))
case dataKind == reflect.Bool && d.config.WeaklyTypedInput:
if dataVal.Bool() {
val.SetUint(1)
diff --git a/mapstructure_test.go b/mapstructure_test.go
index 036e6b5..0c9a31f 100644
--- a/mapstructure_test.go
+++ b/mapstructure_test.go
@@ -261,6 +261,43 @@
}
}
+// TestDecode_DecodeHookType checks that a hook using the reflect.Type-based
+// signature is accepted by DecoderConfig and applied during decoding.
+func TestDecode_DecodeHookType(t *testing.T) {
+	t.Parallel()
+
+	// Any string headed for a non-string field becomes 5; everything else
+	// passes through untouched.
+	hook := func(src, dst reflect.Type, in interface{}) (interface{}, error) {
+		if src.Kind() != reflect.String || dst.Kind() == reflect.String {
+			return in, nil
+		}
+		return 5, nil
+	}
+
+	var result Basic
+	decoder, err := NewDecoder(&DecoderConfig{
+		DecodeHook: hook,
+		Result:     &result,
+	})
+	if err != nil {
+		t.Fatalf("err: %s", err)
+	}
+
+	if err := decoder.Decode(map[string]interface{}{"vint": "WHAT"}); err != nil {
+		t.Fatalf("got an err: %s", err)
+	}
+
+	if result.Vint != 5 {
+		t.Errorf("vint should be 5: %#v", result.Vint)
+	}
+}
+
func TestDecode_Nil(t *testing.T) {
t.Parallel()
@@ -658,6 +695,42 @@
if derr.Errors[0] != "'Vstring' expected type 'string', got unconvertible type 'int'" {
t.Errorf("got unexpected error: %s", err)
}
+
+ inputNegIntUint := map[string]interface{}{
+ "vuint": -42,
+ }
+
+ err = Decode(inputNegIntUint, &result)
+ if err == nil {
+ t.Fatal("error should exist")
+ }
+
+ derr, ok = err.(*Error)
+ if !ok {
+ t.Fatalf("error should be kind of Error, instead: %#v", err)
+ }
+
+ if derr.Errors[0] != "cannot parse 'Vuint', -42 overflows uint" {
+ t.Errorf("got unexpected error: %s", err)
+ }
+
+ inputNegFloatUint := map[string]interface{}{
+ "vuint": -42.0,
+ }
+
+ err = Decode(inputNegFloatUint, &result)
+ if err == nil {
+ t.Fatal("error should exist")
+ }
+
+ derr, ok = err.(*Error)
+ if !ok {
+ t.Fatalf("error should be kind of Error, instead: %#v", err)
+ }
+
+ if derr.Errors[0] != "cannot parse 'Vuint', -42.000000 overflows uint" {
+ t.Errorf("got unexpected error: %s", err)
+ }
}
func TestMetadata(t *testing.T) {