proto: clean up proto API for extensions
Methods that manipulate protos with extensions will now take
proto.Message instead of the internal extendableProto interface.
A ClearExtensions method is added to clear all extensions on protos.
diff --git a/proto/extensions.go b/proto/extensions.go
index 054f4f1..0de0b42 100644
--- a/proto/extensions.go
+++ b/proto/extensions.go
@@ -92,8 +92,12 @@
}
// SetRawExtension is for testing only.
-func SetRawExtension(base extendableProto, id int32, b []byte) {
- base.ExtensionMap()[id] = Extension{enc: b}
+func SetRawExtension(base Message, id int32, b []byte) {
+ epb, ok := base.(extendableProto)
+ if !ok {
+ return
+ }
+ epb.ExtensionMap()[id] = Extension{enc: b}
}
// isExtensionField returns true iff the given field number is in an extension range.
@@ -209,26 +213,39 @@
}
// HasExtension returns whether the given extension is present in pb.
-func HasExtension(pb extendableProto, extension *ExtensionDesc) bool {
+func HasExtension(pb Message, extension *ExtensionDesc) bool {
// TODO: Check types, field numbers, etc.?
- _, ok := pb.ExtensionMap()[extension.Field]
+ epb, ok := pb.(extendableProto)
+ if !ok {
+ return false
+ }
+ _, ok = epb.ExtensionMap()[extension.Field]
return ok
}
// ClearExtension removes the given extension from pb.
-func ClearExtension(pb extendableProto, extension *ExtensionDesc) {
+func ClearExtension(pb Message, extension *ExtensionDesc) {
+ epb, ok := pb.(extendableProto)
+ if !ok {
+ return
+ }
// TODO: Check types, field numbers, etc.?
- delete(pb.ExtensionMap(), extension.Field)
+ delete(epb.ExtensionMap(), extension.Field)
}
// GetExtension parses and returns the given extension of pb.
// If the extension is not present and has no default value it returns ErrMissingExtension.
-func GetExtension(pb extendableProto, extension *ExtensionDesc) (interface{}, error) {
- if err := checkExtensionTypes(pb, extension); err != nil {
+func GetExtension(pb Message, extension *ExtensionDesc) (interface{}, error) {
+ epb, ok := pb.(extendableProto)
+ if !ok {
+ return nil, errors.New("proto: not an extendable proto")
+ }
+
+ if err := checkExtensionTypes(epb, extension); err != nil {
return nil, err
}
- emap := pb.ExtensionMap()
+ emap := epb.ExtensionMap()
e, ok := emap[extension.Field]
if !ok {
// defaultExtensionValue returns the default value or
@@ -334,8 +351,7 @@
func GetExtensions(pb Message, es []*ExtensionDesc) (extensions []interface{}, err error) {
epb, ok := pb.(extendableProto)
if !ok {
- err = errors.New("proto: not an extendable proto")
- return
+ return nil, errors.New("proto: not an extendable proto")
}
extensions = make([]interface{}, len(es))
for i, e := range es {
@@ -351,8 +367,12 @@
}
// SetExtension sets the specified extension of pb to the specified value.
-func SetExtension(pb extendableProto, extension *ExtensionDesc, value interface{}) error {
- if err := checkExtensionTypes(pb, extension); err != nil {
+func SetExtension(pb Message, extension *ExtensionDesc, value interface{}) error {
+ epb, ok := pb.(extendableProto)
+ if !ok {
+ return errors.New("proto: not an extendable proto")
+ }
+ if err := checkExtensionTypes(epb, extension); err != nil {
return err
}
typ := reflect.TypeOf(extension.ExtensionType)
@@ -368,10 +388,22 @@
return fmt.Errorf("proto: SetExtension called with nil value of type %T", value)
}
- pb.ExtensionMap()[extension.Field] = Extension{desc: extension, value: value}
+ epb.ExtensionMap()[extension.Field] = Extension{desc: extension, value: value}
return nil
}
+// ClearAllExtensions clears all extensions from pb.
+func ClearAllExtensions(pb Message) {
+ epb, ok := pb.(extendableProto)
+ if !ok {
+ return
+ }
+ m := epb.ExtensionMap()
+ for k := range m {
+ delete(m, k)
+ }
+}
+
// A global registry of extensions.
// The generated code will register the generated descriptors by calling RegisterExtension.
diff --git a/proto/extensions_test.go b/proto/extensions_test.go
index 8012210..ed6a27d 100644
--- a/proto/extensions_test.go
+++ b/proto/extensions_test.go
@@ -428,3 +428,28 @@
}
}
}
+
+func TestClearAllExtensions(t *testing.T) {
+ // unregistered extension
+ desc := &proto.ExtensionDesc{
+ ExtendedType: (*pb.MyMessage)(nil),
+ ExtensionType: (*bool)(nil),
+ Field: 101010100,
+ Name: "emptyextension",
+ Tag: "varint,0,opt",
+ }
+ m := &pb.MyMessage{}
+ if proto.HasExtension(m, desc) {
+ t.Errorf("proto.HasExtension(%s): got true, want false", proto.MarshalTextString(m))
+ }
+ if err := proto.SetExtension(m, desc, proto.Bool(true)); err != nil {
+ t.Errorf("proto.SetExtension(m, desc, true): got error %q, want nil", err)
+ }
+ if !proto.HasExtension(m, desc) {
+ t.Errorf("proto.HasExtension(%s): got false, want true", proto.MarshalTextString(m))
+ }
+ proto.ClearAllExtensions(m)
+ if proto.HasExtension(m, desc) {
+ t.Errorf("proto.HasExtension(%s): got true, want false", proto.MarshalTextString(m))
+ }
+}