Fix Claims.Validate The `within()` function is wrong since it is supposed to check `max-delta < val < max+delta` but not `val == max`. It also does not work for the `nbf` check since that needs to verify that `nbf-leeway < now` Also, if I read RFC 7519 correctly then the token should not be used after `exp` and not before `nbf`. So `Validate` should check for `nbf < now <= exp` or more generally `nbf-x < now <= exp+y`
diff --git a/jwt/claims.go b/jwt/claims.go index cc135d8..51b491b 100644 --- a/jwt/claims.go +++ b/jwt/claims.go
@@ -14,23 +14,19 @@ // https://tools.ietf.org/html/rfc7519#section-4.1 func (c Claims) Validate(now, expLeeway, nbfLeeway float64) error { if exp, ok := c.Expiration(); ok { - if !within(exp, expLeeway, now) { + if now > exp+expLeeway { return ErrTokenIsExpired } } if nbf, ok := c.NotBefore(); ok { - if !within(nbf, nbfLeeway, now) { + if now <= nbf-nbfLeeway { return ErrTokenNotYetValid } } return nil } -func within(val, delta, max float64) bool { - return val > max+delta || val > max-delta -} - // Get retrieves the value corresponding with key from the Claims. func (c Claims) Get(key string) interface{} { if c == nil {
diff --git a/jwt/claims_test.go b/jwt/claims_test.go index c5edd70..b274647 100644 --- a/jwt/claims_test.go +++ b/jwt/claims_test.go
@@ -5,6 +5,7 @@ "github.com/SermoDigital/jose/crypto" "github.com/SermoDigital/jose/jws" + "github.com/SermoDigital/jose/jwt" ) func TestMultipleAudienceBug_AfterMarshal(t *testing.T) { @@ -83,3 +84,49 @@ t.Logf("aud Value: %s", aud) t.Logf("aud Type : %T", aud) } + +func TestValidate(t *testing.T) { + const before, now, after, leeway float64 = 10, 20, 30, 5 + + exp := func(t float64) jwt.Claims { + return jwt.Claims{"exp": t} + } + nbf := func(t float64) jwt.Claims { + return jwt.Claims{"nbf": t} + } + + var tests = []struct { + desc string + c jwt.Claims + now float64 + expLeeway float64 + nbfLeeway float64 + err error + }{ + // test for nbf < now <= exp + {desc: "exp == nil && nbf == nil", c: jwt.Claims{}, now: now, err: nil}, + + {desc: "now > exp", now: now, c: exp(before), err: jwt.ErrTokenIsExpired}, + {desc: "now = exp", now: now, c: exp(now), err: nil}, + {desc: "now < exp", now: now, c: exp(after), err: nil}, + + {desc: "nbf < now", c: nbf(before), now: now, err: nil}, + {desc: "nbf = now", c: nbf(now), now: now, err: jwt.ErrTokenNotYetValid}, + {desc: "nbf > now", c: nbf(after), now: now, err: jwt.ErrTokenNotYetValid}, + + // test for nbf-x < now <= exp+y + {desc: "now < exp+x", now: now + leeway - 1, expLeeway: leeway, c: exp(now), err: nil}, + {desc: "now = exp+x", now: now + leeway, expLeeway: leeway, c: exp(now), err: nil}, + {desc: "now > exp+x", now: now + leeway + 1, expLeeway: leeway, c: exp(now), err: jwt.ErrTokenIsExpired}, + + {desc: "nbf-x > now", c: nbf(now), nbfLeeway: leeway, now: now - leeway + 1, err: nil}, + {desc: "nbf-x = now", c: nbf(now), nbfLeeway: leeway, now: now - leeway, err: jwt.ErrTokenNotYetValid}, + {desc: "nbf-x < now", c: nbf(now), nbfLeeway: leeway, now: now - leeway - 1, err: jwt.ErrTokenNotYetValid}, + } + + for i, tt := range tests { + if got, want := tt.c.Validate(tt.now, tt.expLeeway, tt.nbfLeeway), tt.err; got != want { + t.Errorf("%d - %q: got %v want %v", i, tt.desc, got, want) + } + } +}