proto: fix equality to work with V1 generated format
When the new V2 generated extension format was introduced, we
mistakenly dropped support for comparing V1 generated extensions
for equality. Add that back.
diff --git a/proto/all_test.go b/proto/all_test.go
index 18eeb0f..fd4a94e 100644
--- a/proto/all_test.go
+++ b/proto/all_test.go
@@ -1956,14 +1956,54 @@
}
func TestMapFieldWithNil(t *testing.T) {
- m := &MessageWithMap{
+ m1 := &MessageWithMap{
MsgMapping: map[int64]*FloatingPoint{
1: nil,
},
}
- b, err := Marshal(m)
- if err == nil {
- t.Fatalf("Marshal of bad map should have failed, got these bytes: %v", b)
+ b, err := Marshal(m1)
+ if err != nil {
+ t.Fatalf("Marshal: %v", err)
+ }
+ m2 := new(MessageWithMap)
+ if err := Unmarshal(b, m2); err != nil {
+ t.Fatalf("Unmarshal: %v, got these bytes: %v", err, b)
+ }
+ if v, ok := m2.MsgMapping[1]; !ok {
+ t.Error("msg_mapping[1] not present")
+ } else if v != nil {
+ t.Errorf("msg_mapping[1] not nil: %v", v)
+ }
+}
+
+func TestMapFieldWithNilBytes(t *testing.T) {
+ m1 := &MessageWithMap{
+ ByteMapping: map[bool][]byte{
+ false: []byte{},
+ true: nil,
+ },
+ }
+ n := Size(m1)
+ b, err := Marshal(m1)
+ if err != nil {
+ t.Fatalf("Marshal: %v", err)
+ }
+ if n != len(b) {
+ t.Errorf("Size(m1) = %d; want len(Marshal(m1)) = %d", n, len(b))
+ }
+ m2 := new(MessageWithMap)
+ if err := Unmarshal(b, m2); err != nil {
+ t.Fatalf("Unmarshal: %v, got these bytes: %v", err, b)
+ }
+ if v, ok := m2.ByteMapping[false]; !ok {
+ t.Error("byte_mapping[false] not present")
+ } else if len(v) != 0 {
+ t.Errorf("byte_mapping[false] not empty: %#v", v)
+ }
+ if v, ok := m2.ByteMapping[true]; !ok {
+ t.Error("byte_mapping[true] not present")
+ } else if len(v) != 0 {
+ t.Errorf("byte_mapping[true] not empty: %#v", v)
}
}
diff --git a/proto/encode.go b/proto/encode.go
index 1b5578d..8c1b8fd 100644
--- a/proto/encode.go
+++ b/proto/encode.go
@@ -1149,7 +1149,7 @@
if err := p.mkeyprop.enc(o, p.mkeyprop, keybase); err != nil {
return err
}
- if err := p.mvalprop.enc(o, p.mvalprop, valbase); err != nil {
+ if err := p.mvalprop.enc(o, p.mvalprop, valbase); err != nil && err != ErrNil {
return err
}
return nil
@@ -1159,11 +1159,6 @@
for _, key := range v.MapKeys() {
val := v.MapIndex(key)
- // The only illegal map entry values are nil message pointers.
- if val.Kind() == reflect.Ptr && val.IsNil() {
- return errors.New("proto: map has nil element")
- }
-
keycopy.Set(key)
valcopy.Set(val)
diff --git a/proto/equal.go b/proto/equal.go
index 8c6fa85..cafb99f 100644
--- a/proto/equal.go
+++ b/proto/equal.go
@@ -128,6 +128,13 @@
}
}
+ if em1 := v1.FieldByName("XXX_extensions"); em1.IsValid() {
+ em2 := v2.FieldByName("XXX_extensions")
+ if !equalExtMap(v1.Type(), em1.Interface().(map[int32]Extension), em2.Interface().(map[int32]Extension)) {
+ return false
+ }
+ }
+
uf := v1.FieldByName("XXX_unrecognized")
if !uf.IsValid() {
return true
@@ -227,6 +234,10 @@
func equalExtensions(base reflect.Type, x1, x2 XXX_InternalExtensions) bool {
em1, _ := x1.extensionsRead()
em2, _ := x2.extensionsRead()
+ return equalExtMap(base, em1, em2)
+}
+
+func equalExtMap(base reflect.Type, em1, em2 map[int32]Extension) bool {
if len(em1) != len(em2) {
return false
}
diff --git a/proto/properties.go b/proto/properties.go
index 39edea3..dd29683 100644
--- a/proto/properties.go
+++ b/proto/properties.go
@@ -473,17 +473,13 @@
p.dec = (*Buffer).dec_slice_int64
p.packedDec = (*Buffer).dec_slice_packed_int64
case reflect.Uint8:
- p.enc = (*Buffer).enc_slice_byte
p.dec = (*Buffer).dec_slice_byte
- p.size = size_slice_byte
- // This is a []byte, which is either a bytes field,
- // or the value of a map field. In the latter case,
- // we always encode an empty []byte, so we should not
- // use the proto3 enc/size funcs.
- // f == nil iff this is the key/value of a map field.
- if p.proto3 && f != nil {
+ if p.proto3 {
p.enc = (*Buffer).enc_proto3_slice_byte
p.size = size_proto3_slice_byte
+ } else {
+ p.enc = (*Buffer).enc_slice_byte
+ p.size = size_slice_byte
}
case reflect.Float32, reflect.Float64:
switch t2.Bits() {