Better unmarshalling for "null". Fixes #35.
diff --git a/decode.go b/decode.go index 1a87da9..5f00702 100644 --- a/decode.go +++ b/decode.go
@@ -288,6 +288,14 @@ return good } +var zeroValue reflect.Value + +func resetMap(out reflect.Value) { + for _, k := range out.MapKeys() { + out.SetMapIndex(k, zeroValue) + } +} + var durationType = reflect.TypeOf(time.Duration(0)) func (d *decoder) scalar(n *node, out reflect.Value) (good bool) { @@ -302,6 +310,15 @@ if set := d.setter(tag, &out, &good); set != nil { defer set() } + if resolved == nil { + if out.Kind() == reflect.Map && !out.CanAddr() { + resetMap(out) + } else { + out.Set(reflect.Zero(out.Type())) + } + good = true + return + } switch out.Kind() { case reflect.String: if resolved != nil { @@ -378,17 +395,11 @@ good = true } case reflect.Ptr: - switch resolved.(type) { - case nil: - out.Set(reflect.Zero(out.Type())) + if out.Type().Elem() == reflect.TypeOf(resolved) { + elem := reflect.New(out.Type().Elem()) + elem.Elem().Set(reflect.ValueOf(resolved)) + out.Set(elem) good = true - default: - if out.Type().Elem() == reflect.TypeOf(resolved) { - elem := reflect.New(out.Type().Elem()) - elem.Elem().Set(reflect.ValueOf(resolved)) - out.Set(elem) - good = true - } } } return good
diff --git a/decode_test.go b/decode_test.go index 46c94d4..7042908 100644 --- a/decode_test.go +++ b/decode_test.go
@@ -316,7 +316,10 @@ map[string]*string{"foo": new(string)}, }, { "foo: null", - map[string]string{}, + map[string]string{"foo": ""}, + }, { + "foo: null", + map[string]interface{}{"foo": nil}, }, // Ignored field @@ -372,7 +375,7 @@ map[string]time.Duration{"a": 3 * time.Second}, }, - // Issue #24. + // Issue #24. { "a: <foo>", map[string]string{"a": "<foo>"}, @@ -630,6 +633,30 @@ } } +var unmarshalNullTests = []func() interface{}{ + func() interface{} { var v interface{}; v = "v"; return &v }, + func() interface{} { var s = "s"; return &s }, + func() interface{} { var s = "s"; sptr := &s; return &sptr }, + func() interface{} { var i = 1; return &i }, + func() interface{} { var i = 1; iptr := &i; return &iptr }, + func() interface{} { m := map[string]int{"s": 1}; return &m }, + func() interface{} { m := map[string]int{"s": 1}; return m }, +} + +func (s *S) TestUnmarshalNull(c *C) { + for _, test := range unmarshalNullTests { + item := test() + zero := reflect.Zero(reflect.TypeOf(item).Elem()).Interface() + err := yaml.Unmarshal([]byte("null"), item) + c.Assert(err, IsNil) + if reflect.TypeOf(item).Kind() == reflect.Map { + c.Assert(reflect.ValueOf(item).Interface(), DeepEquals, reflect.MakeMap(reflect.TypeOf(item)).Interface()) + } else { + c.Assert(reflect.ValueOf(item).Elem().Interface(), DeepEquals, zero) + } + } +} + //var data []byte //func init() { // var err error
diff --git a/yaml.go b/yaml.go index 6370416..903b13b 100644 --- a/yaml.go +++ b/yaml.go
@@ -88,7 +88,11 @@ defer p.destroy() node := p.parse() if node != nil { - d.unmarshal(node, reflect.ValueOf(out)) + v := reflect.ValueOf(out) + if v.Kind() == reflect.Ptr && !v.IsNil() { + v = v.Elem() + } + d.unmarshal(node, v) } return nil }