Add VerifyProtoRepresenting and RespondWithProto These handlers should make it a little easer to validate protobuf messages that are exchanged over http.
diff --git a/ghttp/handlers.go b/ghttp/handlers.go index 1e481f3..fdd4034 100644 --- a/ghttp/handlers.go +++ b/ghttp/handlers.go
@@ -7,7 +7,9 @@ "io/ioutil" "net/http" "net/url" + "reflect" + "github.com/golang/protobuf/proto" . "github.com/onsi/gomega" "github.com/onsi/gomega/types" ) @@ -137,6 +139,32 @@ return VerifyForm(url.Values{key: values}) } +//VerifyProtoRepresenting returns a handler that verifies that the body of the request is a valid protobuf +//representation of the passed message. +// +//VerifyProtoRepresenting also verifies that the request's content type is application/x-protobuf +func VerifyProtoRepresenting(expected proto.Message) http.HandlerFunc { + return CombineHandlers( + VerifyContentType("application/x-protobuf"), + func(w http.ResponseWriter, req *http.Request) { + body, err := ioutil.ReadAll(req.Body) + Ω(err).ShouldNot(HaveOccurred()) + req.Body.Close() + + expectedType := reflect.TypeOf(expected) + actualValuePtr := reflect.New(expectedType.Elem()) + + actual, ok := actualValuePtr.Interface().(proto.Message) + Ω(ok).Should(BeTrue(), "Message value is not a proto.Message") + + err = proto.Unmarshal(body, actual) + Ω(err).ShouldNot(HaveOccurred(), "Failed to unmarshal protobuf") + + Ω(actual).Should(Equal(expected), "ProtoBuf Mismatch") + }, + ) +} + func copyHeader(src http.Header, dst http.Header) { for key, value := range src { dst[key] = value @@ -245,3 +273,28 @@ w.Write(data) } } + +//RespondWithProto returns a handler that responds to a request with the specified status code and a body +//containing the protobuf serialization of the provided message. +// +//Also, RespondWithProto can be given an optional http.Header. The headers defined therein will be added to the response headers. +func RespondWithProto(statusCode int, message proto.Message, optionalHeader ...http.Header) http.HandlerFunc { + return func(w http.ResponseWriter, req *http.Request) { + data, err := proto.Marshal(message) + Ω(err).ShouldNot(HaveOccurred()) + + var headers http.Header + if len(optionalHeader) == 1 { + headers = optionalHeader[0] + } else { + headers = make(http.Header) + } + if _, found := headers["Content-Type"]; !found { + headers["Content-Type"] = []string{"application/x-protobuf"} + } + copyHeader(headers, w.Header()) + + w.WriteHeader(statusCode) + w.Write(data) + } +}
diff --git a/ghttp/protobuf/protobuf.go b/ghttp/protobuf/protobuf.go new file mode 100644 index 0000000..b2972bc --- /dev/null +++ b/ghttp/protobuf/protobuf.go
@@ -0,0 +1,3 @@ +package protobuf + +//go:generate protoc --go_out=. simple_message.proto
diff --git a/ghttp/protobuf/simple_message.pb.go b/ghttp/protobuf/simple_message.pb.go new file mode 100644 index 0000000..c55a484 --- /dev/null +++ b/ghttp/protobuf/simple_message.pb.go
@@ -0,0 +1,55 @@ +// Code generated by protoc-gen-go. +// source: simple_message.proto +// DO NOT EDIT! + +/* +Package protobuf is a generated protocol buffer package. + +It is generated from these files: + simple_message.proto + +It has these top-level messages: + SimpleMessage +*/ +package protobuf + +import proto "github.com/golang/protobuf/proto" +import fmt "fmt" +import math "math" + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +type SimpleMessage struct { + Description *string `protobuf:"bytes,1,req,name=description" json:"description,omitempty"` + Id *int32 `protobuf:"varint,2,req,name=id" json:"id,omitempty"` + Metadata *string `protobuf:"bytes,3,opt,name=metadata" json:"metadata,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *SimpleMessage) Reset() { *m = SimpleMessage{} } +func (m *SimpleMessage) String() string { return proto.CompactTextString(m) } +func (*SimpleMessage) ProtoMessage() {} + +func (m *SimpleMessage) GetDescription() string { + if m != nil && m.Description != nil { + return *m.Description + } + return "" +} + +func (m *SimpleMessage) GetId() int32 { + if m != nil && m.Id != nil { + return *m.Id + } + return 0 +} + +func (m *SimpleMessage) GetMetadata() string { + if m != nil && m.Metadata != nil { + return *m.Metadata + } + return "" +}
diff --git a/ghttp/protobuf/simple_message.proto b/ghttp/protobuf/simple_message.proto new file mode 100644 index 0000000..35b7145 --- /dev/null +++ b/ghttp/protobuf/simple_message.proto
@@ -0,0 +1,9 @@ +syntax = "proto2"; + +package protobuf; + +message SimpleMessage { + required string description = 1; + required int32 id = 2; + optional string metadata = 3; +}
diff --git a/ghttp/test_server_test.go b/ghttp/test_server_test.go index 3540ebf..497e46c 100644 --- a/ghttp/test_server_test.go +++ b/ghttp/test_server_test.go
@@ -7,7 +7,9 @@ "net/url" "regexp" + "github.com/golang/protobuf/proto" "github.com/onsi/gomega/gbytes" + "github.com/onsi/gomega/ghttp/protobuf" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -636,6 +638,53 @@ }) }) + Describe("VerifyProtoRepresenting", func() { + var message *protobuf.SimpleMessage + + BeforeEach(func() { + message = new(protobuf.SimpleMessage) + message.Description = proto.String("A description") + message.Id = proto.Int32(0) + + s.AppendHandlers(CombineHandlers( + VerifyRequest("POST", "/proto"), + VerifyProtoRepresenting(message), + )) + }) + + It("verifies the proto body and the content type", func() { + serialized, err := proto.Marshal(message) + Ω(err).ShouldNot(HaveOccurred()) + + resp, err = http.Post(s.URL()+"/proto", "application/x-protobuf", bytes.NewReader(serialized)) + Ω(err).ShouldNot(HaveOccurred()) + }) + + It("should verify the proto body and the content type", func() { + serialized, err := proto.Marshal(&protobuf.SimpleMessage{ + Description: proto.String("A description"), + Id: proto.Int32(0), + Metadata: proto.String("some metadata"), + }) + Ω(err).ShouldNot(HaveOccurred()) + + failures := InterceptGomegaFailures(func() { + http.Post(s.URL()+"/proto", "application/x-protobuf", bytes.NewReader(serialized)) + }) + Ω(failures).Should(HaveLen(1)) + }) + + It("should verify the proto body and the content type", func() { + serialized, err := proto.Marshal(message) + Ω(err).ShouldNot(HaveOccurred()) + + failures := InterceptGomegaFailures(func() { + http.Post(s.URL()+"/proto", "application/not-x-protobuf", bytes.NewReader(serialized)) + }) + Ω(failures).Should(HaveLen(1)) + }) + }) + Describe("RespondWith", func() { Context("without headers", func() { BeforeEach(func() { @@ -908,5 +957,84 @@ }) }) }) + + Describe("RespondWithProto", func() { + var message *protobuf.SimpleMessage + + BeforeEach(func() { + message = new(protobuf.SimpleMessage) + message.Description = proto.String("A description") + message.Id = proto.Int32(99) + }) + + Context("when no optional headers are set", func() { + BeforeEach(func() { + s.AppendHandlers(CombineHandlers( + VerifyRequest("POST", "/proto"), + RespondWithProto(http.StatusCreated, message), + )) + }) + + It("should return the response", func() { + resp, err = http.Post(s.URL()+"/proto", "application/x-protobuf", nil) + Ω(err).ShouldNot(HaveOccurred()) + + Ω(resp.StatusCode).Should(Equal(http.StatusCreated)) + + var received protobuf.SimpleMessage + body, err := ioutil.ReadAll(resp.Body) + err = proto.Unmarshal(body, &received) + Ω(err).ShouldNot(HaveOccurred()) + }) + + It("should set the Content-Type header to application/x-protobuf", func() { + resp, err = http.Post(s.URL()+"/proto", "application/x-protobuf", nil) + Ω(err).ShouldNot(HaveOccurred()) + + Ω(resp.Header["Content-Type"]).Should(Equal([]string{"application/x-protobuf"})) + }) + }) + + Context("when optional headers are set", func() { + var headers http.Header + BeforeEach(func() { + headers = http.Header{"Stuff": []string{"things"}} + }) + + JustBeforeEach(func() { + s.AppendHandlers(CombineHandlers( + VerifyRequest("POST", "/proto"), + RespondWithProto(http.StatusCreated, message, headers), + )) + }) + + It("should preserve those headers", func() { + resp, err = http.Post(s.URL()+"/proto", "application/x-protobuf", nil) + Ω(err).ShouldNot(HaveOccurred()) + + Ω(resp.Header["Stuff"]).Should(Equal([]string{"things"})) + }) + + It("should set the Content-Type header to application/x-protobuf", func() { + resp, err = http.Post(s.URL()+"/proto", "application/x-protobuf", nil) + Ω(err).ShouldNot(HaveOccurred()) + + Ω(resp.Header["Content-Type"]).Should(Equal([]string{"application/x-protobuf"})) + }) + + Context("when setting the Content-Type explicitly", func() { + BeforeEach(func() { + headers["Content-Type"] = []string{"not-x-protobuf"} + }) + + It("should use the Content-Type header that was explicitly set", func() { + resp, err = http.Post(s.URL()+"/proto", "application/x-protobuf", nil) + Ω(err).ShouldNot(HaveOccurred()) + + Ω(resp.Header["Content-Type"]).Should(Equal([]string{"not-x-protobuf"})) + }) + }) + }) + }) }) })