Add VerifyProtoRepresenting and RespondWithProto
These handlers should make it a little easer to validate protobuf
messages that are exchanged over http.
diff --git a/ghttp/handlers.go b/ghttp/handlers.go
index 1e481f3..fdd4034 100644
--- a/ghttp/handlers.go
+++ b/ghttp/handlers.go
@@ -7,7 +7,9 @@
"io/ioutil"
"net/http"
"net/url"
+ "reflect"
+ "github.com/golang/protobuf/proto"
. "github.com/onsi/gomega"
"github.com/onsi/gomega/types"
)
@@ -137,6 +139,32 @@
return VerifyForm(url.Values{key: values})
}
+//VerifyProtoRepresenting returns a handler that verifies that the body of the request is a valid protobuf
+//representation of the passed message.
+//
+//VerifyProtoRepresenting also verifies that the request's content type is application/x-protobuf
+func VerifyProtoRepresenting(expected proto.Message) http.HandlerFunc {
+ return CombineHandlers(
+ VerifyContentType("application/x-protobuf"),
+ func(w http.ResponseWriter, req *http.Request) {
+ body, err := ioutil.ReadAll(req.Body)
+ Ω(err).ShouldNot(HaveOccurred())
+ req.Body.Close()
+
+ expectedType := reflect.TypeOf(expected)
+ actualValuePtr := reflect.New(expectedType.Elem())
+
+ actual, ok := actualValuePtr.Interface().(proto.Message)
+ Ω(ok).Should(BeTrue(), "Message value is not a proto.Message")
+
+ err = proto.Unmarshal(body, actual)
+ Ω(err).ShouldNot(HaveOccurred(), "Failed to unmarshal protobuf")
+
+ Ω(actual).Should(Equal(expected), "ProtoBuf Mismatch")
+ },
+ )
+}
+
func copyHeader(src http.Header, dst http.Header) {
for key, value := range src {
dst[key] = value
@@ -245,3 +273,28 @@
w.Write(data)
}
}
+
+//RespondWithProto returns a handler that responds to a request with the specified status code and a body
+//containing the protobuf serialization of the provided message.
+//
+//Also, RespondWithProto can be given an optional http.Header. The headers defined therein will be added to the response headers.
+func RespondWithProto(statusCode int, message proto.Message, optionalHeader ...http.Header) http.HandlerFunc {
+ return func(w http.ResponseWriter, req *http.Request) {
+ data, err := proto.Marshal(message)
+ Ω(err).ShouldNot(HaveOccurred())
+
+ var headers http.Header
+ if len(optionalHeader) == 1 {
+ headers = optionalHeader[0]
+ } else {
+ headers = make(http.Header)
+ }
+ if _, found := headers["Content-Type"]; !found {
+ headers["Content-Type"] = []string{"application/x-protobuf"}
+ }
+ copyHeader(headers, w.Header())
+
+ w.WriteHeader(statusCode)
+ w.Write(data)
+ }
+}
diff --git a/ghttp/protobuf/protobuf.go b/ghttp/protobuf/protobuf.go
new file mode 100644
index 0000000..b2972bc
--- /dev/null
+++ b/ghttp/protobuf/protobuf.go
@@ -0,0 +1,3 @@
+package protobuf
+
+//go:generate protoc --go_out=. simple_message.proto
diff --git a/ghttp/protobuf/simple_message.pb.go b/ghttp/protobuf/simple_message.pb.go
new file mode 100644
index 0000000..c55a484
--- /dev/null
+++ b/ghttp/protobuf/simple_message.pb.go
@@ -0,0 +1,55 @@
+// Code generated by protoc-gen-go.
+// source: simple_message.proto
+// DO NOT EDIT!
+
+/*
+Package protobuf is a generated protocol buffer package.
+
+It is generated from these files:
+ simple_message.proto
+
+It has these top-level messages:
+ SimpleMessage
+*/
+package protobuf
+
+import proto "github.com/golang/protobuf/proto"
+import fmt "fmt"
+import math "math"
+
+// Reference imports to suppress errors if they are not otherwise used.
+var _ = proto.Marshal
+var _ = fmt.Errorf
+var _ = math.Inf
+
+type SimpleMessage struct {
+ Description *string `protobuf:"bytes,1,req,name=description" json:"description,omitempty"`
+ Id *int32 `protobuf:"varint,2,req,name=id" json:"id,omitempty"`
+ Metadata *string `protobuf:"bytes,3,opt,name=metadata" json:"metadata,omitempty"`
+ XXX_unrecognized []byte `json:"-"`
+}
+
+func (m *SimpleMessage) Reset() { *m = SimpleMessage{} }
+func (m *SimpleMessage) String() string { return proto.CompactTextString(m) }
+func (*SimpleMessage) ProtoMessage() {}
+
+func (m *SimpleMessage) GetDescription() string {
+ if m != nil && m.Description != nil {
+ return *m.Description
+ }
+ return ""
+}
+
+func (m *SimpleMessage) GetId() int32 {
+ if m != nil && m.Id != nil {
+ return *m.Id
+ }
+ return 0
+}
+
+func (m *SimpleMessage) GetMetadata() string {
+ if m != nil && m.Metadata != nil {
+ return *m.Metadata
+ }
+ return ""
+}
diff --git a/ghttp/protobuf/simple_message.proto b/ghttp/protobuf/simple_message.proto
new file mode 100644
index 0000000..35b7145
--- /dev/null
+++ b/ghttp/protobuf/simple_message.proto
@@ -0,0 +1,9 @@
+syntax = "proto2";
+
+package protobuf;
+
+message SimpleMessage {
+ required string description = 1;
+ required int32 id = 2;
+ optional string metadata = 3;
+}
diff --git a/ghttp/test_server_test.go b/ghttp/test_server_test.go
index 3540ebf..497e46c 100644
--- a/ghttp/test_server_test.go
+++ b/ghttp/test_server_test.go
@@ -7,7 +7,9 @@
"net/url"
"regexp"
+ "github.com/golang/protobuf/proto"
"github.com/onsi/gomega/gbytes"
+ "github.com/onsi/gomega/ghttp/protobuf"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
@@ -636,6 +638,53 @@
})
})
+ Describe("VerifyProtoRepresenting", func() {
+ var message *protobuf.SimpleMessage
+
+ BeforeEach(func() {
+ message = new(protobuf.SimpleMessage)
+ message.Description = proto.String("A description")
+ message.Id = proto.Int32(0)
+
+ s.AppendHandlers(CombineHandlers(
+ VerifyRequest("POST", "/proto"),
+ VerifyProtoRepresenting(message),
+ ))
+ })
+
+ It("verifies the proto body and the content type", func() {
+ serialized, err := proto.Marshal(message)
+ Ω(err).ShouldNot(HaveOccurred())
+
+ resp, err = http.Post(s.URL()+"/proto", "application/x-protobuf", bytes.NewReader(serialized))
+ Ω(err).ShouldNot(HaveOccurred())
+ })
+
+ It("should verify the proto body and the content type", func() {
+ serialized, err := proto.Marshal(&protobuf.SimpleMessage{
+ Description: proto.String("A description"),
+ Id: proto.Int32(0),
+ Metadata: proto.String("some metadata"),
+ })
+ Ω(err).ShouldNot(HaveOccurred())
+
+ failures := InterceptGomegaFailures(func() {
+ http.Post(s.URL()+"/proto", "application/x-protobuf", bytes.NewReader(serialized))
+ })
+ Ω(failures).Should(HaveLen(1))
+ })
+
+ It("should verify the proto body and the content type", func() {
+ serialized, err := proto.Marshal(message)
+ Ω(err).ShouldNot(HaveOccurred())
+
+ failures := InterceptGomegaFailures(func() {
+ http.Post(s.URL()+"/proto", "application/not-x-protobuf", bytes.NewReader(serialized))
+ })
+ Ω(failures).Should(HaveLen(1))
+ })
+ })
+
Describe("RespondWith", func() {
Context("without headers", func() {
BeforeEach(func() {
@@ -908,5 +957,84 @@
})
})
})
+
+ Describe("RespondWithProto", func() {
+ var message *protobuf.SimpleMessage
+
+ BeforeEach(func() {
+ message = new(protobuf.SimpleMessage)
+ message.Description = proto.String("A description")
+ message.Id = proto.Int32(99)
+ })
+
+ Context("when no optional headers are set", func() {
+ BeforeEach(func() {
+ s.AppendHandlers(CombineHandlers(
+ VerifyRequest("POST", "/proto"),
+ RespondWithProto(http.StatusCreated, message),
+ ))
+ })
+
+ It("should return the response", func() {
+ resp, err = http.Post(s.URL()+"/proto", "application/x-protobuf", nil)
+ Ω(err).ShouldNot(HaveOccurred())
+
+ Ω(resp.StatusCode).Should(Equal(http.StatusCreated))
+
+ var received protobuf.SimpleMessage
+ body, err := ioutil.ReadAll(resp.Body)
+ err = proto.Unmarshal(body, &received)
+ Ω(err).ShouldNot(HaveOccurred())
+ })
+
+ It("should set the Content-Type header to application/x-protobuf", func() {
+ resp, err = http.Post(s.URL()+"/proto", "application/x-protobuf", nil)
+ Ω(err).ShouldNot(HaveOccurred())
+
+ Ω(resp.Header["Content-Type"]).Should(Equal([]string{"application/x-protobuf"}))
+ })
+ })
+
+ Context("when optional headers are set", func() {
+ var headers http.Header
+ BeforeEach(func() {
+ headers = http.Header{"Stuff": []string{"things"}}
+ })
+
+ JustBeforeEach(func() {
+ s.AppendHandlers(CombineHandlers(
+ VerifyRequest("POST", "/proto"),
+ RespondWithProto(http.StatusCreated, message, headers),
+ ))
+ })
+
+ It("should preserve those headers", func() {
+ resp, err = http.Post(s.URL()+"/proto", "application/x-protobuf", nil)
+ Ω(err).ShouldNot(HaveOccurred())
+
+ Ω(resp.Header["Stuff"]).Should(Equal([]string{"things"}))
+ })
+
+ It("should set the Content-Type header to application/x-protobuf", func() {
+ resp, err = http.Post(s.URL()+"/proto", "application/x-protobuf", nil)
+ Ω(err).ShouldNot(HaveOccurred())
+
+ Ω(resp.Header["Content-Type"]).Should(Equal([]string{"application/x-protobuf"}))
+ })
+
+ Context("when setting the Content-Type explicitly", func() {
+ BeforeEach(func() {
+ headers["Content-Type"] = []string{"not-x-protobuf"}
+ })
+
+ It("should use the Content-Type header that was explicitly set", func() {
+ resp, err = http.Post(s.URL()+"/proto", "application/x-protobuf", nil)
+ Ω(err).ShouldNot(HaveOccurred())
+
+ Ω(resp.Header["Content-Type"]).Should(Equal([]string{"not-x-protobuf"}))
+ })
+ })
+ })
+ })
})
})