Support floats
diff --git a/mapstructure.go b/mapstructure.go
index e7c4155..a84a489 100644
--- a/mapstructure.go
+++ b/mapstructure.go
@@ -124,17 +124,10 @@
return nil
}
- k := val.Kind()
-
- // Some shortcuts because we treat all ints and uints the same way
- if k >= reflect.Int && k <= reflect.Int64 {
- k = reflect.Int
- } else if k >= reflect.Uint && k <= reflect.Uint64 {
- k = reflect.Uint
- }
-
var err error
- switch k {
+ dataKind := d.getKind(val)
+
+ switch dataKind {
case reflect.Bool:
fallthrough
case reflect.Interface:
@@ -142,9 +135,11 @@
case reflect.String:
err = d.decodeBasic(name, data, val)
case reflect.Int:
- fallthrough
- case reflect.Uint:
err = d.decodeInt(name, data, val)
+ case reflect.Uint:
+ err = d.decodeUint(name, data, val)
+ case reflect.Float32:
+ err = d.decodeFloat(name, data, val)
case reflect.Struct:
err = d.decodeStruct(name, data, val)
case reflect.Map:
@@ -153,7 +148,7 @@
err = d.decodeSlice(name, data, val)
default:
// If we reached this point then we weren't able to decode it
- return fmt.Errorf("%s: unsupported type: %s", name, k)
+ return fmt.Errorf("%s: unsupported type: %s", name, dataKind)
}
// If we reached here, then we successfully decoded SOMETHING, so
@@ -165,6 +160,21 @@
return err
}
+func (d *Decoder) getKind(val reflect.Value) reflect.Kind {
+ kind := val.Kind()
+
+ switch {
+ case kind >= reflect.Int && kind <= reflect.Int64:
+ return reflect.Int
+ case kind >= reflect.Uint && kind <= reflect.Uint64:
+ return reflect.Uint
+ case kind >= reflect.Float32 && kind <= reflect.Float64:
+ return reflect.Float32
+ default:
+ return kind
+ }
+}
+
// This decodes a basic type (bool, int, string, etc.) and sets the
// value to "data" of that type.
func (d *Decoder) decodeBasic(name string, data interface{}, val reflect.Value) error {
@@ -182,47 +192,59 @@
func (d *Decoder) decodeInt(name string, data interface{}, val reflect.Value) error {
dataVal := reflect.ValueOf(data)
- dataKind := dataVal.Kind()
- if dataKind >= reflect.Int && dataKind <= reflect.Int64 {
- dataKind = reflect.Int
- } else if dataKind >= reflect.Uint && dataKind <= reflect.Uint64 {
- dataKind = reflect.Uint
- } else if dataKind >= reflect.Float32 && dataKind <= reflect.Float64 {
- dataKind = reflect.Float32
- } else {
+ dataKind := d.getKind(dataVal)
+
+ switch dataKind {
+ case reflect.Int:
+ val.SetInt(dataVal.Int())
+ case reflect.Uint:
+ val.SetInt(int64(dataVal.Uint()))
+ case reflect.Float32:
+ val.SetInt(int64(dataVal.Float()))
+ default:
return fmt.Errorf(
"'%s' expected type '%s', got unconvertible type '%s'",
name, val.Type(), dataVal.Type())
}
- valKind := val.Kind()
- if valKind >= reflect.Int && valKind <= reflect.Int64 {
- valKind = reflect.Int
- } else if valKind >= reflect.Uint && valKind <= reflect.Uint64 {
- valKind = reflect.Uint
- }
+ return nil
+}
+
+func (d *Decoder) decodeUint(name string, data interface{}, val reflect.Value) error {
+ dataVal := reflect.ValueOf(data)
+ dataKind := d.getKind(dataVal)
switch dataKind {
case reflect.Int:
- if valKind == reflect.Int {
- val.SetInt(dataVal.Int())
- } else {
- val.SetUint(uint64(dataVal.Int()))
- }
+ val.SetUint(uint64(dataVal.Int()))
case reflect.Uint:
- if valKind == reflect.Int {
- val.SetInt(int64(dataVal.Uint()))
- } else {
- val.SetUint(dataVal.Uint())
- }
+ val.SetUint(dataVal.Uint())
case reflect.Float32:
- if valKind == reflect.Int {
- val.SetInt(int64(dataVal.Float()))
- } else {
- val.SetUint(uint64(dataVal.Float()))
- }
+ val.SetUint(uint64(dataVal.Float()))
default:
- panic("should never reach")
+ return fmt.Errorf(
+ "'%s' expected type '%s', got unconvertible type '%s'",
+ name, val.Type(), dataVal.Type())
+ }
+
+ return nil
+}
+
+func (d *Decoder) decodeFloat(name string, data interface{}, val reflect.Value) error {
+ dataVal := reflect.ValueOf(data)
+ dataKind := d.getKind(dataVal)
+
+ switch dataKind {
+ case reflect.Int:
+ val.SetFloat(float64(dataVal.Int()))
+ case reflect.Uint:
+ val.SetFloat(float64(dataVal.Uint()))
+ case reflect.Float32:
+ val.SetFloat(float64(dataVal.Float()))
+ default:
+ return fmt.Errorf(
+ "'%s' expected type '%s', got unconvertible type '%s'",
+ name, val.Type(), dataVal.Type())
}
return nil
diff --git a/mapstructure_test.go b/mapstructure_test.go
index 5f359c2..f459ca8 100644
--- a/mapstructure_test.go
+++ b/mapstructure_test.go
@@ -10,6 +10,7 @@
Vint int
Vuint uint
Vbool bool
+ Vfloat float64
Vextra string
vsilent bool
Vdata interface{}
@@ -61,6 +62,7 @@
"vint": 42,
"Vuint": 42,
"vbool": true,
+ "Vfloat": 42.42,
"vsilent": true,
"vdata": 42,
}
@@ -88,6 +90,10 @@
t.Errorf("vbool value should be true: %#v", result.Vbool)
}
+ if result.Vfloat != 42.42 {
+ t.Errorf("vfloat value should be 42.42: %#v", result.Vfloat)
+ }
+
if result.Vextra != "" {
t.Errorf("vextra value should be empty: %#v", result.Vextra)
}