rewrite token manager
diff --git a/apigeeSync_suite_test.go b/apigeeSync_suite_test.go index 9703cdf..3ded532 100644 --- a/apigeeSync_suite_test.go +++ b/apigeeSync_suite_test.go
@@ -25,7 +25,7 @@ const dummyConfigValue string = "placeholder" -var _ = BeforeSuite(func(){ +var _ = BeforeSuite(func() { wipeDBAferTest = true }) @@ -62,7 +62,7 @@ lastSequence = "" - if (wipeDBAferTest) { + if wipeDBAferTest { _, err := getDB().Exec("DELETE FROM APID_CLUSTER") Expect(err).NotTo(HaveOccurred()) _, err = getDB().Exec("DELETE FROM DATA_SCOPE")
diff --git a/apigee_sync.go b/apigee_sync.go index cf571d8..99275f3 100644 --- a/apigee_sync.go +++ b/apigee_sync.go
@@ -1,18 +1,17 @@ package apidApigeeSync import ( - "time" - "net/http" "github.com/30x/apid-core" + "net/http" + "time" ) const ( - httpTimeout = time.Minute - pluginTimeout = time.Minute + httpTimeout = time.Minute + pluginTimeout = time.Minute ) -var knownTables = make(map[string]bool) - +var knownTables = make(map[string]bool) /* * Start from existing snapshot if possible @@ -120,4 +119,4 @@ func (a changeServerError) Error() string { return a.Code -} \ No newline at end of file +}
diff --git a/apigee_sync_test.go b/apigee_sync_test.go index 12382e0..7aa1a81 100644 --- a/apigee_sync_test.go +++ b/apigee_sync_test.go
@@ -5,8 +5,8 @@ "github.com/apigee-labs/transicator/common" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" - "time" "net/http/httptest" + "time" ) var _ = Describe("Sync", func() { @@ -55,8 +55,8 @@ expectedSnapshotTables := common.ChangeList{ Changes: []common.Change{common.Change{Table: "kms.company"}, - common.Change{Table: "edgex.apid_cluster"}, - common.Change{Table: "edgex.data_scope"}}, + common.Change{Table: "edgex.apid_cluster"}, + common.Change{Table: "edgex.data_scope"}}, } apid.Events().ListenFunc(ApigeeSyncEventSelector, func(event apid.Event) { @@ -105,7 +105,7 @@ } } else if cl, ok := event.(*common.ChangeList); ok { - go func(){quitPollingChangeServer <- true}() + go func() { quitPollingChangeServer <- true }() // ensure that snapshot switched DB versions Expect(apidInfo.LastSnapshot).To(Equal(lastSnapshot.SnapshotInfo)) expectedDB, err := dataService.DBVersion(lastSnapshot.SnapshotInfo) @@ -161,8 +161,8 @@ initializeContext() expectedTables := common.ChangeList{ Changes: []common.Change{common.Change{Table: "kms.company"}, - common.Change{Table: "edgex.apid_cluster"}, - common.Change{Table: "edgex.data_scope"}}, + common.Change{Table: "edgex.apid_cluster"}, + common.Change{Table: "edgex.data_scope"}}, } thisQuitPollingChangeServer := quitPollingChangeServer Expect(apidInfo.LastSnapshot).NotTo(BeEmpty()) @@ -170,7 +170,7 @@ apid.Events().ListenFunc(ApigeeSyncEventSelector, func(event apid.Event) { if s, ok := event.(*common.Snapshot); ok { - go func(){thisQuitPollingChangeServer <- true}() + go func() { thisQuitPollingChangeServer <- true }() //verify that the knownTables array has been properly populated from existing DB Expect(changesRequireDDLSync(expectedTables)).To(BeFalse())
diff --git a/backoff.go b/backoff.go index 57e2176..291a037 100644 --- a/backoff.go +++ b/backoff.go
@@ -6,9 +6,9 @@ "time" ) -const defaultInitial time.Duration = 200 * time.Millisecond -const defaultMax time.Duration = 10 * time.Second -const defaultFactor float64 = 2 +const defaultInitial time.Duration = 200 * time.Millisecond +const defaultMax time.Duration = 10 * time.Second +const defaultFactor float64 = 2 type Backoff struct { attempt int @@ -19,7 +19,6 @@ type ExponentialBackoff struct { Backoff factor float64 - } func NewExponentialBackoff(initial, max time.Duration, factor float64) *Backoff { @@ -45,7 +44,6 @@ return &backoff.Backoff } - func (b *Backoff) Duration() time.Duration { d := b.backoffStrategy() b.attempt++ @@ -80,4 +78,4 @@ func (b *Backoff) Attempt() int { return b.attempt -} \ No newline at end of file +}
diff --git a/changes.go b/changes.go index 6a0afad..59caa34 100644 --- a/changes.go +++ b/changes.go
@@ -1,19 +1,18 @@ package apidApigeeSync import ( - "time" + "encoding/json" + "io/ioutil" + "net/http" "net/url" "path" - "net/http" - "io/ioutil" - "encoding/json" + "time" "github.com/apigee-labs/transicator/common" - ) var lastSequence string -var block string = "45" +var block string = "45" /* * Long polls the change agent with a 45 second block. Parses the response from @@ -64,10 +63,10 @@ v.Add("block", block) /* - * Include all the scopes associated with the config Id - * The Config Id is included as well, as it acts as the - * Bootstrap scope - */ + * Include all the scopes associated with the config Id + * The Config Id is included as well, as it acts as the + * Bootstrap scope + */ for _, scope := range scopes { v.Add("scope", scope) } @@ -154,7 +153,6 @@ return nil } - func changesRequireDDLSync(changes common.ChangeList) bool { return changesHaveNewTables(knownTables, changes.Changes) } @@ -175,7 +173,7 @@ func changesHaveNewTables(a map[string]bool, changes []common.Change) bool { //nil maps should not be passed in. Making the distinction between nil map and empty map - if a == nil || changes == nil{ + if a == nil || changes == nil { return true }
diff --git a/init.go b/init.go index dd21116..c9ad44f 100644 --- a/init.go +++ b/init.go
@@ -30,21 +30,20 @@ var ( /* All set during plugin initialization */ - log apid.LogService - config apid.ConfigService - dataService apid.DataService - events apid.EventsService - apidInfo apidInstanceInfo - newInstanceID bool - tokenManager *tokenMan + log apid.LogService + config apid.ConfigService + dataService apid.DataService + events apid.EventsService + apidInfo apidInstanceInfo + newInstanceID bool + tokenManager *tokenMan quitPollingSnapshotServer chan bool - quitPollingChangeServer chan bool + quitPollingChangeServer chan bool /* Set during post plugin initialization * set this as a default, so that it's guaranteed to be valid even if postInitPlugins isn't called */ apidPluginDetails string = `[{"name":"apidApigeeSync","schemaVer":"1.0"}]` - ) type apidInstanceInfo struct { @@ -72,7 +71,7 @@ log.Debugf("Using %s as display name", config.GetString(configName)) } -func initVariables(services apid.Services) (error) { +func initVariables(services apid.Services) error { dataService = services.Data() events = services.Events() //TODO listen for arbitrary commands, these channels can be used to kill polling goroutines @@ -104,10 +103,10 @@ return nil } -func checkForRequiredValues() (error) { +func checkForRequiredValues() error { // check for required values for _, key := range []string{configProxyServerBaseURI, configConsumerKey, configConsumerSecret, - configSnapServerBaseURI, configChangeServerBaseURI} { + configSnapServerBaseURI, configChangeServerBaseURI} { if !config.IsSet(key) { return fmt.Errorf("Missing required config value: %s", key) } @@ -125,7 +124,7 @@ } /* Idempotent state initialization */ -func _initPlugin(services apid.Services) (error) { +func _initPlugin(services apid.Services) error { SetLogger(services.Log().ForModule("apigeeSync")) log.Debug("start init")
diff --git a/listener.go b/listener.go index 605e309..ba6c02c 100644 --- a/listener.go +++ b/listener.go
@@ -38,7 +38,7 @@ if config.GetString(configSnapshotProtocol) == "json" { processJsonSnapshot(snapshot, db) - } else if config.GetString(configSnapshotProtocol) == "sqlite"{ + } else if config.GetString(configSnapshotProtocol) == "sqlite" { processSqliteSnapshot(snapshot, db) }
diff --git a/snapshot.go b/snapshot.go index 51baeee..1731cfb 100644 --- a/snapshot.go +++ b/snapshot.go
@@ -1,18 +1,18 @@ package apidApigeeSync import ( - "github.com/30x/apid-core" - "net/http" "encoding/json" - "os" + "github.com/30x/apid-core" "github.com/30x/apid-core/data" "github.com/apigee-labs/transicator/common" + "net/http" + "os" "io" - "time" + "io/ioutil" "net/url" "path" - "io/ioutil" + "time" ) // retrieve boot information: apid_config and apid_config_scope @@ -94,7 +94,7 @@ } // Skip Downloading snapshot if there is already a snapshot available from previous run -func startOnLocalSnapshot(snapshot string) *common.Snapshot{ +func startOnLocalSnapshot(snapshot string) *common.Snapshot { log.Infof("Starting on local snapshot: %s", snapshot) // ensure DB version will be accessible on behalf of dependant plugins @@ -146,7 +146,7 @@ } -func getAttemptDownloadClosure(client *http.Client, snapshot *common.Snapshot, uri string) func(chan bool) error{ +func getAttemptDownloadClosure(client *http.Client, snapshot *common.Snapshot, uri string) func(chan bool) error { return func(_ chan bool) error { req, err := http.NewRequest("GET", uri, nil) if err != nil { @@ -155,7 +155,7 @@ } addHeaders(req) - var processSnapshotResponse func(*http.Response, *common.Snapshot) (error) + var processSnapshotResponse func(*http.Response, *common.Snapshot) error // Set the transport protocol type based on conf file input if config.GetString(configSnapshotProtocol) == "json" { @@ -165,7 +165,7 @@ req.Header.Set("Accept", "application/transicator+sqlite") processSnapshotResponse = processSnapshotServerFileResponse } - + // Issue the request to the snapshot server r, err := client.Do(req) if err != nil { @@ -249,4 +249,4 @@ func handleSnapshotServerError(err error) { log.Debugf("Error connecting to snapshot server: %v", err) -} \ No newline at end of file +}
diff --git a/token.go b/token.go index c33c10b..bc6f1df 100644 --- a/token.go +++ b/token.go
@@ -7,7 +7,6 @@ "net/http" "net/url" "path" - "sync" "time" ) @@ -26,23 +25,30 @@ func createTokenManager() *tokenMan { - t := &tokenMan{} - t.doRefresh = make(chan bool, 1) - t.quitPollingForToken = make(chan bool, 1) - t.tokenRefreshed = make(chan bool) - t.closed = make(chan bool) + t := &tokenMan{ + quitPollingForToken: make(chan bool, 1), + closed: make(chan bool), + getTokenChan: make(chan bool), + invalidateTokenChan: make(chan bool), + returnTokenChan: make(chan *oauthToken), + invalidateDone: make(chan bool), + } + t.retrieveNewToken() - t.maintainToken() + t.refreshTimer = time.After(t.token.refreshIn()) + go t.maintainToken() return t } type tokenMan struct { - sync.Mutex - token *oauthToken - doRefresh chan bool + token *oauthToken quitPollingForToken chan bool - tokenRefreshed chan bool - closed chan bool + closed chan bool + getTokenChan chan bool + invalidateTokenChan chan bool + refreshTimer <-chan time.Time + returnTokenChan chan *oauthToken + invalidateDone chan bool } func (t *tokenMan) getBearerToken() string { @@ -50,55 +56,45 @@ } func (t *tokenMan) maintainToken() { - go func() { - for { - t.Lock() + for { + select { + case <-t.closed: + return + case <-t.refreshTimer: + log.Debug("auto refresh token") + t.retrieveNewToken() + t.refreshTimer = time.After(t.token.refreshIn()) + case <-t.getTokenChan: token := t.token - t.Unlock() - select { - - case _, ok := <-t.doRefresh: - if !ok { - log.Debug("closed tokenMan") - t.closed <- true - return - } - log.Debug("force token refresh") - t.retrieveNewToken() - t.tokenRefreshed <- true - case <-time.After(token.refreshIn()): - log.Debug("auto refresh token") - t.retrieveNewToken() - - } + t.returnTokenChan <- token + case <-t.invalidateTokenChan: + t.retrieveNewToken() + t.refreshTimer = time.After(t.token.refreshIn()) + t.invalidateDone <- true } - }() -} - -func (t *tokenMan) invalidateToken() { - log.Debug("invalidating token") - t.Lock() - t.token = nil - t.Unlock() - t.doRefresh <- true - //ensure refresh signal has been received - <-t.tokenRefreshed + } } // will block until valid -//assumption is that if we can get the lock, then it's valid +func (t *tokenMan) invalidateToken() { + log.Debug("invalidating token") + t.invalidateTokenChan <- true + <-t.invalidateDone +} + + func (t *tokenMan) getToken() *oauthToken { - t.Lock() - defer t.Unlock() - return t.token + t.getTokenChan <- true + return <-t.returnTokenChan } func (t *tokenMan) close() { log.Debug("close token manager") t.quitPollingForToken <- true - close(t.doRefresh) - //block until close signal has been received by maintenance routine - <-t.closed + // sending instead of closing, to make sure it enters the t.doRefresh branch + log.Debug("token manager closed") + t.closed <- true + close(t.closed) } // don't call externally. will block until success. @@ -112,7 +108,7 @@ } uri.Path = path.Join(uri.Path, "/accesstoken") - pollWithBackoff(t.quitPollingForToken, t.getRetrieveNewTokenClosure(uri), func(err error) {log.Errorf("Error getting new token : ", err)}) + pollWithBackoff(t.quitPollingForToken, t.getRetrieveNewTokenClosure(uri), func(err error) { log.Errorf("Error getting new token : ", err) }) } func (t *tokenMan) getRetrieveNewTokenClosure(uri *url.URL) func(chan bool) error { @@ -179,9 +175,7 @@ } } - t.Lock() t.token = &token - t.Unlock() config.Set(configBearerToken, token.AccessToken) return nil
diff --git a/token_test.go b/token_test.go index e1bb460..2cff649 100644 --- a/token_test.go +++ b/token_test.go
@@ -177,7 +177,7 @@ } res := oauthToken{ AccessToken: string(count), - ExpiresIn: 200000, + ExpiresIn: 200000, } body, err := json.Marshal(res) Expect(err).NotTo(HaveOccurred())