jwt: more testing, works now. jws: added some accessor methods
diff --git a/crypto/signing_method.go b/crypto/signing_method.go index 998ac59..c8b8874 100644 --- a/crypto/signing_method.go +++ b/crypto/signing_method.go
@@ -1,8 +1,8 @@ package crypto -// SigningMethod is an interface that provides a way to sign JWS tokens. import "crypto" +// SigningMethod is an interface that provides a way to sign JWS tokens. type SigningMethod interface { // Alg describes the signing algorithm, and is used to uniquely // describe the specific crypto.SigningMethod.
diff --git a/jws/jws.go b/jws/jws.go index 2a65f25..e777b18 100644 --- a/jws/jws.go +++ b/jws/jws.go
@@ -23,6 +23,9 @@ // Payload returns the JWS' payload. func (j *JWS) Payload() interface{} { return j.payload.v } +// SetPayload sets the JWS' raw, unexported payload. +func (j *JWS) SetPayload(val interface{}) { j.payload.v = val } + // sigHead represents the 'signatures' member of the JWS' "general" // serialization form per // https://tools.ietf.org/html/rfc7515#section-7.2.1 @@ -356,9 +359,6 @@ if len(j.sb) < 1 { return ErrCannotValidate } - if j.isJWT { - return j.validateJWT(key, method) - } return j.sb[0].validate(j.plcache, key, method) } @@ -368,3 +368,71 @@ } return method.Verify(format(s.Protected, pl), s.Signature, key) } + +// SetProtected sets the protected Header with the given value. +// If i is provided, it'll assume the JWS is in the "general" format, +// and set the Header at index i (inside the signatures member) with +// the given value. +func (j *JWS) SetProtected(key string, val interface{}, i ...int) { + k := 0 + if len(i) > 0 && len(i) < len(j.sb) && i[0] > -1 { + k = i[0] + } + j.sb[k].protected.Set(key, val) +} + +// RemoveProtected removes the value inside the protected Header that +// corresponds with the given key. +// For information on parameter i, see SetProtected. +func (j *JWS) RemoveProtected(key string, i ...int) { + k := 0 + if len(i) > 0 && len(i) < len(j.sb) && i[0] > -1 { + k = i[0] + } + j.sb[k].protected.Del(key) +} + +// GetProtected retrieves the value inside the protected Header that +// corresponds with the given key. +// For information on parameter i, see SetProtected. +func (j *JWS) GetProtected(key string, i ...int) interface{} { + k := 0 + if len(i) > 0 && len(i) < len(j.sb) && i[0] > -1 { + k = i[0] + } + return j.sb[k].protected.Get(key) +} + +// SetUnprotected sets the protected Header with the given value. +// If i is provided, it'll assume the JWS is in the "general" format, +// and set the Header at index i (inside the signatures member) with +// the given value. +func (j *JWS) SetUnprotected(key string, val interface{}, i ...int) { + k := 0 + if len(i) > 0 && len(i) < len(j.sb) && i[0] > -1 { + k = i[0] + } + j.sb[k].unprotected.Set(key, val) +} + +// RemoveUnprotected removes the value inside the unprotected Header that +// corresponds with the given key. +// For information on parameter i, see SetUnprotected. +func (j *JWS) RemoveUnprotected(key string, i ...int) { + k := 0 + if len(i) > 0 && len(i) < len(j.sb) && i[0] > -1 { + k = i[0] + } + j.sb[k].unprotected.Del(key) +} + +// GetUnprotected retrieves the value inside the protected Header that +// corresponds with the given key. +// For information on parameter i, see SetUnprotected. +func (j *JWS) GetUnprotected(key string, i ...int) interface{} { + k := 0 + if len(i) > 0 && len(i) < len(j.sb) && i[0] > -1 { + k = i[0] + } + return j.sb[k].unprotected.Get(key) +}
diff --git a/jws/jwt.go b/jws/jwt.go index ce2e985..2f4dfc2 100644 --- a/jws/jwt.go +++ b/jws/jwt.go
@@ -1,6 +1,8 @@ package jws import ( + "time" + "github.com/SermoDigital/jose/crypto" "github.com/SermoDigital/jose/jwt" ) @@ -33,16 +35,75 @@ return nil } -// ParseJWT parses a serialized JWT into a physical JWT. -func ParseJWT(encoded []byte) (JWT, error) { - return ParseCompact(encoded) +// ParseJWT parses a serialized jwt.JWT into a physical jwt.JWT. +// If its payload isn't a set of claims (or able to be coerced into +// a set of claims) it'll return an error stating the +// JWT isn't a JWT. +func ParseJWT(encoded []byte) (jwt.JWT, error) { + t, err := ParseCompact(encoded) + if err != nil { + return nil, err + } + c, ok := t.payload.v.(map[string]interface{}) + if !ok { + return nil, ErrIsNotJWT + } + t.payload.v = Claims(c) + t.isJWT = true + return t, nil } // IsJWT returns true if the JWS is a JWT. func (j *JWS) IsJWT() bool { return j.isJWT } -func (j *JWS) validateJWT(key interface{}, m crypto.SigningMethod) error { - return nil +// Verify helps implement jwt.JWT. +func (j *JWS) Verify(key interface{}, m crypto.SigningMethod, o ...jwt.Opts) error { + if j.isJWT { + if err := j.Validate(key, m); err != nil { + return err + } + c, ok := j.payload.v.(Claims) + if ok { + var p jwt.Opts + if len(o) > 0 { + p = o[0] + } + + if p.Fn != nil { + if err := p.Fn(jwt.Claims(c)); err != nil { + return err + } + } + return jwt.Claims(c).Validate(time.Now().Unix(), p.EXP, p.NBF) + } + } + return ErrIsNotJWT +} + +// Opts represents some of the validation options. +// It mimics jwt.Opts. +type Opts struct { + EXP int64 // EXPLeeway + NBF int64 // NBFLeeway + Fn func(Claims) error + _ struct{} +} + +// C is shorthand for Convert(fn). +func (o Opts) C() jwt.Opts { return o.Convert() } + +// Convert converts Opts into jwt.Opts. +func (o Opts) Convert() jwt.Opts { + p := jwt.Opts{ + EXP: o.EXP, + NBF: o.NBF, + } + if o.Fn != nil { + p.Fn = func(c jwt.Claims) error { + return o.Fn(Claims(c)) + } + } + return p } var _ jwt.JWT = (*JWS)(nil)
diff --git a/jws/jwt_test.go b/jws/jwt_test.go new file mode 100644 index 0000000..aac4051 --- /dev/null +++ b/jws/jwt_test.go
@@ -0,0 +1,42 @@ +package jws + +import ( + "testing" + + "github.com/SermoDigital/jose/crypto" +) + +var claims = Claims{ + "name": "Eric", + "scopes": []string{ + "user.account.info", + "user.account.update", + "user.account.delete", + }, + "admin": true, + "data": struct { + Foo, Bar int + }{ + Foo: 12, + Bar: 50, + }, +} + +func TestBasicJWT(t *testing.T) { + j := NewJWT(claims, crypto.SigningMethodRS512) + b, err := j.Serialize(rsaPriv) + if err != nil { + t.Error(err) + } + + w, err := ParseJWT(b) + if err != nil { + t.Error(err) + } + + if w.Claims().Get("name") != "Eric" && + w.Claims().Get("admin") != true && + w.Claims().Get("scopes").([]string)[0] != "user.account.info" { + Error(t, claims, w.Claims()) + } +}
diff --git a/jwt/claims.go b/jwt/claims.go index 46076d9..1955400 100644 --- a/jwt/claims.go +++ b/jwt/claims.go
@@ -10,6 +10,36 @@ // methods, similar to net/url.Values. type Claims map[string]interface{} +// Validate ... +func (c Claims) Validate(now, expLeeway, nbfLeeway int64) error { + if exp, ok := c.expiration(); ok { + if !within(exp, expLeeway, now) { + return ErrTokenIsExpired + } + } + + if nbf, ok := c.notBefore(); ok { + if !within(nbf, nbfLeeway, now) { + return ErrTokenNotYetValid + } + } + return nil +} + +func (c Claims) expiration() (int64, bool) { + v, ok := c.Get("exp").(int64) + return v, ok +} + +func (c Claims) notBefore() (int64, bool) { + v, ok := c.Get("nbf").(int64) + return v, ok +} + +func within(cur, delta, max int64) bool { + return cur+delta < max || cur-delta < max +} + // Get retrieves the value corresponding with key from the Claims. func (c Claims) Get(key string) interface{} { if c == nil { @@ -74,3 +104,8 @@ *c = Claims(tmp) return nil } + +var ( + _ json.Marshaler = (Claims)(nil) + _ json.Unmarshaler = (*Claims)(nil) +)
diff --git a/jwt/errors.go b/jwt/errors.go new file mode 100644 index 0000000..e7ed564 --- /dev/null +++ b/jwt/errors.go
@@ -0,0 +1,8 @@ +package jwt + +import "errors" + +var ( + ErrTokenIsExpired = errors.New("token is expired") + ErrTokenNotYetValid = errors.New("token is not yet valid") +)
diff --git a/jwt/jwt.go b/jwt/jwt.go index bd09ec0..5a30aae 100644 --- a/jwt/jwt.go +++ b/jwt/jwt.go
@@ -2,15 +2,39 @@ import "github.com/SermoDigital/jose/crypto" +// Opts represents some of the validation options. +type Opts struct { + EXP int64 // EXPLeeway + NBF int64 // NBFLeeway + Fn ValidateFunc // See ValidateFunc for more information. + _ struct{} +} + +// JWT represents a JWT as per RFC 7519. +// It's described as an interface instead of a physical structure +// because both JWS and JWEs can be JWTs. So, in order to use either, +// import one of those two packages and use their "NewJWT" (and other) +// functions. type JWT interface { // Claims returns the set of Claims. Claims() Claims - // Validate returns an error describing any issues found while - // validating the JWT. - Validate(key interface{}, method crypto.SigningMethod) error + // Verify returns an error describing any issues found while + // validating the JWT. For info on the fn parameter, see the + // comment on ValidateFunc. + Verify(key interface{}, method crypto.SigningMethod, o ...Opts) error // Serialize serializes the JWT into its on-the-wire // representation. Serialize(key interface{}) ([]byte, error) } + +// ValidateFunc is a function that provides access to the JWT +// and allows for custom validation. Keep in mind that the Verify +// methods in the JWS/JWE sibling packages call ValidateFunc *after* +// validating the JWS/JWE, but *before* any validation per the JWT +// RFC. Therefore, the ValidateFunc can be used to short-circuit +// verification, but cannot be used to circumvent the RFC. +// Custom JWT implementations are free to abuse this, but it is +// not recommended. +type ValidateFunc func(Claims) error