Merge pull request #5 from gyim/master
Add support for more advanced type conversions (weak types)
diff --git a/mapstructure.go b/mapstructure.go
index e7c4155..d95886f 100644
--- a/mapstructure.go
+++ b/mapstructure.go
@@ -12,6 +12,7 @@
"fmt"
"reflect"
"sort"
+ "strconv"
"strings"
)
@@ -23,6 +24,15 @@
// (extra keys).
ErrorUnused bool
+ // If WeaklyTypedInput is true, the decoder will convert values between
+ // the following types:
+ //
+ // - numbers and bools
+ // - strings and numbers
+ // - strings and bools
+ // - empty arrays/slices and empty maps
+ WeaklyTypedInput bool
+
// Metadata is the struct that will contain extra metadata about
// the decoding. If this is nil, then no metadata will be tracked.
Metadata *Metadata
@@ -124,27 +134,22 @@
return nil
}
- k := val.Kind()
-
- // Some shortcuts because we treat all ints and uints the same way
- if k >= reflect.Int && k <= reflect.Int64 {
- k = reflect.Int
- } else if k >= reflect.Uint && k <= reflect.Uint64 {
- k = reflect.Uint
- }
-
var err error
- switch k {
+ dataKind := d.getKind(val)
+
+ switch dataKind {
case reflect.Bool:
- fallthrough
+ err = d.decodeBool(name, data, val)
case reflect.Interface:
- fallthrough
- case reflect.String:
err = d.decodeBasic(name, data, val)
+ case reflect.String:
+ err = d.decodeString(name, data, val)
case reflect.Int:
- fallthrough
- case reflect.Uint:
err = d.decodeInt(name, data, val)
+ case reflect.Uint:
+ err = d.decodeUint(name, data, val)
+ case reflect.Float32:
+ err = d.decodeFloat(name, data, val)
case reflect.Struct:
err = d.decodeStruct(name, data, val)
case reflect.Map:
@@ -153,7 +158,7 @@
err = d.decodeSlice(name, data, val)
default:
// If we reached this point then we weren't able to decode it
- return fmt.Errorf("%s: unsupported type: %s", name, k)
+ return fmt.Errorf("%s: unsupported type: %s", name, dataKind)
}
// If we reached here, then we successfully decoded SOMETHING, so
@@ -165,6 +170,21 @@
return err
}
+// getKind folds the sized numeric kinds of val onto one representative kind
+// so callers can switch on a small set of cases: every signed integer kind
+// becomes reflect.Int, uint through uint64 become reflect.Uint (uintptr is
+// not included), and float32/float64 both become reflect.Float32. Every
+// other kind, including reflect.Invalid, is returned unchanged.
+func (d *Decoder) getKind(val reflect.Value) reflect.Kind {
+	switch k := val.Kind(); k {
+	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+		return reflect.Int
+	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+		return reflect.Uint
+	case reflect.Float32, reflect.Float64:
+		return reflect.Float32
+	default:
+		return k
+	}
+}
+
// This decodes a basic type (bool, int, string, etc.) and sets the
// value to "data" of that type.
func (d *Decoder) decodeBasic(name string, data interface{}, val reflect.Value) error {
@@ -180,60 +200,165 @@
return nil
}
-func (d *Decoder) decodeInt(name string, data interface{}, val reflect.Value) error {
-	dataVal := reflect.ValueOf(data)
-	dataKind := dataVal.Kind()
-	if dataKind >= reflect.Int && dataKind <= reflect.Int64 {
-		dataKind = reflect.Int
-	} else if dataKind >= reflect.Uint && dataKind <= reflect.Uint64 {
-		dataKind = reflect.Uint
-	} else if dataKind >= reflect.Float32 && dataKind <= reflect.Float64 {
-		dataKind = reflect.Float32
-	} else {
-		return fmt.Errorf(
-			"'%s' expected type '%s', got unconvertible type '%s'",
-			name, val.Type(), dataVal.Type())
-	}
-	valKind := val.Kind()
-	if valKind >= reflect.Int && valKind <= reflect.Int64 {
-		valKind = reflect.Int
-	} else if valKind >= reflect.Uint && valKind <= reflect.Uint64 {
-		valKind = reflect.Uint
-	}
-	switch dataKind {
-	case reflect.Int:
-		if valKind == reflect.Int {
-			val.SetInt(dataVal.Int())
-		} else {
-			val.SetUint(uint64(dataVal.Int()))
-		}
-	case reflect.Uint:
-		if valKind == reflect.Int {
-			val.SetInt(int64(dataVal.Uint()))
-		} else {
-			val.SetUint(dataVal.Uint())
-		}
-	case reflect.Float32:
-		if valKind == reflect.Int {
-			val.SetInt(int64(dataVal.Float()))
-		} else {
-			val.SetUint(uint64(dataVal.Float()))
-		}
-	default:
-		panic("should never reach")
+// decodeString decodes data into val, whose kind is string.
+//
+// A string source is copied as is. With WeaklyTypedInput set, bools become
+// "1" or "0", integers are formatted in base 10, and floats use the shortest
+// decimal form that round-trips at the source's own precision, so a float32
+// 42.42 becomes "42.42" rather than the digits of its float64 widening.
+func (d *Decoder) decodeString(name string, data interface{}, val reflect.Value) error {
+	dataVal := reflect.ValueOf(data)
+	dataKind := d.getKind(dataVal)
+
+	switch {
+	case dataKind == reflect.String:
+		val.SetString(dataVal.String())
+	case dataKind == reflect.Bool && d.config.WeaklyTypedInput:
+		if dataVal.Bool() {
+			val.SetString("1")
+		} else {
+			val.SetString("0")
+		}
+	case dataKind == reflect.Int && d.config.WeaklyTypedInput:
+		val.SetString(strconv.FormatInt(dataVal.Int(), 10))
+	case dataKind == reflect.Uint && d.config.WeaklyTypedInput:
+		val.SetString(strconv.FormatUint(dataVal.Uint(), 10))
+	case dataKind == reflect.Float32 && d.config.WeaklyTypedInput:
+		// Bits() is 32 or 64 here; formatting a float32 source with
+		// bitSize 64 would print the noise digits of its widening.
+		val.SetString(strconv.FormatFloat(dataVal.Float(), 'f', -1, dataVal.Type().Bits()))
+	default:
+		return fmt.Errorf(
+			"'%s' expected type '%s', got unconvertible type '%s'",
+			name, val.Type(), dataVal.Type())
+	}
+
+	return nil
+}
+
+// decodeInt decodes data into val, whose kind is one of the signed integer
+// kinds (int, int8, ..., int64).
+//
+// Integer and floating-point sources are always accepted (floats are
+// truncated toward zero); bools (1 or 0) and strings (parsed with
+// strconv.ParseInt, base prefix allowed, at val's bit size) are accepted
+// only when WeaklyTypedInput is set. Values that do not fit in val's type
+// are rejected with an error instead of being wrapped or truncated.
+func (d *Decoder) decodeInt(name string, data interface{}, val reflect.Value) error {
+	dataVal := reflect.ValueOf(data)
+	dataKind := d.getKind(dataVal)
+
+	var i int64
+	switch {
+	case dataKind == reflect.Int:
+		i = dataVal.Int()
+	case dataKind == reflect.Uint:
+		u := dataVal.Uint()
+		// int64(u) would silently turn values above MaxInt64 negative.
+		if u > 1<<63-1 {
+			return fmt.Errorf("'%s': value %v overflows type '%s'", name, u, val.Type())
+		}
+		i = int64(u)
+	case dataKind == reflect.Bool && d.config.WeaklyTypedInput:
+		if dataVal.Bool() {
+			i = 1
+		}
+	case dataKind == reflect.Float32:
+		f := dataVal.Float()
+		// Converting a float outside [-2^63, 2^63) to int64 is
+		// implementation-dependent in Go; the negated form also rejects NaN.
+		if !(f >= -(1<<63) && f < 1<<63) {
+			return fmt.Errorf("'%s': value %v overflows type '%s'", name, f, val.Type())
+		}
+		i = int64(f)
+	case dataKind == reflect.String && d.config.WeaklyTypedInput:
+		parsed, err := strconv.ParseInt(dataVal.String(), 0, val.Type().Bits())
+		if err != nil {
+			return fmt.Errorf("cannot parse '%s' as int: %s", name, err)
+		}
+		i = parsed
+	default:
+		return fmt.Errorf(
+			"'%s' expected type '%s', got unconvertible type '%s'",
+			name, val.Type(), dataVal.Type())
+	}
+
+	// SetInt does not range-check narrower targets (e.g. 300 into an
+	// int8), so reject such values explicitly.
+	if val.OverflowInt(i) {
+		return fmt.Errorf("'%s': value %v overflows type '%s'", name, i, val.Type())
+	}
+	val.SetInt(i)
+	return nil
+}
+
+// decodeUint decodes data into val, whose kind is one of the unsigned
+// integer kinds (uint, uint8, ..., uint64).
+//
+// Integer and floating-point sources are always accepted (floats are
+// truncated toward zero); bools (1 or 0) and strings (parsed with
+// strconv.ParseUint, base prefix allowed, at val's bit size) are accepted
+// only when WeaklyTypedInput is set. Negative values and values that do not
+// fit in val's type are rejected with an error instead of being wrapped or
+// truncated.
+func (d *Decoder) decodeUint(name string, data interface{}, val reflect.Value) error {
+	dataVal := reflect.ValueOf(data)
+	dataKind := d.getKind(dataVal)
+
+	var u uint64
+	switch {
+	case dataKind == reflect.Int:
+		i := dataVal.Int()
+		// uint64(i) would silently wrap a negative value to a huge one.
+		if i < 0 {
+			return fmt.Errorf("'%s': value %v overflows type '%s'", name, i, val.Type())
+		}
+		u = uint64(i)
+	case dataKind == reflect.Uint:
+		u = dataVal.Uint()
+	case dataKind == reflect.Bool && d.config.WeaklyTypedInput:
+		if dataVal.Bool() {
+			u = 1
+		}
+	case dataKind == reflect.Float32:
+		f := dataVal.Float()
+		// Converting a float outside [0, 2^64) to uint64 is
+		// implementation-dependent in Go; the negated form also rejects NaN.
+		if !(f >= 0 && f < 1<<64) {
+			return fmt.Errorf("'%s': value %v overflows type '%s'", name, f, val.Type())
+		}
+		u = uint64(f)
+	case dataKind == reflect.String && d.config.WeaklyTypedInput:
+		parsed, err := strconv.ParseUint(dataVal.String(), 0, val.Type().Bits())
+		if err != nil {
+			return fmt.Errorf("cannot parse '%s' as uint: %s", name, err)
+		}
+		u = parsed
+	default:
+		return fmt.Errorf(
+			"'%s' expected type '%s', got unconvertible type '%s'",
+			name, val.Type(), dataVal.Type())
+	}
+
+	// SetUint does not range-check narrower targets (e.g. 300 into a
+	// uint8), so reject such values explicitly.
+	if val.OverflowUint(u) {
+		return fmt.Errorf("'%s': value %v overflows type '%s'", name, u, val.Type())
+	}
+	val.SetUint(u)
+	return nil
+}
+
+// decodeBool decodes data into val, whose kind is bool.
+//
+// A bool source is copied as is. Only with WeaklyTypedInput set: numbers
+// become true when non-zero, the empty string becomes false, and any other
+// string must be accepted by strconv.ParseBool.
+func (d *Decoder) decodeBool(name string, data interface{}, val reflect.Value) error {
+	src := reflect.ValueOf(data)
+	kind := d.getKind(src)
+
+	if kind == reflect.Bool {
+		val.SetBool(src.Bool())
+		return nil
+	}
+
+	if d.config.WeaklyTypedInput {
+		switch kind {
+		case reflect.Int:
+			val.SetBool(src.Int() != 0)
+			return nil
+		case reflect.Uint:
+			val.SetBool(src.Uint() != 0)
+			return nil
+		case reflect.Float32:
+			val.SetBool(src.Float() != 0)
+			return nil
+		case reflect.String:
+			str := src.String()
+			// ParseBool rejects "", but weak input treats it as false.
+			if str == "" {
+				val.SetBool(false)
+				return nil
+			}
+			parsed, err := strconv.ParseBool(str)
+			if err != nil {
+				return fmt.Errorf("cannot parse '%s' as bool: %s", name, err)
+			}
+			val.SetBool(parsed)
+			return nil
+		}
+	}
+
+	return fmt.Errorf(
+		"'%s' expected type '%s', got unconvertible type '%s'",
+		name, val.Type(), src.Type())
+}
+
+// decodeFloat decodes data into val, whose kind is float32 or float64.
+//
+// Integer, unsigned integer and float sources are always accepted; bools
+// (1 or 0) and strings (parsed with strconv.ParseFloat at val's bit size)
+// are accepted only when WeaklyTypedInput is set. A finite value outside a
+// float32 target's range is rejected with an error, mirroring the ErrRange
+// failure strconv.ParseFloat already reports for string input.
+func (d *Decoder) decodeFloat(name string, data interface{}, val reflect.Value) error {
+	dataVal := reflect.ValueOf(data)
+	dataKind := d.getKind(dataVal)
+
+	var f float64
+	switch {
+	case dataKind == reflect.Int:
+		f = float64(dataVal.Int())
+	case dataKind == reflect.Uint:
+		f = float64(dataVal.Uint())
+	case dataKind == reflect.Bool && d.config.WeaklyTypedInput:
+		if dataVal.Bool() {
+			f = 1
+		}
+	case dataKind == reflect.Float32:
+		f = dataVal.Float()
+	case dataKind == reflect.String && d.config.WeaklyTypedInput:
+		parsed, err := strconv.ParseFloat(dataVal.String(), val.Type().Bits())
+		if err != nil {
+			return fmt.Errorf("cannot parse '%s' as float: %s", name, err)
+		}
+		f = parsed
+	default:
+		return fmt.Errorf(
+			"'%s' expected type '%s', got unconvertible type '%s'",
+			name, val.Type(), dataVal.Type())
}
+
+	// SetFloat itself does not reject values outside float32's range;
+	// OverflowFloat is only ever true for float32 targets.
+	if val.OverflowFloat(f) {
+		return fmt.Errorf("'%s': value %v overflows type '%s'", name, f, val.Type())
+	}
+	val.SetFloat(f)
return nil
}
func (d *Decoder) decodeMap(name string, data interface{}, val reflect.Value) error {
- dataVal := reflect.Indirect(reflect.ValueOf(data))
- if dataVal.Kind() != reflect.Map {
- return fmt.Errorf("'%s' expected a map, got '%s'", name, dataVal.Kind())
- }
-
valType := val.Type()
valKeyType := valType.Key()
valElemType := valType.Elem()
@@ -242,6 +367,20 @@
mapType := reflect.MapOf(valKeyType, valElemType)
valMap := reflect.MakeMap(mapType)
+ // Check input type
+ dataVal := reflect.Indirect(reflect.ValueOf(data))
+ if dataVal.Kind() != reflect.Map {
+ // Accept empty array/slice instead of an empty map in weakly typed mode
+ if d.config.WeaklyTypedInput &&
+ (dataVal.Kind() == reflect.Slice || dataVal.Kind() == reflect.Array) &&
+ dataVal.Len() == 0 {
+ val.Set(valMap)
+ return nil
+ } else {
+ return fmt.Errorf("'%s' expected a map, got '%s'", name, dataVal.Kind())
+ }
+ }
+
// Accumulate errors
errors := make([]string, 0)
@@ -280,11 +419,6 @@
func (d *Decoder) decodeSlice(name string, data interface{}, val reflect.Value) error {
dataVal := reflect.Indirect(reflect.ValueOf(data))
dataValKind := dataVal.Kind()
- if dataValKind != reflect.Array && dataValKind != reflect.Slice {
- return fmt.Errorf(
- "'%s': source data must be an array or slice, got %s", name, dataValKind)
- }
-
valType := val.Type()
valElemType := valType.Elem()
@@ -292,6 +426,18 @@
sliceType := reflect.SliceOf(valElemType)
valSlice := reflect.MakeSlice(sliceType, dataVal.Len(), dataVal.Len())
+ // Check input type
+ if dataValKind != reflect.Array && dataValKind != reflect.Slice {
+ // Accept empty map instead of array/slice in weakly typed mode
+ if d.config.WeaklyTypedInput && dataVal.Kind() == reflect.Map && dataVal.Len() == 0 {
+ val.Set(valSlice)
+ return nil
+ } else {
+ return fmt.Errorf(
+ "'%s': source data must be an array or slice, got %s", name, dataValKind)
+ }
+ }
+
// Accumulate any errors
errors := make([]string, 0)
diff --git a/mapstructure_examples_test.go b/mapstructure_examples_test.go
index d7f3d00..27de1db 100644
--- a/mapstructure_examples_test.go
+++ b/mapstructure_examples_test.go
@@ -62,11 +62,11 @@
// Output:
// 5 error(s) decoding:
//
- // * 'Name' expected type 'string', got 'int'
+ // * 'Name' expected type 'string', got unconvertible type 'int'
// * 'Age' expected type 'int', got unconvertible type 'string'
- // * 'Emails[0]' expected type 'string', got 'int'
- // * 'Emails[1]' expected type 'string', got 'int'
- // * 'Emails[2]' expected type 'string', got 'int'
+ // * 'Emails[0]' expected type 'string', got unconvertible type 'int'
+ // * 'Emails[1]' expected type 'string', got unconvertible type 'int'
+ // * 'Emails[2]' expected type 'string', got unconvertible type 'int'
}
func ExampleDecode_metadata() {
@@ -107,3 +107,39 @@
// Output:
// Unused keys: []string{"email"}
}
+
+// ExampleDecode_weaklyTypedInput shows WeaklyTypedInput letting loosely typed
+// data (a number for a string field, a numeric string for an int field, an
+// empty map for a list) decode into a strongly typed struct. The example
+// suffix starts with a lower-case letter so that go vet accepts it and godoc
+// attaches it to Decode.
+func ExampleDecode_weaklyTypedInput() {
+	type Person struct {
+		Name   string
+		Age    int
+		Emails []string
+	}
+
+	// This input can come from anywhere, but typically comes from
+	// something like decoding JSON, generated by a weakly typed language
+	// such as PHP.
+	input := map[string]interface{}{
+		"name":   123,                      // number => string
+		"age":    "42",                     // string => number
+		"emails": map[string]interface{}{}, // empty map => empty array
+	}
+
+	var result Person
+	config := &DecoderConfig{
+		WeaklyTypedInput: true,
+		Result:           &result,
+	}
+
+	decoder, err := NewDecoder(config)
+	if err != nil {
+		panic(err)
+	}
+
+	err = decoder.Decode(input)
+	if err != nil {
+		panic(err)
+	}
+
+	fmt.Printf("%#v", result)
+	// Output: mapstructure.Person{Name:"123", Age:42, Emails:[]string{}}
+}
diff --git a/mapstructure_test.go b/mapstructure_test.go
index 5f359c2..47c98f7 100644
--- a/mapstructure_test.go
+++ b/mapstructure_test.go
@@ -10,6 +10,7 @@
Vint int
Vuint uint
Vbool bool
+ Vfloat float64
Vextra string
vsilent bool
Vdata interface{}
@@ -53,6 +54,31 @@
Value string `mapstructure:"foo"`
}
+// TypeConversionResult is the decode target for TestTypeConversion. Each
+// field is named <Source>To<Target>: the field's Go type is the target, and
+// the test input map supplies a value of the source kind under the same key.
+// SliceToMap and MapToSlice cover the empty slice <-> empty map conversions.
+type TypeConversionResult struct {
+ IntToFloat float32
+ IntToUint uint
+ IntToBool bool
+ IntToString string
+ UintToInt int
+ UintToFloat float32
+ UintToBool bool
+ UintToString string
+ BoolToInt int
+ BoolToUint uint
+ BoolToFloat float32
+ BoolToString string
+ FloatToInt int
+ FloatToUint uint
+ FloatToBool bool
+ FloatToString string
+ StringToInt int
+ StringToUint uint
+ StringToBool bool
+ StringToFloat float32
+ SliceToMap map[string]interface{}
+ MapToSlice []interface{}
+}
+
func TestBasicTypes(t *testing.T) {
t.Parallel()
@@ -61,6 +87,7 @@
"vint": 42,
"Vuint": 42,
"vbool": true,
+ "Vfloat": 42.42,
"vsilent": true,
"vdata": 42,
}
@@ -88,6 +115,10 @@
t.Errorf("vbool value should be true: %#v", result.Vbool)
}
+ if result.Vfloat != 42.42 {
+ t.Errorf("vfloat value should be 42.42: %#v", result.Vfloat)
+ }
+
if result.Vextra != "" {
t.Errorf("vextra value should be empty: %#v", result.Vextra)
}
@@ -115,6 +146,103 @@
}
}
+// TestTypeConversion decodes the same input map into TypeConversionResult
+// twice. In strict mode only the numeric-to-numeric conversions succeed, so
+// Decode must return an error and leave every other field at its zero value.
+// With WeaklyTypedInput every conversion, including empty slice <-> empty
+// map, must succeed. The Uint* inputs are real uint values so that the
+// unsigned source paths are exercised rather than the int ones.
+func TestTypeConversion(t *testing.T) {
+	t.Parallel()
+
+	input := map[string]interface{}{
+		"IntToFloat":    42,
+		"IntToUint":     42,
+		"IntToBool":     1,
+		"IntToString":   42,
+		"UintToInt":     uint(42),
+		"UintToFloat":   uint(42),
+		"UintToBool":    uint(42),
+		"UintToString":  uint(42),
+		"BoolToInt":     true,
+		"BoolToUint":    true,
+		"BoolToFloat":   true,
+		"BoolToString":  true,
+		"FloatToInt":    42.42,
+		"FloatToUint":   42.42,
+		"FloatToBool":   42.42,
+		"FloatToString": 42.42,
+		"StringToInt":   "42",
+		"StringToUint":  "42",
+		"StringToBool":  "1",
+		"StringToFloat": "42.42",
+		"SliceToMap":    []interface{}{},
+		"MapToSlice":    map[string]interface{}{},
+	}
+
+	expectedResultStrict := TypeConversionResult{
+		IntToFloat:  42.0,
+		IntToUint:   42,
+		UintToInt:   42,
+		UintToFloat: 42,
+		BoolToInt:   0,
+		BoolToUint:  0,
+		BoolToFloat: 0,
+		FloatToInt:  42,
+		FloatToUint: 42,
+	}
+
+	expectedResultWeak := TypeConversionResult{
+		IntToFloat:    42.0,
+		IntToUint:     42,
+		IntToBool:     true,
+		IntToString:   "42",
+		UintToInt:     42,
+		UintToFloat:   42,
+		UintToBool:    true,
+		UintToString:  "42",
+		BoolToInt:     1,
+		BoolToUint:    1,
+		BoolToFloat:   1,
+		BoolToString:  "1",
+		FloatToInt:    42,
+		FloatToUint:   42,
+		FloatToBool:   true,
+		FloatToString: "42.42",
+		StringToInt:   42,
+		StringToUint:  42,
+		StringToBool:  true,
+		StringToFloat: 42.42,
+		SliceToMap:    map[string]interface{}{},
+		MapToSlice:    []interface{}{},
+	}
+
+	// Strict mode: the weak-only conversions must fail.
+	var resultStrict TypeConversionResult
+	err := Decode(input, &resultStrict)
+	if err == nil {
+		t.Errorf("should return an error")
+	}
+	if !reflect.DeepEqual(resultStrict, expectedResultStrict) {
+		t.Errorf("expected \n%#v, got: \n%#v", expectedResultStrict, resultStrict)
+	}
+
+	// Weak mode: every conversion must succeed.
+	var resultWeak TypeConversionResult
+	config := &DecoderConfig{
+		WeaklyTypedInput: true,
+		Result:           &resultWeak,
+	}
+
+	decoder, err := NewDecoder(config)
+	if err != nil {
+		t.Fatalf("err: %s", err)
+	}
+
+	err = decoder.Decode(input)
+	if err != nil {
+		t.Fatalf("got an err: %s", err)
+	}
+
+	if !reflect.DeepEqual(resultWeak, expectedResultWeak) {
+		t.Errorf("expected \n%#v, got: \n%#v", expectedResultWeak, resultWeak)
+	}
+}
+
func TestDecode_Embedded(t *testing.T) {
t.Parallel()
@@ -426,7 +554,7 @@
t.Fatalf("error should be kind of Error, instead: %#v", err)
}
- if derr.Errors[0] != "'Vstring' expected type 'string', got 'int'" {
+ if derr.Errors[0] != "'Vstring' expected type 'string', got unconvertible type 'int'" {
t.Errorf("got unexpected error: %s", err)
}
}