Merge pull request #34 from 30x/refactor-haoming
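Remove the data race in the token manager: token state is now owned by a single goroutine and accessed over channels instead of a shared mutex; also gofmt the package.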
diff --git a/apigeeSync_suite_test.go b/apigeeSync_suite_test.go
index 9703cdf..3ded532 100644
--- a/apigeeSync_suite_test.go
+++ b/apigeeSync_suite_test.go
@@ -25,7 +25,7 @@
const dummyConfigValue string = "placeholder"
-var _ = BeforeSuite(func(){
+var _ = BeforeSuite(func() {
wipeDBAferTest = true
})
@@ -62,7 +62,7 @@
lastSequence = ""
- if (wipeDBAferTest) {
+ if wipeDBAferTest {
_, err := getDB().Exec("DELETE FROM APID_CLUSTER")
Expect(err).NotTo(HaveOccurred())
_, err = getDB().Exec("DELETE FROM DATA_SCOPE")
diff --git a/apigee_sync.go b/apigee_sync.go
index 703681a..99275f3 100644
--- a/apigee_sync.go
+++ b/apigee_sync.go
@@ -1,18 +1,17 @@
package apidApigeeSync
import (
- "time"
- "net/http"
"github.com/30x/apid-core"
+ "net/http"
+ "time"
)
const (
- httpTimeout = time.Minute
- pluginTimeout = time.Minute
+ httpTimeout = time.Minute
+ pluginTimeout = time.Minute
)
-var knownTables = make(map[string]bool)
-
+var knownTables = make(map[string]bool)
/*
* Start from existing snapshot if possible
@@ -21,7 +20,7 @@
*
* Then, poll for changes
*/
-func bootstrap(quitPollingSnapshotServer, quitPollingChangeServer chan bool) {
+func bootstrap() {
if apidInfo.LastSnapshot != "" {
snapshot := startOnLocalSnapshot(apidInfo.LastSnapshot)
@@ -120,4 +119,4 @@
func (a changeServerError) Error() string {
return a.Code
-}
\ No newline at end of file
+}
diff --git a/apigee_sync_test.go b/apigee_sync_test.go
index 7f25087..7aa1a81 100644
--- a/apigee_sync_test.go
+++ b/apigee_sync_test.go
@@ -5,8 +5,8 @@
"github.com/apigee-labs/transicator/common"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
- "time"
"net/http/httptest"
+ "time"
)
var _ = Describe("Sync", func() {
@@ -55,8 +55,8 @@
expectedSnapshotTables := common.ChangeList{
Changes: []common.Change{common.Change{Table: "kms.company"},
- common.Change{Table: "edgex.apid_cluster"},
- common.Change{Table: "edgex.data_scope"}},
+ common.Change{Table: "edgex.apid_cluster"},
+ common.Change{Table: "edgex.data_scope"}},
}
apid.Events().ListenFunc(ApigeeSyncEventSelector, func(event apid.Event) {
@@ -105,7 +105,7 @@
}
} else if cl, ok := event.(*common.ChangeList); ok {
- go func(){quitPollingChangeServer <- true}()
+ go func() { quitPollingChangeServer <- true }()
// ensure that snapshot switched DB versions
Expect(apidInfo.LastSnapshot).To(Equal(lastSnapshot.SnapshotInfo))
expectedDB, err := dataService.DBVersion(lastSnapshot.SnapshotInfo)
@@ -161,16 +161,16 @@
initializeContext()
expectedTables := common.ChangeList{
Changes: []common.Change{common.Change{Table: "kms.company"},
- common.Change{Table: "edgex.apid_cluster"},
- common.Change{Table: "edgex.data_scope"}},
+ common.Change{Table: "edgex.apid_cluster"},
+ common.Change{Table: "edgex.data_scope"}},
}
-
+ thisQuitPollingChangeServer := quitPollingChangeServer
Expect(apidInfo.LastSnapshot).NotTo(BeEmpty())
apid.Events().ListenFunc(ApigeeSyncEventSelector, func(event apid.Event) {
if s, ok := event.(*common.Snapshot); ok {
- go func(){quitPollingChangeServer <- true}()
+ go func() { thisQuitPollingChangeServer <- true }()
//verify that the knownTables array has been properly populated from existing DB
Expect(changesRequireDDLSync(expectedTables)).To(BeFalse())
diff --git a/backoff.go b/backoff.go
index 57e2176..291a037 100644
--- a/backoff.go
+++ b/backoff.go
@@ -6,9 +6,9 @@
"time"
)
-const defaultInitial time.Duration = 200 * time.Millisecond
-const defaultMax time.Duration = 10 * time.Second
-const defaultFactor float64 = 2
+const defaultInitial time.Duration = 200 * time.Millisecond
+const defaultMax time.Duration = 10 * time.Second
+const defaultFactor float64 = 2
type Backoff struct {
attempt int
@@ -19,7 +19,6 @@
type ExponentialBackoff struct {
Backoff
factor float64
-
}
func NewExponentialBackoff(initial, max time.Duration, factor float64) *Backoff {
@@ -45,7 +44,6 @@
return &backoff.Backoff
}
-
func (b *Backoff) Duration() time.Duration {
d := b.backoffStrategy()
b.attempt++
@@ -80,4 +78,4 @@
func (b *Backoff) Attempt() int {
return b.attempt
-}
\ No newline at end of file
+}
diff --git a/changes.go b/changes.go
index 6a0afad..59caa34 100644
--- a/changes.go
+++ b/changes.go
@@ -1,19 +1,18 @@
package apidApigeeSync
import (
- "time"
+ "encoding/json"
+ "io/ioutil"
+ "net/http"
"net/url"
"path"
- "net/http"
- "io/ioutil"
- "encoding/json"
+ "time"
"github.com/apigee-labs/transicator/common"
-
)
var lastSequence string
-var block string = "45"
+var block string = "45"
/*
* Long polls the change agent with a 45 second block. Parses the response from
@@ -64,10 +63,10 @@
v.Add("block", block)
/*
- * Include all the scopes associated with the config Id
- * The Config Id is included as well, as it acts as the
- * Bootstrap scope
- */
+ * Include all the scopes associated with the config Id
+ * The Config Id is included as well, as it acts as the
+ * Bootstrap scope
+ */
for _, scope := range scopes {
v.Add("scope", scope)
}
@@ -154,7 +153,6 @@
return nil
}
-
func changesRequireDDLSync(changes common.ChangeList) bool {
return changesHaveNewTables(knownTables, changes.Changes)
}
@@ -175,7 +173,7 @@
func changesHaveNewTables(a map[string]bool, changes []common.Change) bool {
//nil maps should not be passed in. Making the distinction between nil map and empty map
- if a == nil || changes == nil{
+ if a == nil || changes == nil {
return true
}
diff --git a/init.go b/init.go
index 7967e53..c9ad44f 100644
--- a/init.go
+++ b/init.go
@@ -30,21 +30,20 @@
var (
/* All set during plugin initialization */
- log apid.LogService
- config apid.ConfigService
- dataService apid.DataService
- events apid.EventsService
- apidInfo apidInstanceInfo
- newInstanceID bool
- tokenManager *tokenMan
+ log apid.LogService
+ config apid.ConfigService
+ dataService apid.DataService
+ events apid.EventsService
+ apidInfo apidInstanceInfo
+ newInstanceID bool
+ tokenManager *tokenMan
quitPollingSnapshotServer chan bool
- quitPollingChangeServer chan bool
+ quitPollingChangeServer chan bool
/* Set during post plugin initialization
* set this as a default, so that it's guaranteed to be valid even if postInitPlugins isn't called
*/
apidPluginDetails string = `[{"name":"apidApigeeSync","schemaVer":"1.0"}]`
-
)
type apidInstanceInfo struct {
@@ -72,7 +71,7 @@
log.Debugf("Using %s as display name", config.GetString(configName))
}
-func initVariables(services apid.Services) (error) {
+func initVariables(services apid.Services) error {
dataService = services.Data()
events = services.Events()
//TODO listen for arbitrary commands, these channels can be used to kill polling goroutines
@@ -104,10 +103,10 @@
return nil
}
-func checkForRequiredValues() (error) {
+func checkForRequiredValues() error {
// check for required values
for _, key := range []string{configProxyServerBaseURI, configConsumerKey, configConsumerSecret,
- configSnapServerBaseURI, configChangeServerBaseURI} {
+ configSnapServerBaseURI, configChangeServerBaseURI} {
if !config.IsSet(key) {
return fmt.Errorf("Missing required config value: %s", key)
}
@@ -125,7 +124,7 @@
}
/* Idempotent state initialization */
-func _initPlugin(services apid.Services) (error) {
+func _initPlugin(services apid.Services) error {
SetLogger(services.Log().ForModule("apigeeSync"))
log.Debug("start init")
@@ -197,7 +196,7 @@
tokenManager = createTokenManager()
- go bootstrap(quitPollingSnapshotServer, quitPollingChangeServer)
+ go bootstrap()
events.Listen(ApigeeSyncEventSelector, &handler{})
log.Debug("Done post plugin init")
diff --git a/listener.go b/listener.go
index 605e309..ba6c02c 100644
--- a/listener.go
+++ b/listener.go
@@ -38,7 +38,7 @@
if config.GetString(configSnapshotProtocol) == "json" {
processJsonSnapshot(snapshot, db)
- } else if config.GetString(configSnapshotProtocol) == "sqlite"{
+ } else if config.GetString(configSnapshotProtocol) == "sqlite" {
processSqliteSnapshot(snapshot, db)
}
diff --git a/snapshot.go b/snapshot.go
index 51baeee..1731cfb 100644
--- a/snapshot.go
+++ b/snapshot.go
@@ -1,18 +1,18 @@
package apidApigeeSync
import (
- "github.com/30x/apid-core"
- "net/http"
"encoding/json"
- "os"
+ "github.com/30x/apid-core"
"github.com/30x/apid-core/data"
"github.com/apigee-labs/transicator/common"
+ "net/http"
+ "os"
"io"
- "time"
+ "io/ioutil"
"net/url"
"path"
- "io/ioutil"
+ "time"
)
// retrieve boot information: apid_config and apid_config_scope
@@ -94,7 +94,7 @@
}
// Skip Downloading snapshot if there is already a snapshot available from previous run
-func startOnLocalSnapshot(snapshot string) *common.Snapshot{
+func startOnLocalSnapshot(snapshot string) *common.Snapshot {
log.Infof("Starting on local snapshot: %s", snapshot)
// ensure DB version will be accessible on behalf of dependant plugins
@@ -146,7 +146,7 @@
}
-func getAttemptDownloadClosure(client *http.Client, snapshot *common.Snapshot, uri string) func(chan bool) error{
+func getAttemptDownloadClosure(client *http.Client, snapshot *common.Snapshot, uri string) func(chan bool) error {
return func(_ chan bool) error {
req, err := http.NewRequest("GET", uri, nil)
if err != nil {
@@ -155,7 +155,7 @@
}
addHeaders(req)
- var processSnapshotResponse func(*http.Response, *common.Snapshot) (error)
+ var processSnapshotResponse func(*http.Response, *common.Snapshot) error
// Set the transport protocol type based on conf file input
if config.GetString(configSnapshotProtocol) == "json" {
@@ -165,7 +165,7 @@
req.Header.Set("Accept", "application/transicator+sqlite")
processSnapshotResponse = processSnapshotServerFileResponse
}
-
+
// Issue the request to the snapshot server
r, err := client.Do(req)
if err != nil {
@@ -249,4 +249,4 @@
func handleSnapshotServerError(err error) {
log.Debugf("Error connecting to snapshot server: %v", err)
-}
\ No newline at end of file
+}
diff --git a/token.go b/token.go
index 3c4bcbc..bc6f1df 100644
--- a/token.go
+++ b/token.go
@@ -7,13 +7,11 @@
"net/http"
"net/url"
"path"
- "sync"
"time"
)
var (
refreshFloatTime = time.Minute
- getTokenLock sync.Mutex
)
/*
@@ -26,22 +24,31 @@
*/
func createTokenManager() *tokenMan {
- t := &tokenMan{}
- t.doRefresh = make(chan bool, 1)
- t.quitPollingForToken = make(chan bool, 1)
- t.tokenRefreshed = make(chan bool)
- t.closed = make(chan bool)
+
+ t := &tokenMan{
+ quitPollingForToken: make(chan bool, 1),
+ closed: make(chan bool),
+ getTokenChan: make(chan bool),
+ invalidateTokenChan: make(chan bool),
+ returnTokenChan: make(chan *oauthToken),
+ invalidateDone: make(chan bool),
+ }
+
t.retrieveNewToken()
- t.maintainToken()
+ t.refreshTimer = time.After(t.token.refreshIn())
+ go t.maintainToken()
return t
}
type tokenMan struct {
- token *oauthToken
- doRefresh chan bool
+ token *oauthToken
quitPollingForToken chan bool
- tokenRefreshed chan bool
- closed chan bool
+ closed chan bool
+ getTokenChan chan bool
+ invalidateTokenChan chan bool
+ refreshTimer <-chan time.Time
+ returnTokenChan chan *oauthToken
+ invalidateDone chan bool
}
func (t *tokenMan) getBearerToken() string {
@@ -49,56 +56,45 @@
}
func (t *tokenMan) maintainToken() {
- go func() {
- for {
- select {
-
- case _, ok := <-t.doRefresh:
- if !ok {
- log.Debug("closed tokenMan")
- t.closed <- true
- return
- }
- log.Debug("force token refresh")
- t.retrieveNewToken()
- t.tokenRefreshed <- true
- continue
- case <-time.After(t.token.refreshIn()):
- log.Debug("auto refresh token")
- getTokenLock.Lock()
- t.retrieveNewToken()
- getTokenLock.Unlock()
- continue
-
- }
+ for {
+ select {
+ case <-t.closed:
+ return
+ case <-t.refreshTimer:
+ log.Debug("auto refresh token")
+ t.retrieveNewToken()
+ t.refreshTimer = time.After(t.token.refreshIn())
+ case <-t.getTokenChan:
+ token := t.token
+ t.returnTokenChan <- token
+ case <-t.invalidateTokenChan:
+ t.retrieveNewToken()
+ t.refreshTimer = time.After(t.token.refreshIn())
+ t.invalidateDone <- true
}
- }()
-}
-
-func (t *tokenMan) invalidateToken() {
- log.Debug("invalidating token")
- getTokenLock.Lock()
- t.token = nil
- t.doRefresh <- true
- //ensure refresh signal has been received
- <-t.tokenRefreshed
- getTokenLock.Unlock()
+ }
}
// will block until valid
-//assumption is that if we can get the lock, then it's valid
+func (t *tokenMan) invalidateToken() {
+ log.Debug("invalidating token")
+ t.invalidateTokenChan <- true
+ <-t.invalidateDone
+}
+
+
func (t *tokenMan) getToken() *oauthToken {
- getTokenLock.Lock()
- defer getTokenLock.Unlock()
- return t.token
+ t.getTokenChan <- true
+ return <-t.returnTokenChan
}
func (t *tokenMan) close() {
log.Debug("close token manager")
t.quitPollingForToken <- true
- close(t.doRefresh)
- //block until close signal has been received by maintenance routine
- <-t.closed
+ // send (rather than only closing) so this blocks until maintainToken has received on t.closed and is returning
+ log.Debug("token manager closed")
+ t.closed <- true
+ close(t.closed)
}
// don't call externally. will block until success.
@@ -112,7 +108,7 @@
}
uri.Path = path.Join(uri.Path, "/accesstoken")
- pollWithBackoff(t.quitPollingForToken, t.getRetrieveNewTokenClosure(uri), func(err error) {log.Errorf("Error getting new token : ", err)})
+ pollWithBackoff(t.quitPollingForToken, t.getRetrieveNewTokenClosure(uri), func(err error) { log.Errorf("Error getting new token : ", err) })
}
func (t *tokenMan) getRetrieveNewTokenClosure(uri *url.URL) func(chan bool) error {
@@ -179,7 +175,6 @@
}
}
-
t.token = &token
config.Set(configBearerToken, token.AccessToken)
diff --git a/token_test.go b/token_test.go
index e1bb460..2cff649 100644
--- a/token_test.go
+++ b/token_test.go
@@ -177,7 +177,7 @@
}
res := oauthToken{
AccessToken: string(count),
- ExpiresIn: 200000,
+ ExpiresIn: 200000,
}
body, err := json.Marshal(res)
Expect(err).NotTo(HaveOccurred())