Merge pull request #64 from skipor/support_func_decode
Support func decoding via hooks
diff --git a/mapstructure.go b/mapstructure.go
index fbf3c24..bd46603 100644
--- a/mapstructure.go
+++ b/mapstructure.go
@@ -229,6 +229,8 @@
err = d.decodePtr(name, data, val)
case reflect.Slice:
err = d.decodeSlice(name, data, val)
+ case reflect.Func:
+ err = d.decodeFunc(name, data, val)
default:
// If we reached this point then we weren't able to decode it
return fmt.Errorf("%s: unsupported type: %s", name, dataKind)
@@ -560,6 +562,19 @@
return nil
}
+// decodeFunc assigns data to the func-typed value val.
+//
+// A func cannot be assembled from maps, strings or numbers, so data must
+// already hold a func value (typically one returned by a DecodeHook), or a
+// pointer to one, whose type is assignable to val's type.
+func (d *Decoder) decodeFunc(name string, data interface{}, val reflect.Value) error {
+	dataVal := reflect.Indirect(reflect.ValueOf(data))
+	if !dataVal.IsValid() {
+		// data is nil or a nil pointer (e.g. a hook returned nil). Calling
+		// Type on the zero Value would panic, so report it as an error.
+		return fmt.Errorf("'%s' expected type '%s', got nil", name, val.Type())
+	}
+	// AssignableTo is the same rule reflect.Value.Set enforces, so an
+	// unnamed func literal can fill a field declared with a named func type.
+	if !dataVal.Type().AssignableTo(val.Type()) {
+		return fmt.Errorf(
+			"'%s' expected type '%s', got unconvertible type '%s'",
+			name, val.Type(), dataVal.Type())
+	}
+	val.Set(dataVal)
+	return nil
+}
+
func (d *Decoder) decodeSlice(name string, data interface{}, val reflect.Value) error {
dataVal := reflect.Indirect(reflect.ValueOf(data))
dataValKind := dataVal.Kind()
diff --git a/mapstructure_test.go b/mapstructure_test.go
index f949572..226252a 100644
--- a/mapstructure_test.go
+++ b/mapstructure_test.go
@@ -85,6 +85,10 @@
Value []Basic
}
+// Func is a decode target holding a single func field, used by the func decode-hook test.
+type Func struct{ Foo func() string }
+
type Tagged struct {
Extra string `mapstructure:"bar,what,what"`
Value string `mapstructure:"foo"`
@@ -482,6 +486,42 @@
}
}
+// TestDecode_FuncHook checks that a DecodeHook can turn a string input into a
+// func value, and that the decoder stores that func in a func-typed field.
+func TestDecode_FuncHook(t *testing.T) {
+	t.Parallel()
+
+	input := map[string]interface{}{
+		"foo": "baz",
+	}
+
+	// For func targets, wrap the incoming string in a closure returning it.
+	// Parameters are named from/to so they do not shadow the *testing.T t.
+	decodeHook := func(from, to reflect.Type, v interface{}) (interface{}, error) {
+		if to.Kind() != reflect.Func {
+			return v, nil
+		}
+		val, ok := v.(string)
+		if !ok {
+			// Pass non-strings through untouched so the decoder reports a
+			// type error instead of the hook panicking.
+			return v, nil
+		}
+		return func() string { return val }, nil
+	}
+
+	var result Func
+	config := &DecoderConfig{
+		DecodeHook: decodeHook,
+		Result:     &result,
+	}
+
+	decoder, err := NewDecoder(config)
+	if err != nil {
+		t.Fatalf("err: %s", err)
+	}
+
+	err = decoder.Decode(input)
+	if err != nil {
+		t.Fatalf("got an err: %s", err)
+	}
+
+	// Guard before calling: a nil func would panic rather than fail cleanly.
+	if result.Foo == nil {
+		t.Fatal("Foo was not set by the decode hook")
+	}
+	if got := result.Foo(); got != "baz" {
+		t.Errorf("Foo call result should be 'baz': %s", got)
+	}
+}
+
func TestDecode_NonStruct(t *testing.T) {
t.Parallel()