proto: clean up proto API for extensions Methods that manipulate protos with extensions will now take proto.Message instead of the internal extendableProto interface. A ClearExtensions method is added to clear all extensions on protos.
diff --git a/proto/extensions.go b/proto/extensions.go index 054f4f1..0de0b42 100644 --- a/proto/extensions.go +++ b/proto/extensions.go
@@ -92,8 +92,12 @@ } // SetRawExtension is for testing only. -func SetRawExtension(base extendableProto, id int32, b []byte) { - base.ExtensionMap()[id] = Extension{enc: b} +func SetRawExtension(base Message, id int32, b []byte) { + epb, ok := base.(extendableProto) + if !ok { + return + } + epb.ExtensionMap()[id] = Extension{enc: b} } // isExtensionField returns true iff the given field number is in an extension range. @@ -209,26 +213,39 @@ } // HasExtension returns whether the given extension is present in pb. -func HasExtension(pb extendableProto, extension *ExtensionDesc) bool { +func HasExtension(pb Message, extension *ExtensionDesc) bool { // TODO: Check types, field numbers, etc.? - _, ok := pb.ExtensionMap()[extension.Field] + epb, ok := pb.(extendableProto) + if !ok { + return false + } + _, ok = epb.ExtensionMap()[extension.Field] return ok } // ClearExtension removes the given extension from pb. -func ClearExtension(pb extendableProto, extension *ExtensionDesc) { +func ClearExtension(pb Message, extension *ExtensionDesc) { + epb, ok := pb.(extendableProto) + if !ok { + return + } // TODO: Check types, field numbers, etc.? - delete(pb.ExtensionMap(), extension.Field) + delete(epb.ExtensionMap(), extension.Field) } // GetExtension parses and returns the given extension of pb. // If the extension is not present and has no default value it returns ErrMissingExtension. -func GetExtension(pb extendableProto, extension *ExtensionDesc) (interface{}, error) { - if err := checkExtensionTypes(pb, extension); err != nil { +func GetExtension(pb Message, extension *ExtensionDesc) (interface{}, error) { + epb, ok := pb.(extendableProto) + if !ok { + return nil, errors.New("proto: not an extendable proto") + } + + if err := checkExtensionTypes(epb, extension); err != nil { return nil, err } - emap := pb.ExtensionMap() + emap := epb.ExtensionMap() e, ok := emap[extension.Field] if !ok { // defaultExtensionValue returns the default value or @@ -334,8 +351,7 @@ func GetExtensions(pb Message, es []*ExtensionDesc) (extensions []interface{}, err error) { epb, ok := pb.(extendableProto) if !ok { - err = errors.New("proto: not an extendable proto") - return + return nil, errors.New("proto: not an extendable proto") } extensions = make([]interface{}, len(es)) for i, e := range es { @@ -351,8 +367,12 @@ } // SetExtension sets the specified extension of pb to the specified value. -func SetExtension(pb extendableProto, extension *ExtensionDesc, value interface{}) error { - if err := checkExtensionTypes(pb, extension); err != nil { +func SetExtension(pb Message, extension *ExtensionDesc, value interface{}) error { + epb, ok := pb.(extendableProto) + if !ok { + return errors.New("proto: not an extendable proto") + } + if err := checkExtensionTypes(epb, extension); err != nil { return err } typ := reflect.TypeOf(extension.ExtensionType) @@ -368,10 +388,22 @@ return fmt.Errorf("proto: SetExtension called with nil value of type %T", value) } - pb.ExtensionMap()[extension.Field] = Extension{desc: extension, value: value} + epb.ExtensionMap()[extension.Field] = Extension{desc: extension, value: value} return nil } +// ClearAllExtensions clears all extensions from pb. +func ClearAllExtensions(pb Message) { + epb, ok := pb.(extendableProto) + if !ok { + return + } + m := epb.ExtensionMap() + for k := range m { + delete(m, k) + } +} + // A global registry of extensions. // The generated code will register the generated descriptors by calling RegisterExtension.
diff --git a/proto/extensions_test.go b/proto/extensions_test.go index 8012210..ed6a27d 100644 --- a/proto/extensions_test.go +++ b/proto/extensions_test.go
@@ -428,3 +428,28 @@ } } } + +func TestClearAllExtensions(t *testing.T) { + // unregistered extension + desc := &proto.ExtensionDesc{ + ExtendedType: (*pb.MyMessage)(nil), + ExtensionType: (*bool)(nil), + Field: 101010100, + Name: "emptyextension", + Tag: "varint,0,opt", + } + m := &pb.MyMessage{} + if proto.HasExtension(m, desc) { + t.Errorf("proto.HasExtension(%s): got true, want false", proto.MarshalTextString(m)) + } + if err := proto.SetExtension(m, desc, proto.Bool(true)); err != nil { + t.Errorf("proto.SetExtension(m, desc, true): got error %q, want nil", err) + } + if !proto.HasExtension(m, desc) { + t.Errorf("proto.HasExtension(%s): got false, want true", proto.MarshalTextString(m)) + } + proto.ClearAllExtensions(m) + if proto.HasExtension(m, desc) { + t.Errorf("proto.HasExtension(%s): got true, want false", proto.MarshalTextString(m)) + } +}