Add WeaklyTypedMode for automatic string/number/bool and slice/map conversion
diff --git a/mapstructure.go b/mapstructure.go index a84a489..d95886f 100644 --- a/mapstructure.go +++ b/mapstructure.go
@@ -12,6 +12,7 @@ "fmt" "reflect" "sort" + "strconv" "strings" ) @@ -23,6 +24,15 @@ // (extra keys). ErrorUnused bool + // If WeaklyTypedInput is true, the decoder will convert values between + // the following types: + // + // - numbers and bools + // - strings and numbers + // - strings and bools + // - empty arrays/slices and empty maps + WeaklyTypedInput bool + // Metadata is the struct that will contain extra metadata about // the decoding. If this is nil, then no metadata will be tracked. Metadata *Metadata @@ -129,11 +139,11 @@ switch dataKind { case reflect.Bool: - fallthrough + err = d.decodeBool(name, data, val) case reflect.Interface: - fallthrough - case reflect.String: err = d.decodeBasic(name, data, val) + case reflect.String: + err = d.decodeString(name, data, val) case reflect.Int: err = d.decodeInt(name, data, val) case reflect.Uint: @@ -190,17 +200,58 @@ return nil } +func (d *Decoder) decodeString(name string, data interface{}, val reflect.Value) error { + dataVal := reflect.ValueOf(data) + dataKind := d.getKind(dataVal) + + switch { + case dataKind == reflect.String: + val.SetString(dataVal.String()) + case dataKind == reflect.Bool && d.config.WeaklyTypedInput: + if dataVal.Bool() { + val.SetString("1") + } else { + val.SetString("0") + } + case dataKind == reflect.Int && d.config.WeaklyTypedInput: + val.SetString(strconv.FormatInt(dataVal.Int(), 10)) + case dataKind == reflect.Uint && d.config.WeaklyTypedInput: + val.SetString(strconv.FormatUint(dataVal.Uint(), 10)) + case dataKind == reflect.Float32 && d.config.WeaklyTypedInput: + val.SetString(strconv.FormatFloat(dataVal.Float(), 'f', -1, 64)) + default: + return fmt.Errorf( + "'%s' expected type '%s', got unconvertible type '%s'", + name, val.Type(), dataVal.Type()) + } + + return nil +} + func (d *Decoder) decodeInt(name string, data interface{}, val reflect.Value) error { dataVal := reflect.ValueOf(data) dataKind := d.getKind(dataVal) - switch dataKind { - case reflect.Int: + switch { + case dataKind == reflect.Int: val.SetInt(dataVal.Int()) - case reflect.Uint: + case dataKind == reflect.Uint: val.SetInt(int64(dataVal.Uint())) - case reflect.Float32: + case dataKind == reflect.Bool && d.config.WeaklyTypedInput: + if dataVal.Bool() { + val.SetInt(1) + } else { + val.SetInt(0) + } + case dataKind == reflect.Float32: val.SetInt(int64(dataVal.Float())) + case dataKind == reflect.String && d.config.WeaklyTypedInput: + i, err := strconv.ParseInt(dataVal.String(), 0, val.Type().Bits()) + if err == nil { + val.SetInt(i) + } else { + return fmt.Errorf("cannot parse '%s' as int: %s", name, err) + } default: return fmt.Errorf( "'%s' expected type '%s', got unconvertible type '%s'", @@ -214,13 +265,57 @@ dataVal := reflect.ValueOf(data) dataKind := d.getKind(dataVal) - switch dataKind { - case reflect.Int: + switch { + case dataKind == reflect.Int: val.SetUint(uint64(dataVal.Int())) - case reflect.Uint: + case dataKind == reflect.Uint: val.SetUint(dataVal.Uint()) - case reflect.Float32: + case dataKind == reflect.Bool && d.config.WeaklyTypedInput: + if dataVal.Bool() { + val.SetUint(1) + } else { + val.SetUint(0) + } + case dataKind == reflect.Float32: val.SetUint(uint64(dataVal.Float())) + case dataKind == reflect.String && d.config.WeaklyTypedInput: + i, err := strconv.ParseUint(dataVal.String(), 0, val.Type().Bits()) + if err == nil { + val.SetUint(i) + } else { + return fmt.Errorf("cannot parse '%s' as uint: %s", name, err) + } + default: + return fmt.Errorf( + "'%s' expected type '%s', got unconvertible type '%s'", + name, val.Type(), dataVal.Type()) + } + + return nil +} + +func (d *Decoder) decodeBool(name string, data interface{}, val reflect.Value) error { + dataVal := reflect.ValueOf(data) + dataKind := d.getKind(dataVal) + + switch { + case dataKind == reflect.Bool: + val.SetBool(dataVal.Bool()) + case dataKind == reflect.Int && d.config.WeaklyTypedInput: + val.SetBool(dataVal.Int() != 0) + case dataKind == reflect.Uint && d.config.WeaklyTypedInput: + val.SetBool(dataVal.Uint() != 0) + case dataKind == reflect.Float32 && d.config.WeaklyTypedInput: + val.SetBool(dataVal.Float() != 0) + case dataKind == reflect.String && d.config.WeaklyTypedInput: + b, err := strconv.ParseBool(dataVal.String()) + if err == nil { + val.SetBool(b) + } else if dataVal.String() == "" { + val.SetBool(false) + } else { + return fmt.Errorf("cannot parse '%s' as bool: %s", name, err) + } default: return fmt.Errorf( "'%s' expected type '%s', got unconvertible type '%s'", @@ -234,13 +329,26 @@ dataVal := reflect.ValueOf(data) dataKind := d.getKind(dataVal) - switch dataKind { - case reflect.Int: + switch { + case dataKind == reflect.Int: val.SetFloat(float64(dataVal.Int())) - case reflect.Uint: + case dataKind == reflect.Uint: val.SetFloat(float64(dataVal.Uint())) - case reflect.Float32: + case dataKind == reflect.Bool && d.config.WeaklyTypedInput: + if dataVal.Bool() { + val.SetFloat(1) + } else { + val.SetFloat(0) + } + case dataKind == reflect.Float32: val.SetFloat(float64(dataVal.Float())) + case dataKind == reflect.String && d.config.WeaklyTypedInput: + f, err := strconv.ParseFloat(dataVal.String(), val.Type().Bits()) + if err == nil { + val.SetFloat(f) + } else { + return fmt.Errorf("cannot parse '%s' as float: %s", name, err) + } default: return fmt.Errorf( "'%s' expected type '%s', got unconvertible type '%s'", @@ -251,11 +359,6 @@ } func (d *Decoder) decodeMap(name string, data interface{}, val reflect.Value) error { - dataVal := reflect.Indirect(reflect.ValueOf(data)) - if dataVal.Kind() != reflect.Map { - return fmt.Errorf("'%s' expected a map, got '%s'", name, dataVal.Kind()) - } - valType := val.Type() valKeyType := valType.Key() valElemType := valType.Elem() @@ -264,6 +367,20 @@ mapType := reflect.MapOf(valKeyType, valElemType) valMap := reflect.MakeMap(mapType) + // Check input type + dataVal := reflect.Indirect(reflect.ValueOf(data)) + if dataVal.Kind() != reflect.Map { + // Accept empty array/slice instead of an empty map in weakly typed mode + if d.config.WeaklyTypedInput && + (dataVal.Kind() == reflect.Slice || dataVal.Kind() == reflect.Array) && + dataVal.Len() == 0 { + val.Set(valMap) + return nil + } else { + return fmt.Errorf("'%s' expected a map, got '%s'", name, dataVal.Kind()) + } + } + // Accumulate errors errors := make([]string, 0) @@ -302,11 +419,6 @@ func (d *Decoder) decodeSlice(name string, data interface{}, val reflect.Value) error { dataVal := reflect.Indirect(reflect.ValueOf(data)) dataValKind := dataVal.Kind() - if dataValKind != reflect.Array && dataValKind != reflect.Slice { - return fmt.Errorf( - "'%s': source data must be an array or slice, got %s", name, dataValKind) - } - valType := val.Type() valElemType := valType.Elem() @@ -314,6 +426,18 @@ sliceType := reflect.SliceOf(valElemType) valSlice := reflect.MakeSlice(sliceType, dataVal.Len(), dataVal.Len()) + // Check input type + if dataValKind != reflect.Array && dataValKind != reflect.Slice { + // Accept empty map instead of array/slice in weakly typed mode + if d.config.WeaklyTypedInput && dataVal.Kind() == reflect.Map && dataVal.Len() == 0 { + val.Set(valSlice) + return nil + } else { + return fmt.Errorf( + "'%s': source data must be an array or slice, got %s", name, dataValKind) + } + } + // Accumulate any errors errors := make([]string, 0)
diff --git a/mapstructure_examples_test.go b/mapstructure_examples_test.go index d7f3d00..27de1db 100644 --- a/mapstructure_examples_test.go +++ b/mapstructure_examples_test.go
@@ -62,11 +62,11 @@ // Output: // 5 error(s) decoding: // - // * 'Name' expected type 'string', got 'int' + // * 'Name' expected type 'string', got unconvertible type 'int' // * 'Age' expected type 'int', got unconvertible type 'string' - // * 'Emails[0]' expected type 'string', got 'int' - // * 'Emails[1]' expected type 'string', got 'int' - // * 'Emails[2]' expected type 'string', got 'int' + // * 'Emails[0]' expected type 'string', got unconvertible type 'int' + // * 'Emails[1]' expected type 'string', got unconvertible type 'int' + // * 'Emails[2]' expected type 'string', got unconvertible type 'int' } func ExampleDecode_metadata() { @@ -107,3 +107,39 @@ // Output: // Unused keys: []string{"email"} } + +func ExampleDecode_WeaklyTypedInput() { + type Person struct { + Name string + Age int + Emails []string + } + + // This input can come from anywhere, but typically comes from + // something like decoding JSON, generated by a weakly typed language + // such as PHP. + input := map[string]interface{}{ + "name": 123, // number => string + "age": "42", // string => number + "emails": map[string]interface{}{}, // empty map => empty array + } + + var result Person + config := &DecoderConfig{ + WeaklyTypedInput: true, + Result: &result, + } + + decoder, err := NewDecoder(config) + if err != nil { + panic(err) + } + + err = decoder.Decode(input) + if err != nil { + panic(err) + } + + fmt.Printf("%#v", result) + // Output: mapstructure.Person{Name:"123", Age:42, Emails:[]string{}} +}
diff --git a/mapstructure_test.go b/mapstructure_test.go index f459ca8..47c98f7 100644 --- a/mapstructure_test.go +++ b/mapstructure_test.go
@@ -54,6 +54,31 @@ Value string `mapstructure:"foo"` } +type TypeConversionResult struct { + IntToFloat float32 + IntToUint uint + IntToBool bool + IntToString string + UintToInt int + UintToFloat float32 + UintToBool bool + UintToString string + BoolToInt int + BoolToUint uint + BoolToFloat float32 + BoolToString string + FloatToInt int + FloatToUint uint + FloatToBool bool + FloatToString string + StringToInt int + StringToUint uint + StringToBool bool + StringToFloat float32 + SliceToMap map[string]interface{} + MapToSlice []interface{} +} + func TestBasicTypes(t *testing.T) { t.Parallel() @@ -121,6 +146,103 @@ } } +func TestTypeConversion(t *testing.T) { + input := map[string]interface{}{ + "IntToFloat": 42, + "IntToUint": 42, + "IntToBool": 1, + "IntToString": 42, + "UintToInt": 42, + "UintToFloat": 42, + "UintToBool": 42, + "UintToString": 42, + "BoolToInt": true, + "BoolToUint": true, + "BoolToFloat": true, + "BoolToString": true, + "FloatToInt": 42.42, + "FloatToUint": 42.42, + "FloatToBool": 42.42, + "FloatToString": 42.42, + "StringToInt": "42", + "StringToUint": "42", + "StringToBool": "1", + "StringToFloat": "42.42", + "SliceToMap": []interface{} {}, + "MapToSlice": map[string]interface{} {}, + } + + expectedResultStrict := TypeConversionResult{ + IntToFloat: 42.0, + IntToUint: 42, + UintToInt: 42, + UintToFloat: 42, + BoolToInt: 0, + BoolToUint: 0, + BoolToFloat: 0, + FloatToInt: 42, + FloatToUint: 42, + } + + expectedResultWeak := TypeConversionResult{ + IntToFloat: 42.0, + IntToUint: 42, + IntToBool: true, + IntToString: "42", + UintToInt: 42, + UintToFloat: 42, + UintToBool: true, + UintToString: "42", + BoolToInt: 1, + BoolToUint: 1, + BoolToFloat: 1, + BoolToString: "1", + FloatToInt: 42, + FloatToUint: 42, + FloatToBool: true, + FloatToString: "42.42", + StringToInt: 42, + StringToUint: 42, + StringToBool: true, + StringToFloat: 42.42, + SliceToMap: map[string]interface{} {}, + MapToSlice: []interface{} {}, + } + + // Test strict type conversion + var resultStrict TypeConversionResult + err := Decode(input, &resultStrict) + if err == nil { + t.Errorf("should return an error") + } + if !reflect.DeepEqual(resultStrict, expectedResultStrict) { + t.Errorf("expected %v, got: %v", expectedResultStrict, resultStrict) + } + + // Test weak type conversion + var decoder *Decoder + var resultWeak TypeConversionResult + + config := &DecoderConfig{ + WeaklyTypedInput: true, + Result: &resultWeak, + } + + decoder, err = NewDecoder(config) + if err != nil { + t.Fatalf("err: %s", err) + } + + err = decoder.Decode(input) + if err != nil { + t.Fatalf("got an err: %s", err) + } + + if !reflect.DeepEqual(resultWeak, expectedResultWeak) { + t.Errorf("expected \n%#v, got: \n%#v", expectedResultWeak, resultWeak) + } +} + func TestDecode_Embedded(t *testing.T) { t.Parallel() @@ -432,7 +554,7 @@ t.Fatalf("error should be kind of Error, instead: %#v", err) } - if derr.Errors[0] != "'Vstring' expected type 'string', got 'int'" { + if derr.Errors[0] != "'Vstring' expected type 'string', got unconvertible type 'int'" { t.Errorf("got unexpected error: %s", err) } }