net/proto2/go: add GetAllExtensionDescs returns a []*ExtensionDesc // GetAllExtensionDescs returns a slice of all the descriptors // extensions present in pb. // // If an extension is not registered, a descriptor with only the // 'field' value set will // // be returned instead of a full descriptor. // // The returned slice is not guaranteed to be in any given order. // func GetAllExtensionDescs(pb Message) (extensions []*ExtensionDesc, // err error)
diff --git a/proto/extensions.go b/proto/extensions.go index 9f484f5..482f3e9 100644 --- a/proto/extensions.go +++ b/proto/extensions.go
@@ -489,6 +489,34 @@ return } +// ExtensionDescs returns a new slice containing pb's extension descriptors, in undefined order. +// For non-registered extensions, ExtensionDescs returns an incomplete descriptor containing +// just the Field field, which defines the extension's field number. +func ExtensionDescs(pb Message) ([]*ExtensionDesc, error) { + epb, ok := extendable(pb) + if !ok { + return nil, fmt.Errorf("proto: %T is not an extendable proto.Message", pb) + } + registeredExtensions := RegisteredExtensions(pb) + + emap, mu := epb.extensionsRead() + mu.Lock() + defer mu.Unlock() + extensions := make([]*ExtensionDesc, 0, len(emap)) + for extid, e := range emap { + desc := e.desc + if desc == nil { + desc = registeredExtensions[extid] + if desc == nil { + desc = &ExtensionDesc{Field: extid} + } + } + + extensions = append(extensions, desc) + } + return extensions, nil +} + // SetExtension sets the specified extension of pb to the specified value. func SetExtension(pb Message, extension *ExtensionDesc, value interface{}) error { epb, ok := extendable(pb)
diff --git a/proto/extensions_test.go b/proto/extensions_test.go index ed6a27d..4278a87 100644 --- a/proto/extensions_test.go +++ b/proto/extensions_test.go
@@ -35,6 +35,7 @@ "bytes" "fmt" "reflect" + "sort" "testing" "github.com/golang/protobuf/proto" @@ -45,7 +46,7 @@ msg := &pb.MyMessage{} ext1 := &pb.Ext{} if err := proto.SetExtension(msg, pb.E_Ext_More, ext1); err != nil { - t.Fatalf("Could not set ext1: %s", ext1) + t.Fatalf("Could not set ext1: %s", err) } exts, err := proto.GetExtensions(msg, []*proto.ExtensionDesc{ pb.E_Ext_More, @@ -62,6 +63,54 @@ } } +func TestExtensionDescsWithMissingExtensions(t *testing.T) { + msg := &pb.MyMessage{Count: proto.Int32(0)} + extdesc1 := pb.E_Ext_More + ext1 := &pb.Ext{} + if err := proto.SetExtension(msg, extdesc1, ext1); err != nil { + t.Fatalf("Could not set ext1: %s", err) + } + extdesc2 := &proto.ExtensionDesc{ + ExtendedType: (*pb.MyMessage)(nil), + ExtensionType: (*bool)(nil), + Field: 123456789, + Name: "a.b", + Tag: "varint,123456789,opt", + } + ext2 := proto.Bool(false) + if err := proto.SetExtension(msg, extdesc2, ext2); err != nil { + t.Fatalf("Could not set ext2: %s", err) + } + + b, err := proto.Marshal(msg) + if err != nil { + t.Fatalf("Could not marshal msg: %v", err) + } + if err := proto.Unmarshal(b, msg); err != nil { + t.Fatalf("Could not unmarshal into msg: %v", err) + } + + descs, err := proto.ExtensionDescs(msg) + if err != nil { + t.Fatalf("proto.ExtensionDescs: got error %v", err) + } + sortExtDescs(descs) + wantDescs := []*proto.ExtensionDesc{extdesc1, &proto.ExtensionDesc{Field: extdesc2.Field}} + if !reflect.DeepEqual(descs, wantDescs) { + t.Errorf("proto.ExtensionDescs(msg) sorted extension ids: got %+v, want %+v", descs, wantDescs) + } +} + +type ExtensionDescSlice []*proto.ExtensionDesc + +func (s ExtensionDescSlice) Len() int { return len(s) } +func (s ExtensionDescSlice) Less(i, j int) bool { return s[i].Field < s[j].Field } +func (s ExtensionDescSlice) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +func sortExtDescs(s []*proto.ExtensionDesc) { + sort.Sort(ExtensionDescSlice(s)) +} + func TestGetExtensionStability(t *testing.T) { check := func(m *pb.MyMessage) bool { ext1, err := proto.GetExtension(m, pb.E_Ext_More)