[ISSUE-69568832] refactor code, fix bugs, add tests
diff --git a/api.go b/api.go index c0e75f3..62e370c 100644 --- a/api.go +++ b/api.go
@@ -24,28 +24,34 @@ const tokenEndpoint = "/accesstoken" +const ( + parBlock = "block" + parTag = "If-None-Match" +) + type ApiManager struct { tokenMan tokenManager + endpoint string } func (a *ApiManager) InitAPI(api apid.APIService) { - api.HandleFunc(tokenEndpoint, a.getAccessToken).Methods("GET") + api.HandleFunc(a.endpoint, a.getAccessToken).Methods("GET") } func (a *ApiManager) getAccessToken(w http.ResponseWriter, r *http.Request) { - b := r.URL.Query().Get("block") + b := r.URL.Query().Get(parBlock) var timeout int if b != "" { var err error timeout, err = strconv.Atoi(b) - if err != nil { + if err != nil || timeout < 0 { writeError(w, http.StatusBadRequest, "bad block value, must be number of seconds") return } } log.Debugf("api timeout: %d", timeout) - ifNoneMatch := r.Header.Get("If-None-Match") - + ifNoneMatch := r.Header.Get(parTag) + log.Debugf("ifNoneMatch: %s", ifNoneMatch) if a.tokenMan.getBearerToken() != ifNoneMatch { w.Write([]byte(a.tokenMan.getBearerToken())) return
diff --git a/api_test.go b/api_test.go new file mode 100644 index 0000000..d46ac4c --- /dev/null +++ b/api_test.go
@@ -0,0 +1,125 @@ +// Copyright 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package apidApigeeSync + +import ( + "fmt" + "github.com/apid/apid-core" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + "io/ioutil" + "net/http" + "net/url" + "strconv" + "time" +) + +const ( + apiTestUrl = "http://127.0.0.1:9000" +) + +var _ = Describe("API Manager", func() { + testCount := 0 + var testApiMan *ApiManager + var dummyTokenMan *dummyTokenManager + var client *http.Client + BeforeEach(func() { + testCount++ + dummyTokenMan = &dummyTokenManager{ + token: fmt.Sprintf("test_token_%d", testCount), + tokenReadyChan: make(chan bool, 1), + } + testApiMan = &ApiManager{ + endpoint: tokenEndpoint + strconv.Itoa(testCount), + tokenMan: dummyTokenMan, + } + testApiMan.InitAPI(apid.API()) + time.Sleep(100 * time.Millisecond) + client = &http.Client{} + }) + + clientGet := func(path string, pars map[string][]string, header map[string][]string) (int, []byte) { + uri, err := url.Parse(apiTestUrl + path) + Expect(err).Should(Succeed()) + query := url.Values(pars) + uri.RawQuery = query.Encode() + httpReq, err := http.NewRequest("GET", uri.String(), nil) + httpReq.Header = http.Header(header) + Expect(err).Should(Succeed()) + res, err := client.Do(httpReq) + Expect(err).Should(Succeed()) + defer res.Body.Close() + responseBody, err := ioutil.ReadAll(res.Body) + Expect(err).Should(Succeed()) + return res.StatusCode, responseBody + } + + It("should get token without long-polling", func() { + code, res := clientGet(testApiMan.endpoint, nil, nil) + Expect(code).Should(Equal(http.StatusOK)) + Expect(string(res)).Should(Equal(dummyTokenMan.token)) + }) + + It("should get bad request for invalid timeout", func() { + code, _ := clientGet(testApiMan.endpoint, map[string][]string{ + parBlock: {"invalid"}, + }, map[string][]string{ + parTag: {dummyTokenMan.getBearerToken()}, + }) + Expect(code).Should(Equal(http.StatusBadRequest)) + + code, _ = clientGet(testApiMan.endpoint, map[string][]string{ + parBlock: {"-1"}, + }, map[string][]string{ + parTag: {dummyTokenMan.getBearerToken()}, + }) + Expect(code).Should(Equal(http.StatusBadRequest)) + }) + + It("should get token immediately if mismatch", func() { + code, res := clientGet(testApiMan.endpoint, map[string][]string{ + parBlock: {"10"}, + }, map[string][]string{ + parTag: {"mismatch"}, + }) + Expect(code).Should(Equal(http.StatusOK)) + Expect(string(res)).Should(Equal(dummyTokenMan.token)) + }, 3) + + It("should get StatusNotModified if timeout", func() { + code, _ := clientGet(testApiMan.endpoint, map[string][]string{ + parBlock: {"1"}, + }, map[string][]string{ + parTag: {dummyTokenMan.getBearerToken()}, + }) + Expect(code).Should(Equal(http.StatusNotModified)) + }, 3) + + It("should do long-polling", func() { + go func() { + time.Sleep(1) + dummyTokenMan.token = "new_token" + dummyTokenMan.tokenReadyChan <- true + }() + code, res := clientGet(testApiMan.endpoint, map[string][]string{ + parBlock: {"10"}, + }, map[string][]string{ + parTag: {dummyTokenMan.getBearerToken()}, + }) + Expect(code).Should(Equal(http.StatusOK)) + Expect(string(res)).Should(Equal(dummyTokenMan.getBearerToken())) + }, 3) + +})
diff --git a/apigeeSync_suite_test.go b/apigeeSync_suite_test.go index a53a488..47c5498 100644 --- a/apigeeSync_suite_test.go +++ b/apigeeSync_suite_test.go
@@ -24,27 +24,27 @@ "time" "github.com/apid/apid-core" - + "github.com/apid/apid-core/events" "github.com/apid/apid-core/factory" ) -var ( - tmpDir string -) - const dummyConfigValue string = "placeholder" const expectedClusterId = "bootstrap" +var tmpDir string + var _ = BeforeSuite(func() { apid.Initialize(factory.DefaultServicesFactory()) - config = apid.Config() dataService = apid.Data() - eventService = apid.Events() + config = apid.Config() + apiService = apid.API() + go apiService.Listen() + //dataService = apid.Data() log = apid.Log().ForModule("apigeeSync") var err error - tmpDir, err = ioutil.TempDir("", "api_test") + tmpDir, err = ioutil.TempDir("", "apid_test") Expect(err).NotTo(HaveOccurred()) - config.Set("local_storage_path", tmpDir) + config.Set(configLocalStoragePath, tmpDir) config.Set(configProxyServerBaseURI, dummyConfigValue) config.Set(configSnapServerBaseURI, dummyConfigValue) config.Set(configChangeServerBaseURI, dummyConfigValue) @@ -57,6 +57,7 @@ }, 3) var _ = BeforeEach(func() { + eventService = events.CreateService() config.Set(configName, "testhost") config.Set(configApidClusterId, expectedClusterId) apidInfo.ClusterID = expectedClusterId @@ -66,14 +67,22 @@ }) var _ = AfterEach(func() { + cleanCommonDb() + eventService.Close() }) var _ = AfterSuite(func() { - apid.Events().Close() - os.RemoveAll(tmpDir) + Expect(os.RemoveAll(tmpDir)).Should(Succeed()) }) func TestApigeeSync(t *testing.T) { RegisterFailHandler(Fail) RunSpecs(t, "ApigeeSync Suite") } + +func cleanCommonDb() { + db, err := dataService.DB() + Expect(err).Should(Succeed()) + _, err = db.Exec(`DROP TABLE IF EXISTS APID;`) + Expect(err).Should(Succeed()) +}
diff --git a/apigee_sync.go b/apigee_sync.go deleted file mode 100644 index 47b9fd6..0000000 --- a/apigee_sync.go +++ /dev/null
@@ -1,108 +0,0 @@ -// Copyright 2017 Google Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package apidApigeeSync - -import ( - "net/http" - "time" -) - -const ( - httpTimeout = time.Minute - pluginTimeout = time.Minute - maxIdleConnsPerHost = 10 -) - -/* - * Call toExecute repeatedly until it does not return an error, with an exponential backoff policy - * for retrying on errors - */ -func pollWithBackoff(quit chan bool, toExecute func(chan bool) error, handleError func(error)) { - - backoff := NewExponentialBackoff(200*time.Millisecond, config.GetDuration(configPollInterval), 2, true) - - //inintialize the retry channel to start first attempt immediately - retry := time.After(0 * time.Millisecond) - - for { - select { - case <-quit: - log.Info("Quit signal recieved. Returning") - return - case <-retry: - start := time.Now() - - err := toExecute(quit) - if err == nil { - return - } - - if _, ok := err.(quitSignalError); ok { - return - } - - end := time.Now() - //error encountered, since we would have returned above otherwise - handleError(err) - - /* TODO keep this around? Imagine an immediately erroring service, - * causing many sequential requests which could pollute logs - */ - //only backoff if the request took less than one second - if end.After(start.Add(time.Second)) { - backoff.Reset() - retry = time.After(0 * time.Millisecond) - } else { - retry = time.After(backoff.Duration()) - } - } - } -} - -func addHeaders(req *http.Request, token string) { - req.Header.Set("Authorization", "Bearer "+token) - req.Header.Set("apid_instance_id", apidInfo.InstanceID) - req.Header.Set("apid_cluster_Id", apidInfo.ClusterID) - req.Header.Set("updated_at_apid", time.Now().Format(time.RFC3339)) -} - -type changeServerError struct { - Code string `json:"code"` -} - -type quitSignalError struct { -} - -type expected200Error struct { -} - -type authFailError struct { -} - -func (an expected200Error) Error() string { - return "Did not recieve OK response" -} - -func (a quitSignalError) Error() string { - return "Signal to quit encountered" -} - -func (a changeServerError) Error() string { - return a.Code -} - -func (a authFailError) Error() string { - return "Authorization failed" -}
diff --git a/backoff.go b/backoff.go deleted file mode 100644 index bad8077..0000000 --- a/backoff.go +++ /dev/null
@@ -1,98 +0,0 @@ -// Copyright 2017 Google Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package apidApigeeSync - -import ( - "math" - "math/rand" - "time" -) - -const defaultInitial time.Duration = 200 * time.Millisecond -const defaultMax time.Duration = 10 * time.Second -const defaultFactor float64 = 2 - -type Backoff struct { - attempt int - initial, max time.Duration - jitter bool - backoffStrategy func() time.Duration -} - -type ExponentialBackoff struct { - Backoff - factor float64 -} - -func NewExponentialBackoff(initial, max time.Duration, factor float64, jitter bool) *ExponentialBackoff { - backoff := &ExponentialBackoff{} - - if initial <= 0 { - initial = defaultInitial - } - if max <= 0 { - max = defaultMax - } - - if factor <= 0 { - factor = defaultFactor - } - - backoff.initial = initial - backoff.max = max - backoff.attempt = 0 - backoff.factor = factor - backoff.jitter = jitter - backoff.backoffStrategy = backoff.exponentialBackoffStrategy - - return backoff -} - -func (b *Backoff) Duration() time.Duration { - d := b.backoffStrategy() - b.attempt++ - return d -} - -func (b *ExponentialBackoff) exponentialBackoffStrategy() time.Duration { - - initial := float64(b.Backoff.initial) - attempt := float64(b.Backoff.attempt) - duration := initial * math.Pow(b.factor, attempt) - - if duration > math.MaxInt64 { - return b.max - } - dur := time.Duration(duration) - - if b.jitter { - duration = (rand.Float64()*(duration-initial) + initial) - } - - if dur > b.max { - return b.max - } - - log.Debugf("Backing off for %d ms", int64(dur/time.Millisecond)) - return dur -} - -func (b *Backoff) Reset() { - b.attempt = 0 -} - -func (b *Backoff) Attempt() int { - return b.attempt -}
diff --git a/change_test.go b/change_test.go index a467a18..97d5fb3 100644 --- a/change_test.go +++ b/change_test.go
@@ -19,8 +19,10 @@ "github.com/apid/apid-core/api" "github.com/apigee-labs/transicator/common" . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" "net/http" "net/http/httptest" + "strconv" "time" ) @@ -31,118 +33,179 @@ var _ = Describe("Change Agent", func() { Context("Change Agent Unit Tests", func() { - testCount := 0 - var testChangeMan *pollChangeManager - var dummyDbMan *dummyDbManager - var dummySnapMan *dummySnapshotManager - var dummyTokenMan *dummyTokenManager - var testServer *httptest.Server - var testRouter apid.Router - var testMock *MockServer - BeforeEach(func() { - testCount++ - dummyDbMan = &dummyDbManager{ - knownTables: map[string]bool{ - "_transicator_metadata": true, - "_transicator_tables": true, - "attributes": true, - "edgex_apid_cluster": true, - "edgex_data_scope": true, - "kms_api_product": true, - "kms_app": true, - "kms_app_credential": true, - "kms_app_credential_apiproduct_mapper": true, - "kms_company": true, - "kms_company_developer": true, - "kms_deployment": true, - "kms_developer": true, - "kms_organization": true, - }, - scopes: []string{"43aef41d"}, - } - dummySnapMan = &dummySnapshotManager{ - downloadCalledChan: make(chan bool, 1), - } - dummyTokenMan = &dummyTokenManager{ - invalidateChan: make(chan bool, 1), - } - client := &http.Client{} - testChangeMan = createChangeManager(dummyDbMan, dummySnapMan, dummyTokenMan, client) - testChangeMan.block = 0 - // create a new API service to have a new router for testing - testRouter = api.CreateService().Router() - testServer = httptest.NewServer(testRouter) - // set up mock server - mockParms := MockParms{ - ReliableAPI: true, - ClusterID: config.GetString(configApidClusterId), - TokenKey: config.GetString(configConsumerKey), - TokenSecret: config.GetString(configConsumerSecret), - Scope: "ert452", - Organization: "att", - Environment: "prod", - } - apidInfo.ClusterID = expectedClusterId - apidInfo.InstanceID = expectedInstanceId - testMock = Mock(mockParms, testRouter) - config.Set(configProxyServerBaseURI, testServer.URL) - config.Set(configSnapServerBaseURI, testServer.URL) - config.Set(configChangeServerBaseURI, testServer.URL) - config.Set(configPollInterval, 1*time.Millisecond) + Context("utils", func() { - }) + It("should correctly identify non-proper subsets with respect to maps", func() { + //test b proper subset of a + Expect(changesHaveNewTables(map[string]bool{"a": true, "b": true}, + []common.Change{{Table: "b"}}, + )).To(BeFalse()) - AfterEach(func() { - testServer.Close() - <-testChangeMan.close() - config.Set(configProxyServerBaseURI, dummyConfigValue) - config.Set(configSnapServerBaseURI, dummyConfigValue) - config.Set(configChangeServerBaseURI, dummyConfigValue) - config.Set(configPollInterval, 10*time.Millisecond) - }) + //test a == b + Expect(changesHaveNewTables(map[string]bool{"a": true, "b": true}, + []common.Change{{Table: "a"}, {Table: "b"}}, + )).To(BeFalse()) - It("test change agent with authorization failure", func() { - log.Debug("test change agent with authorization failure") - testMock.forceAuthFail() - testChangeMan.pollChangeWithBackoff() - // auth check fails - <-dummyTokenMan.invalidateChan - log.Debug("closing") - }) + //test b superset of a + Expect(changesHaveNewTables(map[string]bool{"a": true, "b": true}, + []common.Change{{Table: "a"}, {Table: "b"}, {Table: "c"}}, + )).To(BeTrue()) - It("test change agent with too old snapshot", func() { - log.Debug("test change agent with too old snapshot") - testMock.passAuthCheck() - testMock.forceNewSnapshot() - testChangeMan.pollChangeWithBackoff() - <-dummySnapMan.downloadCalledChan - log.Debug("closing") - }) + //test b not subset of a + Expect(changesHaveNewTables(map[string]bool{"a": true, "b": true}, + []common.Change{{Table: "c"}}, + )).To(BeTrue()) - It("change agent should retry with authorization failure", func(done Done) { - log.Debug("change agent should retry with authorization failure") - testMock.forceAuthFail() - testMock.forceNoSnapshot() - apid.Events().ListenFunc(ApigeeSyncEventSelector, func(event apid.Event) { + //test a empty + Expect(changesHaveNewTables(map[string]bool{}, + []common.Change{{Table: "a"}}, + )).To(BeTrue()) - if _, ok := event.(*common.ChangeList); ok { - closeDone := testChangeMan.close() - log.Debug("closing") - go func() { - // when close done, all handlers for the first snapshot have been executed - <-closeDone - close(done) - }() + //test b empty + Expect(changesHaveNewTables(map[string]bool{"a": true, "b": true}, + []common.Change{}, + )).To(BeFalse()) - } + //test b nil + Expect(changesHaveNewTables(map[string]bool{"a": true, "b": true}, nil)).To(BeFalse()) + + //test a nil + Expect(changesHaveNewTables(nil, + []common.Change{{Table: "a"}}, + )).To(BeTrue()) }) - testChangeMan.pollChangeWithBackoff() - // auth check fails - <-dummyTokenMan.invalidateChan - testMock.passAuthCheck() - }, 3) + It("Compare Sequence Number", func() { + Expect(getChangeStatus("1.1.1", "1.1.2")).To(Equal(1)) + Expect(getChangeStatus("1.1.1", "1.2.1")).To(Equal(1)) + Expect(getChangeStatus("1.2.1", "1.2.1")).To(Equal(0)) + Expect(getChangeStatus("1.2.1", "1.2.2")).To(Equal(1)) + Expect(getChangeStatus("2.2.1", "1.2.2")).To(Equal(-1)) + Expect(getChangeStatus("2.2.1", "2.2.0")).To(Equal(-1)) + }) + }) + + Context("changeManager", func() { + testCount := 0 + var testChangeMan *pollChangeManager + var dummyDbMan *dummyDbManager + var dummySnapMan *dummySnapshotManager + var dummyTokenMan *dummyTokenManager + var testServer *httptest.Server + var testRouter apid.Router + var testMock *MockServer + BeforeEach(func() { + testCount++ + dummyDbMan = &dummyDbManager{ + knownTables: map[string]bool{ + "_transicator_metadata": true, + "_transicator_tables": true, + "attributes": true, + "edgex_apid_cluster": true, + "edgex_data_scope": true, + "kms_api_product": true, + "kms_app": true, + "kms_app_credential": true, + "kms_app_credential_apiproduct_mapper": true, + "kms_company": true, + "kms_company_developer": true, + "kms_deployment": true, + "kms_developer": true, + "kms_organization": true, + }, + scopes: []string{"43aef41d"}, + lastSeqUpdated: make(chan string, 1), + } + dummySnapMan = &dummySnapshotManager{ + downloadCalledChan: make(chan bool, 1), + } + dummyTokenMan = &dummyTokenManager{ + invalidateChan: make(chan bool, 1), + } + client := &http.Client{} + testChangeMan = createChangeManager(dummyDbMan, dummySnapMan, dummyTokenMan, client) + testChangeMan.block = 0 + + // create a new API service to have a new router for testing + testRouter = api.CreateService().Router() + testServer = httptest.NewServer(testRouter) + // set up mock server + mockParms := MockParms{ + ReliableAPI: true, + ClusterID: config.GetString(configApidClusterId), + TokenKey: config.GetString(configConsumerKey), + TokenSecret: config.GetString(configConsumerSecret), + Scope: "", + Organization: "att", + Environment: "prod", + } + apidInfo.ClusterID = expectedClusterId + apidInfo.InstanceID = expectedInstanceId + testMock = Mock(mockParms, testRouter) + config.Set(configProxyServerBaseURI, testServer.URL) + config.Set(configSnapServerBaseURI, testServer.URL) + config.Set(configChangeServerBaseURI, testServer.URL) + config.Set(configPollInterval, 1*time.Millisecond) + + initialBackoffInterval = time.Millisecond + testMock.oauthToken = "test_token_" + strconv.Itoa(testCount) + dummyTokenMan.token = testMock.oauthToken + + }) + + AfterEach(func() { + testServer.Close() + <-testChangeMan.close() + config.Set(configProxyServerBaseURI, dummyConfigValue) + config.Set(configSnapServerBaseURI, dummyConfigValue) + config.Set(configChangeServerBaseURI, dummyConfigValue) + config.Set(configPollInterval, 10*time.Millisecond) + }) + + It("test change agent with authorization failure", func() { + log.Debug("test change agent with authorization failure") + testMock.forceAuthFailOnce() + testChangeMan.pollChangeWithBackoff() + // auth check fails + <-dummyTokenMan.invalidateChan + log.Debug("closing") + }) + + It("test change agent with too old snapshot", func() { + log.Debug("test change agent with too old snapshot") + testMock.passAuthCheck() + testMock.forceNewSnapshot() + testChangeMan.pollChangeWithBackoff() + <-dummySnapMan.downloadCalledChan + log.Debug("closing") + }) + + It("change agent should retry with authorization failure", func() { + log.Debug("change agent should retry with authorization failure") + testMock.forceAuthFailOnce() + testMock.forceNoSnapshot() + called := false + eventService.ListenFunc(ApigeeSyncEventSelector, func(event apid.Event) { + if _, ok := event.(*common.ChangeList); ok { + called = true + } + }) + testChangeMan.pollChangeWithBackoff() + <-dummyTokenMan.invalidateChan + Expect(<-dummyDbMan.lastSeqUpdated).Should(Equal(testMock.lastSequenceID())) + Expect(called).Should(BeTrue()) + }, 3) + + }) + + Context("offline change manager", func() { + It("offline change manager should have no effect", func() { + o := &offlineChangeManager{} + o.pollChangeWithBackoff() + <-o.close() + }) + }) }) })
diff --git a/changes.go b/changes.go index 9b32658..6afa827 100644 --- a/changes.go +++ b/changes.go
@@ -16,6 +16,7 @@ import ( "encoding/json" + "fmt" "github.com/apigee-labs/transicator/common" "io/ioutil" "net/http" @@ -36,12 +37,12 @@ block int lastSequence string dbMan DbManager - snapMan snapShotManager + snapMan snapshotManager tokenMan tokenManager client *http.Client } -func createChangeManager(dbMan DbManager, snapMan snapShotManager, tokenMan tokenManager, client *http.Client) *pollChangeManager { +func createChangeManager(dbMan DbManager, snapMan snapshotManager, tokenMan tokenManager, client *http.Client) *pollChangeManager { isClosedInt := int32(0) isLaunchedInt := int32(0) return &pollChangeManager{ @@ -65,7 +66,7 @@ finishChan := make(chan bool, 1) //has been closed if atomic.SwapInt32(c.isClosed, 1) == int32(1) { - log.Error("pollChangeManager: close() called on a closed pollChangeManager!") + log.Warn("pollChangeManager: close() called on a closed pollChangeManager!") go func() { log.Debug("change manager closed") finishChan <- false @@ -74,9 +75,8 @@ } // not launched if atomic.LoadInt32(c.isLaunched) == int32(0) { - log.Warn("pollChangeManager: close() called when pollChangeWithBackoff unlaunched! Will wait until pollChangeWithBackoff is launched and then kill it and tokenManager!") + log.Warn("pollChangeManager: close() called when pollChangeWithBackoff unlaunched!") go func() { - c.quitChan <- true c.tokenMan.close() <-c.snapMan.close() log.Debug("change manager closed") @@ -134,7 +134,7 @@ select { case <-c.quitChan: log.Info("pollChangeAgent; Recevied quit signal to stop polling change server, close token manager") - return quitSignalError{} + return quitSignalError default: scopes, err := c.dbMan.findScopesForId(apidInfo.ClusterID) if err != nil { @@ -163,11 +163,8 @@ log.Errorf("Get changes request failed with status code: %d", r.StatusCode) switch r.StatusCode { case http.StatusUnauthorized: - err = c.tokenMan.invalidateToken() - if err != nil { - return nil, err - } - return nil, authFailError{} + c.tokenMan.invalidateToken() + return nil, authFailError case http.StatusNotModified: return nil, nil @@ -191,7 +188,7 @@ return nil, err default: log.Errorf("Unknown response code from change server: %v", r.Status) - return nil, nil + return nil, fmt.Errorf("unknown response code from change server: %v", r.Status) } } @@ -227,12 +224,24 @@ log.Errorf("Error in processChangeList: %v", err) return err } + /* + * Check to see if there was any change in scope. If found, handle it + * by getting a new snapshot + */ + newScopes, err := c.dbMan.findScopesForId(apidInfo.ClusterID) + if err != nil { + return err + } + cs := scopeChanged(newScopes, scopes) + if cs != nil { + return cs + } select { case <-time.After(httpTimeout): log.Panic("Timeout. Plugins failed to respond to changes.") case <-eventService.Emit(ApigeeSyncEventSelector, cl): } - } else if c.lastSequence == "" { + } else if c.lastSequence == "" { // emit the first changelist anyway select { case <-time.After(httpTimeout): log.Panic("Timeout. Plugins failed to respond to changes.") @@ -247,20 +256,6 @@ log.Panicf("Unable to update Sequence in DB. Err {%v}", err) } c.lastSequence = cl.LastSequence - - /* - * Check to see if there was any change in scope. If found, handle it - * by getting a new snapshot - */ - newScopes, err := c.dbMan.findScopesForId(apidInfo.ClusterID) - if err != nil { - return err - } - cs := scopeChanged(newScopes, scopes) - if cs != nil { - return cs - } - return nil }
diff --git a/data.go b/data.go index 0205bb7..be8dd34 100644 --- a/data.go +++ b/data.go
@@ -112,7 +112,7 @@ prep, err := txn.Prepare(sql) if err != nil { - log.Errorf("INSERT Fail to prepare statement [%s] error=[%v]", sql, err) + log.Errorf("INSERT Fail to prepare statement %s error=%v", sql, err) return err } defer prep.Close() @@ -130,10 +130,10 @@ _, err = prep.Exec(values...) if err != nil { - log.Errorf("INSERT Fail [%s] values=%v error=[%v]", sql, values, err) + log.Errorf("INSERT Fail %s values=%v error=%v", sql, values, err) return err } - log.Debugf("INSERT Success [%s] values=%v", sql, values) + log.Debugf("INSERT Success %s values=%v", sql, values) return nil } @@ -158,13 +158,13 @@ } if len(rows) == 0 { - return fmt.Errorf("No rows found for table.", tableName) + return fmt.Errorf("no rows found for table %s", tableName) } sql := dbMan.buildDeleteSql(tableName, rows[0], pkeys) prep, err := txn.Prepare(sql) if err != nil { - return fmt.Errorf("DELETE Fail to prep statement [%s] error=[%v]", sql, err) + return fmt.Errorf("DELETE Fail to prep statement %s error=%v", sql, err) } defer prep.Close() for _, row := range rows { @@ -172,15 +172,15 @@ // delete prepared statement from existing template statement res, err := txn.Stmt(prep).Exec(values...) if err != nil { - return fmt.Errorf("DELETE Fail [%s] values=%v error=[%v]", sql, values, err) + return fmt.Errorf("DELETE Fail %s values=%v error=%v", sql, values, err) } affected, err := res.RowsAffected() if err == nil && affected != 0 { - log.Debugf("DELETE Success [%s] values=%v", sql, values) + log.Debugf("DELETE Success %s values=%v", sql, values) } else if err == nil && affected == 0 { - return fmt.Errorf("Entry not found [%s] values=%v. Nothing to delete.", sql, values) + return fmt.Errorf("entry not found %s values=%v, nothing to delete", sql, values) } else { - return fmt.Errorf("DELETE Failed [%s] values=%v error=[%v]", sql, values, err) + return fmt.Errorf("DELETE Failed %s values=%v error=%v", sql, values, err) } } @@ -232,7 +232,7 @@ sql := dbMan.buildUpdateSql(tableName, orderedColumns, newRows[0], pkeys) prep, err := txn.Prepare(sql) if err != nil { - return fmt.Errorf("UPDATE Fail to prep statement [%s] error=[%v]", sql, err) + return fmt.Errorf("UPDATE Fail to prep statement %s error=%v", sql, err) } defer prep.Close() @@ -264,15 +264,15 @@ res, err := txn.Stmt(prep).Exec(values...) if err != nil { - return fmt.Errorf("UPDATE Fail [%s] values=%v error=[%v]", sql, values, err) + return fmt.Errorf("UPDATE Fail %s values=%v error=%v", sql, values, err) } numRowsAffected, err := res.RowsAffected() if err != nil { - return fmt.Errorf("UPDATE Fail [%s] values=%v error=[%v]", sql, values, err) + return fmt.Errorf("UPDATE Fail %s values=%v error=%v", sql, values, err) } //delete this once we figure out why tests are failing/not updating log.Debugf("NUM ROWS AFFECTED BY UPDATE: %d", numRowsAffected) - log.Debugf("UPDATE Success [%s] values=%v", sql, values) + log.Debugf("UPDATE Success %s values=%v", sql, values) } @@ -346,7 +346,7 @@ sql := "SELECT columnName FROM _transicator_tables WHERE tableName=$1 AND primaryKey ORDER BY columnName;" rows, err := db.Query(sql, normalizedTableName) if err != nil { - log.Errorf("Failed [%s] values=[s%] Error: %v", sql, normalizedTableName, err) + log.Errorf("Failed %s values=%s Error: %v", sql, normalizedTableName, err) return nil, err } var columnNames []string @@ -484,7 +484,8 @@ info.InstanceID, info.ClusterID, "") } } else if savedClusterId != info.ClusterID { - log.Warn("Detected apid cluster id change in config. Apid will start clean") + log.Warnf("Detected apid cluster id change in config. %v v.s. %v Apid will start clean.", + savedClusterId, info.ClusterID) err = nil info.IsNewInstance = true info.InstanceID = util.GenerateUUID() @@ -500,7 +501,7 @@ } func (dbMan *dbManager) updateApidInstanceInfo(instanceId, clusterId, lastSnap string) error { - + log.Debugf("updateApidInstanceInfo: %v, %v, %v", instanceId, clusterId, lastSnap) // always use default database for this db, err := dataService.DB() if err != nil { @@ -609,22 +610,22 @@ } db, err := dataService.DBVersion(snapshot.SnapshotInfo) if err != nil { - return fmt.Errorf("Unable to access database: %v", err) + return fmt.Errorf("unable to access database: %v", err) } var numApidClusters int tx, err := db.Begin() if err != nil { - return fmt.Errorf("Unable to open DB txn: {%v}", err.Error()) + return fmt.Errorf("unable to open DB txn: {%v}", err.Error()) } defer tx.Rollback() err = tx.QueryRow("SELECT COUNT(*) FROM edgex_apid_cluster").Scan(&numApidClusters) if err != nil { - return fmt.Errorf("Unable to read database: {%s}", err.Error()) + return fmt.Errorf("unable to read database: {%s}", err.Error()) } if numApidClusters != 1 { - return fmt.Errorf("Illegal state for apid_cluster. Must be a single row.") + return fmt.Errorf("illegal state for apid_cluster, must be a single row") } _, err = tx.Exec("ALTER TABLE edgex_apid_cluster ADD COLUMN last_sequence text DEFAULT ''") @@ -633,21 +634,22 @@ } if err = tx.Commit(); err != nil { - return fmt.Errorf("Error when commit in processSqliteSnapshot: %v", err) + return fmt.Errorf("error when commit in processSqliteSnapshot: %v", err) } //update apid instance info apidInfo.LastSnapshot = snapshot.SnapshotInfo err = dbMan.updateApidInstanceInfo(apidInfo.InstanceID, apidInfo.ClusterID, apidInfo.LastSnapshot) if err != nil { - return fmt.Errorf("Unable to update instance info: %v", err) + log.Errorf("Unable to update instance info: %v", err) + return fmt.Errorf("unable to update instance info: %v", err) } dbMan.setDB(db) if isDataSnapshot { dbMan.knownTables, err = dbMan.extractTables() if err != nil { - return fmt.Errorf("Unable to extract tables: %v", err) + return fmt.Errorf("unable to extract tables: %v", err) } } log.Debugf("Snapshot processed: %s", snapshot.SnapshotInfo)
diff --git a/data_test.go b/data_test.go index e952168..abad465 100644 --- a/data_test.go +++ b/data_test.go
@@ -30,6 +30,10 @@ var testDbMan *dbManager var dbVersion string BeforeEach(func() { + var testDir string + testDir, err := ioutil.TempDir(tmpDir, "data_test") + config.Set(configLocalStoragePath, testDir) + Expect(err).NotTo(HaveOccurred()) testDbMan = creatDbManager() testCount++ dbVersion = "data_test_" + strconv.Itoa(testCount) @@ -39,7 +43,7 @@ }) AfterEach(func() { - dataService.ReleaseDB(dbVersion) + config.Set(configLocalStoragePath, tmpDir) }) It("check scope changes", func() { @@ -1370,7 +1374,7 @@ } AfterEach(func() { - dataService.ReleaseCommonDB() + }) It("should fail if more than one apid_cluster rows", func() {
diff --git a/init.go b/init.go index dd310c0..ac0eeb7 100644 --- a/init.go +++ b/init.go
@@ -38,7 +38,8 @@ // special value - set by ApigeeSync, not taken from configuration configApidInstanceID = "apigeesync_apid_instance_id" // This will not be needed once we have plugin handling tokens. - configBearerToken = "apigeesync_bearer_token" + configBearerToken = "apigeesync_bearer_token" + configLocalStoragePath = "local_storage_path" ) const ( @@ -47,13 +48,12 @@ var ( /* All set during plugin initialization */ - log apid.LogService - config apid.ConfigService - dataService apid.DataService - eventService apid.EventsService - apiService apid.APIService - apidInfo apidInstanceInfo - isOfflineMode bool + log apid.LogService + config apid.ConfigService + dataService apid.DataService + eventService apid.EventsService + apiService apid.APIService + apidInfo apidInstanceInfo /* Set during post plugin initialization * set this as a default, so that it's guaranteed to be valid even if postInitPlugins isn't called @@ -89,7 +89,7 @@ log.Debugf("Using %s as display name", config.GetString(configName)) } -func checkForRequiredValues() error { +func checkForRequiredValues(isOfflineMode bool) error { required := []string{configProxyServerBaseURI, configConsumerKey, configConsumerSecret} if !isOfflineMode { required = append(required, configSnapServerBaseURI, configChangeServerBaseURI) @@ -97,12 +97,12 @@ // check for required values for _, key := range required { if !config.IsSet(key) { - return fmt.Errorf("Missing required config value: %s", key) + return fmt.Errorf("missing required config value: %s", key) } } proto := config.GetString(configSnapshotProtocol) if proto != "sqlite" { - return fmt.Errorf("Illegal value for %s. Only currently supported snashot protocol is sqlite", configSnapshotProtocol) + return fmt.Errorf("illegal value for %s. Only currently supported snashot protocol is sqlite", configSnapshotProtocol) } return nil @@ -114,25 +114,11 @@ /* initialization */ func initConfigs(services apid.Services) error { - log.Debug("start init") - - config = services.Config() - initConfigDefaults() - - if config.GetBool(configDiagnosticMode) { - log.Warn("Diagnostic mode: will not download changelist and snapshots!") - isOfflineMode = true - } - - err := checkForRequiredValues() - if err != nil { - return err - } return nil } -func initManagers() error { +func initManagers(isOfflineMode bool) (*listenerManager, *ApiManager, error) { // check for forward proxy var tr *http.Transport tr = util.Transport(config.GetString(util.ConfigfwdProxyPortURL)) @@ -141,17 +127,17 @@ apidDbManager := creatDbManager() db, err := dataService.DB() if err != nil { - return fmt.Errorf("Unable to access DB: %v", err) + return nil, nil, fmt.Errorf("unable to access DB: %v", err) } apidDbManager.setDB(db) err = apidDbManager.initDB() if err != nil { - return fmt.Errorf("Unable to access DB: %v", err) + return nil, nil, fmt.Errorf("unable to access DB: %v", err) } apidInfo, err = apidDbManager.getApidInstanceInfo() if err != nil { - return fmt.Errorf("Unable to get apid instance info: %v", err) + return nil, nil, fmt.Errorf("unable to get apid instance info: %v", err) } if config.IsSet(configApidInstanceID) { @@ -159,12 +145,12 @@ } config.Set(configApidInstanceID, apidInfo.InstanceID) - apidTokenManager := createSimpleTokenManager(apidInfo.IsNewInstance) - var apidSnapshotManager snapShotManager + apidTokenManager := createApidTokenManager(apidInfo.IsNewInstance) + var snapMan snapshotManager var apidChangeManager changeManager if isOfflineMode { - apidSnapshotManager = &offlineSnapshotManager{ + snapMan = &offlineSnapshotManager{ dbMan: apidDbManager, } apidChangeManager = &offlineChangeManager{} @@ -177,23 +163,22 @@ return nil }, } - apidSnapshotManager = createSnapShotManager(apidDbManager, apidTokenManager, httpClient) - apidChangeManager = createChangeManager(apidDbManager, apidSnapshotManager, apidTokenManager, httpClient) + snapMan = createSnapShotManager(apidDbManager, apidTokenManager, httpClient) + apidChangeManager = createChangeManager(apidDbManager, snapMan, apidTokenManager, httpClient) } listenerMan := &listenerManager{ - changeMan: apidChangeManager, - snapMan: apidSnapshotManager, - tokenMan: apidTokenManager, + changeMan: apidChangeManager, + snapMan: snapMan, + tokenMan: apidTokenManager, + isOfflineMode: isOfflineMode, } apiMan := &ApiManager{ + endpoint: tokenEndpoint, tokenMan: apidTokenManager, } - - listenerMan.init() - apiMan.InitAPI(apiService) - return nil + return listenerMan, apiMan, nil } func initPlugin(services apid.Services) (apid.PluginData, error) { @@ -201,14 +186,29 @@ dataService = services.Data() eventService = services.Events() apiService = services.API() - err := initConfigs(services) + log.Debug("start init") + config = services.Config() + initConfigDefaults() + + isOfflineMode := false + if config.GetBool(configDiagnosticMode) { + log.Warn("Diagnostic mode: will not download changelist and snapshots!") + isOfflineMode = true + } + + err := checkForRequiredValues(isOfflineMode) if err != nil { return pluginData, err } - - if err = initManagers(); err != nil { + if err != nil { return pluginData, err } + listenerMan, apiMan, err := initManagers(isOfflineMode) + if err != nil { + return pluginData, err + } + listenerMan.init() + apiMan.InitAPI(apiService) log.Debug("end init") return pluginData, nil
diff --git a/init_test.go b/init_test.go index 7727ee7..f2063e4 100644 --- a/init_test.go +++ b/init_test.go
@@ -30,11 +30,10 @@ Context("Apid Instance display name", func() { AfterEach(func() { - eventService = apid.Events() apiService = apid.API() }) - It("should be hostname by default", func() { + It("init should register listener", func() { me := &mockEvent{ listenerMap: make(map[apid.EventSelector]apid.EventHandlerFunc), } @@ -49,7 +48,7 @@ events: me, } testname := "test_" + strconv.Itoa(testCount) - config.Set(configName, testname) + ms.config.Set(configName, testname) pd, err := initPlugin(ms) Expect(err).Should(Succeed()) Expect(apidInfo.InstanceName).To(Equal(testname)) @@ -57,8 +56,40 @@ Expect(ma.handleMap[tokenEndpoint]).ToNot(BeNil()) Expect(pd).Should(Equal(pluginData)) Expect(apidInfo.IsNewInstance).Should(BeTrue()) - dataService.ReleaseCommonDB() - }, 3) + }) + + It("create managers for normal mode", func() { + listenerMan, apiMan, err := initManagers(false) + Expect(err).Should(Succeed()) + Expect(listenerMan).ToNot(BeNil()) + Expect(listenerMan.tokenMan).ToNot(BeNil()) + snapMan, ok := listenerMan.snapMan.(*apidSnapshotManager) + Expect(ok).Should(BeTrue()) + Expect(snapMan.tokenMan).ToNot(BeNil()) + Expect(snapMan.dbMan).ToNot(BeNil()) + changeMan, ok := listenerMan.changeMan.(*pollChangeManager) + Expect(ok).Should(BeTrue()) + Expect(changeMan.tokenMan).ToNot(BeNil()) + Expect(changeMan.dbMan).ToNot(BeNil()) + Expect(changeMan.snapMan).ToNot(BeNil()) + Expect(apiMan).ToNot(BeNil()) + Expect(apiMan.tokenMan).ToNot(BeNil()) + }) + + It("create managers for diagnostic mode", func() { + config.Set(configDiagnosticMode, true) + listenerMan, apiMan, err := initManagers(true) + Expect(err).Should(Succeed()) + Expect(listenerMan).ToNot(BeNil()) + Expect(listenerMan.tokenMan).ToNot(BeNil()) + snapMan, ok := listenerMan.snapMan.(*offlineSnapshotManager) + Expect(ok).Should(BeTrue()) + Expect(snapMan.dbMan).ToNot(BeNil()) + _, ok = listenerMan.changeMan.(*offlineChangeManager) + Expect(ok).Should(BeTrue()) + Expect(apiMan).ToNot(BeNil()) + Expect(apiMan.tokenMan).ToNot(BeNil()) + }) }) })
diff --git a/listener.go b/listener.go index caafe8f..7f46b34 100644 --- a/listener.go +++ b/listener.go
@@ -25,9 +25,10 @@ ) type listenerManager struct { - changeMan changeManager - snapMan snapShotManager - tokenMan tokenManager + changeMan changeManager + snapMan snapshotManager + tokenMan tokenManager + isOfflineMode bool } func (l *listenerManager) init() { @@ -85,7 +86,7 @@ * Then, poll for changes */ func (l *listenerManager) bootstrap(lastSnap string) { - if isOfflineMode && lastSnap == "" { + if l.isOfflineMode && lastSnap == "" { log.Panic("Diagnostic mode requires existent snapshot info in default DB.") }
diff --git a/managerInterfaces.go b/managerInterfaces.go index facc616..978aede 100644 --- a/managerInterfaces.go +++ b/managerInterfaces.go
@@ -21,13 +21,13 @@ type tokenManager interface { getBearerToken() string - invalidateToken() error + invalidateToken() close() start() getTokenReadyChannel() <-chan bool } -type snapShotManager interface { +type snapshotManager interface { close() <-chan bool downloadBootSnapshot() downloadDataSnapshot() error
diff --git a/mock_server.go b/mock_server_test.go similarity index 96% rename from mock_server.go rename to mock_server_test.go index e201097..e61b83e 100644 --- a/mock_server.go +++ b/mock_server_test.go
@@ -97,7 +97,7 @@ authFail *int32 } -func (m *MockServer) forceAuthFail() { +func (m *MockServer) forceAuthFailOnce() { atomic.StoreInt32(m.authFail, 1) } @@ -268,8 +268,11 @@ scopes := q["scope"] Expect(scopes).To(ContainElement(m.params.ClusterID)) - - w.Header().Set("Transicator-Snapshot-TXID", util.GenerateUUID()) + if m.params.Scope != "" { + Expect(scopes).To(ContainElement(m.params.Scope)) + } + m.snapshotID = util.GenerateUUID() + w.Header().Set(headerSnapshotNumber, m.snapshotID) if len(scopes) == 1 { //send bootstrap db @@ -313,7 +316,9 @@ //Expect(q.Get("snapshot")).To(Equal(m.snapshotID)) Expect(scopes).To(ContainElement(m.params.ClusterID)) - //Expect(scopes).To(ContainElement(m.params.Scope)) + if m.params.Scope != "" { + Expect(scopes).To(ContainElement(m.params.Scope)) + } // todo: the following is just legacy for the existing test in apigeeSync_suite_test developer := m.createDeveloperWithProductAndApp() @@ -345,6 +350,7 @@ // force failing auth check if atomic.LoadInt32(m.authFail) == 1 { + atomic.StoreInt32(m.authFail, 0) w.WriteHeader(http.StatusUnauthorized) w.Write([]byte(fmt.Sprintf("Force fail: bad auth token. "))) return @@ -358,7 +364,7 @@ // check auth header auth := req.Header.Get("Authorization") - expectedAuth := fmt.Sprintf("Bearer %s", m.oauthToken) + expectedAuth := m.getBearerToken() if auth != expectedAuth { w.WriteHeader(http.StatusUnauthorized) w.Write([]byte(fmt.Sprintf("Bad auth token. Is: %s, should be: %s", auth, expectedAuth))) @@ -368,6 +374,10 @@ } } +func (m *MockServer) getBearerToken() string { + return fmt.Sprintf("Bearer %s", m.oauthToken) +} + // make a handler unreliable func (m *MockServer) unreliable(target http.HandlerFunc) http.HandlerFunc { if m.params.ReliableAPI {
diff --git a/snapshot.go b/snapshot.go index e3787bf..c18b1c8 100644 --- a/snapshot.go +++ b/snapshot.go
@@ -29,7 +29,13 @@ "time" ) -type simpleSnapShotManager struct { +const bootstrapSnapshotName = "bootstrap" +const ( + headerSnapshotNumber = "Transicator-Snapshot-TXID" +) + +type apidSnapshotManager struct { + *offlineSnapshotManager // to send quit signal to the downloading thread quitChan chan bool // to mark the graceful close of snapshotManager @@ -43,10 +49,13 @@ client *http.Client } -func createSnapShotManager(dbMan DbManager, tokenMan tokenManager, client *http.Client) *simpleSnapShotManager { +func createSnapShotManager(dbMan DbManager, tokenMan tokenManager, client *http.Client) *apidSnapshotManager { isClosedInt := int32(0) isDownloadingInt := int32(0) - return &simpleSnapShotManager{ + return &apidSnapshotManager{ + offlineSnapshotManager: &offlineSnapshotManager{ + dbMan: dbMan, + }, quitChan: make(chan bool, 1), finishChan: make(chan bool, 1), isClosed: &isClosedInt, @@ -58,18 +67,17 @@ } /* - * thread-safe close of snapShotManager + * thread-safe close of snapshotManager * It marks status as closed immediately, and quits backoff downloading * use <- close() for blocking close * should only be called by pollChangeManager, because pollChangeManager is dependent on it */ -func (s *simpleSnapShotManager) close() <-chan bool { +func (s *apidSnapshotManager) close() <-chan bool { //has been closed before if atomic.SwapInt32(s.isClosed, 1) == int32(1) { - log.Error("snapShotManager: close() called on a closed snapShotManager!") + log.Warn("snapshotManager: close() called on a closed snapshotManager!") go func() { s.finishChan <- false - log.Debug("change manager closed") }() return s.finishChan } @@ -83,61 +91,36 @@ } // retrieve boot information: apid_config and apid_config_scope -func (s *simpleSnapShotManager) downloadBootSnapshot() { +func (s *apidSnapshotManager) downloadBootSnapshot() { if atomic.SwapInt32(s.isDownloading, 1) == int32(1) { log.Panic("downloadBootSnapshot: only 1 thread can download snapshot at the same time!") } defer atomic.StoreInt32(s.isDownloading, int32(0)) - // has been closed - if atomic.LoadInt32(s.isClosed) == int32(1) { - log.Warn("snapShotManager: downloadBootSnapshot called on closed snapShotManager") - return - } - log.Debug("download Snapshot for boot data") scopes := []string{apidInfo.ClusterID} snapshot := &common.Snapshot{} - err := s.downloadSnapshot(true, scopes, snapshot) - if err != nil { - // this may happen during shutdown - if _, ok := err.(quitSignalError); ok { - log.Warn("downloadBootSnapshot failed due to shutdown: " + err.Error()) - } - return - } - - // has been closed - if atomic.LoadInt32(s.isClosed) == int32(1) { - log.Error("snapShotManager: processSnapshot called on closed snapShotManager") - return - } + s.downloadSnapshot(true, scopes, snapshot) // note that for boot snapshot case, we don't need to inform plugins as they'll get the data snapshot s.storeBootSnapshot(snapshot) } -func (s *simpleSnapShotManager) storeBootSnapshot(snapshot *common.Snapshot) { +func (s *apidSnapshotManager) storeBootSnapshot(snapshot *common.Snapshot) { if err := s.dbMan.processSnapshot(snapshot, false); err != nil { log.Panic(err) } } // use the scope IDs from the boot snapshot to get all the data associated with the scopes -func (s *simpleSnapShotManager) downloadDataSnapshot() error { +func (s *apidSnapshotManager) downloadDataSnapshot() error { if atomic.SwapInt32(s.isDownloading, 1) == int32(1) { log.Panic("downloadDataSnapshot: only 1 thread can download snapshot at the same time!") } defer atomic.StoreInt32(s.isDownloading, int32(0)) - // has been closed - if atomic.LoadInt32(s.isClosed) == int32(1) { - log.Warn("snapShotManager: downloadDataSnapshot called on closed snapShotManager") - return nil - } - log.Debug("download Snapshot for data scopes") scopes, err := s.dbMan.findScopesForId(apidInfo.ClusterID) @@ -146,45 +129,14 @@ } scopes = append(scopes, apidInfo.ClusterID) snapshot := &common.Snapshot{} - err = s.downloadSnapshot(false, scopes, snapshot) - if err != nil { - // this may happen during shutdown - if _, ok := err.(quitSignalError); ok { - log.Warn("downloadDataSnapshot failed due to shutdown: " + err.Error()) - } - return err - } + s.downloadSnapshot(false, scopes, snapshot) return s.startOnDataSnapshot(snapshot.SnapshotInfo) } -// Skip Downloading snapshot if there is already a snapshot available from previous run -func (s *simpleSnapShotManager) startOnDataSnapshot(snapshotName string) error { - log.Infof("Processing snapshot: %s", snapshotName) - snapshot := &common.Snapshot{ - SnapshotInfo: snapshotName, - } - if err := s.dbMan.processSnapshot(snapshot, true); err != nil { - return err - } - log.Info("Emitting Snapshot to plugins") - select { - case <-time.After(pluginTimeout): - return fmt.Errorf("timeout, plugins failed to respond to snapshot") - case <-eventService.Emit(ApigeeSyncEventSelector, snapshot): - // the new snapshot has been processed - } - return nil -} - // a blocking method // will keep retrying with backoff until success -func (s *simpleSnapShotManager) downloadSnapshot(isBoot bool, scopes []string, snapshot *common.Snapshot) error { - // if closed - if atomic.LoadInt32(s.isClosed) == int32(1) { - log.Warn("Trying to download snapshot with a closed snapShotManager") - return quitSignalError{} - } +func (s *apidSnapshotManager) downloadSnapshot(isBoot bool, scopes []string, snapshot *common.Snapshot) { log.Debug("downloadSnapshot") @@ -207,10 +159,9 @@ //to accommodate functions which need more parameters, wrap them in closures attemptDownload := s.getAttemptDownloadClosure(isBoot, snapshot, uri) pollWithBackoff(s.quitChan, attemptDownload, handleSnapshotServerError) - return nil } -func (s *simpleSnapShotManager) getAttemptDownloadClosure(isBoot bool, snapshot *common.Snapshot, uri string) func(chan bool) error { +func (s *apidSnapshotManager) getAttemptDownloadClosure(isBoot bool, snapshot *common.Snapshot, uri string) func(chan bool) error { return func(_ chan bool) error { var tid string @@ -221,14 +172,11 @@ } addHeaders(req, s.tokenMan.getBearerToken()) - var processSnapshotResponse func(string, io.Reader, *common.Snapshot) error - if config.GetString(configSnapshotProtocol) != "sqlite" { log.Panic("Only currently supported snashot protocol is sqlite") } req.Header.Set("Accept", "application/transicator+sqlite") - processSnapshotResponse = processSnapshotServerFileResponse // Issue the request to the snapshot server r, err := s.client.Do(req) @@ -239,22 +187,28 @@ defer r.Body.Close() - if r.StatusCode != 200 { + switch r.StatusCode { + case http.StatusOK: + break + case http.StatusUnauthorized: + s.tokenMan.invalidateToken() + fallthrough + default: body, _ := ioutil.ReadAll(r.Body) log.Errorf("Snapshot server conn failed with resp code %d, body: %s", r.StatusCode, string(body)) - return expected200Error{} + return expected200Error } // Bootstrap scope is a special case, that can occur only once. The tid is // hardcoded to "bootstrap" to ensure there can be no clash of tid between // bootstrap and subsequent data scopes. if isBoot { - tid = "bootstrap" + tid = bootstrapSnapshotName } else { - tid = r.Header.Get("Transicator-Snapshot-TXID") + tid = r.Header.Get(headerSnapshotNumber) } // Decode the Snapshot server response - err = processSnapshotResponse(tid, r.Body, snapshot) + err = processSnapshotServerFileResponse(tid, r.Body, snapshot) if err != nil { log.Errorf("Snapshot server response Data not parsable: %v", err) return err @@ -303,7 +257,7 @@ } func handleSnapshotServerError(err error) { - log.Debugf("Error connecting to snapshot server: %v", err) + log.Errorf("Error connecting to snapshot server: %v", err) } type offlineSnapshotManager struct {
diff --git a/snapshot_test.go b/snapshot_test.go new file mode 100644 index 0000000..8d99563 --- /dev/null +++ b/snapshot_test.go
@@ -0,0 +1,146 @@ +// Copyright 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package apidApigeeSync + +import ( + "github.com/apid/apid-core" + "github.com/apid/apid-core/api" + "github.com/apigee-labs/transicator/common" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + "net/http" + "net/http/httptest" + "strconv" + "time" +) + +var _ = Describe("Snapshot Manager", func() { + testCount := 0 + var dummyDbMan *dummyDbManager + BeforeEach(func() { + testCount++ + dummyDbMan = &dummyDbManager{} + }) + + Context("offlineSnapshotManager", func() { + var testSnapMan *offlineSnapshotManager + BeforeEach(func() { + testSnapMan = &offlineSnapshotManager{ + dbMan: dummyDbMan, + } + }) + AfterEach(func() { + <-testSnapMan.close() + }) + + It("should have error if download called", func() { + Expect(testSnapMan.downloadDataSnapshot()).ToNot(Succeed()) + Expect(func() { testSnapMan.downloadBootSnapshot() }).To(Panic()) + }) + + It("startOnDataSnapshot should emit events", func() { + called := false + eventService.ListenFunc(ApigeeSyncEventSelector, func(event apid.Event) { + if _, ok := event.(*common.Snapshot); ok { + called = true + } + }) + snapshotId := "test_snapshot_" + strconv.Itoa(testCount) + Expect(testSnapMan.startOnDataSnapshot(snapshotId)).Should(Succeed()) + Expect(dummyDbMan.snapshot.SnapshotInfo).Should(Equal(snapshotId)) + Expect(called).Should(BeTrue()) + }) + }) + + Context("apidSnapshotManager", func() { + var testSnapMan *apidSnapshotManager + var dummyTokenMan *dummyTokenManager + var testServer *httptest.Server + var testRouter apid.Router + var testMock *MockServer + BeforeEach(func() { + dummyTokenMan = &dummyTokenManager{ + invalidateChan: make(chan bool, 1), + } + client := &http.Client{} + testSnapMan = createSnapShotManager(dummyDbMan, dummyTokenMan, client) + + // create a new API service to have a new router for testing + testRouter = api.CreateService().Router() + testServer = httptest.NewServer(testRouter) + // set up mock server + mockParms := MockParms{ + ReliableAPI: true, + ClusterID: config.GetString(configApidClusterId), + TokenKey: config.GetString(configConsumerKey), + TokenSecret: config.GetString(configConsumerSecret), + Scope: "", + Organization: "att", + Environment: "prod", + } + apidInfo.ClusterID = expectedClusterId + apidInfo.InstanceID = expectedInstanceId + testMock = Mock(mockParms, testRouter) + config.Set(configProxyServerBaseURI, testServer.URL) + config.Set(configSnapServerBaseURI, testServer.URL) + config.Set(configChangeServerBaseURI, testServer.URL) + config.Set(configPollInterval, 1*time.Millisecond) + + initialBackoffInterval = time.Millisecond + testMock.oauthToken = "test_token_" + strconv.Itoa(testCount) + dummyTokenMan.token = testMock.oauthToken + }) + + AfterEach(func() { + <-testSnapMan.close() + }) + + It("downloadBootSnapshot happy path", func() { + testMock.normalAuthCheck() + testSnapMan.downloadBootSnapshot() + Expect(dummyDbMan.isDataSnapshot).Should(BeFalse()) + Expect(dummyDbMan.snapshot.SnapshotInfo).Should(Equal(bootstrapSnapshotName)) + }) + + It("downloadBootSnapshot should retry for auth failure", func() { + testMock.forceAuthFailOnce() + testSnapMan.downloadBootSnapshot() + Expect(dummyDbMan.isDataSnapshot).Should(BeFalse()) + Expect(dummyDbMan.snapshot.SnapshotInfo).Should(Equal(bootstrapSnapshotName)) + Expect(<-dummyTokenMan.invalidateChan).Should(BeTrue()) + }) + + It("downloadDataSnapshot happy path", func() { + testMock.params.Scope = "test_scope_" + strconv.Itoa(testCount) + dummyDbMan.scopes = []string{testMock.params.Scope} + testMock.normalAuthCheck() + testSnapMan.downloadDataSnapshot() + Expect(dummyDbMan.isDataSnapshot).Should(BeTrue()) + Expect(dummyDbMan.snapshot.SnapshotInfo).Should(Equal(testMock.snapshotID)) + }) + + It("downloadDataSnapshot should retry for auth failure", func() { + testMock.params.Scope = "test_scope_" + strconv.Itoa(testCount) + dummyDbMan.scopes = []string{testMock.params.Scope} + testMock.forceAuthFailOnce() + testSnapMan.downloadDataSnapshot() + Expect(dummyDbMan.isDataSnapshot).Should(BeTrue()) + Expect(dummyDbMan.snapshot.SnapshotInfo).Should(Equal(testMock.snapshotID)) + Expect(<-dummyTokenMan.invalidateChan).Should(BeTrue()) + }) + + }) + +})
diff --git a/test_mock_test.go b/test_mock_test.go index dbb4a43..9ba4e19 100644 --- a/test_mock_test.go +++ b/test_mock_test.go
@@ -139,20 +139,21 @@ type dummyTokenManager struct { invalidateChan chan bool + token string + tokenReadyChan chan bool } func (t *dummyTokenManager) getTokenReadyChannel() <-chan bool { - return nil + return t.tokenReadyChan } func (t *dummyTokenManager) getBearerToken() string { - return "" + return t.token } -func (t *dummyTokenManager) invalidateToken() error { +func (t *dummyTokenManager) invalidateToken() { log.Debug("invalidateToken called") t.invalidateChan <- true - return nil } func (t *dummyTokenManager) close() { @@ -189,9 +190,12 @@ } type dummyDbManager struct { - lastSequence string - knownTables map[string]bool - scopes []string + lastSequence string + knownTables map[string]bool + scopes []string + snapshot *common.Snapshot + isDataSnapshot bool + lastSeqUpdated chan string } func (d *dummyDbManager) initDB() error { @@ -207,6 +211,7 @@ return d.scopes, nil } func (d *dummyDbManager) updateLastSequence(lastSequence string) error { + d.lastSeqUpdated <- lastSequence return nil } func (d *dummyDbManager) getApidInstanceInfo() (info apidInstanceInfo, err error) { @@ -221,6 +226,8 @@ return nil } func (d *dummyDbManager) processSnapshot(snapshot *common.Snapshot, isDataSnapshot bool) error { + d.snapshot = snapshot + d.isDataSnapshot = isDataSnapshot return nil } func (d *dummyDbManager) getKnowTables() map[string]bool {
diff --git a/token.go b/token.go index 8d49a8e..bfd8900 100644 --- a/token.go +++ b/token.go
@@ -17,7 +17,6 @@ import ( "bytes" "encoding/json" - "errors" "github.com/apid/apid-core/util" "io/ioutil" "net/http" @@ -40,10 +39,10 @@ man.close() */ -func createSimpleTokenManager(isNewInstance bool) *simpleTokenManager { +func createApidTokenManager(isNewInstance bool) *apidTokenManager { isClosedInt := int32(0) - t := &simpleTokenManager{ + t := &apidTokenManager{ quitPollingForToken: make(chan bool, 1), closed: make(chan bool), getTokenChan: make(chan bool), @@ -57,7 +56,7 @@ return t } -type simpleTokenManager struct { +type apidTokenManager struct { token *OauthToken isClosed *int32 quitPollingForToken chan bool @@ -71,17 +70,17 @@ isNewInstance bool } -func (t *simpleTokenManager) start() { +func (t *apidTokenManager) start() { t.retrieveNewToken() t.refreshTimer = time.After(t.token.refreshIn()) go t.maintainToken() } -func (t *simpleTokenManager) getBearerToken() string { +func (t *apidTokenManager) getBearerToken() string { return t.getToken().AccessToken } -func (t *simpleTokenManager) maintainToken() { +func (t *apidTokenManager) maintainToken() { for { select { case <-t.closed: @@ -102,19 +101,13 @@ } // will block until valid -func (t *simpleTokenManager) invalidateToken() error { - //has been closed - if atomic.LoadInt32(t.isClosed) == int32(1) { - log.Debug("TokenManager: invalidateToken() called on closed tokenManager") - return errors.New("invalidateToken() called on closed tokenManager") - } +func (t *apidTokenManager) invalidateToken() { log.Debug("invalidating token") t.invalidateTokenChan <- true <-t.invalidateDone - return nil } -func (t *simpleTokenManager) getToken() *OauthToken { +func (t *apidTokenManager) getToken() *OauthToken { //has been closed if atomic.LoadInt32(t.isClosed) == int32(1) { log.Debug("TokenManager: getToken() called on closed tokenManager") @@ -128,7 +121,7 @@ * blocking close() of tokenMan */ -func (t *simpleTokenManager) close() { +func (t *apidTokenManager) close() { //has been closed if atomic.SwapInt32(t.isClosed, 1) == int32(1) { log.Panic("TokenManager: close() has been called before!") @@ -143,7 +136,7 @@ } // don't call externally. will block until success. -func (t *simpleTokenManager) retrieveNewToken() { +func (t *apidTokenManager) retrieveNewToken() { log.Debug("Getting OAuth token...") uriString := config.GetString(configProxyServerBaseURI) @@ -153,10 +146,10 @@ } uri.Path = path.Join(uri.Path, "/accesstoken") - pollWithBackoff(t.quitPollingForToken, t.getRetrieveNewTokenClosure(uri), func(err error) { log.Errorf("Error getting new token : ", err) }) + pollWithBackoff(t.quitPollingForToken, t.getRetrieveNewTokenClosure(uri), func(err error) { log.Errorf("Error getting new token : %v", err) }) } -func (t *simpleTokenManager) getRetrieveNewTokenClosure(uri *url.URL) func(chan bool) error { +func (t *apidTokenManager) getRetrieveNewTokenClosure(uri *url.URL) func(chan bool) error { return func(_ chan bool) error { form := url.Values{} form.Set("grant_type", "client_credentials") @@ -196,7 +189,7 @@ if resp.StatusCode != 200 { log.Errorf("Oauth Request Failed with Resp Code: %d. Body: %s", resp.StatusCode, string(body)) - return expected200Error{} + return expected200Error } var token OauthToken @@ -214,22 +207,11 @@ } log.Debugf("Got new token: %#v", token) - - /* - if newInstanceID { - newInstanceID = false - err = updateApidInstanceInfo() - if err != nil { - log.Errorf("unable to unmarshal update apid instance info : %v", string(body), err) - return err - - } - } - */ t.token = &token config.Set(configBearerToken, token.AccessToken) //don't block on the buffered channel. that means there is already a signal to serve new token + //TODO: This assumes apid-gateway is 1-1 mapping. Make use of generic long-polling provided by apid-core select { case t.tokenUpdatedChan <- true: default: @@ -240,7 +222,7 @@ } } -func (t *simpleTokenManager) getTokenReadyChannel() <-chan bool { +func (t *apidTokenManager) getTokenReadyChannel() <-chan bool { return t.tokenUpdatedChan }
diff --git a/token_test.go b/token_test.go index 1883a6f..fa9379c 100644 --- a/token_test.go +++ b/token_test.go
@@ -94,7 +94,7 @@ w.Write(body) })) config.Set(configProxyServerBaseURI, ts.URL) - testedTokenManager := createSimpleTokenManager(false) + testedTokenManager := createApidTokenManager(false) testedTokenManager.start() token := testedTokenManager.getToken() @@ -123,7 +123,7 @@ })) config.Set(configProxyServerBaseURI, ts.URL) - testedTokenManager := createSimpleTokenManager(false) + testedTokenManager := createApidTokenManager(false) testedTokenManager.start() token := testedTokenManager.getToken() Expect(token.AccessToken).ToNot(BeEmpty()) @@ -163,7 +163,7 @@ })) config.Set(configProxyServerBaseURI, ts.URL) - testedTokenManager := createSimpleTokenManager(false) + testedTokenManager := createApidTokenManager(false) testedTokenManager.start() testedTokenManager.getToken() @@ -200,7 +200,7 @@ })) config.Set(configProxyServerBaseURI, ts.URL) - testedTokenManager := createSimpleTokenManager(true) + testedTokenManager := createApidTokenManager(true) testedTokenManager.start() testedTokenManager.getToken() testedTokenManager.invalidateToken()
diff --git a/util.go b/util.go new file mode 100644 index 0000000..7d59214 --- /dev/null +++ b/util.go
@@ -0,0 +1,172 @@ +// Copyright 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package apidApigeeSync + +import ( + "fmt" + "math" + "math/rand" + "net/http" + "time" +) + +const ( + httpTimeout = time.Minute + pluginTimeout = time.Minute + maxIdleConnsPerHost = 10 + defaultInitial time.Duration = 200 * time.Millisecond + defaultMax time.Duration = 10 * time.Second + defaultFactor float64 = 2 +) + +var ( + initialBackoffInterval = defaultInitial +) + +var ( + expected200Error = fmt.Errorf("did not recieve OK response") + quitSignalError = fmt.Errorf("signal to quit encountered") + authFailError = fmt.Errorf("authorization failed") +) + +type Backoff struct { + attempt int + initial, max time.Duration + jitter bool + backoffStrategy func() time.Duration +} + +type ExponentialBackoff struct { + Backoff + factor float64 +} + +func NewExponentialBackoff(initial, max time.Duration, factor float64, jitter bool) *ExponentialBackoff { + backoff := &ExponentialBackoff{} + + if initial <= 0 { + initial = defaultInitial + } + if max <= 0 { + max = defaultMax + } + + if factor <= 0 { + factor = defaultFactor + } + + backoff.initial = initial + backoff.max = max + backoff.attempt = 0 + backoff.factor = factor + backoff.jitter = jitter + backoff.backoffStrategy = backoff.exponentialBackoffStrategy + + return backoff +} + +func (b *Backoff) Duration() time.Duration { + d := b.backoffStrategy() + b.attempt++ + return d +} + +func (b *ExponentialBackoff) exponentialBackoffStrategy() time.Duration { + + initial := float64(b.Backoff.initial) + attempt := float64(b.Backoff.attempt) + duration := initial * math.Pow(b.factor, attempt) + + if duration > math.MaxInt64 { + return b.max + } + dur := time.Duration(duration) + + if b.jitter { + duration = rand.Float64()*(duration-initial) + initial + } + + if dur > b.max { + return b.max + } + + log.Debugf("Backing off for %d ms", int64(dur/time.Millisecond)) + return dur +} + +func (b *Backoff) Reset() { + b.attempt = 0 +} + +func (b *Backoff) Attempt() int { + return b.attempt +} + +/* + * Call toExecute repeatedly until it does not return an error, with an exponential backoff policy + * for retrying on errors + */ +func pollWithBackoff(quit chan bool, toExecute func(chan bool) error, handleError func(error)) { + + backoff := NewExponentialBackoff(initialBackoffInterval, config.GetDuration(configPollInterval), 2, true) + + //initialize the retry channel to start first attempt immediately + retry := time.After(0 * time.Millisecond) + + for { + select { + case <-quit: + log.Info("Quit signal recieved. Returning") + return + case <-retry: + start := time.Now() + + err := toExecute(quit) + if err == nil || err == quitSignalError { + return + } + + end := time.Now() + //error encountered, since we would have returned above otherwise + handleError(err) + + /* TODO keep this around? Imagine an immediately erroring service, + * causing many sequential requests which could pollute logs + */ + //only backoff if the request took less than one second + if end.After(start.Add(time.Second)) { + backoff.Reset() + retry = time.After(0 * time.Millisecond) + } else { + retry = time.After(backoff.Duration()) + } + } + } +} + +func addHeaders(req *http.Request, token string) { + req.Header.Set("Authorization", "Bearer "+token) + req.Header.Set("apid_instance_id", apidInfo.InstanceID) + req.Header.Set("apid_cluster_Id", apidInfo.ClusterID) + req.Header.Set("updated_at_apid", time.Now().Format(time.RFC3339)) +} + +type changeServerError struct { + Code string `json:"code"` +} + +func (a changeServerError) Error() string { + return a.Code +}
diff --git a/backoff_test.go b/util_test.go similarity index 100% rename from backoff_test.go rename to util_test.go