net/proto2/go: add GetAllExtensionDescs returns a []*ExtensionDesc
// GetAllExtensionDescs returns a slice of all the descriptors
// extensions present in pb.
// // If an extension is not registered, a descriptor with only the
// 'field' value set will
// // be returned instead of a full descriptor.
// // The returned slice is not guaranteed to be in any given order.
// func GetAllExtensionDescs(pb Message) (extensions []*ExtensionDesc,
// err error)
diff --git a/proto/extensions.go b/proto/extensions.go
index 9f484f5..482f3e9 100644
--- a/proto/extensions.go
+++ b/proto/extensions.go
@@ -489,6 +489,34 @@
return
}
+// ExtensionDescs returns a new slice containing pb's extension descriptors, in undefined order.
+// For non-registered extensions, ExtensionDescs returns an incomplete descriptor containing
+// just the Field field, which defines the extension's field number.
+func ExtensionDescs(pb Message) ([]*ExtensionDesc, error) {
+ epb, ok := extendable(pb)
+ if !ok {
+ return nil, fmt.Errorf("proto: %T is not an extendable proto.Message", pb)
+ }
+ registeredExtensions := RegisteredExtensions(pb)
+
+ emap, mu := epb.extensionsRead()
+ mu.Lock()
+ defer mu.Unlock()
+ extensions := make([]*ExtensionDesc, 0, len(emap))
+ for extid, e := range emap {
+ desc := e.desc
+ if desc == nil {
+ desc = registeredExtensions[extid]
+ if desc == nil {
+ desc = &ExtensionDesc{Field: extid}
+ }
+ }
+
+ extensions = append(extensions, desc)
+ }
+ return extensions, nil
+}
+
// SetExtension sets the specified extension of pb to the specified value.
func SetExtension(pb Message, extension *ExtensionDesc, value interface{}) error {
epb, ok := extendable(pb)
diff --git a/proto/extensions_test.go b/proto/extensions_test.go
index ed6a27d..4278a87 100644
--- a/proto/extensions_test.go
+++ b/proto/extensions_test.go
@@ -35,6 +35,7 @@
"bytes"
"fmt"
"reflect"
+ "sort"
"testing"
"github.com/golang/protobuf/proto"
@@ -45,7 +46,7 @@
msg := &pb.MyMessage{}
ext1 := &pb.Ext{}
if err := proto.SetExtension(msg, pb.E_Ext_More, ext1); err != nil {
- t.Fatalf("Could not set ext1: %s", ext1)
+ t.Fatalf("Could not set ext1: %s", err)
}
exts, err := proto.GetExtensions(msg, []*proto.ExtensionDesc{
pb.E_Ext_More,
@@ -62,6 +63,54 @@
}
}
+func TestExtensionDescsWithMissingExtensions(t *testing.T) {
+ msg := &pb.MyMessage{Count: proto.Int32(0)}
+ extdesc1 := pb.E_Ext_More
+ ext1 := &pb.Ext{}
+ if err := proto.SetExtension(msg, extdesc1, ext1); err != nil {
+ t.Fatalf("Could not set ext1: %s", err)
+ }
+ extdesc2 := &proto.ExtensionDesc{
+ ExtendedType: (*pb.MyMessage)(nil),
+ ExtensionType: (*bool)(nil),
+ Field: 123456789,
+ Name: "a.b",
+ Tag: "varint,123456789,opt",
+ }
+ ext2 := proto.Bool(false)
+ if err := proto.SetExtension(msg, extdesc2, ext2); err != nil {
+ t.Fatalf("Could not set ext2: %s", err)
+ }
+
+ b, err := proto.Marshal(msg)
+ if err != nil {
+ t.Fatalf("Could not marshal msg: %v", err)
+ }
+ if err := proto.Unmarshal(b, msg); err != nil {
+ t.Fatalf("Could not unmarshal into msg: %v", err)
+ }
+
+ descs, err := proto.ExtensionDescs(msg)
+ if err != nil {
+ t.Fatalf("proto.ExtensionDescs: got error %v", err)
+ }
+ sortExtDescs(descs)
+ wantDescs := []*proto.ExtensionDesc{extdesc1, &proto.ExtensionDesc{Field: extdesc2.Field}}
+ if !reflect.DeepEqual(descs, wantDescs) {
+ t.Errorf("proto.ExtensionDescs(msg) sorted extension ids: got %+v, want %+v", descs, wantDescs)
+ }
+}
+
+type ExtensionDescSlice []*proto.ExtensionDesc
+
+func (s ExtensionDescSlice) Len() int { return len(s) }
+func (s ExtensionDescSlice) Less(i, j int) bool { return s[i].Field < s[j].Field }
+func (s ExtensionDescSlice) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
+
+func sortExtDescs(s []*proto.ExtensionDesc) {
+ sort.Sort(ExtensionDescSlice(s))
+}
+
func TestGetExtensionStability(t *testing.T) {
check := func(m *pb.MyMessage) bool {
ext1, err := proto.GetExtension(m, pb.E_Ext_More)