Merge pull request #46 from 30x/XAPID-958
reformat code for unit tests, add new test cases, coverage increased…
diff --git a/apigeeSync_suite_test.go b/apigeeSync_suite_test.go
index c00de6e..dd6cba4 100644
--- a/apigeeSync_suite_test.go
+++ b/apigeeSync_suite_test.go
@@ -55,6 +55,7 @@
log = apid.Log()
_initPlugin(apid.AllServices())
+ createManagers()
close(done)
}, 3)
diff --git a/apigee_sync.go b/apigee_sync.go
index 391355b..6fc1389 100644
--- a/apigee_sync.go
+++ b/apigee_sync.go
@@ -27,17 +27,17 @@
snapshot := startOnLocalSnapshot(apidInfo.LastSnapshot)
events.EmitWithCallback(ApigeeSyncEventSelector, snapshot, func(event apid.Event) {
- changeManager.pollChangeWithBackoff()
+ apidChangeManager.pollChangeWithBackoff()
})
log.Infof("Started on local snapshot: %s", snapshot.SnapshotInfo)
return
}
- snapManager.downloadBootSnapshot()
- snapManager.downloadDataSnapshot()
+ apidSnapshotManager.downloadBootSnapshot()
+ apidSnapshotManager.downloadDataSnapshot()
- changeManager.pollChangeWithBackoff()
+ apidChangeManager.pollChangeWithBackoff()
}
@@ -88,7 +88,7 @@
}
func addHeaders(req *http.Request) {
- req.Header.Set("Authorization", "Bearer "+tokenManager.getBearerToken())
+ req.Header.Set("Authorization", "Bearer "+apidTokenManager.getBearerToken())
req.Header.Set("apid_instance_id", apidInfo.InstanceID)
req.Header.Set("apid_cluster_Id", apidInfo.ClusterID)
req.Header.Set("updated_at_apid", time.Now().Format(time.RFC3339))
@@ -104,6 +104,9 @@
type expected200Error struct {
}
+type authFailError struct { // authFailError is an empty error type signalling an authorization failure
+}
+
func (an expected200Error) Error() string {
return "Did not recieve OK response"
}
@@ -115,3 +118,7 @@
func (a changeServerError) Error() string {
return a.Code
}
+// Error implements the error interface for authFailError with a fixed message.
+func (a authFailError) Error() string {
+ return "Authorization failed"
+}
diff --git a/apigee_sync_test.go b/apigee_sync_test.go
index 463b9e4..22823a4 100644
--- a/apigee_sync_test.go
+++ b/apigee_sync_test.go
@@ -111,7 +111,7 @@
}
} else if cl, ok := event.(*common.ChangeList); ok {
- closeDone = changeManager.close()
+ closeDone = apidChangeManager.close()
// ensure that snapshot switched DB versions
Expect(apidInfo.LastSnapshot).To(Equal(lastSnapshot.SnapshotInfo))
expectedDB, err := dataService.DBVersion(lastSnapshot.SnapshotInfo)
@@ -180,7 +180,7 @@
if s, ok := event.(*common.Snapshot); ok {
// In this test, the changeManager.pollChangeWithBackoff() has not been launched when changeManager closed
// This is because the changeManager.pollChangeWithBackoff() in bootstrap() happened after this handler
- closeDone = changeManager.close()
+ closeDone = apidChangeManager.close()
go func() {
// when close done, all handlers for the first snapshot have been executed
<-closeDone
@@ -289,18 +289,19 @@
*/
It("Should be able to handle duplicate snapshot during bootstrap", func() {
initializeContext()
- tokenManager = createTokenManager()
- snapManager = createSnapShotManager()
+ apidTokenManager = createSimpleTokenManager()
+ apidTokenManager.start()
+ apidSnapshotManager = createSnapShotManager()
events.Listen(ApigeeSyncEventSelector, &handler{})
scopes := []string{apidInfo.ClusterID}
snapshot := &common.Snapshot{}
- snapManager.downloadSnapshot(scopes, snapshot)
- snapManager.storeBootSnapshot(snapshot)
- snapManager.storeDataSnapshot(snapshot)
+ apidSnapshotManager.downloadSnapshot(scopes, snapshot)
+ apidSnapshotManager.storeBootSnapshot(snapshot)
+ apidSnapshotManager.storeDataSnapshot(snapshot)
restoreContext()
- <-snapManager.close()
- tokenManager.close()
+ <-apidSnapshotManager.close()
+ apidTokenManager.close()
}, 3)
It("Reuse http.Client connection for multiple concurrent requests", func() {
diff --git a/change_test.go b/change_test.go
new file mode 100644
index 0000000..7a69995
--- /dev/null
+++ b/change_test.go
@@ -0,0 +1,202 @@
+package apidApigeeSync
+
+import (
+ "github.com/30x/apid-core"
+ "github.com/apigee-labs/transicator/common"
+ . "github.com/onsi/ginkgo"
+ "net/http/httptest"
+ "net/url"
+ "os"
+ "time"
+)
+
+var _ = Describe("Change Agent", func() {
+
+ Context("Change Agent Unit Tests", func() {
+ handler := handler{}
+ // createTestDb calls initDb(sqlfile, "./mockdb_change.sqlite3"), opens that file and fills a Snapshot from it via processSnapshotServerFileResponse(dbId, ...); errors call Fail.
+ var createTestDb = func(sqlfile string, dbId string) common.Snapshot {
+ initDb(sqlfile, "./mockdb_change.sqlite3")
+ file, err := os.Open("./mockdb_change.sqlite3") // NOTE(review): no Close() in this helper — confirm whether the callee closes the handle
+ if err != nil {
+ Fail("Failed to open mock db for test")
+ }
+
+ s := common.Snapshot{}
+ err = processSnapshotServerFileResponse(dbId, file, &s)
+ if err != nil {
+ Fail("Error processing test snapshots")
+ }
+ return s
+ }
+ // Before every spec: build a snapshot from ./sql/init_mock_db.sql, pass it to handler.Handle, then reload knownTables from the DB returned by getDB().
+ BeforeEach(func() {
+ event := createTestDb("./sql/init_mock_db.sql", "test_change")
+ handler.Handle(&event)
+ knownTables = extractTablesFromDB(getDB())
+ })
+ // initializeContext starts an httptest server on apid's router, installs the mock on that router, and points the proxy/snapshot/change base URIs at it with a 1ms poll interval.
+ var initializeContext = func() {
+ testRouter = apid.API().Router()
+ testServer = httptest.NewServer(testRouter)
+
+ // set up mock server
+ mockParms := MockParms{
+ ReliableAPI: true,
+ ClusterID: config.GetString(configApidClusterId),
+ TokenKey: config.GetString(configConsumerKey),
+ TokenSecret: config.GetString(configConsumerSecret),
+ Scope: "ert452",
+ Organization: "att",
+ Environment: "prod",
+ }
+ testMock = Mock(mockParms, testRouter)
+
+ config.Set(configProxyServerBaseURI, testServer.URL)
+ config.Set(configSnapServerBaseURI, testServer.URL)
+ config.Set(configChangeServerBaseURI, testServer.URL)
+ config.Set(configPollInterval, 1*time.Millisecond)
+ }
+ // restoreContext closes the test server and resets the three base URIs to dummyConfigValue and the poll interval to 10ms.
+ var restoreContext = func() {
+
+ testServer.Close()
+ config.Set(configProxyServerBaseURI, dummyConfigValue)
+ config.Set(configSnapServerBaseURI, dummyConfigValue)
+ config.Set(configChangeServerBaseURI, dummyConfigValue)
+ config.Set(configPollInterval, 10*time.Millisecond)
+ }
+ // Spec: with the mock forcing auth failure, polling must reach invalidateToken (observed on invalidateChan) before the change manager is closed.
+ It("test change agent with authorization failure", func() {
+ log.Debug("test change agent with authorization failure")
+ testTokenManager := &dummyTokenManager{make(chan bool)}
+ apidTokenManager = testTokenManager
+ apidTokenManager.start()
+ apidSnapshotManager = &dummySnapshotManager{} // downloadCalledChan stays nil; this spec never receives from it
+ initializeContext()
+ testMock.forceAuthFail()
+ wipeDBAferTest = true
+ apidChangeManager.pollChangeWithBackoff()
+ // auth check fails
+ <-testTokenManager.invalidateChan
+ log.Debug("closing")
+ <-apidChangeManager.close()
+ restoreContext()
+ })
+ // Spec: with auth bypassed and forceNewSnapshot set on the mock, polling must call downloadDataSnapshot (observed on downloadCalledChan).
+ It("test change agent with too old snapshot", func() {
+ log.Debug("test change agent with too old snapshot")
+ testTokenManager := &dummyTokenManager{make(chan bool)}
+ apidTokenManager = testTokenManager
+ apidTokenManager.start()
+ testSnapshotManager := &dummySnapshotManager{make(chan bool)}
+ apidSnapshotManager = testSnapshotManager
+ initializeContext()
+
+ testMock.passAuthCheck()
+ testMock.forceNewSnapshot()
+ wipeDBAferTest = true
+ apidChangeManager.pollChangeWithBackoff()
+ <-testSnapshotManager.downloadCalledChan
+ log.Debug("closing")
+ <-apidChangeManager.close()
+ restoreContext()
+ })
+ // Spec: after a forced auth failure, invalidateToken switches the mock to pass-auth; the first ChangeList event then closes the manager and finishes the spec.
+ It("change agent should retry with authorization failure", func(done Done) {
+ log.Debug("change agent should retry with authorization failure")
+ testTokenManager := &dummyTokenManager{make(chan bool)}
+ apidTokenManager = testTokenManager
+ apidTokenManager.start()
+ apidSnapshotManager = &dummySnapshotManager{}
+ initializeContext()
+ testMock.forceAuthFail()
+ testMock.forceNoSnapshot()
+ wipeDBAferTest = true
+
+ apid.Events().ListenFunc(ApigeeSyncEventSelector, func(event apid.Event) {
+
+ if _, ok := event.(*common.ChangeList); ok {
+ closeDone := apidChangeManager.close()
+ log.Debug("closing")
+ go func() {
+ // when close done, all handlers for the first snapshot have been executed
+ <-closeDone
+ restoreContext()
+ close(done)
+ }()
+
+ }
+ })
+
+ apidChangeManager.pollChangeWithBackoff()
+ // auth check fails
+ <-testTokenManager.invalidateChan
+ })
+
+ })
+})
+// dummyTokenManager is a test stand-in for tokenManager; invalidateToken sends a value on invalidateChan each time it is called.
+type dummyTokenManager struct {
+ invalidateChan chan bool
+}
+// getBearerToken always returns the empty string.
+func (t *dummyTokenManager) getBearerToken() string {
+ return ""
+}
+// invalidateToken logs the call, switches testMock to pass-auth mode, signals invalidateChan and returns nil.
+func (t *dummyTokenManager) invalidateToken() error {
+ log.Debug("invalidateToken called")
+ testMock.passAuthCheck()
+ t.invalidateChan <- true // blocks until a spec receives when the channel is unbuffered, as the specs construct it
+ return nil
+}
+// getToken always returns nil.
+func (t *dummyTokenManager) getToken() *oauthToken {
+ return nil
+}
+// close is a no-op.
+func (t *dummyTokenManager) close() {
+ return
+}
+// getRetrieveNewTokenClosure ignores the URL and returns a closure that always reports success (nil).
+func (t *dummyTokenManager) getRetrieveNewTokenClosure(*url.URL) func(chan bool) error {
+ return func(chan bool) error {
+ return nil
+ }
+}
+// start is a no-op.
+func (t *dummyTokenManager) start() {
+
+}
+// dummySnapshotManager is a test stand-in for snapShotManager; downloadDataSnapshot sends a value on downloadCalledChan.
+type dummySnapshotManager struct {
+ downloadCalledChan chan bool
+}
+// close returns an already-closed channel, so a receive on the result completes immediately.
+func (s *dummySnapshotManager) close() <-chan bool {
+ closeChan := make(chan bool)
+ close(closeChan)
+ return closeChan
+}
+// downloadBootSnapshot is a no-op.
+func (s *dummySnapshotManager) downloadBootSnapshot() {
+
+}
+// storeBootSnapshot is a no-op; the snapshot argument is ignored.
+func (s *dummySnapshotManager) storeBootSnapshot(snapshot *common.Snapshot) {
+
+}
+// downloadDataSnapshot logs the call and signals downloadCalledChan so a spec can observe that it was invoked.
+func (s *dummySnapshotManager) downloadDataSnapshot() {
+ log.Debug("dummySnapshotManager.downloadDataSnapshot() called")
+ s.downloadCalledChan <- true // NOTE(review): blocks forever if the manager was built as &dummySnapshotManager{} (nil channel)
+}
+// storeDataSnapshot is a no-op; the snapshot argument is ignored.
+func (s *dummySnapshotManager) storeDataSnapshot(snapshot *common.Snapshot) {
+
+}
+// downloadSnapshot does nothing and returns nil; scopes and snapshot are left untouched.
+func (s *dummySnapshotManager) downloadSnapshot(scopes []string, snapshot *common.Snapshot) error {
+ return nil
+}
diff --git a/changes.go b/changes.go
index a3a39b5..9e8d170 100644
--- a/changes.go
+++ b/changes.go
@@ -54,8 +54,8 @@
log.Warn("pollChangeManager: close() called when pollChangeWithBackoff unlaunched! Will wait until pollChangeWithBackoff is launched and then kill it and tokenManager!")
go func() {
c.quitChan <- true
- tokenManager.close()
- <-snapManager.close()
+ apidTokenManager.close()
+ <-apidSnapshotManager.close()
log.Debug("change manager closed")
finishChan <- false
}()
@@ -65,8 +65,8 @@
log.Debug("pollChangeManager: close pollChangeWithBackoff and token manager")
go func() {
c.quitChan <- true
- tokenManager.close()
- <-snapManager.close()
+ apidTokenManager.close()
+ <-apidSnapshotManager.close()
log.Debug("change manager closed")
finishChan <- true
}()
@@ -183,8 +183,11 @@
log.Errorf("Get changes request failed with status code: %d", r.StatusCode)
switch r.StatusCode {
case http.StatusUnauthorized:
- tokenManager.invalidateToken()
- return nil
+ err = apidTokenManager.invalidateToken()
+ if err != nil {
+ return err
+ }
+ return authFailError{}
case http.StatusNotModified:
return nil
@@ -206,7 +209,7 @@
log.Debug("Received SNAPSHOT_TOO_OLD message from change server.")
err = apiErr
}
- return nil
+ return err
}
return nil
}
@@ -271,7 +274,7 @@
}
if c, ok := err.(changeServerError); ok {
log.Debugf("%s. Fetch a new snapshot to sync...", c.Code)
- snapManager.downloadDataSnapshot()
+ apidSnapshotManager.downloadDataSnapshot()
} else {
log.Debugf("Error connecting to changeserver: %v", err)
}
diff --git a/cmd/mockServer/main.go b/cmd/mockServer/main.go
index 6773606..bf4de27 100644
--- a/cmd/mockServer/main.go
+++ b/cmd/mockServer/main.go
@@ -5,7 +5,6 @@
"os"
- "time"
"github.com/30x/apid-core"
"github.com/30x/apid-core/factory"
@@ -22,11 +21,8 @@
reliable := f.Bool("reliable", true, "if false, server will often send 500 errors")
numDevs := f.Int("numDevs", 2, "number of developers in snapshot")
- addDevEach := f.Duration("addDevEach", 0*time.Second, "add a developer each duration (default 0s)")
- upDevEach := f.Duration("upDevEach", 0*time.Second, "update a developer each duration (default 0s)")
numDeps := f.Int("numDeps", 2, "number of deployments in snapshot")
- upDepEach := f.Duration("upDepEach", 0*time.Second, "update (replace) a deployment each duration (default 0s)")
f.Parse(os.Args[1:])
@@ -51,10 +47,7 @@
Organization: "org",
Environment: "test",
NumDevelopers: *numDevs,
- AddDeveloperEvery: *addDevEach,
- UpdateDeveloperEvery: *upDevEach,
NumDeployments: *numDeps,
- ReplaceDeploymentEvery: *upDepEach,
BundleURI: *bundleURI,
}
diff --git a/data.go b/data.go
index bf0bd2e..a9b1496 100644
--- a/data.go
+++ b/data.go
@@ -101,9 +101,8 @@
if err != nil {
log.Errorf("INSERT Fail [%s] values=%v error=[%v]", sql, values, err)
return false
- } else {
- log.Debugf("INSERT Success [%s] values=%v", sql, values)
}
+ log.Debugf("INSERT Success [%s] values=%v", sql, values)
return true
}
@@ -126,39 +125,42 @@
if len(pkeys) == 0 || err != nil {
log.Errorf("DELETE No primary keys found for table. %s", tableName)
return false
- } else if len(rows) == 0 {
+ }
+
+ if len(rows) == 0 {
log.Errorf("No rows found for table.", tableName)
return false
- } else {
- sql := buildDeleteSql(tableName, rows[0], pkeys)
- prep, err := txn.Prepare(sql)
+ }
+
+ sql := buildDeleteSql(tableName, rows[0], pkeys)
+ prep, err := txn.Prepare(sql)
+ if err != nil {
+ log.Errorf("DELETE Fail to prep statement [%s] error=[%v]", sql, err)
+ return false
+ }
+ defer prep.Close()
+ for _, row := range rows {
+ values := getValueListFromKeys(row, pkeys)
+ // delete prepared statement from existing template statement
+ res, err := txn.Stmt(prep).Exec(values...)
if err != nil {
- log.Errorf("DELETE Fail to prep statement [%s] error=[%v]", sql, err)
+ log.Errorf("DELETE Fail [%s] values=%v error=[%v]", sql, values, err)
return false
}
- defer prep.Close()
- for _, row := range rows {
- values := getValueListFromKeys(row, pkeys)
- // delete prepared statement from existing template statement
- res, err := txn.Stmt(prep).Exec(values...)
- if err != nil {
- log.Errorf("DELETE Fail [%s] values=%v error=[%v]", sql, values, err)
- return false
- } else {
- affected, err := res.RowsAffected()
- if err == nil && affected != 0 {
- log.Debugf("DELETE Success [%s] values=%v", sql, values)
- } else if err == nil && affected == 0 {
- log.Errorf("Entry not found [%s] values=%v. Nothing to delete.", sql, values)
- return false
- } else {
- log.Errorf("DELETE Failed [%s] values=%v error=[%v]", sql, values, err)
- return false
- }
- }
+ affected, err := res.RowsAffected()
+ if err == nil && affected != 0 {
+ log.Debugf("DELETE Success [%s] values=%v", sql, values)
+ } else if err == nil && affected == 0 {
+ log.Errorf("Entry not found [%s] values=%v. Nothing to delete.", sql, values)
+ return false
+ } else {
+ log.Errorf("DELETE Failed [%s] values=%v error=[%v]", sql, values, err)
+ return false
}
- return true
+
}
+ return true
+
}
// Syntax "DELETE FROM Obj WHERE key1=$1 AND key2=$2 ... ;"
diff --git a/init.go b/init.go
index 437ce5c..d9b767e 100644
--- a/init.go
+++ b/init.go
@@ -30,16 +30,16 @@
var (
/* All set during plugin initialization */
- log apid.LogService
- config apid.ConfigService
- dataService apid.DataService
- events apid.EventsService
- apidInfo apidInstanceInfo
- newInstanceID bool
- tokenManager *tokenMan
- changeManager *pollChangeManager
- snapManager *snapShotManager
- httpclient *http.Client
+ log apid.LogService
+ config apid.ConfigService
+ dataService apid.DataService
+ events apid.EventsService
+ apidInfo apidInstanceInfo
+ newInstanceID bool
+ apidTokenManager tokenManager
+ apidChangeManager changeManager
+ apidSnapshotManager snapShotManager
+ httpclient *http.Client
/* Set during post plugin initialization
* set this as a default, so that it's guaranteed to be valid even if postInitPlugins isn't called
@@ -83,16 +83,11 @@
Transport: tr,
Timeout: httpTimeout,
CheckRedirect: func(req *http.Request, _ []*http.Request) error {
- req.Header.Set("Authorization", "Bearer "+tokenManager.getBearerToken())
+ req.Header.Set("Authorization", "Bearer "+apidTokenManager.getBearerToken())
return nil
},
}
- //TODO listen for arbitrary commands, these channels can be used to kill polling goroutines
- //also useful for testing
- snapManager = createSnapShotManager()
- changeManager = createChangeManager()
-
// set up default database
db, err := dataService.DB()
if err != nil {
@@ -117,6 +112,12 @@
return nil
}
+func createManagers() { // createManagers assigns fresh snapshot, change and token managers to the package-level variables; it does not call start() on the token manager
+ apidSnapshotManager = createSnapShotManager()
+ apidChangeManager = createChangeManager()
+ apidTokenManager = createSimpleTokenManager()
+}
+
func checkForRequiredValues() error {
// check for required values
for _, key := range []string{configProxyServerBaseURI, configConsumerKey, configConsumerSecret,
@@ -137,7 +138,7 @@
log = logger
}
-/* Idempotent state initialization */
+/* initialization */
func _initPlugin(services apid.Services) error {
SetLogger(services.Log().ForModule("apigeeSync"))
log.Debug("start init")
@@ -165,6 +166,8 @@
return pluginData, err
}
+ createManagers()
+
/* This callback function will get called once all the plugins are
* initialized (not just this plugin). This is needed because,
* downloadSnapshots/changes etc have to begin to be processed only
@@ -208,8 +211,7 @@
log.Debug("start post plugin init")
- tokenManager = createTokenManager()
-
+ apidTokenManager.start()
go bootstrap()
events.Listen(ApigeeSyncEventSelector, &handler{})
diff --git a/listener.go b/listener.go
index a1360d4..2a8b492 100644
--- a/listener.go
+++ b/listener.go
@@ -58,7 +58,10 @@
log.Panicf("Unable to read database: %s", err.Error())
}
apidClusters.Next()
- apidClusters.Scan(&numApidClusters)
+ err = apidClusters.Scan(&numApidClusters)
+ if err != nil {
+ log.Panicf("Unable to read database: %s", err.Error())
+ }
if numApidClusters != 1 {
log.Panic("Illegal state for apid_cluster. Must be a single row.")
@@ -117,20 +120,3 @@
return ok
}
-
-func makeDataScopeFromRow(row common.Row) dataDataScope {
-
- ds := dataDataScope{}
-
- row.Get("id", &ds.ID)
- row.Get("apid_cluster_id", &ds.ClusterID)
- row.Get("scope", &ds.Scope)
- row.Get("org", &ds.Org)
- row.Get("env", &ds.Env)
- row.Get("created", &ds.Created)
- row.Get("created_by", &ds.CreatedBy)
- row.Get("updated", &ds.Updated)
- row.Get("updated_by", &ds.UpdatedBy)
-
- return ds
-}
diff --git a/listener_test.go b/listener_test.go
index 2c9c1c6..d424a0b 100644
--- a/listener_test.go
+++ b/listener_test.go
@@ -15,15 +15,11 @@
var createTestDb = func(sqlfile string, dbId string) common.Snapshot {
initDb(sqlfile, "./mockdb.sqlite3")
file, err := os.Open("./mockdb.sqlite3")
- if err != nil {
- Fail("Failed to open mock db for test")
- }
+ Expect(err).ShouldNot(HaveOccurred())
s := common.Snapshot{}
err = processSnapshotServerFileResponse(dbId, file, &s)
- if err != nil {
- Fail("Error processing test snapshots")
- }
+ Expect(err).ShouldNot(HaveOccurred())
return s
}
diff --git a/managerInterfaces.go b/managerInterfaces.go
new file mode 100644
index 0000000..20bbf6f
--- /dev/null
+++ b/managerInterfaces.go
@@ -0,0 +1,29 @@
+package apidApigeeSync
+
+import (
+ "github.com/apigee-labs/transicator/common"
+ "net/url"
+)
+// tokenManager is the method set used through apidTokenManager; simpleTokenManager and the test-only dummyTokenManager both define these methods.
+type tokenManager interface {
+ getBearerToken() string
+ invalidateToken() error // simpleTokenManager returns an error when called after close
+ getToken() *oauthToken
+ close()
+ getRetrieveNewTokenClosure(*url.URL) func(chan bool) error
+ start() // for simpleTokenManager this performs the initial token retrieval
+}
+// snapShotManager is the method set used through apidSnapshotManager; simpleSnapShotManager and the test-only dummySnapshotManager both define these methods.
+type snapShotManager interface {
+ close() <-chan bool // receive from the returned channel to wait for the close to finish
+ downloadBootSnapshot()
+ storeBootSnapshot(snapshot *common.Snapshot)
+ downloadDataSnapshot()
+ storeDataSnapshot(snapshot *common.Snapshot)
+ downloadSnapshot(scopes []string, snapshot *common.Snapshot) error
+}
+// changeManager is the method set used through apidChangeManager (assigned from createChangeManager()).
+type changeManager interface {
+ close() <-chan bool // callers receive from the returned channel to wait for shutdown
+ pollChangeWithBackoff()
+}
diff --git a/mock_server.go b/mock_server.go
index dacdd58..8349131 100644
--- a/mock_server.go
+++ b/mock_server.go
@@ -79,10 +79,27 @@
minDeploymentID *int64
maxDeploymentID *int64
newSnap *int32
+ authFail *int32
+}
+// forceAuthFail sets authFail to 1, which makes the auth wrapper answer 401 without checking the Authorization header.
+func (m *MockServer) forceAuthFail() {
+ atomic.StoreInt32(m.authFail, 1)
+}
+// normalAuthCheck sets authFail to 0, so the auth wrapper compares the Authorization header against the expected bearer token.
+func (m *MockServer) normalAuthCheck() {
+ atomic.StoreInt32(m.authFail, 0)
+}
+// passAuthCheck sets authFail to 2, which makes the auth wrapper call the target handler without checking the header.
+func (m *MockServer) passAuthCheck() {
+ atomic.StoreInt32(m.authFail, 2)
}
func (m *MockServer) forceNewSnapshot() {
- atomic.SwapInt32(m.newSnap, 1)
+ atomic.StoreInt32(m.newSnap, 1)
+}
+// forceNoSnapshot clears newSnap (stores 0), cancelling a pending forceNewSnapshot request.
+func (m *MockServer) forceNoSnapshot() {
+ atomic.StoreInt32(m.newSnap, 0)
}
func (m *MockServer) lastSequenceID() string {
@@ -146,6 +163,8 @@
*m.minDeploymentID = 1
m.maxDeploymentID = new(int64)
m.newSnap = new(int32)
+ m.authFail = new(int32)
+ *m.authFail = 0
initDb("./sql/init_mock_db.sql", "./mockdb.sqlite3")
initDb("./sql/init_mock_boot_db.sql", "./mockdb_boot.sqlite3")
@@ -253,6 +272,7 @@
val := atomic.SwapInt32(m.newSnap, 0)
if val > 0 {
+ log.Debug("MockServer: force new snapshot")
w.WriteHeader(http.StatusBadRequest)
apiErr := changeServerError{
Code: "SNAPSHOT_TOO_OLD",
@@ -263,6 +283,8 @@
return
}
+ log.Debug("mock server sending change list")
+
q := req.URL.Query()
scopes := q["scope"]
@@ -303,15 +325,29 @@
// enforces handler auth
func (m *MockServer) auth(target http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, req *http.Request) {
- auth := req.Header.Get("Authorization")
+ // force failing auth check
+ if atomic.LoadInt32(m.authFail) == 1 {
+ w.WriteHeader(http.StatusUnauthorized)
+ w.Write([]byte(fmt.Sprintf("Force fail: bad auth token. ")))
+ return
+ }
+
+ // force passing auth check
+ if atomic.LoadInt32(m.authFail) == 2 {
+ target(w, req)
+ return
+ }
+
+ // check auth header
+ auth := req.Header.Get("Authorization")
expectedAuth := fmt.Sprintf("Bearer %s", m.oauthToken)
if auth != expectedAuth {
- w.WriteHeader(http.StatusBadRequest)
+ w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte(fmt.Sprintf("Bad auth token. Is: %s, should be: %s", auth, expectedAuth)))
- } else {
- target(w, req)
+ return
}
+ target(w, req)
}
}
diff --git a/snapshot.go b/snapshot.go
index 2389d5f..f4cb8bd 100644
--- a/snapshot.go
+++ b/snapshot.go
@@ -15,7 +15,7 @@
"time"
)
-type snapShotManager struct {
+type simpleSnapShotManager struct {
// to send quit signal to the downloading thread
quitChan chan bool
// to mark the graceful close of snapshotManager
@@ -26,10 +26,10 @@
isDownloading *int32
}
-func createSnapShotManager() *snapShotManager {
+func createSnapShotManager() *simpleSnapShotManager {
isClosedInt := int32(0)
isDownloadingInt := int32(0)
- return &snapShotManager{
+ return &simpleSnapShotManager{
quitChan: make(chan bool, 1),
finishChan: make(chan bool, 1),
isClosed: &isClosedInt,
@@ -43,7 +43,7 @@
* use <- close() for blocking close
* should only be called by pollChangeManager, because pollChangeManager is dependent on it
*/
-func (s *snapShotManager) close() <-chan bool {
+func (s *simpleSnapShotManager) close() <-chan bool {
//has been closed before
if atomic.SwapInt32(s.isClosed, 1) == int32(1) {
log.Error("snapShotManager: close() called on a closed snapShotManager!")
@@ -63,7 +63,7 @@
}
// retrieve boot information: apid_config and apid_config_scope
-func (s *snapShotManager) downloadBootSnapshot() {
+func (s *simpleSnapShotManager) downloadBootSnapshot() {
if atomic.SwapInt32(s.isDownloading, 1) == int32(1) {
log.Panic("downloadBootSnapshot: only 1 thread can download snapshot at the same time!")
}
@@ -99,12 +99,12 @@
s.storeBootSnapshot(snapshot)
}
-func (s *snapShotManager) storeBootSnapshot(snapshot *common.Snapshot) {
+func (s *simpleSnapShotManager) storeBootSnapshot(snapshot *common.Snapshot) {
processSnapshot(snapshot)
}
// use the scope IDs from the boot snapshot to get all the data associated with the scopes
-func (s *snapShotManager) downloadDataSnapshot() {
+func (s *simpleSnapShotManager) downloadDataSnapshot() {
if atomic.SwapInt32(s.isDownloading, 1) == int32(1) {
log.Panic("downloadDataSnapshot: only 1 thread can download snapshot at the same time!")
}
@@ -132,7 +132,7 @@
s.storeDataSnapshot(snapshot)
}
-func (s *snapShotManager) storeDataSnapshot(snapshot *common.Snapshot) {
+func (s *simpleSnapShotManager) storeDataSnapshot(snapshot *common.Snapshot) {
knownTables = extractTablesFromSnapshot(snapshot)
_, err := dataService.DBVersion(snapshot.SnapshotInfo)
@@ -232,7 +232,7 @@
// a blocking method
// will keep retrying with backoff until success
-func (s *snapShotManager) downloadSnapshot(scopes []string, snapshot *common.Snapshot) error {
+func (s *simpleSnapShotManager) downloadSnapshot(scopes []string, snapshot *common.Snapshot) error {
// if closed
if atomic.LoadInt32(s.isClosed) == int32(1) {
log.Warn("Trying to download snapshot with a closed snapShotManager")
diff --git a/token.go b/token.go
index 424c8d4..56d6676 100644
--- a/token.go
+++ b/token.go
@@ -3,6 +3,7 @@
import (
"bytes"
"encoding/json"
+ "errors"
"io/ioutil"
"net/http"
"net/url"
@@ -24,10 +25,10 @@
man.close()
*/
-func createTokenManager() *tokenMan {
+func createSimpleTokenManager() *simpleTokenManager {
isClosedInt := int32(0)
- t := &tokenMan{
+ t := &simpleTokenManager{
quitPollingForToken: make(chan bool, 1),
closed: make(chan bool),
getTokenChan: make(chan bool),
@@ -36,14 +37,10 @@
invalidateDone: make(chan bool),
isClosed: &isClosedInt,
}
-
- t.retrieveNewToken()
- t.refreshTimer = time.After(t.token.refreshIn())
- go t.maintainToken()
return t
}
-type tokenMan struct {
+type simpleTokenManager struct {
token *oauthToken
isClosed *int32
quitPollingForToken chan bool
@@ -55,11 +52,17 @@
invalidateDone chan bool
}
-func (t *tokenMan) getBearerToken() string {
+func (t *simpleTokenManager) start() { // start fetches the first token, arms refreshTimer from token.refreshIn(), then launches maintainToken in its own goroutine
+ t.retrieveNewToken() // documented on retrieveNewToken as blocking until success
+ t.refreshTimer = time.After(t.token.refreshIn())
+ go t.maintainToken()
+}
+
+func (t *simpleTokenManager) getBearerToken() string {
return t.getToken().AccessToken
}
-func (t *tokenMan) maintainToken() {
+func (t *simpleTokenManager) maintainToken() {
for {
select {
case <-t.closed:
@@ -80,18 +83,19 @@
}
// will block until valid
-func (t *tokenMan) invalidateToken() {
+func (t *simpleTokenManager) invalidateToken() error {
//has been closed
if atomic.LoadInt32(t.isClosed) == int32(1) {
log.Debug("TokenManager: invalidateToken() called on closed tokenManager")
- return
+ return errors.New("invalidateToken() called on closed tokenManager")
}
log.Debug("invalidating token")
t.invalidateTokenChan <- true
<-t.invalidateDone
+ return nil
}
-func (t *tokenMan) getToken() *oauthToken {
+func (t *simpleTokenManager) getToken() *oauthToken {
//has been closed
if atomic.LoadInt32(t.isClosed) == int32(1) {
log.Debug("TokenManager: getToken() called on closed tokenManager")
@@ -105,7 +109,7 @@
* blocking close() of tokenMan
*/
-func (t *tokenMan) close() {
+func (t *simpleTokenManager) close() {
//has been closed
if atomic.SwapInt32(t.isClosed, 1) == int32(1) {
log.Panic("TokenManager: close() has been called before!")
@@ -120,7 +124,7 @@
}
// don't call externally. will block until success.
-func (t *tokenMan) retrieveNewToken() {
+func (t *simpleTokenManager) retrieveNewToken() {
log.Debug("Getting OAuth token...")
uriString := config.GetString(configProxyServerBaseURI)
@@ -133,7 +137,7 @@
pollWithBackoff(t.quitPollingForToken, t.getRetrieveNewTokenClosure(uri), func(err error) { log.Errorf("Error getting new token : ", err) })
}
-func (t *tokenMan) getRetrieveNewTokenClosure(uri *url.URL) func(chan bool) error {
+func (t *simpleTokenManager) getRetrieveNewTokenClosure(uri *url.URL) func(chan bool) error {
return func(_ chan bool) error {
form := url.Values{}
form.Set("grant_type", "client_credentials")
diff --git a/token_test.go b/token_test.go
index c8ec7ba..22ff425 100644
--- a/token_test.go
+++ b/token_test.go
@@ -80,7 +80,8 @@
w.Write(body)
}))
config.Set(configProxyServerBaseURI, ts.URL)
- testedTokenManager := createTokenManager()
+ testedTokenManager := createSimpleTokenManager()
+ testedTokenManager.start()
token := testedTokenManager.getToken()
Expect(token.AccessToken).ToNot(BeEmpty())
@@ -108,7 +109,8 @@
}))
config.Set(configProxyServerBaseURI, ts.URL)
- testedTokenManager := createTokenManager()
+ testedTokenManager := createSimpleTokenManager()
+ testedTokenManager.start()
token := testedTokenManager.getToken()
Expect(token.AccessToken).ToNot(BeEmpty())
@@ -147,8 +149,8 @@
}))
config.Set(configProxyServerBaseURI, ts.URL)
- testedTokenManager := createTokenManager()
-
+ testedTokenManager := createSimpleTokenManager()
+ testedTokenManager.start()
testedTokenManager.getToken()
<-finished
@@ -188,8 +190,8 @@
}))
config.Set(configProxyServerBaseURI, ts.URL)
- testedTokenManager := createTokenManager()
-
+ testedTokenManager := createSimpleTokenManager()
+ testedTokenManager.start()
testedTokenManager.getToken()
testedTokenManager.invalidateToken()
testedTokenManager.getToken()