Support floats
diff --git a/mapstructure.go b/mapstructure.go index e7c4155..a84a489 100644 --- a/mapstructure.go +++ b/mapstructure.go
@@ -124,17 +124,10 @@ return nil } - k := val.Kind() - - // Some shortcuts because we treat all ints and uints the same way - if k >= reflect.Int && k <= reflect.Int64 { - k = reflect.Int - } else if k >= reflect.Uint && k <= reflect.Uint64 { - k = reflect.Uint - } - var err error - switch k { + dataKind := d.getKind(val) + + switch dataKind { case reflect.Bool: fallthrough case reflect.Interface: @@ -142,9 +135,11 @@ case reflect.String: err = d.decodeBasic(name, data, val) case reflect.Int: - fallthrough - case reflect.Uint: err = d.decodeInt(name, data, val) + case reflect.Uint: + err = d.decodeUint(name, data, val) + case reflect.Float32: + err = d.decodeFloat(name, data, val) case reflect.Struct: err = d.decodeStruct(name, data, val) case reflect.Map: @@ -153,7 +148,7 @@ err = d.decodeSlice(name, data, val) default: // If we reached this point then we weren't able to decode it - return fmt.Errorf("%s: unsupported type: %s", name, k) + return fmt.Errorf("%s: unsupported type: %s", name, dataKind) } // If we reached here, then we successfully decoded SOMETHING, so @@ -165,6 +160,21 @@ return err } +func (d *Decoder) getKind(val reflect.Value) reflect.Kind { + kind := val.Kind() + + switch { + case kind >= reflect.Int && kind <= reflect.Int64: + return reflect.Int + case kind >= reflect.Uint && kind <= reflect.Uint64: + return reflect.Uint + case kind >= reflect.Float32 && kind <= reflect.Float64: + return reflect.Float32 + default: + return kind + } +} + // This decodes a basic type (bool, int, string, etc.) and sets the // value to "data" of that type. func (d *Decoder) decodeBasic(name string, data interface{}, val reflect.Value) error { @@ -182,47 +192,59 @@ func (d *Decoder) decodeInt(name string, data interface{}, val reflect.Value) error { dataVal := reflect.ValueOf(data) - dataKind := dataVal.Kind() - if dataKind >= reflect.Int && dataKind <= reflect.Int64 { - dataKind = reflect.Int - } else if dataKind >= reflect.Uint && dataKind <= reflect.Uint64 { - dataKind = reflect.Uint - } else if dataKind >= reflect.Float32 && dataKind <= reflect.Float64 { - dataKind = reflect.Float32 - } else { + dataKind := d.getKind(dataVal) + + switch dataKind { + case reflect.Int: + val.SetInt(dataVal.Int()) + case reflect.Uint: + val.SetInt(int64(dataVal.Uint())) + case reflect.Float32: + val.SetInt(int64(dataVal.Float())) + default: return fmt.Errorf( "'%s' expected type '%s', got unconvertible type '%s'", name, val.Type(), dataVal.Type()) } - valKind := val.Kind() - if valKind >= reflect.Int && valKind <= reflect.Int64 { - valKind = reflect.Int - } else if valKind >= reflect.Uint && valKind <= reflect.Uint64 { - valKind = reflect.Uint - } + return nil +} + +func (d *Decoder) decodeUint(name string, data interface{}, val reflect.Value) error { + dataVal := reflect.ValueOf(data) + dataKind := d.getKind(dataVal) switch dataKind { case reflect.Int: - if valKind == reflect.Int { - val.SetInt(dataVal.Int()) - } else { - val.SetUint(uint64(dataVal.Int())) - } + val.SetUint(uint64(dataVal.Int())) case reflect.Uint: - if valKind == reflect.Int { - val.SetInt(int64(dataVal.Uint())) - } else { - val.SetUint(dataVal.Uint()) - } + val.SetUint(dataVal.Uint()) case reflect.Float32: - if valKind == reflect.Int { - val.SetInt(int64(dataVal.Float())) - } else { - val.SetUint(uint64(dataVal.Float())) - } + val.SetUint(uint64(dataVal.Float())) default: - panic("should never reach") + return fmt.Errorf( + "'%s' expected type '%s', got unconvertible type '%s'", + name, val.Type(), dataVal.Type()) + } + + return nil +} + +func (d *Decoder) decodeFloat(name string, data interface{}, val reflect.Value) error { + dataVal := reflect.ValueOf(data) + dataKind := d.getKind(dataVal) + + switch dataKind { + case reflect.Int: + val.SetFloat(float64(dataVal.Int())) + case reflect.Uint: + val.SetFloat(float64(dataVal.Uint())) + case reflect.Float32: + val.SetFloat(float64(dataVal.Float())) + default: + return fmt.Errorf( + "'%s' expected type '%s', got unconvertible type '%s'", + name, val.Type(), dataVal.Type()) } return nil
diff --git a/mapstructure_test.go b/mapstructure_test.go index 5f359c2..f459ca8 100644 --- a/mapstructure_test.go +++ b/mapstructure_test.go
@@ -10,6 +10,7 @@ Vint int Vuint uint Vbool bool + Vfloat float64 Vextra string vsilent bool Vdata interface{} @@ -61,6 +62,7 @@ "vint": 42, "Vuint": 42, "vbool": true, + "Vfloat": 42.42, "vsilent": true, "vdata": 42, } @@ -88,6 +90,10 @@ t.Errorf("vbool value should be true: %#v", result.Vbool) } + if result.Vfloat != 42.42 { + t.Errorf("vfloat value should be 42.42: %#v", result.Vfloat) + } + if result.Vextra != "" { t.Errorf("vextra value should be empty: %#v", result.Vextra) }