Merge branch 'magiconair-master'
diff --git a/jwt/claims.go b/jwt/claims.go
index cc135d8..51b491b 100644
--- a/jwt/claims.go
+++ b/jwt/claims.go
@@ -14,23 +14,19 @@
// https://tools.ietf.org/html/rfc7519#section-4.1
func (c Claims) Validate(now, expLeeway, nbfLeeway float64) error {
if exp, ok := c.Expiration(); ok {
- if !within(exp, expLeeway, now) {
+ if now > exp+expLeeway { // NOTE(review): now == exp+expLeeway is still accepted; RFC 7519 section 4.1.4 wants now strictly before exp - confirm the inclusive bound is intended.
return ErrTokenIsExpired
}
}
if nbf, ok := c.NotBefore(); ok {
- if !within(nbf, nbfLeeway, now) {
+ if now <= nbf-nbfLeeway { // NOTE(review): now == nbf-nbfLeeway is rejected; RFC 7519 section 4.1.5 accepts now equal to nbf - confirm the exclusive bound is intended.
return ErrTokenNotYetValid
}
}
return nil
}
-func within(val, delta, max float64) bool {
- return val > max+delta || val > max-delta
-}
-
// Get retrieves the value corresponding with key from the Claims.
func (c Claims) Get(key string) interface{} {
if c == nil {
diff --git a/jwt/claims_test.go b/jwt/claims_test.go
index c5edd70..b274647 100644
--- a/jwt/claims_test.go
+++ b/jwt/claims_test.go
@@ -5,6 +5,7 @@
"github.com/SermoDigital/jose/crypto"
"github.com/SermoDigital/jose/jws"
+ "github.com/SermoDigital/jose/jwt"
)
func TestMultipleAudienceBug_AfterMarshal(t *testing.T) {
@@ -83,3 +84,49 @@
t.Logf("aud Value: %s", aud)
t.Logf("aud Type : %T", aud)
}
+
+func TestValidate(t *testing.T) {
+	const before, now, after, leeway float64 = 10, 20, 30, 5
+
+	exp := func(v float64) jwt.Claims {
+		return jwt.Claims{"exp": v}
+	}
+	nbf := func(v float64) jwt.Claims {
+		return jwt.Claims{"nbf": v}
+	}
+
+	var tests = []struct {
+		desc      string
+		c         jwt.Claims
+		now       float64
+		expLeeway float64
+		nbfLeeway float64
+		err       error
+	}{
+		// Without leeway a token is valid iff nbf < now <= exp.
+		{desc: "exp == nil && nbf == nil", c: jwt.Claims{}, now: now, err: nil},
+
+		{desc: "now > exp", now: now, c: exp(before), err: jwt.ErrTokenIsExpired},
+		{desc: "now = exp", now: now, c: exp(now), err: nil},
+		{desc: "now < exp", now: now, c: exp(after), err: nil},
+
+		{desc: "nbf < now", c: nbf(before), now: now, err: nil},
+		{desc: "nbf = now", c: nbf(now), now: now, err: jwt.ErrTokenNotYetValid},
+		{desc: "nbf > now", c: nbf(after), now: now, err: jwt.ErrTokenNotYetValid},
+
+		// With leeways x (nbf) and y (exp) a token is valid iff nbf-x < now <= exp+y.
+		{desc: "now < exp+y", now: now + leeway - 1, expLeeway: leeway, c: exp(now), err: nil},
+		{desc: "now = exp+y", now: now + leeway, expLeeway: leeway, c: exp(now), err: nil},
+		{desc: "now > exp+y", now: now + leeway + 1, expLeeway: leeway, c: exp(now), err: jwt.ErrTokenIsExpired},
+
+		{desc: "nbf-x < now", c: nbf(now), nbfLeeway: leeway, now: now - leeway + 1, err: nil},
+		{desc: "nbf-x = now", c: nbf(now), nbfLeeway: leeway, now: now - leeway, err: jwt.ErrTokenNotYetValid},
+		{desc: "nbf-x > now", c: nbf(now), nbfLeeway: leeway, now: now - leeway - 1, err: jwt.ErrTokenNotYetValid},
+	}
+
+	for i, tt := range tests {
+		if got, want := tt.c.Validate(tt.now, tt.expLeeway, tt.nbfLeeway), tt.err; got != want {
+			t.Errorf("%d - %q: got %v want %v", i, tt.desc, got, want)
+		}
+	}
+}