Ensure token is always up to date
diff --git a/apigeeSync_suite_test.go b/apigeeSync_suite_test.go index 4a9de78..6181ff0 100644 --- a/apigeeSync_suite_test.go +++ b/apigeeSync_suite_test.go
@@ -1,6 +1,7 @@ package apidApigeeSync import ( + "github.com/apigee-labs/transicator/common" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -12,7 +13,6 @@ "github.com/30x/apid-core" "github.com/30x/apid-core/factory" - "github.com/apigee-labs/transicator/common" ) var ( @@ -151,7 +151,6 @@ var _ = BeforeEach(func() { apid.Events().Close() - token = "" lastSequence = "" _, err := getDB().Exec("DELETE FROM APID_CLUSTER")
diff --git a/apigee_sync.go b/apigee_sync.go index 5928168..deead43 100644 --- a/apigee_sync.go +++ b/apigee_sync.go
@@ -1,9 +1,7 @@ package apidApigeeSync import ( - "bytes" "encoding/json" - "io/ioutil" "net/http" "net/url" "path" @@ -11,6 +9,8 @@ "sync/atomic" + "io/ioutil" + "github.com/30x/apid-core" "github.com/apigee-labs/transicator/common" ) @@ -23,7 +23,6 @@ var ( block string = "45" - token string lastSequence string polling uint32 ) @@ -84,10 +83,6 @@ lastSequence = getLastSequence() for { log.Debug("polling...") - if token == "" { - // invalid token, loop until we get one - getBearerToken() - } /* Find the scopes associated with the config id */ scopes := findScopesForId(apidInfo.ClusterID) @@ -117,6 +112,7 @@ client := &http.Client{Timeout: httpTimeout} // must be greater than block value req, err := http.NewRequest("GET", uri, nil) addHeaders(req) + r, err := client.Do(req) if err != nil { log.Errorf("change agent comm error: %s", err) @@ -127,7 +123,7 @@ log.Errorf("Get changes request failed with status code: %d", r.StatusCode) switch r.StatusCode { case http.StatusUnauthorized: - token = "" + tokenManager.invalidateToken() case http.StatusNotModified: r.Body.Close() @@ -135,9 +131,15 @@ case http.StatusBadRequest: var apiErr apiError - err = json.NewDecoder(r.Body).Decode(&apiErr) + var b []byte + b, err = ioutil.ReadAll(r.Body) if err != nil { - log.Errorf("JSON Response Data not parsable: %v", err) + log.Errorf("Unable to read response body: %v", err) + break + } + err = json.Unmarshal(b, &apiErr) + if err != nil { + log.Errorf("JSON Response Data not parsable: %s", string(b)) break } if apiErr.Code == "SNAPSHOT_TOO_OLD" { @@ -196,113 +198,8 @@ } } -/* - * This function will (for now) use the Access Key/Secret Key/ApidConfig Id - * to get the bearer token, and the scopes (as comma separated scope) - */ -func getBearerToken() { - - log.Info("Getting a Bearer token...") - uriString := config.GetString(configProxyServerBaseURI) - uri, err := url.Parse(uriString) - if err != nil { - log.Panicf("unable to parse uri config '%s' value: '%s': %v", configProxyServerBaseURI, uriString, err) - } - uri.Path = path.Join(uri.Path, "/accesstoken") - - retryIn := 5 * time.Millisecond - maxBackOff := maxBackoffTimeout - backOffFunc := createBackOff(retryIn, maxBackOff) - first := true - - for { - if first { - first = false - } else { - backOffFunc() - } - - token = "" - form := url.Values{} - form.Set("grant_type", "client_credentials") - form.Add("client_id", config.GetString(configConsumerKey)) - form.Add("client_secret", config.GetString(configConsumerSecret)) - req, err := http.NewRequest("POST", uri.String(), bytes.NewBufferString(form.Encode())) - req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") - req.Header.Set("display_name", apidInfo.InstanceName) - req.Header.Set("apid_instance_id", apidInfo.InstanceID) - req.Header.Set("apid_cluster_Id", apidInfo.ClusterID) - req.Header.Set("status", "ONLINE") - req.Header.Set("plugin_details", apidPluginDetails) - - if newInstanceID { - req.Header.Set("created_at_apid", time.Now().Format(time.RFC3339)) - } else { - req.Header.Set("updated_at_apid", time.Now().Format(time.RFC3339)) - } - - client := &http.Client{Timeout: httpTimeout} - resp, err := client.Do(req) - if err != nil { - log.Errorf("Unable to Connect to Edge Proxy Server: %v", err) - continue - } - - body, err := ioutil.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - log.Errorf("Unable to read EdgeProxy Sever response: %v", err) - continue - } - - if resp.StatusCode != 200 { - log.Errorf("Oauth Request Failed with Resp Code: %d. Body: %s", resp.StatusCode, string(body)) - continue - } - - var oauthResp oauthTokenResp - log.Debugf("Response: %s ", body) - err = json.Unmarshal(body, &oauthResp) - if err != nil { - log.Error("unable to unmarshal JSON response %s: %v", string(body), err) - continue - } - token = oauthResp.AccessToken - - if newInstanceID { - newInstanceID = false - updateApidInstanceInfo() - } - - /* - * This stores the bearer token for any other plugin to - * consume. - */ - config.Set(bearerToken, token) - - log.Debug("Got a new Bearer token.") - - return - } -} - -type oauthTokenResp struct { - IssuedAt int64 `json:"issuedAt"` - AppName string `json:"applicationName"` - Scope string `json:"scope"` - Status string `json:"status"` - ApiProdList []string `json:"apiProductList"` - ExpiresIn int64 `json:"expiresIn"` - DeveloperEmail string `json:"developerEmail"` - TokenType string `json:"tokenType"` - ClientId string `json:"clientId"` - AccessToken string `json:"accessToken"` - TokenExpIn int64 `json:"refreshTokenExpiresIn"` - RefreshCount int64 `json:"refreshCount"` -} - func Redirect(req *http.Request, via []*http.Request) error { - req.Header.Add("Authorization", "Bearer "+token) + req.Header.Add("Authorization", "Bearer "+tokenManager.getBearerToken()) req.Header.Add("org", apidInfo.ClusterID) // todo: this is strange.. is it needed? return nil } @@ -383,10 +280,6 @@ log.Panicf("bad url value for config %s: %s", snapshotUri, err) } - // getBearerToken loops until good - getBearerToken() - // todo: this could expire... ensure it's called again as needed - /* Frame and send the snapshot request */ snapshotUri.Path = path.Join(snapshotUri.Path, "snapshots") @@ -457,7 +350,7 @@ } func addHeaders(req *http.Request) { - req.Header.Add("Authorization", "Bearer "+token) + req.Header.Add("Authorization", "Bearer "+tokenManager.getBearerToken()) req.Header.Set("apid_instance_id", apidInfo.InstanceID) req.Header.Set("apid_cluster_Id", apidInfo.ClusterID) req.Header.Set("updated_at_apid", time.Now().Format(time.RFC3339))
diff --git a/data.go b/data.go index 2531c8d..c532bf1 100644 --- a/data.go +++ b/data.go
@@ -251,8 +251,8 @@ } else { // first start - no row, generate a UUID and store it err = nil - newInstanceID = true info.InstanceID = generateUUID() + updateApidInstanceInfo() db.Exec("INSERT INTO APID (instance_id) VALUES (?)", info.InstanceID) }
diff --git a/init.go b/init.go index 6078147..8ecc940 100644 --- a/init.go +++ b/init.go
@@ -25,7 +25,7 @@ // special value - set by ApigeeSync, not taken from configuration configApidInstanceID = "apigeesync_apid_instance_id" // This will not be needed once we have plugin handling tokens. - bearerToken = "apigeesync_bearer_token" + configBearerToken = "apigeesync_bearer_token" ) var ( @@ -36,6 +36,7 @@ apidInfo apidInstanceInfo apidPluginDetails string newInstanceID bool + tokenManager *tokenMan ) type apidInstanceInfo struct { @@ -77,6 +78,8 @@ data = services.Data() events = services.Events() + tokenManager = createTokenManager() + /* This callback function will get called, once all the plugins are * initialized (not just this plugin). This is needed because, * downloadSnapshots/changes etc have to begin to be processed only
diff --git a/mock_server.go b/mock_server.go index 6800c9e..4e0c3ca 100644 --- a/mock_server.go +++ b/mock_server.go
@@ -38,6 +38,8 @@ product => * app_credential */ +const oauthExpiresIn = 2 * 60 * 1000 // 2 minutes + type MockParms struct { ReliableAPI bool ClusterID string @@ -212,9 +214,9 @@ func (m *MockServer) registerRoutes(router apid.Router) { - router.HandleFunc("/accesstoken", m.unreliable(m.sendToken)).Methods("POST") - router.HandleFunc("/snapshots", m.unreliable(m.auth(m.sendSnapshot))).Methods("GET") - router.HandleFunc("/changes", m.unreliable(m.auth(m.sendChanges))).Methods("GET") + router.HandleFunc("/accesstoken", m.unreliable(m.gomega(m.sendToken))).Methods("POST") + router.HandleFunc("/snapshots", m.unreliable(m.gomega(m.auth(m.sendSnapshot)))).Methods("GET") + router.HandleFunc("/changes", m.unreliable(m.gomega(m.auth(m.sendChanges)))).Methods("GET") router.HandleFunc("/bundles/{id}", m.sendDeploymentBundle).Methods("GET") router.HandleFunc("/analytics", m.sendAnalyticsURL).Methods("GET") router.HandleFunc("/analytics", m.putAnalyticsData).Methods("PUT") @@ -236,7 +238,6 @@ func (m *MockServer) sendToken(w http.ResponseWriter, req *http.Request) { defer GinkgoRecover() - m.registerFailHandler(w) Expect(req.Header.Get("Content-Type")).To(Equal("application/x-www-form-urlencoded; param=value")) @@ -263,8 +264,9 @@ Expect(err).NotTo(HaveOccurred()) m.oauthToken = generateUUID() - res := oauthTokenResp{ + res := oauthToken{ AccessToken: m.oauthToken, + ExpiresIn: oauthExpiresIn, } body, err := json.Marshal(res) Expect(err).NotTo(HaveOccurred()) @@ -273,7 +275,6 @@ func (m *MockServer) sendSnapshot(w http.ResponseWriter, req *http.Request) { defer GinkgoRecover() - m.registerFailHandler(w) q := req.URL.Query() scopes := q["scope"] @@ -306,7 +307,6 @@ func (m *MockServer) sendChanges(w http.ResponseWriter, req *http.Request) { defer GinkgoRecover() - m.registerFailHandler(w) val := atomic.SwapInt32(m.newSnap, 0) if val > 0 { @@ -435,13 +435,28 @@ } } +// enables GoMega handling +func (m *MockServer) gomega(target http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, req *http.Request) { + errors := InterceptGomegaFailures(func() { + target(w, req) + }) + if len(errors) > 0 { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(fmt.Sprintf("assertion errors:\n%v", errors))) + } + } +} + // enforces handler auth func (m *MockServer) auth(target http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, req *http.Request) { auth := req.Header.Get("Authorization") - if auth != fmt.Sprintf("Bearer %s", m.oauthToken) { + expectedAuth := fmt.Sprintf("Bearer %s", m.oauthToken) + if auth != expectedAuth { w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(fmt.Sprintf("Bad auth token. Is: %s, should be: %s", auth, expectedAuth))) } else { target(w, req) } @@ -465,14 +480,6 @@ } } -func (m *MockServer) registerFailHandler(w http.ResponseWriter) { - RegisterFailHandler(func(message string, callerSkip ...int) { - w.WriteHeader(400) - w.Write([]byte(message)) - panic(message) - }) -} - func (m *MockServer) newRow(keyAndVals map[string]string) (row common.Row) { row = common.Row{}
diff --git a/token.go b/token.go new file mode 100644 index 0000000..cca8963 --- /dev/null +++ b/token.go
@@ -0,0 +1,210 @@ +package apidApigeeSync + +import ( + "bytes" + "encoding/json" + "io/ioutil" + "net/http" + "net/url" + "path" + "sync" + "time" +) + +var ( + refreshFloatTime = time.Minute + getTokenLock sync.Mutex +) + +/* +Usage: + man := createTokenManager() + bearer := man.getBearerToken() + // will automatically update config(configBearerToken) for other modules + // optionally, when done... + man.close() +*/ + +func createTokenManager() *tokenMan { + t := &tokenMan{} + t.doRefresh = make(chan bool, 1) + t.maintainToken() + return t +} + +type tokenMan struct { + token *oauthToken + doRefresh chan bool +} + +func (t *tokenMan) getBearerToken() string { + return t.getToken().AccessToken +} + +func (t *tokenMan) maintainToken() { + go func() { + for { + if t.token.needsRefresh() { + getTokenLock.Lock() + t.retrieveNewToken() + getTokenLock.Unlock() + } + select { + case _, ok := <-t.doRefresh: + if !ok { + log.Debug("closed tokenMan") + return + } + log.Debug("force token refresh") + case <-time.After(t.token.refreshIn()): + log.Debug("auto refresh token") + } + } + }() +} + +func (t *tokenMan) invalidateToken() { + log.Debug("invalidating token") + t.token = nil + t.doRefresh <- true +} + +// will block until valid +func (t *tokenMan) getToken() *oauthToken { + getTokenLock.Lock() + defer getTokenLock.Unlock() + + if t.token.isValid() { + log.Debugf("returning existing token: %v", t.token) + return t.token + } + + t.retrieveNewToken() + return t.token +} + +func (t *tokenMan) close() { + log.Debug("close token manager") + close(t.doRefresh) +} + +// don't call externally. will block until success. +func (t *tokenMan) retrieveNewToken() { + + log.Debug("Getting OAuth token...") + uriString := config.GetString(configProxyServerBaseURI) + uri, err := url.Parse(uriString) + if err != nil { + log.Panicf("unable to parse uri config '%s' value: '%s': %v", configProxyServerBaseURI, uriString, err) + } + uri.Path = path.Join(uri.Path, "/accesstoken") + + retryIn := 5 * time.Millisecond + maxBackOff := maxBackoffTimeout + backOffFunc := createBackOff(retryIn, maxBackOff) + first := true + + for { + if first { + first = false + } else { + backOffFunc() + } + + form := url.Values{} + form.Set("grant_type", "client_credentials") + form.Add("client_id", config.GetString(configConsumerKey)) + form.Add("client_secret", config.GetString(configConsumerSecret)) + req, err := http.NewRequest("POST", uri.String(), bytes.NewBufferString(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") + req.Header.Set("display_name", apidInfo.InstanceName) + req.Header.Set("apid_instance_id", apidInfo.InstanceID) + req.Header.Set("apid_cluster_Id", apidInfo.ClusterID) + req.Header.Set("status", "ONLINE") + req.Header.Set("plugin_details", apidPluginDetails) + + if newInstanceID { + req.Header.Set("created_at_apid", time.Now().Format(time.RFC3339)) + } else { + req.Header.Set("updated_at_apid", time.Now().Format(time.RFC3339)) + } + + client := &http.Client{Timeout: httpTimeout} + resp, err := client.Do(req) + if err != nil { + log.Errorf("Unable to Connect to Edge Proxy Server: %v", err) + continue + } + + body, err := ioutil.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + log.Errorf("Unable to read EdgeProxy Sever response: %v", err) + continue + } + + if resp.StatusCode != 200 { + log.Errorf("Oauth Request Failed with Resp Code: %d. Body: %s", resp.StatusCode, string(body)) + continue + } + + var token oauthToken + err = json.Unmarshal(body, &token) + if err != nil { + log.Error("unable to unmarshal JSON response %s: %v", string(body), err) + continue + } + + if token.ExpiresIn >= 0 { + token.ExpiresAt = time.Now().Add(time.Duration(token.ExpiresIn) * time.Millisecond) + } else { + // no expiration, arbitrarily expire about a year from now + token.ExpiresAt = time.Now().Add(365 * 24 * time.Hour) + } + + log.Debugf("Got new token: %#v", token) + + t.token = &token + config.Set(configBearerToken, token.AccessToken) + return + } +} + +type oauthToken struct { + IssuedAt int64 `json:"issuedAt"` + AppName string `json:"applicationName"` + Scope string `json:"scope"` + Status string `json:"status"` + ApiProdList []string `json:"apiProductList"` + ExpiresIn int64 `json:"expiresIn"` + DeveloperEmail string `json:"developerEmail"` + TokenType string `json:"tokenType"` + ClientId string `json:"clientId"` + AccessToken string `json:"accessToken"` + RefreshExpIn int64 `json:"refreshTokenExpiresIn"` + RefreshCount int64 `json:"refreshCount"` + ExpiresAt time.Time +} + +var noTime time.Time + +func (t *oauthToken) isValid() bool { + if t == nil || t.AccessToken == "" { + return false + } + return t.AccessToken != "" && time.Now().Before(t.ExpiresAt) +} + +func (t *oauthToken) refreshIn() time.Duration { + if t == nil || t.ExpiresAt == noTime { + return time.Duration(0) + } + return t.ExpiresAt.Sub(time.Now()) - refreshFloatTime +} + +func (t *oauthToken) needsRefresh() bool { + if t == nil || t.ExpiresAt == noTime { + return true + } + return time.Now().Add(refreshFloatTime).After(t.ExpiresAt) +}
diff --git a/token_test.go b/token_test.go new file mode 100644 index 0000000..24bd13a --- /dev/null +++ b/token_test.go
@@ -0,0 +1,129 @@ +package apidApigeeSync + +import ( + "time" + + "net/http" + "net/http/httptest" + + "encoding/json" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("token", func() { + + Context("oauthToken", func() { + + It("should calculate valid token", func() { + t := &oauthToken{ + AccessToken: "x", + ExpiresIn: 120000, + ExpiresAt: time.Now().Add(2 * time.Minute), + } + Expect(t.refreshIn().Seconds()).To(BeNumerically(">", 0)) + Expect(t.needsRefresh()).To(BeFalse()) + Expect(t.isValid()).To(BeTrue()) + }) + + It("should calculate expired token", func() { + t := &oauthToken{ + AccessToken: "x", + ExpiresIn: 0, + ExpiresAt: time.Now(), + } + Expect(t.refreshIn().Seconds()).To(BeNumerically("<", 0)) + Expect(t.needsRefresh()).To(BeTrue()) + Expect(t.isValid()).To(BeFalse()) + }) + + It("should calculate token needing refresh", func() { + t := &oauthToken{ + AccessToken: "x", + ExpiresIn: 59000, + ExpiresAt: time.Now().Add(time.Minute - time.Second), + } + Expect(t.refreshIn().Seconds()).To(BeNumerically("<", 0)) + Expect(t.needsRefresh()).To(BeTrue()) + Expect(t.isValid()).To(BeTrue()) + }) + + It("should calculate on empty token", func() { + t := &oauthToken{} + Expect(t.refreshIn().Seconds()).To(BeNumerically("<=", 0)) + Expect(t.needsRefresh()).To(BeTrue()) + Expect(t.isValid()).To(BeFalse()) + }) + }) + + Context("tokenMan", func() { + + It("should get a valid token", func() { + token := tokenManager.getToken() + + Expect(token.AccessToken).ToNot(BeEmpty()) + Expect(token.ExpiresIn > 0).To(BeTrue()) + Expect(token.ExpiresAt).To(BeTemporally(">", time.Now())) + + bToken := tokenManager.getBearerToken() + Expect(bToken).To(Equal(token.AccessToken)) + }) + + It("should refresh when forced to", func() { + token := tokenManager.getToken() + Expect(token.AccessToken).ToNot(BeEmpty()) + + tokenManager.invalidateToken() + + token2 := tokenManager.getToken() + Expect(token).ToNot(Equal(token2)) + Expect(token.AccessToken).ToNot(Equal(token2.AccessToken)) + }) + + It("should refresh in refresh interval", func(done Done) { + + finished := make(chan bool) + var tm *tokenMan + start := time.Now() + count := 0 + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer GinkgoRecover() + + count++ + + if count > 1 { + if start.Add(500).After(time.Now()) { + Fail("didn't refresh within expected interval") + } + finished <- true + } + + res := oauthToken{ + AccessToken: string(count), + ExpiresIn: 1000, + } + body, err := json.Marshal(res) + Expect(err).NotTo(HaveOccurred()) + w.Write(body) + })) + defer ts.Close() + + tokenManager.close() + oldBase := config.Get(configProxyServerBaseURI) + config.Set(configProxyServerBaseURI, ts.URL) + oldFloat := refreshFloatTime + refreshFloatTime = 950 * time.Millisecond + defer func() { + tm.close() + config.Set(configProxyServerBaseURI, oldBase) + tokenManager = createTokenManager() + refreshFloatTime = oldFloat + }() + + tm = createTokenManager() + <-finished + close(done) + }) + }) +})