jwt: more testing, works now.
jws: added some accessor methods
diff --git a/crypto/signing_method.go b/crypto/signing_method.go
index 998ac59..c8b8874 100644
--- a/crypto/signing_method.go
+++ b/crypto/signing_method.go
@@ -1,8 +1,8 @@
package crypto
-// SigningMethod is an interface that provides a way to sign JWS tokens.
import "crypto"
+// SigningMethod is an interface that provides a way to sign JWS tokens.
type SigningMethod interface {
// Alg describes the signing algorithm, and is used to uniquely
// describe the specific crypto.SigningMethod.
diff --git a/jws/jws.go b/jws/jws.go
index 2a65f25..e777b18 100644
--- a/jws/jws.go
+++ b/jws/jws.go
@@ -23,6 +23,9 @@
// Payload returns the JWS' payload.
func (j *JWS) Payload() interface{} { return j.payload.v }
+// SetPayload sets the JWS' raw, unexported payload.
+func (j *JWS) SetPayload(val interface{}) { j.payload.v = val }
+
// sigHead represents the 'signatures' member of the JWS' "general"
// serialization form per
// https://tools.ietf.org/html/rfc7515#section-7.2.1
@@ -356,9 +359,6 @@
if len(j.sb) < 1 {
return ErrCannotValidate
}
- if j.isJWT {
- return j.validateJWT(key, method)
- }
return j.sb[0].validate(j.plcache, key, method)
}
@@ -368,3 +368,71 @@
}
return method.Verify(format(s.Protected, pl), s.Signature, key)
}
+
+// SetProtected sets the protected Header with the given value.
+// If i is provided, it'll assume the JWS is in the "general" format,
+// and set the Header at index i (inside the signatures member) with
+// the given value.
+func (j *JWS) SetProtected(key string, val interface{}, i ...int) {
+ k := 0
+ if len(i) > 0 && len(i) < len(j.sb) && i[0] > -1 {
+ k = i[0]
+ }
+ j.sb[k].protected.Set(key, val)
+}
+
+// RemoveProtected removes the value inside the protected Header that
+// corresponds with the given key.
+// For information on parameter i, see SetProtected.
+func (j *JWS) RemoveProtected(key string, i ...int) {
+ k := 0
+ if len(i) > 0 && len(i) < len(j.sb) && i[0] > -1 {
+ k = i[0]
+ }
+ j.sb[k].protected.Del(key)
+}
+
+// GetProtected retrieves the value inside the protected Header that
+// corresponds with the given key.
+// For information on parameter i, see SetProtected.
+func (j *JWS) GetProtected(key string, i ...int) interface{} {
+ k := 0
+ if len(i) > 0 && len(i) < len(j.sb) && i[0] > -1 {
+ k = i[0]
+ }
+ return j.sb[k].protected.Get(key)
+}
+
+// SetUnprotected sets the protected Header with the given value.
+// If i is provided, it'll assume the JWS is in the "general" format,
+// and set the Header at index i (inside the signatures member) with
+// the given value.
+func (j *JWS) SetUnprotected(key string, val interface{}, i ...int) {
+ k := 0
+ if len(i) > 0 && len(i) < len(j.sb) && i[0] > -1 {
+ k = i[0]
+ }
+ j.sb[k].unprotected.Set(key, val)
+}
+
+// RemoveUnprotected removes the value inside the unprotected Header that
+// corresponds with the given key.
+// For information on parameter i, see SetUnprotected.
+func (j *JWS) RemoveUnprotected(key string, i ...int) {
+ k := 0
+ if len(i) > 0 && len(i) < len(j.sb) && i[0] > -1 {
+ k = i[0]
+ }
+ j.sb[k].unprotected.Del(key)
+}
+
+// GetUnprotected retrieves the value inside the protected Header that
+// corresponds with the given key.
+// For information on parameter i, see SetUnprotected.
+func (j *JWS) GetUnprotected(key string, i ...int) interface{} {
+ k := 0
+ if len(i) > 0 && len(i) < len(j.sb) && i[0] > -1 {
+ k = i[0]
+ }
+ return j.sb[k].unprotected.Get(key)
+}
diff --git a/jws/jwt.go b/jws/jwt.go
index ce2e985..2f4dfc2 100644
--- a/jws/jwt.go
+++ b/jws/jwt.go
@@ -1,6 +1,8 @@
package jws
import (
+ "time"
+
"github.com/SermoDigital/jose/crypto"
"github.com/SermoDigital/jose/jwt"
)
@@ -33,16 +35,75 @@
return nil
}
-// ParseJWT parses a serialized JWT into a physical JWT.
-func ParseJWT(encoded []byte) (JWT, error) {
- return ParseCompact(encoded)
+// ParseJWT parses a serialized jwt.JWT into a physical jwt.JWT.
+// If its payload isn't a set of claims (or able to be coerced into
+// a set of claims) it'll return an error stating the
+// JWT isn't a JWT.
+func ParseJWT(encoded []byte) (jwt.JWT, error) {
+ t, err := ParseCompact(encoded)
+ if err != nil {
+ return nil, err
+ }
+ c, ok := t.payload.v.(map[string]interface{})
+ if !ok {
+ return nil, ErrIsNotJWT
+ }
+ t.payload.v = Claims(c)
+ t.isJWT = true
+ return t, nil
}
// IsJWT returns true if the JWS is a JWT.
func (j *JWS) IsJWT() bool { return j.isJWT }
-func (j *JWS) validateJWT(key interface{}, m crypto.SigningMethod) error {
- return nil
+// Verify helps implement jwt.JWT.
+func (j *JWS) Verify(key interface{}, m crypto.SigningMethod, o ...jwt.Opts) error {
+ if j.isJWT {
+ if err := j.Validate(key, m); err != nil {
+ return err
+ }
+ c, ok := j.payload.v.(Claims)
+ if ok {
+ var p jwt.Opts
+ if len(o) > 0 {
+ p = o[0]
+ }
+
+ if p.Fn != nil {
+ if err := p.Fn(jwt.Claims(c)); err != nil {
+ return err
+ }
+ }
+ return jwt.Claims(c).Validate(time.Now().Unix(), p.EXP, p.NBF)
+ }
+ }
+ return ErrIsNotJWT
+}
+
+// Opts represents some of the validation options.
+// It mimics jwt.Opts.
+type Opts struct {
+ EXP int64 // EXPLeeway
+ NBF int64 // NBFLeeway
+ Fn func(Claims) error
+ _ struct{}
+}
+
+// C is shorthand for Convert(fn).
+func (o Opts) C() jwt.Opts { return o.Convert() }
+
+// Convert converts Opts into jwt.Opts.
+func (o Opts) Convert() jwt.Opts {
+ p := jwt.Opts{
+ EXP: o.EXP,
+ NBF: o.NBF,
+ }
+ if o.Fn != nil {
+ p.Fn = func(c jwt.Claims) error {
+ return o.Fn(Claims(c))
+ }
+ }
+ return p
}
var _ jwt.JWT = (*JWS)(nil)
diff --git a/jws/jwt_test.go b/jws/jwt_test.go
new file mode 100644
index 0000000..aac4051
--- /dev/null
+++ b/jws/jwt_test.go
@@ -0,0 +1,42 @@
+package jws
+
+import (
+ "testing"
+
+ "github.com/SermoDigital/jose/crypto"
+)
+
+var claims = Claims{
+ "name": "Eric",
+ "scopes": []string{
+ "user.account.info",
+ "user.account.update",
+ "user.account.delete",
+ },
+ "admin": true,
+ "data": struct {
+ Foo, Bar int
+ }{
+ Foo: 12,
+ Bar: 50,
+ },
+}
+
+func TestBasicJWT(t *testing.T) {
+ j := NewJWT(claims, crypto.SigningMethodRS512)
+ b, err := j.Serialize(rsaPriv)
+ if err != nil {
+ t.Error(err)
+ }
+
+ w, err := ParseJWT(b)
+ if err != nil {
+ t.Error(err)
+ }
+
+ if w.Claims().Get("name") != "Eric" &&
+ w.Claims().Get("admin") != true &&
+ w.Claims().Get("scopes").([]string)[0] != "user.account.info" {
+ Error(t, claims, w.Claims())
+ }
+}
diff --git a/jwt/claims.go b/jwt/claims.go
index 46076d9..1955400 100644
--- a/jwt/claims.go
+++ b/jwt/claims.go
@@ -10,6 +10,36 @@
// methods, similar to net/url.Values.
type Claims map[string]interface{}
+// Validate ...
+func (c Claims) Validate(now, expLeeway, nbfLeeway int64) error {
+ if exp, ok := c.expiration(); ok {
+ if !within(exp, expLeeway, now) {
+ return ErrTokenIsExpired
+ }
+ }
+
+ if nbf, ok := c.notBefore(); ok {
+ if !within(nbf, nbfLeeway, now) {
+ return ErrTokenNotYetValid
+ }
+ }
+ return nil
+}
+
+func (c Claims) expiration() (int64, bool) {
+ v, ok := c.Get("exp").(int64)
+ return v, ok
+}
+
+func (c Claims) notBefore() (int64, bool) {
+ v, ok := c.Get("nbf").(int64)
+ return v, ok
+}
+
+func within(cur, delta, max int64) bool {
+ return cur+delta < max || cur-delta < max
+}
+
// Get retrieves the value corresponding with key from the Claims.
func (c Claims) Get(key string) interface{} {
if c == nil {
@@ -74,3 +104,8 @@
*c = Claims(tmp)
return nil
}
+
+var (
+ _ json.Marshaler = (Claims)(nil)
+ _ json.Unmarshaler = (*Claims)(nil)
+)
diff --git a/jwt/errors.go b/jwt/errors.go
new file mode 100644
index 0000000..e7ed564
--- /dev/null
+++ b/jwt/errors.go
@@ -0,0 +1,8 @@
+package jwt
+
+import "errors"
+
+var (
+ ErrTokenIsExpired = errors.New("token is expired")
+ ErrTokenNotYetValid = errors.New("token is not yet valid")
+)
diff --git a/jwt/jwt.go b/jwt/jwt.go
index bd09ec0..5a30aae 100644
--- a/jwt/jwt.go
+++ b/jwt/jwt.go
@@ -2,15 +2,39 @@
import "github.com/SermoDigital/jose/crypto"
+// Opts represents some of the validation options.
+type Opts struct {
+ EXP int64 // EXPLeeway
+ NBF int64 // NBFLeeway
+ Fn ValidateFunc // See ValidateFunc for more information.
+ _ struct{}
+}
+
+// JWT represents a JWT as per RFC 7519.
+// It's described as an interface instead of a physical structure
+// because both JWS and JWEs can be JWTs. So, in order to use either,
+// import one of those two packages and use their "NewJWT" (and other)
+// functions.
type JWT interface {
// Claims returns the set of Claims.
Claims() Claims
- // Validate returns an error describing any issues found while
- // validating the JWT.
- Validate(key interface{}, method crypto.SigningMethod) error
+ // Verify returns an error describing any issues found while
+ // validating the JWT. For info on the fn parameter, see the
+ // comment on ValidateFunc.
+ Verify(key interface{}, method crypto.SigningMethod, o ...Opts) error
// Serialize serializes the JWT into its on-the-wire
// representation.
Serialize(key interface{}) ([]byte, error)
}
+
+// ValidateFunc is a function that provides access to the JWT
+// and allows for custom validation. Keep in mind that the Verify
+// methods in the JWS/JWE sibling packages call ValidateFunc *after*
+// validating the JWS/JWE, but *before* any validation per the JWT
+// RFC. Therefore, the ValidateFunc can be used to short-circuit
+// verification, but cannot be used to circumvent the RFC.
+// Custom JWT implementations are free to abuse this, but it is
+// not recommended.
+type ValidateFunc func(Claims) error