Merge branch 'use-time' of https://github.com/magiconair/jose into magiconair-use-time
diff --git a/jws/claims.go b/jws/claims.go
index 068caa8..4cc616c 100644
--- a/jws/claims.go
+++ b/jws/claims.go
@@ -2,6 +2,7 @@
import (
"encoding/json"
+ "time"
"github.com/SermoDigital/jose"
"github.com/SermoDigital/jose/jwt"
@@ -84,19 +85,19 @@
// Expiration retrieves claim "exp" per its type in
// https://tools.ietf.org/html/rfc7519#section-4.1.4
-func (c Claims) Expiration() (float64, bool) {
+func (c Claims) Expiration() (time.Time, bool) {
return jwt.Claims(c).Expiration()
}
// NotBefore retrieves claim "nbf" per its type in
// https://tools.ietf.org/html/rfc7519#section-4.1.5
-func (c Claims) NotBefore() (float64, bool) {
+func (c Claims) NotBefore() (time.Time, bool) {
return jwt.Claims(c).NotBefore()
}
// IssuedAt retrieves claim "iat" per its type in
// https://tools.ietf.org/html/rfc7519#section-4.1.6
-func (c Claims) IssuedAt() (float64, bool) {
+func (c Claims) IssuedAt() (time.Time, bool) {
return jwt.Claims(c).IssuedAt()
}
@@ -161,19 +162,19 @@
// SetExpiration sets claim "exp" per its type in
// https://tools.ietf.org/html/rfc7519#section-4.1.4
-func (c Claims) SetExpiration(expiration float64) {
+func (c Claims) SetExpiration(expiration time.Time) {
jwt.Claims(c).SetExpiration(expiration)
}
// SetNotBefore sets claim "nbf" per its type in
// https://tools.ietf.org/html/rfc7519#section-4.1.5
-func (c Claims) SetNotBefore(notBefore float64) {
+func (c Claims) SetNotBefore(notBefore time.Time) {
jwt.Claims(c).SetNotBefore(notBefore)
}
// SetIssuedAt sets claim "iat" per its type in
// https://tools.ietf.org/html/rfc7519#section-4.1.6
-func (c Claims) SetIssuedAt(issuedAt float64) {
+func (c Claims) SetIssuedAt(issuedAt time.Time) {
jwt.Claims(c).SetIssuedAt(issuedAt)
}
diff --git a/jws/jwt.go b/jws/jwt.go
index 67f18f7..9720613 100644
--- a/jws/jwt.go
+++ b/jws/jwt.go
@@ -80,7 +80,7 @@
if err := v1.Validate(j); err != nil {
return err
}
- return jwt.Claims(c).Validate(float64(time.Now().Unix()), v1.EXP, v1.NBF)
+ return jwt.Claims(c).Validate(time.Now(), v1.EXP, v1.NBF)
}
}
return ErrIsNotJWT
@@ -98,7 +98,7 @@
// NewValidator returns a pointer to a jwt.Validator structure containing
// the info to be used in the validation of a JWT.
-func NewValidator(c Claims, exp, nbf float64, fn func(Claims) error) *jwt.Validator {
+func NewValidator(c Claims, exp, nbf time.Duration, fn func(Claims) error) *jwt.Validator {
return &jwt.Validator{
Expected: jwt.Claims(c),
EXP: exp,
diff --git a/jws/jwt_test.go b/jws/jwt_test.go
index a8acb10..49e75a3 100644
--- a/jws/jwt_test.go
+++ b/jws/jwt_test.go
@@ -61,7 +61,7 @@
t.Error(err)
}
- d := float64(time.Now().Add(1 * time.Hour).Unix())
+ d := time.Hour
fn := func(c Claims) error {
if c.Get("name") != "Eric" &&
c.Get("admin") != true &&
diff --git a/jwt/claims.go b/jwt/claims.go
index 51b491b..a44f207 100644
--- a/jwt/claims.go
+++ b/jwt/claims.go
@@ -2,6 +2,8 @@
import (
"encoding/json"
+ "reflect"
+ "time"
"github.com/SermoDigital/jose"
)
@@ -12,15 +14,15 @@
// Validate validates the Claims per the claims found in
// https://tools.ietf.org/html/rfc7519#section-4.1
-func (c Claims) Validate(now, expLeeway, nbfLeeway float64) error {
+func (c Claims) Validate(now time.Time, expLeeway, nbfLeeway time.Duration) error {
if exp, ok := c.Expiration(); ok {
- if now > exp+expLeeway {
+ if now.After(exp.Add(expLeeway)) {
return ErrTokenIsExpired
}
}
if nbf, ok := c.NotBefore(); ok {
- if now <= nbf-nbfLeeway {
+ if !now.After(nbf.Add(-nbfLeeway)) {
return ErrTokenNotYetValid
}
}
@@ -140,23 +142,20 @@
// Expiration retrieves claim "exp" per its type in
// https://tools.ietf.org/html/rfc7519#section-4.1.4
-func (c Claims) Expiration() (float64, bool) {
- v, ok := c.Get("exp").(float64)
- return v, ok
+func (c Claims) Expiration() (time.Time, bool) {
+ return c.GetTime("exp")
}
// NotBefore retrieves claim "nbf" per its type in
// https://tools.ietf.org/html/rfc7519#section-4.1.5
-func (c Claims) NotBefore() (float64, bool) {
- v, ok := c.Get("nbf").(float64)
- return v, ok
+func (c Claims) NotBefore() (time.Time, bool) {
+ return c.GetTime("nbf")
}
// IssuedAt retrieves claim "iat" per its type in
// https://tools.ietf.org/html/rfc7519#section-4.1.6
-func (c Claims) IssuedAt() (float64, bool) {
- v, ok := c.Get("iat").(float64)
- return v, ok
+func (c Claims) IssuedAt() (time.Time, bool) {
+ return c.GetTime("iat")
}
// JWTID retrieves claim "jti" per its type in
@@ -211,20 +210,20 @@
// SetExpiration sets claim "exp" per its type in
// https://tools.ietf.org/html/rfc7519#section-4.1.4
-func (c Claims) SetExpiration(expiration float64) {
- c.Set("exp", expiration)
+func (c Claims) SetExpiration(expiration time.Time) {
+ c.SetTime("exp", expiration)
}
// SetNotBefore sets claim "nbf" per its type in
// https://tools.ietf.org/html/rfc7519#section-4.1.5
-func (c Claims) SetNotBefore(notBefore float64) {
- c.Set("nbf", notBefore)
+func (c Claims) SetNotBefore(notBefore time.Time) {
+ c.SetTime("nbf", notBefore)
}
// SetIssuedAt sets claim "iat" per its type in
// https://tools.ietf.org/html/rfc7519#section-4.1.6
-func (c Claims) SetIssuedAt(issuedAt float64) {
- c.Set("iat", issuedAt)
+func (c Claims) SetIssuedAt(issuedAt time.Time) {
+ c.SetTime("iat", issuedAt)
}
// SetJWTID sets claim "jti" per its type in
@@ -233,6 +232,41 @@
c.Set("jti", uniqueID)
}
+// zero is the zero time.Time; GetTime returns it alongside false.
+var zero = time.Time{}
+
+// GetTime returns the claim stored under key as a UNIX time.
+//
+// It converts an int, int32, int64, uint, uint32, uint64 or float64 value
+// into a UNIX time (epoch seconds, fractions truncated). float32 is not
+// accepted as it lacks the precision to store a UNIX time. JSON numbers
+// are always parsed as float64; other types only occur when stored directly.
+//
+// It returns false if the claim is missing, is of any other type, or does
+// not fit in an int64 (which also rejects NaN and infinite floats).
+func (c Claims) GetTime(key string) (time.Time, bool) {
+	// A missing (nil) claim has Kind Invalid and so matches no case below.
+	switch v := reflect.ValueOf(c.Get(key)); v.Kind() {
+	case reflect.Int, reflect.Int32, reflect.Int64:
+		return time.Unix(v.Int(), 0), true
+	case reflect.Uint, reflect.Uint32, reflect.Uint64:
+		if u := v.Uint(); u <= 1<<63-1 {
+			return time.Unix(int64(u), 0), true
+		}
+	case reflect.Float64:
+		// NaN fails both comparisons, so it is rejected along with ±Inf.
+		if f := v.Float(); f >= -(1<<63) && f < 1<<63 {
+			return time.Unix(int64(f), 0), true
+		}
+	}
+	return zero, false
+}
+
+// SetTime stores t under key as UNIX epoch seconds, dropping sub-second precision.
+func (c Claims) SetTime(key string, t time.Time) {
+ c.Set(key, t.Unix())
+}
+
var (
_ json.Marshaler = (Claims)(nil)
_ json.Unmarshaler = (*Claims)(nil)
diff --git a/jwt/claims_test.go b/jwt/claims_test.go
index b274647..b653785 100644
--- a/jwt/claims_test.go
+++ b/jwt/claims_test.go
@@ -2,6 +2,7 @@
import (
"testing"
+ "time"
"github.com/SermoDigital/jose/crypto"
"github.com/SermoDigital/jose/jws"
@@ -86,21 +87,23 @@
}
func TestValidate(t *testing.T) {
- const before, now, after, leeway float64 = 10, 20, 30, 5
+ now := time.Date(2015, 1, 1, 0, 0, 0, 0, time.UTC)
+ before, after := now.Add(-time.Minute), now.Add(time.Minute)
+ leeway := 10 * time.Second
- exp := func(t float64) jwt.Claims {
- return jwt.Claims{"exp": t}
+ exp := func(t time.Time) jwt.Claims {
+ return jwt.Claims{"exp": t.Unix()}
}
- nbf := func(t float64) jwt.Claims {
- return jwt.Claims{"nbf": t}
+ nbf := func(t time.Time) jwt.Claims {
+ return jwt.Claims{"nbf": t.Unix()}
}
var tests = []struct {
desc string
c jwt.Claims
- now float64
- expLeeway float64
- nbfLeeway float64
+ now time.Time
+ expLeeway time.Duration
+ nbfLeeway time.Duration
err error
}{
// test for nbf < now <= exp
@@ -115,13 +118,13 @@
{desc: "nbf > now", c: nbf(after), now: now, err: jwt.ErrTokenNotYetValid},
// test for nbf-x < now <= exp+y
- {desc: "now < exp+x", now: now + leeway - 1, expLeeway: leeway, c: exp(now), err: nil},
- {desc: "now = exp+x", now: now + leeway, expLeeway: leeway, c: exp(now), err: nil},
- {desc: "now > exp+x", now: now + leeway + 1, expLeeway: leeway, c: exp(now), err: jwt.ErrTokenIsExpired},
+ {desc: "now < exp+x", now: now.Add(leeway - time.Second), expLeeway: leeway, c: exp(now), err: nil},
+ {desc: "now = exp+x", now: now.Add(leeway), expLeeway: leeway, c: exp(now), err: nil},
+ {desc: "now > exp+x", now: now.Add(leeway + time.Second), expLeeway: leeway, c: exp(now), err: jwt.ErrTokenIsExpired},
- {desc: "nbf-x > now", c: nbf(now), nbfLeeway: leeway, now: now - leeway + 1, err: nil},
- {desc: "nbf-x = now", c: nbf(now), nbfLeeway: leeway, now: now - leeway, err: jwt.ErrTokenNotYetValid},
- {desc: "nbf-x < now", c: nbf(now), nbfLeeway: leeway, now: now - leeway - 1, err: jwt.ErrTokenNotYetValid},
+ {desc: "nbf-x > now", c: nbf(now), nbfLeeway: leeway, now: now.Add(-leeway + time.Second), err: nil},
+ {desc: "nbf-x = now", c: nbf(now), nbfLeeway: leeway, now: now.Add(-leeway), err: jwt.ErrTokenNotYetValid},
+ {desc: "nbf-x < now", c: nbf(now), nbfLeeway: leeway, now: now.Add(-leeway - time.Second), err: jwt.ErrTokenNotYetValid},
}
for i, tt := range tests {
@@ -130,3 +133,67 @@
}
}
}
+
+// TestGetAndSetTime checks GetTime against every supported numeric kind.
+func TestGetAndSetTime(t *testing.T) {
+	now := time.Date(2015, 1, 1, 0, 0, 0, 0, time.UTC) // fixed: reproducible, fits int32/uint32
+	nowUnix := now.Unix()
+	c := jwt.Claims{
+		"int":     int(nowUnix),
+		"int32":   int32(nowUnix),
+		"int64":   int64(nowUnix),
+		"uint":    uint(nowUnix),
+		"uint32":  uint32(nowUnix),
+		"uint64":  uint64(nowUnix),
+		"float64": float64(nowUnix),
+	}
+	c.SetTime("setTime", now)
+	for k := range c {
+		if got, ok := c.GetTime(k); !ok || !got.Equal(now) {
+			t.Errorf("%s: got %v (ok=%v) want %v", k, got, ok, now)
+		}
+	}
+}
+
+// TestTimeValuesThroughJSON verifies that time values set via the
+// Set{IssuedAt,NotBefore,Expiration}() methods survive serialization
+// to a JWT and can be read back unchanged after parsing it again.
+func TestTimeValuesThroughJSON(t *testing.T) {
+ now := time.Unix(time.Now().Unix(), 0) // whole seconds only: SetTime stores t.Unix()
+
+ c := jws.Claims{}
+ c.SetIssuedAt(now)
+ c.SetNotBefore(now)
+ c.SetExpiration(now)
+
+ // serialize the claims into a signed JWT
+ tok := jws.NewJWT(c, crypto.SigningMethodHS256)
+ b, err := tok.Serialize([]byte("key"))
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // parse the serialized JWT back into claims
+ tok2, err := jws.ParseJWT(b)
+ if err != nil {
+ t.Fatal(err)
+ }
+ c2 := tok2.Claims()
+
+ iat, ok1 := c2.IssuedAt()
+ nbf, ok2 := c2.NotBefore()
+ exp, ok3 := c2.Expiration()
+ if !ok1 || !ok2 || !ok3 {
+ t.Fatal("got false want true")
+ }
+
+ if got, want := iat, now; !got.Equal(want) {
+ t.Errorf("%s: got %v want %v", "iat", got, want)
+ }
+ if got, want := nbf, now; !got.Equal(want) {
+ t.Errorf("%s: got %v want %v", "nbf", got, want)
+ }
+ if got, want := exp, now; !got.Equal(want) {
+ t.Errorf("%s: got %v want %v", "exp", got, want)
+ }
+}
diff --git a/jwt/jwt.go b/jwt/jwt.go
index bd84259..18f5758 100644
--- a/jwt/jwt.go
+++ b/jwt/jwt.go
@@ -1,6 +1,10 @@
package jwt
-import "github.com/SermoDigital/jose/crypto"
+import (
+ "time"
+
+ "github.com/SermoDigital/jose/crypto"
+)
// JWT represents a JWT per RFC 7519.
// It's described as an interface instead of a physical structure
@@ -33,10 +37,10 @@
// Validator represents some of the validation options.
type Validator struct {
- Expected Claims // If non-nil, these are required to match.
- EXP float64 // EXPLeeway
- NBF float64 // NBFLeeway
- Fn ValidateFunc // See ValidateFunc for more information.
+ Expected Claims // If non-nil, these are required to match.
+ EXP time.Duration // EXPLeeway: leeway added to "exp" when checking expiry.
+ NBF time.Duration // NBFLeeway: leeway subtracted from "nbf" when checking not-before.
+ Fn ValidateFunc // See ValidateFunc for more information.
_ struct{}
}
@@ -71,7 +75,7 @@
}
if aud, ok := v.Expected.Audience(); ok {
- if aud2, _ := j.Claims().Audience(); !eq(aud, aud2){
+ if aud2, _ := j.Claims().Audience(); !eq(aud, aud2) {
return ErrInvalidAUDClaim
}
}
@@ -111,21 +115,21 @@
-// SetExpiration sets the "exp" claim per
-// https://tools.ietf.org/html/rfc7519#section-4.1.4
+// SetExpiration sets the expected "exp" claim, stored as UNIX epoch
+// seconds, per https://tools.ietf.org/html/rfc7519#section-4.1.4
-func (v *Validator) SetExpiration(exp float64) {
+func (v *Validator) SetExpiration(exp time.Time) {
v.expect()
-	v.Expected.Set("exp", exp)
+	v.Expected.SetTime("exp", exp)
}
-// SetNotBefore sets the "nbf" claim per
-// https://tools.ietf.org/html/rfc7519#section-4.1.5
+// SetNotBefore sets the expected "nbf" claim, stored as UNIX epoch
+// seconds, per https://tools.ietf.org/html/rfc7519#section-4.1.5
-func (v *Validator) SetNotBefore(nbf float64) {
+func (v *Validator) SetNotBefore(nbf time.Time) {
v.expect()
-	v.Expected.Set("nbf", nbf)
+	v.Expected.SetTime("nbf", nbf)
}
-// SetIssuedAt sets the "iat" claim per
-// https://tools.ietf.org/html/rfc7519#section-4.1.6
+// SetIssuedAt sets the expected "iat" claim, stored as UNIX epoch
+// seconds, per https://tools.ietf.org/html/rfc7519#section-4.1.6
-func (v *Validator) SetIssuedAt(iat float64) {
+func (v *Validator) SetIssuedAt(iat time.Time) {
v.expect()
-	v.Expected.Set("iat", iat)
+	v.Expected.SetTime("iat", iat)
}