Merge branch 'magiconair-use-time' into v1
diff --git a/jws/claims.go b/jws/claims.go index 068caa8..4cc616c 100644 --- a/jws/claims.go +++ b/jws/claims.go
@@ -2,6 +2,7 @@ import ( "encoding/json" + "time" "github.com/SermoDigital/jose" "github.com/SermoDigital/jose/jwt" @@ -84,19 +85,19 @@ // Expiration retrieves claim "exp" per its type in // https://tools.ietf.org/html/rfc7519#section-4.1.4 -func (c Claims) Expiration() (float64, bool) { +func (c Claims) Expiration() (time.Time, bool) { return jwt.Claims(c).Expiration() } // NotBefore retrieves claim "nbf" per its type in // https://tools.ietf.org/html/rfc7519#section-4.1.5 -func (c Claims) NotBefore() (float64, bool) { +func (c Claims) NotBefore() (time.Time, bool) { return jwt.Claims(c).NotBefore() } // IssuedAt retrieves claim "iat" per its type in // https://tools.ietf.org/html/rfc7519#section-4.1.6 -func (c Claims) IssuedAt() (float64, bool) { +func (c Claims) IssuedAt() (time.Time, bool) { return jwt.Claims(c).IssuedAt() } @@ -161,19 +162,19 @@ // SetExpiration sets claim "exp" per its type in // https://tools.ietf.org/html/rfc7519#section-4.1.4 -func (c Claims) SetExpiration(expiration float64) { +func (c Claims) SetExpiration(expiration time.Time) { jwt.Claims(c).SetExpiration(expiration) } // SetNotBefore sets claim "nbf" per its type in // https://tools.ietf.org/html/rfc7519#section-4.1.5 -func (c Claims) SetNotBefore(notBefore float64) { +func (c Claims) SetNotBefore(notBefore time.Time) { jwt.Claims(c).SetNotBefore(notBefore) } // SetIssuedAt sets claim "iat" per its type in // https://tools.ietf.org/html/rfc7519#section-4.1.6 -func (c Claims) SetIssuedAt(issuedAt float64) { +func (c Claims) SetIssuedAt(issuedAt time.Time) { jwt.Claims(c).SetIssuedAt(issuedAt) }
diff --git a/jws/jwt.go b/jws/jwt.go index 67f18f7..9720613 100644 --- a/jws/jwt.go +++ b/jws/jwt.go
@@ -80,7 +80,7 @@ if err := v1.Validate(j); err != nil { return err } - return jwt.Claims(c).Validate(float64(time.Now().Unix()), v1.EXP, v1.NBF) + return jwt.Claims(c).Validate(time.Now(), v1.EXP, v1.NBF) } } return ErrIsNotJWT @@ -98,7 +98,7 @@ // NewValidator returns a pointer to a jwt.Validator structure containing // the info to be used in the validation of a JWT. -func NewValidator(c Claims, exp, nbf float64, fn func(Claims) error) *jwt.Validator { +func NewValidator(c Claims, exp, nbf time.Duration, fn func(Claims) error) *jwt.Validator { return &jwt.Validator{ Expected: jwt.Claims(c), EXP: exp,
diff --git a/jws/jwt_test.go b/jws/jwt_test.go index 94a6f3e..5ea4a9b 100644 --- a/jws/jwt_test.go +++ b/jws/jwt_test.go
@@ -61,7 +61,7 @@ t.Error(err) } - d := float64(time.Now().Add(1 * time.Hour).Unix()) + d := time.Hour fn := func(c Claims) error { scopes, ok := c.Get("scopes").([]interface{})
diff --git a/jwt/claims.go b/jwt/claims.go index 51b491b..a44f207 100644 --- a/jwt/claims.go +++ b/jwt/claims.go
@@ -2,6 +2,8 @@ import ( "encoding/json" + "reflect" + "time" "github.com/SermoDigital/jose" ) @@ -12,15 +14,15 @@ // Validate validates the Claims per the claims found in // https://tools.ietf.org/html/rfc7519#section-4.1 -func (c Claims) Validate(now, expLeeway, nbfLeeway float64) error { +func (c Claims) Validate(now time.Time, expLeeway, nbfLeeway time.Duration) error { if exp, ok := c.Expiration(); ok { - if now > exp+expLeeway { + if now.After(exp.Add(expLeeway)) { return ErrTokenIsExpired } } if nbf, ok := c.NotBefore(); ok { - if now <= nbf-nbfLeeway { + if !now.After(nbf.Add(-nbfLeeway)) { return ErrTokenNotYetValid } } @@ -140,23 +142,20 @@ // Expiration retrieves claim "exp" per its type in // https://tools.ietf.org/html/rfc7519#section-4.1.4 -func (c Claims) Expiration() (float64, bool) { - v, ok := c.Get("exp").(float64) - return v, ok +func (c Claims) Expiration() (time.Time, bool) { + return c.GetTime("exp") } // NotBefore retrieves claim "nbf" per its type in // https://tools.ietf.org/html/rfc7519#section-4.1.5 -func (c Claims) NotBefore() (float64, bool) { - v, ok := c.Get("nbf").(float64) - return v, ok +func (c Claims) NotBefore() (time.Time, bool) { + return c.GetTime("nbf") } // IssuedAt retrieves claim "iat" per its type in // https://tools.ietf.org/html/rfc7519#section-4.1.6 -func (c Claims) IssuedAt() (float64, bool) { - v, ok := c.Get("iat").(float64) - return v, ok +func (c Claims) IssuedAt() (time.Time, bool) { + return c.GetTime("iat") } // JWTID retrieves claim "jti" per its type in @@ -211,20 +210,20 @@ // SetExpiration sets claim "exp" per its type in // https://tools.ietf.org/html/rfc7519#section-4.1.4 -func (c Claims) SetExpiration(expiration float64) { - c.Set("exp", expiration) +func (c Claims) SetExpiration(expiration time.Time) { + c.SetTime("exp", expiration) } // SetNotBefore sets claim "nbf" per its type in // https://tools.ietf.org/html/rfc7519#section-4.1.5 -func (c Claims) SetNotBefore(notBefore float64) { - c.Set("nbf", notBefore) +func (c Claims) SetNotBefore(notBefore time.Time) { + c.SetTime("nbf", notBefore) } // SetIssuedAt sets claim "iat" per its type in // https://tools.ietf.org/html/rfc7519#section-4.1.6 -func (c Claims) SetIssuedAt(issuedAt float64) { - c.Set("iat", issuedAt) +func (c Claims) SetIssuedAt(issuedAt time.Time) { + c.SetTime("iat", issuedAt) } // SetJWTID sets claim "jti" per its type in @@ -233,6 +232,41 @@ c.Set("jti", uniqueID) } +// zero pre-allocs the zero-time value +var zero = time.Time{} + +// GetTime returns a UNIX time for the given key. +// +// It converts an int, int32, int64, uint, uint32, uint64 or float64 value +// into a UNIX time (epoch seconds). float32 does not have sufficient +// precision to store a UNIX time. +// +// Numeric values parsed from JSON will always be stored as float64 since +// Claims is a map[string]interface{}. However, internally the values may be +// stored directly in the claims map as different types. +func (c Claims) GetTime(key string) (time.Time, bool) { + x := c.Get(key) + if x == nil { + return zero, false + } + v := reflect.ValueOf(x) + switch v.Kind() { + case reflect.Int, reflect.Int32, reflect.Int64: + return time.Unix(v.Int(), 0), true + case reflect.Uint, reflect.Uint32, reflect.Uint64: + return time.Unix(int64(v.Uint()), 0), true + case reflect.Float64: + return time.Unix(int64(v.Float()), 0), true + default: + return zero, false + } +} + +// SetTime stores a UNIX time for the given key. +func (c Claims) SetTime(key string, t time.Time) { + c.Set(key, t.Unix()) +} + var ( _ json.Marshaler = (Claims)(nil) _ json.Unmarshaler = (*Claims)(nil)
diff --git a/jwt/claims_test.go b/jwt/claims_test.go index b274647..b653785 100644 --- a/jwt/claims_test.go +++ b/jwt/claims_test.go
@@ -2,6 +2,7 @@ import ( "testing" + "time" "github.com/SermoDigital/jose/crypto" "github.com/SermoDigital/jose/jws" @@ -86,21 +87,23 @@ } func TestValidate(t *testing.T) { - const before, now, after, leeway float64 = 10, 20, 30, 5 + now := time.Date(2015, 1, 1, 0, 0, 0, 0, time.UTC) + before, after := now.Add(-time.Minute), now.Add(time.Minute) + leeway := 10 * time.Second - exp := func(t float64) jwt.Claims { - return jwt.Claims{"exp": t} + exp := func(t time.Time) jwt.Claims { + return jwt.Claims{"exp": t.Unix()} } - nbf := func(t float64) jwt.Claims { - return jwt.Claims{"nbf": t} + nbf := func(t time.Time) jwt.Claims { + return jwt.Claims{"nbf": t.Unix()} } var tests = []struct { desc string c jwt.Claims - now float64 - expLeeway float64 - nbfLeeway float64 + now time.Time + expLeeway time.Duration + nbfLeeway time.Duration err error }{ // test for nbf < now <= exp @@ -115,13 +118,13 @@ {desc: "nbf > now", c: nbf(after), now: now, err: jwt.ErrTokenNotYetValid}, // test for nbf-x < now <= exp+y - {desc: "now < exp+x", now: now + leeway - 1, expLeeway: leeway, c: exp(now), err: nil}, - {desc: "now = exp+x", now: now + leeway, expLeeway: leeway, c: exp(now), err: nil}, - {desc: "now > exp+x", now: now + leeway + 1, expLeeway: leeway, c: exp(now), err: jwt.ErrTokenIsExpired}, + {desc: "now < exp+x", now: now.Add(leeway - time.Second), expLeeway: leeway, c: exp(now), err: nil}, + {desc: "now = exp+x", now: now.Add(leeway), expLeeway: leeway, c: exp(now), err: nil}, + {desc: "now > exp+x", now: now.Add(leeway + time.Second), expLeeway: leeway, c: exp(now), err: jwt.ErrTokenIsExpired}, - {desc: "nbf-x > now", c: nbf(now), nbfLeeway: leeway, now: now - leeway + 1, err: nil}, - {desc: "nbf-x = now", c: nbf(now), nbfLeeway: leeway, now: now - leeway, err: jwt.ErrTokenNotYetValid}, - {desc: "nbf-x < now", c: nbf(now), nbfLeeway: leeway, now: now - leeway - 1, err: jwt.ErrTokenNotYetValid}, + {desc: "nbf-x > now", c: nbf(now), nbfLeeway: leeway, now: now.Add(-leeway + time.Second), err: nil}, + {desc: "nbf-x = now", c: nbf(now), nbfLeeway: leeway, now: now.Add(-leeway), err: jwt.ErrTokenNotYetValid}, + {desc: "nbf-x < now", c: nbf(now), nbfLeeway: leeway, now: now.Add(-leeway - time.Second), err: jwt.ErrTokenNotYetValid}, } for i, tt := range tests { @@ -130,3 +133,67 @@ } } } + +func TestGetAndSetTime(t *testing.T) { + now := time.Now() + nowUnix := now.Unix() + c := jwt.Claims{ + "int": int(nowUnix), + "int32": int32(nowUnix), + "int64": int64(nowUnix), + "uint": uint(nowUnix), + "uint32": uint32(nowUnix), + "uint64": uint64(nowUnix), + "float64": float64(nowUnix), + } + c.SetTime("setTime", now) + for k := range c { + v, ok := c.GetTime(k) + if got, want := v, time.Unix(nowUnix, 0); !ok || !got.Equal(want) { + t.Errorf("%s: got %v want %v", k, got, want) + } + } +} + +// TestTimeValuesThroughJSON verifies that the time values +// that are set via the Set{IssuedAt,NotBefore,Expiration}() +// methods can actually be parsed back +func TestTimeValuesThroughJSON(t *testing.T) { + now := time.Unix(time.Now().Unix(), 0) + + c := jws.Claims{} + c.SetIssuedAt(now) + c.SetNotBefore(now) + c.SetExpiration(now) + + // serialize to JWT + tok := jws.NewJWT(c, crypto.SigningMethodHS256) + b, err := tok.Serialize([]byte("key")) + if err != nil { + t.Fatal(err) + } + + // parse the JWT again + tok2, err := jws.ParseJWT(b) + if err != nil { + t.Fatal(err) + } + c2 := tok2.Claims() + + iat, ok1 := c2.IssuedAt() + nbf, ok2 := c2.NotBefore() + exp, ok3 := c2.Expiration() + if !ok1 || !ok2 || !ok3 { + t.Fatal("got false want true") + } + + if got, want := iat, now; !got.Equal(want) { + t.Errorf("%s: got %v want %v", "iat", got, want) + } + if got, want := nbf, now; !got.Equal(want) { + t.Errorf("%s: got %v want %v", "nbf", got, want) + } + if got, want := exp, now; !got.Equal(want) { + t.Errorf("%s: got %v want %v", "exp", got, want) + } +}
diff --git a/jwt/jwt.go b/jwt/jwt.go index bd84259..18f5758 100644 --- a/jwt/jwt.go +++ b/jwt/jwt.go
@@ -1,6 +1,10 @@ package jwt -import "github.com/SermoDigital/jose/crypto" +import ( + "time" + + "github.com/SermoDigital/jose/crypto" +) // JWT represents a JWT per RFC 7519. // It's described as an interface instead of a physical structure @@ -33,10 +37,10 @@ // Validator represents some of the validation options. type Validator struct { - Expected Claims // If non-nil, these are required to match. - EXP float64 // EXPLeeway - NBF float64 // NBFLeeway - Fn ValidateFunc // See ValidateFunc for more information. + Expected Claims // If non-nil, these are required to match. + EXP time.Duration // EXPLeeway + NBF time.Duration // NBFLeeway + Fn ValidateFunc // See ValidateFunc for more information. _ struct{} } @@ -71,7 +75,7 @@ } if aud, ok := v.Expected.Audience(); ok { - if aud2, _ := j.Claims().Audience(); !eq(aud, aud2){ + if aud2, _ := j.Claims().Audience(); !eq(aud, aud2) { return ErrInvalidAUDClaim } } @@ -111,21 +115,21 @@ // SetExpiration sets the "exp" claim per // https://tools.ietf.org/html/rfc7519#section-4.1.4 -func (v *Validator) SetExpiration(exp float64) { +func (v *Validator) SetExpiration(exp time.Time) { v.expect() v.Expected.Set("exp", exp) } // SetNotBefore sets the "nbf" claim per // https://tools.ietf.org/html/rfc7519#section-4.1.5 -func (v *Validator) SetNotBefore(nbf float64) { +func (v *Validator) SetNotBefore(nbf time.Time) { v.expect() v.Expected.Set("nbf", nbf) } // SetIssuedAt sets the "iat" claim per // https://tools.ietf.org/html/rfc7519#section-4.1.6 -func (v *Validator) SetIssuedAt(iat float64) { +func (v *Validator) SetIssuedAt(iat time.Time) { v.expect() v.Expected.Set("iat", iat) }