Ensure token is always up to date
diff --git a/apigeeSync_suite_test.go b/apigeeSync_suite_test.go
index 4a9de78..6181ff0 100644
--- a/apigeeSync_suite_test.go
+++ b/apigeeSync_suite_test.go
@@ -1,6 +1,7 @@
package apidApigeeSync
import (
+ "github.com/apigee-labs/transicator/common"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
@@ -12,7 +13,6 @@
"github.com/30x/apid-core"
"github.com/30x/apid-core/factory"
- "github.com/apigee-labs/transicator/common"
)
var (
@@ -151,7 +151,6 @@
var _ = BeforeEach(func() {
apid.Events().Close()
- token = ""
lastSequence = ""
_, err := getDB().Exec("DELETE FROM APID_CLUSTER")
diff --git a/apigee_sync.go b/apigee_sync.go
index 5928168..deead43 100644
--- a/apigee_sync.go
+++ b/apigee_sync.go
@@ -1,9 +1,7 @@
package apidApigeeSync
import (
- "bytes"
"encoding/json"
- "io/ioutil"
"net/http"
"net/url"
"path"
@@ -11,6 +9,8 @@
"sync/atomic"
+ "io/ioutil"
+
"github.com/30x/apid-core"
"github.com/apigee-labs/transicator/common"
)
@@ -23,7 +23,6 @@
var (
block string = "45"
- token string
lastSequence string
polling uint32
)
@@ -84,10 +83,6 @@
lastSequence = getLastSequence()
for {
log.Debug("polling...")
- if token == "" {
- // invalid token, loop until we get one
- getBearerToken()
- }
/* Find the scopes associated with the config id */
scopes := findScopesForId(apidInfo.ClusterID)
@@ -117,6 +112,7 @@
client := &http.Client{Timeout: httpTimeout} // must be greater than block value
req, err := http.NewRequest("GET", uri, nil)
addHeaders(req)
+
r, err := client.Do(req)
if err != nil {
log.Errorf("change agent comm error: %s", err)
@@ -127,7 +123,7 @@
log.Errorf("Get changes request failed with status code: %d", r.StatusCode)
switch r.StatusCode {
case http.StatusUnauthorized:
- token = ""
+ tokenManager.invalidateToken()
case http.StatusNotModified:
r.Body.Close()
@@ -135,9 +131,15 @@
case http.StatusBadRequest:
var apiErr apiError
- err = json.NewDecoder(r.Body).Decode(&apiErr)
+ var b []byte
+ b, err = ioutil.ReadAll(r.Body)
if err != nil {
- log.Errorf("JSON Response Data not parsable: %v", err)
+ log.Errorf("Unable to read response body: %v", err)
+ break
+ }
+ err = json.Unmarshal(b, &apiErr)
+ if err != nil {
+ log.Errorf("JSON Response Data not parsable: %s", string(b))
break
}
if apiErr.Code == "SNAPSHOT_TOO_OLD" {
@@ -196,113 +198,8 @@
}
}
-/*
- * This function will (for now) use the Access Key/Secret Key/ApidConfig Id
- * to get the bearer token, and the scopes (as comma separated scope)
- */
-func getBearerToken() {
-
- log.Info("Getting a Bearer token...")
- uriString := config.GetString(configProxyServerBaseURI)
- uri, err := url.Parse(uriString)
- if err != nil {
- log.Panicf("unable to parse uri config '%s' value: '%s': %v", configProxyServerBaseURI, uriString, err)
- }
- uri.Path = path.Join(uri.Path, "/accesstoken")
-
- retryIn := 5 * time.Millisecond
- maxBackOff := maxBackoffTimeout
- backOffFunc := createBackOff(retryIn, maxBackOff)
- first := true
-
- for {
- if first {
- first = false
- } else {
- backOffFunc()
- }
-
- token = ""
- form := url.Values{}
- form.Set("grant_type", "client_credentials")
- form.Add("client_id", config.GetString(configConsumerKey))
- form.Add("client_secret", config.GetString(configConsumerSecret))
- req, err := http.NewRequest("POST", uri.String(), bytes.NewBufferString(form.Encode()))
- req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value")
- req.Header.Set("display_name", apidInfo.InstanceName)
- req.Header.Set("apid_instance_id", apidInfo.InstanceID)
- req.Header.Set("apid_cluster_Id", apidInfo.ClusterID)
- req.Header.Set("status", "ONLINE")
- req.Header.Set("plugin_details", apidPluginDetails)
-
- if newInstanceID {
- req.Header.Set("created_at_apid", time.Now().Format(time.RFC3339))
- } else {
- req.Header.Set("updated_at_apid", time.Now().Format(time.RFC3339))
- }
-
- client := &http.Client{Timeout: httpTimeout}
- resp, err := client.Do(req)
- if err != nil {
- log.Errorf("Unable to Connect to Edge Proxy Server: %v", err)
- continue
- }
-
- body, err := ioutil.ReadAll(resp.Body)
- resp.Body.Close()
- if err != nil {
- log.Errorf("Unable to read EdgeProxy Sever response: %v", err)
- continue
- }
-
- if resp.StatusCode != 200 {
- log.Errorf("Oauth Request Failed with Resp Code: %d. Body: %s", resp.StatusCode, string(body))
- continue
- }
-
- var oauthResp oauthTokenResp
- log.Debugf("Response: %s ", body)
- err = json.Unmarshal(body, &oauthResp)
- if err != nil {
- log.Error("unable to unmarshal JSON response %s: %v", string(body), err)
- continue
- }
- token = oauthResp.AccessToken
-
- if newInstanceID {
- newInstanceID = false
- updateApidInstanceInfo()
- }
-
- /*
- * This stores the bearer token for any other plugin to
- * consume.
- */
- config.Set(bearerToken, token)
-
- log.Debug("Got a new Bearer token.")
-
- return
- }
-}
-
-type oauthTokenResp struct {
- IssuedAt int64 `json:"issuedAt"`
- AppName string `json:"applicationName"`
- Scope string `json:"scope"`
- Status string `json:"status"`
- ApiProdList []string `json:"apiProductList"`
- ExpiresIn int64 `json:"expiresIn"`
- DeveloperEmail string `json:"developerEmail"`
- TokenType string `json:"tokenType"`
- ClientId string `json:"clientId"`
- AccessToken string `json:"accessToken"`
- TokenExpIn int64 `json:"refreshTokenExpiresIn"`
- RefreshCount int64 `json:"refreshCount"`
-}
-
func Redirect(req *http.Request, via []*http.Request) error {
- req.Header.Add("Authorization", "Bearer "+token)
+ req.Header.Add("Authorization", "Bearer "+tokenManager.getBearerToken())
req.Header.Add("org", apidInfo.ClusterID) // todo: this is strange.. is it needed?
return nil
}
@@ -383,10 +280,6 @@
log.Panicf("bad url value for config %s: %s", snapshotUri, err)
}
- // getBearerToken loops until good
- getBearerToken()
- // todo: this could expire... ensure it's called again as needed
-
/* Frame and send the snapshot request */
snapshotUri.Path = path.Join(snapshotUri.Path, "snapshots")
@@ -457,7 +350,7 @@
}
func addHeaders(req *http.Request) {
- req.Header.Add("Authorization", "Bearer "+token)
+ req.Header.Add("Authorization", "Bearer "+tokenManager.getBearerToken())
req.Header.Set("apid_instance_id", apidInfo.InstanceID)
req.Header.Set("apid_cluster_Id", apidInfo.ClusterID)
req.Header.Set("updated_at_apid", time.Now().Format(time.RFC3339))
diff --git a/data.go b/data.go
index 2531c8d..c532bf1 100644
--- a/data.go
+++ b/data.go
@@ -251,8 +251,8 @@
} else {
// first start - no row, generate a UUID and store it
err = nil
- newInstanceID = true
info.InstanceID = generateUUID()
+ updateApidInstanceInfo()
db.Exec("INSERT INTO APID (instance_id) VALUES (?)", info.InstanceID)
}
diff --git a/init.go b/init.go
index 6078147..8ecc940 100644
--- a/init.go
+++ b/init.go
@@ -25,7 +25,7 @@
// special value - set by ApigeeSync, not taken from configuration
configApidInstanceID = "apigeesync_apid_instance_id"
// This will not be needed once we have plugin handling tokens.
- bearerToken = "apigeesync_bearer_token"
+ configBearerToken = "apigeesync_bearer_token"
)
var (
@@ -36,6 +36,7 @@
apidInfo apidInstanceInfo
apidPluginDetails string
newInstanceID bool
+ tokenManager *tokenMan
)
type apidInstanceInfo struct {
@@ -77,6 +78,8 @@
data = services.Data()
events = services.Events()
+ tokenManager = createTokenManager()
+
/* This callback function will get called, once all the plugins are
* initialized (not just this plugin). This is needed because,
* downloadSnapshots/changes etc have to begin to be processed only
diff --git a/mock_server.go b/mock_server.go
index 6800c9e..4e0c3ca 100644
--- a/mock_server.go
+++ b/mock_server.go
@@ -38,6 +38,8 @@
product => * app_credential
*/
+const oauthExpiresIn = 2 * 60 * 1000 // 2 minutes
+
type MockParms struct {
ReliableAPI bool
ClusterID string
@@ -212,9 +214,9 @@
func (m *MockServer) registerRoutes(router apid.Router) {
- router.HandleFunc("/accesstoken", m.unreliable(m.sendToken)).Methods("POST")
- router.HandleFunc("/snapshots", m.unreliable(m.auth(m.sendSnapshot))).Methods("GET")
- router.HandleFunc("/changes", m.unreliable(m.auth(m.sendChanges))).Methods("GET")
+ router.HandleFunc("/accesstoken", m.unreliable(m.gomega(m.sendToken))).Methods("POST")
+ router.HandleFunc("/snapshots", m.unreliable(m.gomega(m.auth(m.sendSnapshot)))).Methods("GET")
+ router.HandleFunc("/changes", m.unreliable(m.gomega(m.auth(m.sendChanges)))).Methods("GET")
router.HandleFunc("/bundles/{id}", m.sendDeploymentBundle).Methods("GET")
router.HandleFunc("/analytics", m.sendAnalyticsURL).Methods("GET")
router.HandleFunc("/analytics", m.putAnalyticsData).Methods("PUT")
@@ -236,7 +238,6 @@
func (m *MockServer) sendToken(w http.ResponseWriter, req *http.Request) {
defer GinkgoRecover()
- m.registerFailHandler(w)
Expect(req.Header.Get("Content-Type")).To(Equal("application/x-www-form-urlencoded; param=value"))
@@ -263,8 +264,9 @@
Expect(err).NotTo(HaveOccurred())
m.oauthToken = generateUUID()
- res := oauthTokenResp{
+ res := oauthToken{
AccessToken: m.oauthToken,
+ ExpiresIn: oauthExpiresIn,
}
body, err := json.Marshal(res)
Expect(err).NotTo(HaveOccurred())
@@ -273,7 +275,6 @@
func (m *MockServer) sendSnapshot(w http.ResponseWriter, req *http.Request) {
defer GinkgoRecover()
- m.registerFailHandler(w)
q := req.URL.Query()
scopes := q["scope"]
@@ -306,7 +307,6 @@
func (m *MockServer) sendChanges(w http.ResponseWriter, req *http.Request) {
defer GinkgoRecover()
- m.registerFailHandler(w)
val := atomic.SwapInt32(m.newSnap, 0)
if val > 0 {
@@ -435,13 +435,28 @@
}
}
+// enables GoMega handling
+func (m *MockServer) gomega(target http.HandlerFunc) http.HandlerFunc {
+ return func(w http.ResponseWriter, req *http.Request) {
+ errors := InterceptGomegaFailures(func() {
+ target(w, req)
+ })
+ if len(errors) > 0 {
+ w.WriteHeader(http.StatusBadRequest)
+ w.Write([]byte(fmt.Sprintf("assertion errors:\n%v", errors)))
+ }
+ }
+}
+
// enforces handler auth
func (m *MockServer) auth(target http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, req *http.Request) {
auth := req.Header.Get("Authorization")
- if auth != fmt.Sprintf("Bearer %s", m.oauthToken) {
+ expectedAuth := fmt.Sprintf("Bearer %s", m.oauthToken)
+ if auth != expectedAuth {
w.WriteHeader(http.StatusBadRequest)
+ w.Write([]byte(fmt.Sprintf("Bad auth token. Is: %s, should be: %s", auth, expectedAuth)))
} else {
target(w, req)
}
@@ -465,14 +480,6 @@
}
}
-func (m *MockServer) registerFailHandler(w http.ResponseWriter) {
- RegisterFailHandler(func(message string, callerSkip ...int) {
- w.WriteHeader(400)
- w.Write([]byte(message))
- panic(message)
- })
-}
-
func (m *MockServer) newRow(keyAndVals map[string]string) (row common.Row) {
row = common.Row{}
diff --git a/token.go b/token.go
new file mode 100644
index 0000000..cca8963
--- /dev/null
+++ b/token.go
@@ -0,0 +1,210 @@
+package apidApigeeSync
+
+import (
+ "bytes"
+ "encoding/json"
+ "io/ioutil"
+ "net/http"
+ "net/url"
+ "path"
+ "sync"
+ "time"
+)
+
+var (
+ refreshFloatTime = time.Minute
+ getTokenLock sync.Mutex
+)
+
+/*
+Usage:
+ man := createTokenManager()
+ bearer := man.getBearerToken()
+ // will automatically update config(configBearerToken) for other modules
+ // optionally, when done...
+ man.close()
+*/
+
+func createTokenManager() *tokenMan {
+ t := &tokenMan{}
+ t.doRefresh = make(chan bool, 1)
+ t.maintainToken()
+ return t
+}
+
+type tokenMan struct {
+ token *oauthToken
+ doRefresh chan bool
+}
+
+func (t *tokenMan) getBearerToken() string {
+ return t.getToken().AccessToken
+}
+
+func (t *tokenMan) maintainToken() {
+ go func() {
+ for {
+ if t.token.needsRefresh() {
+ getTokenLock.Lock()
+ t.retrieveNewToken()
+ getTokenLock.Unlock()
+ }
+ select {
+ case _, ok := <-t.doRefresh:
+ if !ok {
+ log.Debug("closed tokenMan")
+ return
+ }
+ log.Debug("force token refresh")
+ case <-time.After(t.token.refreshIn()):
+ log.Debug("auto refresh token")
+ }
+ }
+ }()
+}
+
+func (t *tokenMan) invalidateToken() {
+ log.Debug("invalidating token")
+ t.token = nil
+ t.doRefresh <- true
+}
+
+// will block until valid
+func (t *tokenMan) getToken() *oauthToken {
+ getTokenLock.Lock()
+ defer getTokenLock.Unlock()
+
+ if t.token.isValid() {
+ log.Debugf("returning existing token: %v", t.token)
+ return t.token
+ }
+
+ t.retrieveNewToken()
+ return t.token
+}
+
+func (t *tokenMan) close() {
+ log.Debug("close token manager")
+ close(t.doRefresh)
+}
+
+// don't call externally. will block until success.
+func (t *tokenMan) retrieveNewToken() {
+
+ log.Debug("Getting OAuth token...")
+ uriString := config.GetString(configProxyServerBaseURI)
+ uri, err := url.Parse(uriString)
+ if err != nil {
+ log.Panicf("unable to parse uri config '%s' value: '%s': %v", configProxyServerBaseURI, uriString, err)
+ }
+ uri.Path = path.Join(uri.Path, "/accesstoken")
+
+ retryIn := 5 * time.Millisecond
+ maxBackOff := maxBackoffTimeout
+ backOffFunc := createBackOff(retryIn, maxBackOff)
+ first := true
+
+ for {
+ if first {
+ first = false
+ } else {
+ backOffFunc()
+ }
+
+ form := url.Values{}
+ form.Set("grant_type", "client_credentials")
+ form.Add("client_id", config.GetString(configConsumerKey))
+ form.Add("client_secret", config.GetString(configConsumerSecret))
+ req, err := http.NewRequest("POST", uri.String(), bytes.NewBufferString(form.Encode()))
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value")
+ req.Header.Set("display_name", apidInfo.InstanceName)
+ req.Header.Set("apid_instance_id", apidInfo.InstanceID)
+ req.Header.Set("apid_cluster_Id", apidInfo.ClusterID)
+ req.Header.Set("status", "ONLINE")
+ req.Header.Set("plugin_details", apidPluginDetails)
+
+ if newInstanceID {
+ req.Header.Set("created_at_apid", time.Now().Format(time.RFC3339))
+ } else {
+ req.Header.Set("updated_at_apid", time.Now().Format(time.RFC3339))
+ }
+
+ client := &http.Client{Timeout: httpTimeout}
+ resp, err := client.Do(req)
+ if err != nil {
+ log.Errorf("Unable to Connect to Edge Proxy Server: %v", err)
+ continue
+ }
+
+ body, err := ioutil.ReadAll(resp.Body)
+ resp.Body.Close()
+ if err != nil {
+ log.Errorf("Unable to read EdgeProxy Sever response: %v", err)
+ continue
+ }
+
+ if resp.StatusCode != 200 {
+ log.Errorf("Oauth Request Failed with Resp Code: %d. Body: %s", resp.StatusCode, string(body))
+ continue
+ }
+
+ var token oauthToken
+ err = json.Unmarshal(body, &token)
+ if err != nil {
+ log.Error("unable to unmarshal JSON response %s: %v", string(body), err)
+ continue
+ }
+
+ if token.ExpiresIn >= 0 {
+ token.ExpiresAt = time.Now().Add(time.Duration(token.ExpiresIn) * time.Millisecond)
+ } else {
+ // no expiration, arbitrarily expire about a year from now
+ token.ExpiresAt = time.Now().Add(365 * 24 * time.Hour)
+ }
+
+ log.Debugf("Got new token: %#v", token)
+
+ t.token = &token
+ config.Set(configBearerToken, token.AccessToken)
+ return
+ }
+}
+
+type oauthToken struct {
+ IssuedAt int64 `json:"issuedAt"`
+ AppName string `json:"applicationName"`
+ Scope string `json:"scope"`
+ Status string `json:"status"`
+ ApiProdList []string `json:"apiProductList"`
+ ExpiresIn int64 `json:"expiresIn"`
+ DeveloperEmail string `json:"developerEmail"`
+ TokenType string `json:"tokenType"`
+ ClientId string `json:"clientId"`
+ AccessToken string `json:"accessToken"`
+ RefreshExpIn int64 `json:"refreshTokenExpiresIn"`
+ RefreshCount int64 `json:"refreshCount"`
+ ExpiresAt time.Time
+}
+
+var noTime time.Time
+
+func (t *oauthToken) isValid() bool {
+ if t == nil || t.AccessToken == "" {
+ return false
+ }
+ return t.AccessToken != "" && time.Now().Before(t.ExpiresAt)
+}
+
+func (t *oauthToken) refreshIn() time.Duration {
+ if t == nil || t.ExpiresAt == noTime {
+ return time.Duration(0)
+ }
+ return t.ExpiresAt.Sub(time.Now()) - refreshFloatTime
+}
+
+func (t *oauthToken) needsRefresh() bool {
+ if t == nil || t.ExpiresAt == noTime {
+ return true
+ }
+ return time.Now().Add(refreshFloatTime).After(t.ExpiresAt)
+}
diff --git a/token_test.go b/token_test.go
new file mode 100644
index 0000000..24bd13a
--- /dev/null
+++ b/token_test.go
@@ -0,0 +1,129 @@
+package apidApigeeSync
+
+import (
+ "time"
+
+ "net/http"
+ "net/http/httptest"
+
+ "encoding/json"
+
+ . "github.com/onsi/ginkgo"
+ . "github.com/onsi/gomega"
+)
+
+var _ = Describe("token", func() {
+
+ Context("oauthToken", func() {
+
+ It("should calculate valid token", func() {
+ t := &oauthToken{
+ AccessToken: "x",
+ ExpiresIn: 120000,
+ ExpiresAt: time.Now().Add(2 * time.Minute),
+ }
+ Expect(t.refreshIn().Seconds()).To(BeNumerically(">", 0))
+ Expect(t.needsRefresh()).To(BeFalse())
+ Expect(t.isValid()).To(BeTrue())
+ })
+
+ It("should calculate expired token", func() {
+ t := &oauthToken{
+ AccessToken: "x",
+ ExpiresIn: 0,
+ ExpiresAt: time.Now(),
+ }
+ Expect(t.refreshIn().Seconds()).To(BeNumerically("<", 0))
+ Expect(t.needsRefresh()).To(BeTrue())
+ Expect(t.isValid()).To(BeFalse())
+ })
+
+ It("should calculate token needing refresh", func() {
+ t := &oauthToken{
+ AccessToken: "x",
+ ExpiresIn: 59000,
+ ExpiresAt: time.Now().Add(time.Minute - time.Second),
+ }
+ Expect(t.refreshIn().Seconds()).To(BeNumerically("<", 0))
+ Expect(t.needsRefresh()).To(BeTrue())
+ Expect(t.isValid()).To(BeTrue())
+ })
+
+ It("should calculate on empty token", func() {
+ t := &oauthToken{}
+ Expect(t.refreshIn().Seconds()).To(BeNumerically("<=", 0))
+ Expect(t.needsRefresh()).To(BeTrue())
+ Expect(t.isValid()).To(BeFalse())
+ })
+ })
+
+ Context("tokenMan", func() {
+
+ It("should get a valid token", func() {
+ token := tokenManager.getToken()
+
+ Expect(token.AccessToken).ToNot(BeEmpty())
+ Expect(token.ExpiresIn > 0).To(BeTrue())
+ Expect(token.ExpiresAt).To(BeTemporally(">", time.Now()))
+
+ bToken := tokenManager.getBearerToken()
+ Expect(bToken).To(Equal(token.AccessToken))
+ })
+
+ It("should refresh when forced to", func() {
+ token := tokenManager.getToken()
+ Expect(token.AccessToken).ToNot(BeEmpty())
+
+ tokenManager.invalidateToken()
+
+ token2 := tokenManager.getToken()
+ Expect(token).ToNot(Equal(token2))
+ Expect(token.AccessToken).ToNot(Equal(token2.AccessToken))
+ })
+
+ It("should refresh in refresh interval", func(done Done) {
+
+ finished := make(chan bool)
+ var tm *tokenMan
+ start := time.Now()
+ count := 0
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ defer GinkgoRecover()
+
+ count++
+
+ if count > 1 {
+ if start.Add(500).After(time.Now()) {
+ Fail("didn't refresh within expected interval")
+ }
+ finished <- true
+ }
+
+ res := oauthToken{
+ AccessToken: string(count),
+ ExpiresIn: 1000,
+ }
+ body, err := json.Marshal(res)
+ Expect(err).NotTo(HaveOccurred())
+ w.Write(body)
+ }))
+ defer ts.Close()
+
+ tokenManager.close()
+ oldBase := config.Get(configProxyServerBaseURI)
+ config.Set(configProxyServerBaseURI, ts.URL)
+ oldFloat := refreshFloatTime
+ refreshFloatTime = 950 * time.Millisecond
+ defer func() {
+ tm.close()
+ config.Set(configProxyServerBaseURI, oldBase)
+ tokenManager = createTokenManager()
+ refreshFloatTime = oldFloat
+ }()
+
+ tm = createTokenManager()
+ <-finished
+ close(done)
+ })
+ })
+})