add tests, fix bugs
diff --git a/api.go b/api.go index 553673c..55e5fcb 100644 --- a/api.go +++ b/api.go
@@ -130,7 +130,7 @@ services.API().HandleFunc(a.blobEndpoint, a.apiReturnBlobData).Methods("GET") services.API().HandleFunc(a.configStatusEndpoint, a.apiPutConfigStatus).Methods("PUT") services.API().HandleFunc(a.heartbeatEndpoint, a.apiPutHeartbeat).Methods("PUT") - services.API().HandleFunc(a.registerEndpoint, a.apiPutRegister).Methods("POST") + services.API().HandleFunc(a.registerEndpoint, a.apiPutRegister).Methods("PUT") a.apiInitialized = true log.Debug("API endpoints initialized") } @@ -445,7 +445,7 @@ vars := mux.Vars(r) uuid := vars["uuid"] if !isValidUuid(uuid) { - a.writeError(w, http.StatusBadRequest, API_ERR_INVALID_PARAMETERS, "Bad/Missing gateway uuid") + a.writeError(w, http.StatusBadRequest, API_ERR_INVALID_PARAMETERS, "Bad/Missing gateway UUID") return } reported := r.Header.Get("reportedTime") @@ -462,7 +462,7 @@ case http.StatusOK: a.writePutHeartbeatResp(w, trackerResp) default: - log.Infof("apiPutHeartbeat code: %v Reason: %v", trackerResp.code, trackerResp.body) + log.Infof("apiPutHeartbeat code: %v Reason: %v", trackerResp.code, string(trackerResp.body)) a.writeError(w, trackerResp.code, API_ERR_FROM_TRACKER, string(trackerResp.body)) } } @@ -540,10 +540,7 @@ } func isValidUuid(uuid string) bool { - r, err := regexp.Compile("^[a-fA-F0-9]{8}-[a-fA-F0-9]{4}-4[a-fA-F0-9]{3}-[8|9|aA|bB][a-fA-F0-9]{3}-[a-fA-F0-9]{12}$") - if err != nil { - return false - } + r := regexp.MustCompile("^[a-fA-F0-9]{8}-[a-fA-F0-9]{4}-4[a-fA-F0-9]{3}-[8|9|aA|bB][a-fA-F0-9]{3}-[a-fA-F0-9]{12}$") return r.MatchString(uuid) } @@ -575,7 +572,7 @@ case !isValidUuid(body.Uuid): return false, "Bad/Missing gateway UUID" case body.ReportedTime == "" || !isIso8601(body.ReportedTime): - return false, "Bad/Missing gateway ReportedTimeService" + return false, "Bad/Missing gateway reportedTime" } return true, "" } @@ -598,7 +595,7 @@ case !isValidUuid(body.ServiceId): return false, "Bad/Missing gateway ServiceId" case body.ReportedTime == "": - return false, "Bad/Missing gateway ReportedTimeService" + return false, "Bad/Missing gateway reportedTime" } for _, s := range body.StatusDetails {
diff --git a/api_test.go b/api_test.go index ccd3b0c..8178020 100644 --- a/api_test.go +++ b/api_test.go
@@ -19,13 +19,16 @@ "net/http" "net/url" + "bytes" "crypto/rand" "fmt" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" mathrand "math/rand" "os" + "reflect" "strconv" + "strings" "time" ) @@ -43,13 +46,16 @@ testCount += 1 dummyDbMan = &dummyDbManager{} testApiMan = &apiManager{ - dbMan: dummyDbMan, - deploymentsEndpoint: deploymentsEndpoint + strconv.Itoa(testCount), - blobEndpoint: blobEndpointPath + strconv.Itoa(testCount) + "/{blobId}", - eTag: int64(testCount * 10), - deploymentsChanged: make(chan interface{}, 5), - addSubscriber: make(chan chan deploymentsResult), - removeSubscriber: make(chan chan deploymentsResult), + dbMan: dummyDbMan, + deploymentsEndpoint: deploymentsEndpoint + strconv.Itoa(testCount), + blobEndpoint: blobEndpointPath + strconv.Itoa(testCount) + "/{blobId}", + configStatusEndpoint: "/test" + strconv.Itoa(testCount) + configStatusEndpoint, + heartbeatEndpoint: "/test" + strconv.Itoa(testCount) + heartbeatEndpoint, + registerEndpoint: "/test" + strconv.Itoa(testCount) + registerEndpoint, + eTag: int64(testCount * 10), + deploymentsChanged: make(chan interface{}, 5), + addSubscriber: make(chan chan deploymentsResult), + removeSubscriber: make(chan chan deploymentsResult), } testApiMan.InitAPI() time.Sleep(100 * time.Millisecond) @@ -274,6 +280,212 @@ }) }) + FContext("Tracking endpoints", func() { + var dummyClient *dummyTrackerClient + var testClient *http.Client + + var _ = BeforeEach(func() { + dummyClient = &dummyTrackerClient{} + testApiMan.trackerCl = dummyClient + testClient = &http.Client{} + }) + + var _ = AfterEach(func() { + + }) + + Context("PUT /heartbeat/{uuid}", func() { + It("/heartbeat should validate request", func() { + // setup test data + dummyClient.code = http.StatusOK + testData := [][]string{ + {GenerateUUID(), time.Now().Format(time.RFC3339)}, + {GenerateUUID(), time.Now().Format(iso8601)}, + {GenerateUUID(), "invalid-time"}, + {GenerateUUID(), time.Now().Format(time.RubyDate)}, + {"invalid-uuid", time.Now().Format(iso8601)}, + {"invalid-uuid", "invalid-time"}, + {"", time.Now().Format(time.RFC3339)}, + {GenerateUUID(), ""}, + } + + expectedCode := []int{ + http.StatusOK, + http.StatusOK, + http.StatusBadRequest, + http.StatusBadRequest, + http.StatusBadRequest, + http.StatusBadRequest, + http.StatusNotFound, + http.StatusBadRequest, + } + + expectedBody := []string{ + strings.Join(testData[0], " "), + strings.Join(testData[1], " "), + "reportedTime", + "reportedTime", + "UUID", + "UUID", + "", + "reportedTime", + } + + // setup http client + uri, err := url.Parse(apiTestUrl) + Expect(err).Should(Succeed()) + for i, data := range testData { + uri.Path = strings.Replace(testApiMan.heartbeatEndpoint, "{uuid}", data[0], 1) + log.Debug(uri.String()) + req, err := http.NewRequest("PUT", uri.String(), nil) + Expect(err).Should(Succeed()) + req.Header.Set("reportedTime", data[1]) + // http put + res, err := testClient.Do(req) + Expect(err).Should(Succeed()) + // parse response + defer res.Body.Close() + Expect(res.StatusCode).Should(Equal(expectedCode[i])) + body, err := ioutil.ReadAll(res.Body) + Expect(err).Should(Succeed()) + Expect(strings.Contains(strings.ToLower(string(body)), strings.ToLower(expectedBody[i]))).To(BeTrue()) + } + }) + + It("/heartbeat should populate errors from tracker", func() { + // setup test data + testData := [][]string{ + {GenerateUUID(), time.Now().Format(iso8601)}, + {GenerateUUID(), time.Now().Format(iso8601)}, + {GenerateUUID(), time.Now().Format(iso8601)}, + } + + expectedCode := []int{ + http.StatusBadRequest, + http.StatusInternalServerError, + http.StatusBadGateway, + } + + expectedBody := []string{ + strings.Join(testData[0], " "), + strings.Join(testData[1], " "), + strings.Join(testData[2], " "), + } + + // setup http client + uri, err := url.Parse(apiTestUrl) + Expect(err).Should(Succeed()) + for i, data := range testData { + dummyClient.code = expectedCode[i] + uri.Path = strings.Replace(testApiMan.heartbeatEndpoint, "{uuid}", data[0], 1) + req, err := http.NewRequest("PUT", uri.String(), nil) + Expect(err).Should(Succeed()) + req.Header.Set("reportedTime", data[1]) + // http put + res, err := testClient.Do(req) + Expect(err).Should(Succeed()) + // parse response + defer res.Body.Close() + Expect(res.StatusCode).Should(Equal(expectedCode[i])) + body, err := ioutil.ReadAll(res.Body) + Expect(err).Should(Succeed()) + Expect(strings.Contains(strings.ToLower(string(body)), strings.ToLower(expectedBody[i]))).To(BeTrue()) + } + }) + }) + + Context("PUT /register/{uuid}", func() { + It("/register should validate request", func() { + // setup test data + dummyClient.code = http.StatusOK + + testData := [][]string{ + {GenerateUUID(), "pod", "podType", time.Now().Format(iso8601), "name", "type"}, + {"", "pod", "podType", time.Now().Format(iso8601), "name", "type"}, + {GenerateUUID(), "", "podType", time.Now().Format(iso8601), "name", "type"}, + {GenerateUUID(), "pod", "", time.Now().Format(iso8601), "name", "type"}, + {GenerateUUID(), "pod", "podType", "", "name", "type"}, + {GenerateUUID(), "pod", "podType", time.Now().Format(iso8601), "", "type"}, + {GenerateUUID(), "pod", "podType", time.Now().Format(iso8601), "name", ""}, + {"invalid-uuid", "pod", "podType", time.Now().Format(iso8601), "name", "type"}, + {GenerateUUID(), "pod", "podType", "invalid-time", "name", "type"}, + {GenerateUUID(), "pod", "podType", time.Now().Format(iso8601), "name", "type"}, + } + + pathUuid := []string{ + "", + GenerateUUID(), + "", + "", + "", + "", + "", + "", + "", + GenerateUUID(), + } + + expectedCode := []int{ + http.StatusOK, + http.StatusBadRequest, + http.StatusOK, + http.StatusOK, + http.StatusBadRequest, + http.StatusOK, + http.StatusOK, + http.StatusBadRequest, + http.StatusBadRequest, + http.StatusBadRequest, + } + + expectedBody := []string{ + strings.Join(testData[0], " "), + "mismatch UUID", + strings.Join(testData[2], " "), + strings.Join(testData[3], " "), + "reportedTime", + strings.Join(testData[5], " "), + strings.Join(testData[6], " "), + "UUID", + "reportedTime", + "mismatch UUID", + } + + // setup http client + uri, err := url.Parse(apiTestUrl) + Expect(err).Should(Succeed()) + for i, data := range testData { + uuid := pathUuid[i] + if uuid == "" { + uuid = data[0] + } + uri.Path = strings.Replace(testApiMan.registerEndpoint, "{uuid}", uuid, 1) + reqBody, err := json.Marshal(registerBody{ + Uuid: data[0], + Pod: data[1], + PodType: data[2], + ReportedTime: data[3], + Name: data[4], + Type: data[5], + }) + Expect(err).Should(Succeed()) + log.Debug(uri.String()) + req, err := http.NewRequest("PUT", uri.String(), bytes.NewReader(reqBody)) + Expect(err).Should(Succeed()) + // http put + res, err := testClient.Do(req) + Expect(err).Should(Succeed()) + // parse response + defer res.Body.Close() + Expect(res.StatusCode).Should(Equal(expectedCode[i])) + body, err := ioutil.ReadAll(res.Body) + Expect(err).Should(Succeed()) + Expect(strings.Contains(strings.ToLower(string(body)), strings.ToLower(expectedBody[i]))).To(BeTrue()) + } + }) + }) + }) + }) func setTestDeployments(dummyDbMan *dummyDbManager, self string) []ApiDeploymentDetails { @@ -386,3 +598,44 @@ buff[8] = (buff[8] | 0x80) & 0xBF return fmt.Sprintf("%x-%x-%x-%x-%x", buff[0:4], buff[4:6], buff[6:8], buff[8:10], buff[10:]) } + +type dummyTrackerClient struct { + code int + args []string +} + +func (d *dummyTrackerClient) putConfigStatus(reqBody *configStatusBody) *trackerResponse { + + return &trackerResponse{ + code: d.code, + contentType: "application/octet-stream", + body: []byte(concatenateFields(reqBody)), + } +} + +func (d *dummyTrackerClient) putRegister(uuid string, reqBody *registerBody) *trackerResponse { + + return &trackerResponse{ + code: d.code, + contentType: "application/octet-stream", + body: []byte(concatenateFields(reqBody)), + } +} + +func (d *dummyTrackerClient) putHeartbeat(uuid, reported string) *trackerResponse { + return &trackerResponse{ + code: d.code, + contentType: "application/octet-stream", + body: []byte(uuid + " " + reported), + } +} + +func concatenateFields(s interface{}) string { + v := reflect.ValueOf(s).Elem() + fields := []string{} + for i := 0; i < v.NumField(); i++ { + fields = append(fields, v.Field(i).String()) + } + log.Warn(strings.Join(fields, " ")) + return strings.Join(fields, " ") +}