jws + jwt: converted JWS -> interfce{} and cleaned up JWS and JWT interfaces.
diff --git a/_test.sh b/_test.sh
new file mode 100755
index 0000000..a36a470
--- /dev/null
+++ b/_test.sh
@@ -0,0 +1,8 @@
+#!/usr/bin/env bash
+
+set -euo pipefail
+
+go build ./...
+go test ./...
+golint ./...
+go vet ./...
\ No newline at end of file
diff --git a/jws/jws.go b/jws/jws.go
index e777b18..e448caa 100644
--- a/jws/jws.go
+++ b/jws/jws.go
@@ -3,14 +3,65 @@
import (
"bytes"
"encoding/json"
- "sort"
"github.com/SermoDigital/jose"
"github.com/SermoDigital/jose/crypto"
)
-// JWS represents a specific JWS.
-type JWS struct {
+// JWS implements a JWS per RFC 7515.
+type JWS interface {
+ // Payload Returns the payload.
+ Payload() interface{}
+
+ // SetPayload sets the payload with the given value.
+ SetPayload(interface{})
+
+ // Protected returns the JWS' Protected Header.
+ // i represents the index of the Protected Header.
+ // Left empty, it defaults to 0.
+ Protected(...int) jose.Protected
+
+ // Header returns the JWS' unprotected Header.
+ // i represents the index of the Protected Header.
+ // Left empty, it defaults to 0.
+ Header(...int) jose.Header
+
+ // Verify validates the current JWS' signature as-is. Refer to
+ // ValidateMulti for more information.
+ Verify(key interface{}, method crypto.SigningMethod) error
+
+ // ValidateMulti validates the current JWS' signature as-is. Since it's
+ // meant to be called after parsing a stream of bytes into a JWS, it
+ // shouldn't do any internal parsing like the Sign, Flat, Compact, or
+ // General methods do.
+ VerifyMulti(keys []interface{}, methods []crypto.SigningMethod, o *SigningOpts) error
+
+ // VerifyCallback validates the current JWS' signature as-is. It
+ // accepts a callback function that can be used to access header
+ // parameters to lookup needed information. For example, looking
+ // up the "kid" parameter.
+ // The return slice must be a slice of keys used in the verification
+ // of the JWS.
+ VerifyCallback(fn VerifyCallback, methods []crypto.SigningMethod, o *SigningOpts) error
+
+ // General serializes the JWS into its "general" form per
+ // https://tools.ietf.org/html/rfc7515#section-7.2.1
+ General(keys ...interface{}) ([]byte, error)
+
+ // Flat serializes the JWS to its "flattened" form per
+ // https://tools.ietf.org/html/rfc7515#section-7.2.2
+ Flat(key interface{}) ([]byte, error)
+
+ // Compact serializes the JWS into its "compact" form per
+ // https://tools.ietf.org/html/rfc7515#section-7.1
+ Compact(key interface{}) ([]byte, error)
+
+ // IsJWT returns true if the JWS is a JWT.
+ IsJWT() bool
+}
+
+// jws represents a specific jws.
+type jws struct {
payload *payload
plcache rawBase64
clean bool
@@ -20,18 +71,38 @@
isJWT bool
}
-// Payload returns the JWS' payload.
-func (j *JWS) Payload() interface{} { return j.payload.v }
+// Payload returns the jws' payload.
+func (j *jws) Payload() interface{} { return j.payload.v }
-// SetPayload sets the JWS' raw, unexported payload.
-func (j *JWS) SetPayload(val interface{}) { j.payload.v = val }
+// SetPayload sets the jws' raw, unexported payload.
+func (j *jws) SetPayload(val interface{}) { j.payload.v = val }
-// sigHead represents the 'signatures' member of the JWS' "general"
+// Protected returns the JWS' Protected Header.
+// i represents the index of the Protected Header.
+// Left empty, it defaults to 0.
+func (j *jws) Protected(i ...int) jose.Protected {
+ if len(i) == 0 {
+ return j.sb[0].protected
+ }
+ return j.sb[i[0]].protected
+}
+
+// Header returns the JWS' unprotected Header.
+// i represents the index of the Protected Header.
+// Left empty, it defaults to 0.
+func (j *jws) Header(i ...int) jose.Header {
+ if len(i) == 0 {
+ return j.sb[0].unprotected
+ }
+ return j.sb[i[0]].unprotected
+}
+
+// sigHead represents the 'signatures' member of the jws' "general"
// serialization form per
// https://tools.ietf.org/html/rfc7515#section-7.2.1
//
// It's embedded inside the "flat" structure in order to properly
-// create the "flat" JWS.
+// create the "flat" jws.
type sigHead struct {
Protected rawBase64 `json:"protected,omitempty"`
Unprotected rawBase64 `json:"header,omitempty"`
@@ -54,8 +125,8 @@
return nil
}
-// New creates a new JWS with the provided crypto.SigningMethods.
-func New(content interface{}, methods ...crypto.SigningMethod) *JWS {
+// New creates a JWS with the provided crypto.SigningMethods.
+func New(content interface{}, methods ...crypto.SigningMethod) JWS {
sb := make([]sigHead, len(methods))
for i := range methods {
sb[i] = sigHead{
@@ -66,7 +137,7 @@
method: methods[i],
}
}
- return &JWS{
+ return &jws{
payload: &payload{v: content},
sb: sb,
}
@@ -93,8 +164,8 @@
Signatures []sigHead `json:"signatures,omitempty"`
}
-// Parse parses any of the three serialized JWS forms into a physical
-// JWS per https://tools.ietf.org/html/rfc7515#section-5.2
+// Parse parses any of the three serialized jws forms into a physical
+// jws per https://tools.ietf.org/html/rfc7515#section-5.2
//
// It accepts a json.Unmarshaler in order to properly parse
// the payload. In order to keep the caller from having to do extra
@@ -108,7 +179,7 @@
// ParseGeneral, ParseFlat, or ParseCompact.
// It should only be called if, for whatever reason, you do not
// know which form the serialized JWT is in.
-func Parse(encoded []byte, u ...json.Unmarshaler) (*JWS, error) {
+func Parse(encoded []byte, u ...json.Unmarshaler) (JWS, error) {
// Try and unmarshal into a generic struct that'll
// hopefully hold either of the two JSON serialization
// formats.s
@@ -125,13 +196,13 @@
return g.parseGeneral(u...)
}
-// ParseGeneral parses a JWS serialized into its "general" form per
+// ParseGeneral parses a jws serialized into its "general" form per
// https://tools.ietf.org/html/rfc7515#section-7.2.1
-// into a physical JWS per
+// into a physical jws per
// https://tools.ietf.org/html/rfc7515#section-5.2
//
// For information on the json.Unmarshaler parameter, see Parse.
-func ParseGeneral(encoded []byte, u ...json.Unmarshaler) (*JWS, error) {
+func ParseGeneral(encoded []byte, u ...json.Unmarshaler) (JWS, error) {
var g generic
if err := json.Unmarshal(encoded, &g); err != nil {
return nil, err
@@ -139,7 +210,7 @@
return g.parseGeneral(u...)
}
-func (g *generic) parseGeneral(u ...json.Unmarshaler) (*JWS, error) {
+func (g *generic) parseGeneral(u ...json.Unmarshaler) (JWS, error) {
var p payload
if len(u) > 0 {
@@ -165,7 +236,7 @@
g.clean = true
}
- return &JWS{
+ return &jws{
payload: &p,
plcache: g.Payload,
clean: true,
@@ -173,13 +244,13 @@
}, nil
}
-// ParseFlat parses a JWS serialized into its "flat" form per
+// ParseFlat parses a jws serialized into its "flat" form per
// https://tools.ietf.org/html/rfc7515#section-7.2.2
-// into a physical JWS per
+// into a physical jws per
// https://tools.ietf.org/html/rfc7515#section-5.2
//
// For information on the json.Unmarshaler parameter, see Parse.
-func ParseFlat(encoded []byte, u ...json.Unmarshaler) (*JWS, error) {
+func ParseFlat(encoded []byte, u ...json.Unmarshaler) (JWS, error) {
var g generic
if err := json.Unmarshal(encoded, &g); err != nil {
return nil, err
@@ -187,7 +258,7 @@
return g.parseFlat(u...)
}
-func (g *generic) parseFlat(u ...json.Unmarshaler) (*JWS, error) {
+func (g *generic) parseFlat(u ...json.Unmarshaler) (JWS, error) {
var p payload
if len(u) > 0 {
@@ -211,7 +282,7 @@
return nil, err
}
- return &JWS{
+ return &jws{
payload: &p,
plcache: g.Payload,
clean: true,
@@ -219,17 +290,21 @@
}, nil
}
-// ParseCompact parses a JWS serialized into its "compact" form per
+// ParseCompact parses a jws serialized into its "compact" form per
// https://tools.ietf.org/html/rfc7515#section-7.1
-// into a physical JWS per
+// into a physical jws per
// https://tools.ietf.org/html/rfc7515#section-5.2
//
// For information on the json.Unmarshaler parameter, see Parse.
-func ParseCompact(encoded []byte, u ...json.Unmarshaler) (*JWS, error) {
+func ParseCompact(encoded []byte, u ...json.Unmarshaler) (JWS, error) {
+ return parseCompact(encoded, false)
+}
+
+func parseCompact(encoded []byte, jwt bool) (*jws, error) {
// This section loosely follows
// https://tools.ietf.org/html/rfc7519#section-7.2
- // because it's used to parse _both_ JWS and JWTs.
+ // because it's used to parse _both_ jws and JWTs.
parts := bytes.Split(encoded, []byte{'.'})
if len(parts) != 3 {
@@ -250,9 +325,10 @@
return nil, err
}
- j := JWS{
+ j := jws{
payload: &payload{},
sb: []sigHead{s},
+ isJWT: jwt,
}
if err := j.payload.UnmarshalJSON(parts[1]); err != nil {
@@ -277,10 +353,11 @@
// should ignore duplicate Header keys instead of reporting an error when
// duplicate Header keys are found.
//
-// Note: Duplicate Header keys are defined in
-// https://tools.ietf.org/html/rfc7515#section-5.2
-// meaning keys that both the protected and unprotected
-// Headers possess.
+// Note:
+// Duplicate Header keys are defined in
+// https://tools.ietf.org/html/rfc7515#section-5.2
+// meaning keys that both the protected and unprotected
+// Headers possess.
var IgnoreDupes bool
// checkHeaders returns an error per the constraints described in
@@ -296,143 +373,3 @@
}
return nil
}
-
-// Any means any of the JWS signatures need to validate.
-// Refer to ValidateMulti for more information.
-const Any int = -1
-
-// ValidateMulti validates the current JWS as-is. Since it's meant to be
-// called after parsing a stream of bytes into a JWS, it doesn't do any
-// internal parsing like the Sign, Flat, Compact, or General methods do.
-// idx represents which signatures need to validate
-// in order for the JWS to be considered valid.
-// Use the constant `Any` (-1) if *any* should validate the JWS. Otherwise,
-// use the indexes of the signatures that need to validate in order
-// for the JWS to be considered valid.
-//
-// Notes:
-// 1.) If idx is omitted it defaults to requiring *all*
-// signatures validate
-// 2.) The JWS spec requires *at least* one
-// signature to validate in order for the JWS to be considered valid.
-func (j *JWS) ValidateMulti(keys []interface{}, methods []crypto.SigningMethod, idx ...int) error {
-
- if len(j.sb) != len(methods) {
- return ErrNotEnoughMethods
- }
-
- if len(keys) < 1 ||
- len(keys) > 1 && len(keys) != len(j.sb) {
- return ErrNotEnoughKeys
- }
-
- if len(keys) == 1 {
- k := keys[0]
- keys = make([]interface{}, len(methods))
- for i := range keys {
- keys[i] = k
- }
- }
-
- any := len(idx) == 1 && idx[0] == Any
- if !any {
- sort.Ints(idx)
- }
-
- rp := 0
- for i := range j.sb {
- if j.sb[i].validate(j.plcache, keys[i], methods[i]) == nil &&
- any || (rp < len(idx) && idx[rp] == i) {
- rp++
- }
- }
-
- if rp < len(idx) {
- return ErrDidNotValidate
- }
- return nil
-}
-
-// Validate validates the current JWS as-is. Refer to ValidateMulti
-// for more information.
-func (j *JWS) Validate(key interface{}, method crypto.SigningMethod) error {
- if len(j.sb) < 1 {
- return ErrCannotValidate
- }
- return j.sb[0].validate(j.plcache, key, method)
-}
-
-func (s *sigHead) validate(pl []byte, key interface{}, method crypto.SigningMethod) error {
- if s.method != method {
- return ErrMismatchedAlgorithms
- }
- return method.Verify(format(s.Protected, pl), s.Signature, key)
-}
-
-// SetProtected sets the protected Header with the given value.
-// If i is provided, it'll assume the JWS is in the "general" format,
-// and set the Header at index i (inside the signatures member) with
-// the given value.
-func (j *JWS) SetProtected(key string, val interface{}, i ...int) {
- k := 0
- if len(i) > 0 && len(i) < len(j.sb) && i[0] > -1 {
- k = i[0]
- }
- j.sb[k].protected.Set(key, val)
-}
-
-// RemoveProtected removes the value inside the protected Header that
-// corresponds with the given key.
-// For information on parameter i, see SetProtected.
-func (j *JWS) RemoveProtected(key string, i ...int) {
- k := 0
- if len(i) > 0 && len(i) < len(j.sb) && i[0] > -1 {
- k = i[0]
- }
- j.sb[k].protected.Del(key)
-}
-
-// GetProtected retrieves the value inside the protected Header that
-// corresponds with the given key.
-// For information on parameter i, see SetProtected.
-func (j *JWS) GetProtected(key string, i ...int) interface{} {
- k := 0
- if len(i) > 0 && len(i) < len(j.sb) && i[0] > -1 {
- k = i[0]
- }
- return j.sb[k].protected.Get(key)
-}
-
-// SetUnprotected sets the protected Header with the given value.
-// If i is provided, it'll assume the JWS is in the "general" format,
-// and set the Header at index i (inside the signatures member) with
-// the given value.
-func (j *JWS) SetUnprotected(key string, val interface{}, i ...int) {
- k := 0
- if len(i) > 0 && len(i) < len(j.sb) && i[0] > -1 {
- k = i[0]
- }
- j.sb[k].unprotected.Set(key, val)
-}
-
-// RemoveUnprotected removes the value inside the unprotected Header that
-// corresponds with the given key.
-// For information on parameter i, see SetUnprotected.
-func (j *JWS) RemoveUnprotected(key string, i ...int) {
- k := 0
- if len(i) > 0 && len(i) < len(j.sb) && i[0] > -1 {
- k = i[0]
- }
- j.sb[k].unprotected.Del(key)
-}
-
-// GetUnprotected retrieves the value inside the protected Header that
-// corresponds with the given key.
-// For information on parameter i, see SetUnprotected.
-func (j *JWS) GetUnprotected(key string, i ...int) interface{} {
- k := 0
- if len(i) > 0 && len(i) < len(j.sb) && i[0] > -1 {
- k = i[0]
- }
- return j.sb[k].unprotected.Get(key)
-}
diff --git a/jws/jws_serialize.go b/jws/jws_serialize.go
index 33369dd..9cb53af 100644
--- a/jws/jws_serialize.go
+++ b/jws/jws_serialize.go
@@ -7,7 +7,7 @@
// Flat serializes the JWS to its "flattened" form per
// https://tools.ietf.org/html/rfc7515#section-7.2.2
-func (j *JWS) Flat(key interface{}) ([]byte, error) {
+func (j *jws) Flat(key interface{}) ([]byte, error) {
if len(j.sb) < 1 {
return nil, ErrNotEnoughMethods
}
@@ -29,7 +29,7 @@
// If only one key is passed it's used for all the provided
// crypto.SigningMethods. Otherwise, len(keys) must equal the number
// of crypto.SigningMethods added.
-func (j *JWS) General(keys ...interface{}) ([]byte, error) {
+func (j *jws) General(keys ...interface{}) ([]byte, error) {
if err := j.sign(keys...); err != nil {
return nil, err
}
@@ -44,7 +44,7 @@
// Compact serializes the JWS into its "compact" form per
// https://tools.ietf.org/html/rfc7515#section-7.1
-func (j *JWS) Compact(key interface{}) ([]byte, error) {
+func (j *jws) Compact(key interface{}) ([]byte, error) {
if len(j.sb) < 1 {
return nil, ErrNotEnoughMethods
}
@@ -65,7 +65,7 @@
}
// sign signs each index of j's sb member.
-func (j *JWS) sign(keys ...interface{}) error {
+func (j *jws) sign(keys ...interface{}) error {
if err := j.cache(); err != nil {
return err
}
@@ -100,7 +100,7 @@
}
// cache marshals the payload, but only if it's changed since the last cache.
-func (j *JWS) cache() error {
+func (j *jws) cache() error {
if !j.clean {
var err error
j.plcache, err = j.payload.Base64()
diff --git a/jws/jws_test.go b/jws/jws_test.go
index 6c8355f..483d604 100644
--- a/jws/jws_test.go
+++ b/jws/jws_test.go
@@ -4,6 +4,7 @@
"bytes"
"encoding/base64"
"encoding/json"
+ "math/rand"
"testing"
"github.com/SermoDigital/jose/crypto"
@@ -43,8 +44,8 @@
t.Error(err)
}
- if !bytes.Equal(easyData, *j2.payload.v.(*easy)) {
- Error(t, easyData, *j2.payload.v.(*easy))
+ if !bytes.Equal(easyData, *j2.Payload().(*easy)) {
+ Error(t, easyData, *j2.Payload().(*easy))
}
}
@@ -61,7 +62,7 @@
}
var k easy
- if err := k.UnmarshalJSON([]byte(j2.payload.v.(string))); err != nil {
+ if err := k.UnmarshalJSON([]byte(j2.Payload().(string))); err != nil {
t.Error(err)
}
@@ -71,7 +72,12 @@
}
func TestParseGeneral(t *testing.T) {
- sm := []crypto.SigningMethod{crypto.SigningMethodRS512, crypto.SigningMethodPS384, crypto.SigningMethodPS256}
+ sm := []crypto.SigningMethod{
+ crypto.SigningMethodRS256,
+ crypto.SigningMethodPS384,
+ crypto.SigningMethodPS512,
+ }
+
j := New(easyData, sm...)
b, err := j.General(rsaPriv)
if err != nil {
@@ -83,7 +89,7 @@
t.Error(err)
}
- for i, v := range j2.sb {
+ for i, v := range j2.(*jws).sb {
k := v.protected.Get("alg").(string)
if k != sm[i].Alg() {
Error(t, sm[i].Alg(), k)
@@ -91,8 +97,13 @@
}
}
-func TestValidateMulti(t *testing.T) {
- sm := []crypto.SigningMethod{crypto.SigningMethodRS512, crypto.SigningMethodPS384, crypto.SigningMethodPS256}
+func TestVerifyMulti(t *testing.T) {
+ sm := []crypto.SigningMethod{
+ crypto.SigningMethodRS256,
+ crypto.SigningMethodPS384,
+ crypto.SigningMethodPS512,
+ }
+
j := New(easyData, sm...)
b, err := j.General(rsaPriv)
if err != nil {
@@ -105,13 +116,18 @@
}
keys := []interface{}{rsaPub, rsaPub, rsaPub}
- if err := j2.ValidateMulti(keys, sm, Any); err != nil {
+ if err := j2.VerifyMulti(keys, sm, nil); err != nil {
t.Error(err)
}
}
-func TestValidateMultiMismatchedAlgs(t *testing.T) {
- sm := []crypto.SigningMethod{crypto.SigningMethodRS256, crypto.SigningMethodPS384, crypto.SigningMethodPS512}
+func TestVerifyMultiMismatchedAlgs(t *testing.T) {
+ sm := []crypto.SigningMethod{
+ crypto.SigningMethodRS256,
+ crypto.SigningMethodPS384,
+ crypto.SigningMethodPS512,
+ }
+
j := New(easyData, sm...)
b, err := j.General(rsaPriv)
if err != nil {
@@ -123,17 +139,29 @@
t.Error(err)
}
- // Shuffle it.
- sm = []crypto.SigningMethod{crypto.SigningMethodRS512, crypto.SigningMethodPS256, crypto.SigningMethodPS384}
+ shuffle := func(a []crypto.SigningMethod) {
+ N := len(a)
+ for i := 0; i < N; i++ {
+ r := i + rand.Intn(N-i)
+ a[r], a[i] = a[i], a[r]
+ }
+ }
+
+ shuffle(sm)
keys := []interface{}{rsaPub, rsaPub, rsaPub}
- if err := j2.ValidateMulti(keys, sm, Any); err == nil {
+ if err := j2.VerifyMulti(keys, sm, nil); err == nil {
t.Error("Should NOT be nil")
}
}
-func TestValidateMultiNotEnoughMethods(t *testing.T) {
- sm := []crypto.SigningMethod{crypto.SigningMethodRS256, crypto.SigningMethodPS384, crypto.SigningMethodPS512}
+func TestVerifyMultiNotEnoughMethods(t *testing.T) {
+ sm := []crypto.SigningMethod{
+ crypto.SigningMethodRS256,
+ crypto.SigningMethodPS384,
+ crypto.SigningMethodPS512,
+ }
+
j := New(easyData, sm...)
b, err := j.General(rsaPriv)
if err != nil {
@@ -148,13 +176,18 @@
sm = sm[0 : len(sm)-1]
keys := []interface{}{rsaPub, rsaPub, rsaPub}
- if err := j2.ValidateMulti(keys, sm, Any); err == nil {
+ if err := j2.VerifyMulti(keys, sm, nil); err == nil {
t.Error("Should NOT be nil")
}
}
-func TestValidateMultiNotEnoughKeys(t *testing.T) {
- sm := []crypto.SigningMethod{crypto.SigningMethodRS256, crypto.SigningMethodPS384, crypto.SigningMethodPS512}
+func TestVerifyMultiNotEnoughKeys(t *testing.T) {
+ sm := []crypto.SigningMethod{
+ crypto.SigningMethodRS256,
+ crypto.SigningMethodPS384,
+ crypto.SigningMethodPS512,
+ }
+
j := New(easyData, sm...)
b, err := j.General(rsaPriv)
if err != nil {
@@ -167,12 +200,12 @@
}
keys := []interface{}{rsaPub, rsaPub}
- if err := j2.ValidateMulti(keys, sm, Any); err == nil {
+ if err := j2.VerifyMulti(keys, sm, nil); err == nil {
t.Error("Should NOT be nil")
}
}
-func TestValidate(t *testing.T) {
+func TestVerify(t *testing.T) {
j := New(easyData, crypto.SigningMethodPS512)
b, err := j.Flat(rsaPriv)
if err != nil {
@@ -184,7 +217,7 @@
t.Error(err)
}
- if err := j2.Validate(rsaPub, crypto.SigningMethodPS512); err != nil {
+ if err := j2.Verify(rsaPub, crypto.SigningMethodPS512); err != nil {
t.Error(err)
}
}
diff --git a/jws/jws_validate.go b/jws/jws_validate.go
new file mode 100644
index 0000000..453eb8f
--- /dev/null
+++ b/jws/jws_validate.go
@@ -0,0 +1,216 @@
+package jws
+
+import (
+ "errors"
+ "fmt"
+
+ "github.com/SermoDigital/jose/crypto"
+)
+
+// VerifyCallback is a callback function that can be used to access header
+// parameters to lookup needed information. For example, looking
+// up the "kid" parameter.
+// The return slice must be a slice of keys used in the verification
+// of the JWS.
+type VerifyCallback func(JWS) ([]interface{}, error)
+
+// VerifyCallback validates the current JWS' signature as-is. It
+// accepts a callback function that can be used to access header
+// parameters to lookup needed information. For example, looking
+// up the "kid" parameter.
+// The return slice must be a slice of keys used in the verification
+// of the JWS.
+func (j *jws) VerifyCallback(fn VerifyCallback, methods []crypto.SigningMethod, o *SigningOpts) error {
+ keys, err := fn(j)
+ if err != nil {
+ return err
+ }
+ return j.VerifyMulti(keys, methods, o)
+}
+
+// IsMultiError returns true if the given error is type MultiError.
+func IsMultiError(err error) bool {
+ _, ok := err.(MultiError)
+ return ok
+}
+
+// MultiError is a slice of errors.
+type MultiError []error
+
+func (m MultiError) sanityCheck() error {
+ if m == nil {
+ return nil
+ }
+ return m
+}
+
+// Errors implements the error interface.
+func (m MultiError) Error() string {
+ s, n := "", 0
+ for _, e := range m {
+ if e != nil {
+ if n == 0 {
+ s = e.Error()
+ }
+ n++
+ }
+ }
+ switch n {
+ case 0:
+ return "(0 errors)"
+ case 1:
+ return s
+ case 2:
+ return s + " (and 1 other error)"
+ }
+ return fmt.Sprintf("%s (and %d other errors)", s, n-1)
+}
+
+// Any means any of the JWS signatures need to verify.
+// Refer to verifyMulti for more information.
+const Any int = 0
+
+// VerifyMulti verifies the current JWS as-is. Since it's meant to be
+// called after parsing a stream of bytes into a JWS, it doesn't do any
+// internal parsing like the Sign, Flat, Compact, or General methods do.
+func (j *jws) VerifyMulti(keys []interface{}, methods []crypto.SigningMethod, o *SigningOpts) error {
+
+ // Catch a simple mistake. Parameter o is irrelevant in this scenario.
+ if len(keys) == 1 &&
+ len(methods) == 1 &&
+ len(j.sb) == 1 {
+ return j.Verify(keys[0], methods[0])
+ }
+
+ if len(j.sb) != len(methods) {
+ return ErrNotEnoughMethods
+ }
+
+ if len(keys) < 1 ||
+ len(keys) > 1 && len(keys) != len(j.sb) {
+ return ErrNotEnoughKeys
+ }
+
+ // TODO do this better.
+ if len(keys) == 1 {
+ k := keys[0]
+ keys = make([]interface{}, len(methods))
+ for i := range keys {
+ keys[i] = k
+ }
+ }
+
+ var o2 SigningOpts
+ if o == nil {
+ o = &SigningOpts{}
+ }
+
+ var m MultiError
+ for i := range j.sb {
+ err := j.sb[i].verify(j.plcache, keys[i], methods[i])
+ if err != nil {
+ m = append(m, err)
+ } else {
+ o2.Inc()
+ if o.Needs(i) {
+ o2.Append(i)
+ }
+ }
+ }
+
+ if err := o.Validate(&o2); err != nil {
+ return err
+ }
+ return m.sanityCheck()
+}
+
+// SigningOpts is a struct which holds options for validating
+// JWS signatures.
+// Number represents the cumulative which signatures need to verify
+// in order for the JWS to be considered valid.
+// Leave 'Number' empty or set it to the constant 'Any' if any number of
+// valid signatures (greater than one) should verify the JWS.
+//
+// Use the indices of the signatures that need to verify in order
+// for the JWS to be considered valid if specific signatures need
+// to verify in order for the JWS to be considered valid.
+//
+// Note:
+// The JWS spec requires *at least* one
+// signature to verify in order for the JWS to be considered valid.
+type SigningOpts struct {
+ // Minimum of signatures which need to verify.
+ Number int
+
+ // Indices of specific signatures which need to verify.
+ Indices []int
+ ptr int
+
+ _ struct{}
+}
+
+// Append appends x to s's Indices member.
+func (s *SigningOpts) Append(x int) {
+ s.Indices = append(s.Indices, x)
+}
+
+// Needs returns true if x resides inside s's Indices member
+// for the given index. If true, it increments s's internal
+// index. It's used to match two SigningOpts Indices members.
+func (s *SigningOpts) Needs(x int) bool {
+ if s.ptr < len(s.Indices) &&
+ s.Indices[s.ptr] == x {
+ s.ptr++
+ return true
+ }
+ return false
+}
+
+// Inc increments s's Number member by one.
+func (s *SigningOpts) Inc() { s.Number++ }
+
+// Validate returns any errors found while validating the
+// provided SigningOpts. The receiver validates the parameter `have`.
+// It'll return an error if the passed SigningOpts' Number member is less
+// than s's or if the passed SigningOpts' Indices slice isn't equal to s's.
+func (s *SigningOpts) Validate(have *SigningOpts) error {
+ if have.Number < s.Number {
+ return errors.New("TODO 2")
+ }
+ if s.Indices != nil &&
+ !eq(s.Indices, have.Indices) {
+ return errors.New("TODO 3")
+ }
+ return nil
+}
+
+func eq(a, b []int) bool {
+ if a == nil && b == nil {
+ return true
+ }
+ if a == nil || b == nil || len(a) != len(b) {
+ return false
+ }
+ for i := range a {
+ if a[i] != b[i] {
+ return false
+ }
+ }
+ return true
+}
+
+// Verify verifies the current JWS as-is. Refer to verifyMulti
+// for more information.
+func (j *jws) Verify(key interface{}, method crypto.SigningMethod) error {
+ if len(j.sb) < 1 {
+ return ErrCannotValidate
+ }
+ return j.sb[0].verify(j.plcache, key, method)
+}
+
+func (s *sigHead) verify(pl []byte, key interface{}, method crypto.SigningMethod) error {
+ if s.method != method {
+ return ErrMismatchedAlgorithms
+ }
+ return method.Verify(format(s.Protected, pl), s.Signature, key)
+}
diff --git a/jws/jwt.go b/jws/jwt.go
index 2f4dfc2..c6051dc 100644
--- a/jws/jwt.go
+++ b/jws/jwt.go
@@ -12,13 +12,13 @@
// NewJWT creates a new JWT with the given claims.
func NewJWT(claims Claims, method crypto.SigningMethod) jwt.JWT {
- j := New(claims, method)
+ j := New(claims, method).(*jws)
j.isJWT = true
return j
}
// Serialize helps implements jwt.JWT.
-func (j *JWS) Serialize(key interface{}) ([]byte, error) {
+func (j *jws) Serialize(key interface{}) ([]byte, error) {
if j.isJWT {
return j.Compact(key)
}
@@ -26,7 +26,7 @@
}
// Claims helps implements jwt.JWT.
-func (j *JWS) Claims() jwt.Claims {
+func (j *jws) Claims() jwt.Claims {
if j.isJWT {
if c, ok := j.payload.v.(Claims); ok {
return jwt.Claims(c)
@@ -40,70 +40,60 @@
// a set of claims) it'll return an error stating the
// JWT isn't a JWT.
func ParseJWT(encoded []byte) (jwt.JWT, error) {
- t, err := ParseCompact(encoded)
+ t, err := parseCompact(encoded, true)
if err != nil {
return nil, err
}
- c, ok := t.payload.v.(map[string]interface{})
+ c, ok := t.Payload().(map[string]interface{})
if !ok {
return nil, ErrIsNotJWT
}
- t.payload.v = Claims(c)
- t.isJWT = true
+ t.SetPayload(Claims(c))
return t, nil
}
// IsJWT returns true if the JWS is a JWT.
-func (j *JWS) IsJWT() bool { return j.isJWT }
+func (j *jws) IsJWT() bool { return j.isJWT }
-// Verify helps implement jwt.JWT.
-func (j *JWS) Verify(key interface{}, m crypto.SigningMethod, o ...jwt.Opts) error {
+func (j *jws) Validate(key interface{}, m crypto.SigningMethod, v ...*jwt.Validator) error {
if j.isJWT {
- if err := j.Validate(key, m); err != nil {
+ if err := j.Verify(key, m); err != nil {
return err
}
+ var v1 jwt.Validator
+ if len(v) > 0 {
+ v1 = *v[0]
+ }
+
c, ok := j.payload.v.(Claims)
if ok {
- var p jwt.Opts
- if len(o) > 0 {
- p = o[0]
+ if err := v1.Validate(j); err != nil {
+ return err
}
-
- if p.Fn != nil {
- if err := p.Fn(jwt.Claims(c)); err != nil {
- return err
- }
- }
- return jwt.Claims(c).Validate(time.Now().Unix(), p.EXP, p.NBF)
+ return jwt.Claims(c).Validate(time.Now().Unix(), v1.EXP, v1.NBF)
}
}
return ErrIsNotJWT
}
-// Opts represents some of the validation options.
-// It mimics jwt.Opts.
-type Opts struct {
- EXP int64 // EXPLeeway
- NBF int64 // NBFLeeway
- Fn func(Claims) error
- _ struct{}
+// Conv converts a func(Claims) error to type jwt.ValidateFunc.
+func Conv(fn func(Claims) error) jwt.ValidateFunc {
+ if fn == nil {
+ return nil
+ }
+ return func(c jwt.Claims) error {
+ return fn(Claims(c))
+ }
}
-// C is shorthand for Convert(fn).
-func (o Opts) C() jwt.Opts { return o.Convert() }
-
-// Convert converts Opts into jwt.Opts.
-func (o Opts) Convert() jwt.Opts {
- p := jwt.Opts{
- EXP: o.EXP,
- NBF: o.NBF,
+// NewOpts returns a pointer to a jwt.Validator structure containing
+// the info to be used in the validation of a JWT.
+func NewOpts(c Claims, exp, nbf int64) *jwt.Validator {
+ return &jwt.Validator{
+ Expected: jwt.Claims(c),
+ EXP: exp,
+ NBF: nbf,
}
- if o.Fn != nil {
- p.Fn = func(c jwt.Claims) error {
- return o.Fn(Claims(c))
- }
- }
- return p
}
-var _ jwt.JWT = (*JWS)(nil)
+var _ jwt.JWT = (*jws)(nil)
diff --git a/jwt/claims.go b/jwt/claims.go
index 8f53d65..8700d48 100644
--- a/jwt/claims.go
+++ b/jwt/claims.go
@@ -13,13 +13,13 @@
// Validate validates the Claims per the claims found in
// https://tools.ietf.org/html/rfc7519#section-4.1
func (c Claims) Validate(now, expLeeway, nbfLeeway int64) error {
- if exp, ok := c.expiration(); ok {
+ if exp, ok := c.Expiration(); ok {
if !within(exp, expLeeway, now) {
return ErrTokenIsExpired
}
}
- if nbf, ok := c.notBefore(); ok {
+ if nbf, ok := c.NotBefore(); ok {
if !within(nbf, nbfLeeway, now) {
return ErrTokenNotYetValid
}
@@ -27,16 +27,6 @@
return nil
}
-func (c Claims) expiration() (int64, bool) {
- v, ok := c.Get("exp").(int64)
- return v, ok
-}
-
-func (c Claims) notBefore() (int64, bool) {
- v, ok := c.Get("nbf").(int64)
- return v, ok
-}
-
func within(cur, delta, max int64) bool {
return cur+delta < max || cur-delta < max
}
@@ -106,6 +96,126 @@
return nil
}
+// Issuer retrieves claim "iss" per its type in
+// https://tools.ietf.org/html/rfc7519#section-4.1.1
+func (c Claims) Issuer() (string, bool) {
+ v, ok := c.Get("iss").(string)
+ return v, ok
+}
+
+// Subject retrieves claim "sub" per its type in
+// https://tools.ietf.org/html/rfc7519#section-4.1.2
+func (c Claims) Subject() (string, bool) {
+ v, ok := c.Get("sub").(string)
+ return v, ok
+}
+
+// Audience retrieves claim "aud" per its type in
+// https://tools.ietf.org/html/rfc7519#section-4.1.3
+func (c Claims) Audience() (interface{}, bool) {
+ switch t := c.Get("aud").(type) {
+ case string, []string:
+ return t, true
+ default:
+ return nil, false
+ }
+}
+
+// Expiration retrieves claim "exp" per its type in
+// https://tools.ietf.org/html/rfc7519#section-4.1.4
+func (c Claims) Expiration() (int64, bool) {
+ v, ok := c.Get("exp").(int64)
+ return v, ok
+}
+
+// NotBefore retrieves claim "nbf" per its type in
+// https://tools.ietf.org/html/rfc7519#section-4.1.5
+func (c Claims) NotBefore() (int64, bool) {
+ v, ok := c.Get("nbf").(int64)
+ return v, ok
+}
+
+// IssuedAt retrieves claim "iat" per its type in
+// https://tools.ietf.org/html/rfc7519#section-4.1.6
+func (c Claims) IssuedAt() (int64, bool) {
+ v, ok := c.Get("iat").(int64)
+ return v, ok
+}
+
+// JWTID retrieves claim "jti" per its type in
+// https://tools.ietf.org/html/rfc7519#section-4.1.7
+func (c Claims) JWTID() (string, bool) {
+ v, ok := c.Get("jti").(string)
+ return v, ok
+}
+
+// RemoveIssuer deletes claim "iss" from c.
+func (c Claims) RemoveIssuer() { c.Del("iss") }
+
+// RemoveSubject deletes claim "sub" from c.
+func (c Claims) RemoveSubject() { c.Del("sub") }
+
+// RemoveAudience deletes claim "aud" from c.
+func (c Claims) RemoveAudience() { c.Del("aud") }
+
+// RemoveExpiration deletes claim "exp" from c.
+func (c Claims) RemoveExpiration() { c.Del("exp") }
+
+// RemoveNotBefore deletes claim "nbf" from c.
+func (c Claims) RemoveNotBefore() { c.Del("nbf") }
+
+// RemoveIssuedAt deletes claim "iat" from c.
+func (c Claims) RemoveIssuedAt() { c.Del("iat") }
+
+// RemoveJWTID deletes claim "jti" from c.
+func (c Claims) RemoveJWTID() { c.Del("jti") }
+
+// SetIssuer sets claim "iss" per its type in
+// https://tools.ietf.org/html/rfc7519#section-4.1.1
+func (c Claims) SetIssuer(issuer string) {
+ c.Set("iss", issuer)
+}
+
+// SetSubject sets claim "iss" per its type in
+// https://tools.ietf.org/html/rfc7519#section-4.1.2
+func (c Claims) SetSubject(subject string) {
+ c.Set("sub", subject)
+}
+
+// SetAudience sets claim "aud" per its type in
+// https://tools.ietf.org/html/rfc7519#section-4.1.3
+func (c Claims) SetAudience(audience ...string) {
+ if len(audience) == 1 {
+ c.Set("aud", audience[0])
+ } else {
+ c.Set("aud", audience)
+ }
+}
+
+// SetExpiration sets claim "exp" per its type in
+// https://tools.ietf.org/html/rfc7519#section-4.1.4
+func (c Claims) SetExpiration(expiration int64) {
+ c.Set("exp", expiration)
+}
+
+// SetNotBefore sets claim "nbf" per its type in
+// https://tools.ietf.org/html/rfc7519#section-4.1.5
+func (c Claims) SetNotBefore(notBefore int64) {
+ c.Set("nbf", notBefore)
+}
+
+// SetIssuedAt sets claim "iat" per its type in
+// https://tools.ietf.org/html/rfc7519#section-4.1.6
+func (c Claims) SetIssuedAt(issuedAt int64) {
+ c.Set("iat", issuedAt)
+}
+
+// SetJWTID sets claim "jti" per its type in
+// https://tools.ietf.org/html/rfc7519#section-4.1.7
+func (c Claims) SetJWTID(uniqueID string) {
+ c.Set("jti", uniqueID)
+}
+
var (
_ json.Marshaler = (Claims)(nil)
_ json.Unmarshaler = (*Claims)(nil)
diff --git a/jwt/eq.go b/jwt/eq.go
new file mode 100644
index 0000000..a7a37a9
--- /dev/null
+++ b/jwt/eq.go
@@ -0,0 +1,52 @@
+package jwt
+
+import "reflect"
+
+// eq returns true if the two types are either strings
+// or comparable slices.
+func eq(a, b interface{}) bool {
+ t1 := reflect.TypeOf(a)
+ t2 := reflect.TypeOf(b)
+
+ if t1.Kind() == t2.Kind() {
+ switch t1.Kind() {
+ case reflect.Slice:
+ return eqSlice(a, b)
+ case reflect.String:
+ return reflect.ValueOf(a).String() ==
+ reflect.ValueOf(b).String()
+ }
+ }
+ return false
+}
+
+// eqSlice returns true if the two interfaces are both slices
+// and are equal. For example: https://play.golang.org/p/5VLMwNE3i-
+func eqSlice(a, b interface{}) bool {
+ if a == nil || b == nil {
+ return false
+ }
+
+ v1 := reflect.ValueOf(a)
+ v2 := reflect.ValueOf(b)
+
+ if v1.Kind() != reflect.Slice ||
+ v2.Kind() != reflect.Slice {
+ return false
+ }
+
+ if v1.Len() == v2.Len() && v1.Len() > 0 {
+ for i := 0; i < v1.Len() && i < v2.Len(); i++ {
+ k1 := v1.Index(i)
+ k2 := v2.Index(i)
+ if k1.Type().Comparable() &&
+ k2.Type().Comparable() &&
+ k1.CanInterface() && k2.CanInterface() &&
+ k1.Interface() != k2.Interface() {
+ return false
+ }
+ }
+ return true
+ }
+ return false
+}
diff --git a/jwt/jwt.go b/jwt/jwt.go
index 5a30aae..1b83c39 100644
--- a/jwt/jwt.go
+++ b/jwt/jwt.go
@@ -1,16 +1,12 @@
package jwt
-import "github.com/SermoDigital/jose/crypto"
+import (
+ "errors"
-// Opts represents some of the validation options.
-type Opts struct {
- EXP int64 // EXPLeeway
- NBF int64 // NBFLeeway
- Fn ValidateFunc // See ValidateFunc for more information.
- _ struct{}
-}
+ "github.com/SermoDigital/jose/crypto"
+)
-// JWT represents a JWT as per RFC 7519.
+// JWT represents a JWT per RFC 7519.
// It's described as an interface instead of a physical structure
// because both JWS and JWEs can be JWTs. So, in order to use either,
// import one of those two packages and use their "NewJWT" (and other)
@@ -19,10 +15,10 @@
// Claims returns the set of Claims.
Claims() Claims
- // Verify returns an error describing any issues found while
+ // Validate returns an error describing any issues found while
// validating the JWT. For info on the fn parameter, see the
// comment on ValidateFunc.
- Verify(key interface{}, method crypto.SigningMethod, o ...Opts) error
+ Validate(key interface{}, method crypto.SigningMethod, v ...*Validator) error
// Serialize serializes the JWT into its on-the-wire
// representation.
@@ -38,3 +34,113 @@
// Custom JWT implementations are free to abuse this, but it is
// not recommended.
type ValidateFunc func(Claims) error
+
+// Validator represents some of the validation options.
+type Validator struct {
+ Expected Claims // If non-nil, these are required to match.
+ EXP int64 // EXPLeeway
+ NBF int64 // NBFLeeway
+ Fn ValidateFunc // See ValidateFunc for more information.
+
+ _ struct{}
+}
+
+var defaultClaims = []string{
+ "iss", "sub", "aud",
+ "exp", "nbf", "iat",
+ "jti",
+}
+
+// Validate validates the JWT based on the expected claims in v.
+// Note: it only validates the registered claims per
+// https://tools.ietf.org/html/rfc7519#section-4.1
+//
+// Custom claims should be validated using v's Fn member.
+func (v *Validator) Validate(j JWT) error {
+ if iss, ok := v.Expected.Issuer(); ok &&
+ j.Claims().Get("iss") != iss {
+ return errors.New("TODO 12")
+ }
+ if sub, ok := v.Expected.Subject(); ok &&
+ j.Claims().Get("sub") != sub {
+ return errors.New("TODO 12")
+ }
+ if iat, ok := v.Expected.IssuedAt(); ok &&
+ j.Claims().Get("iat") != iat {
+ return errors.New("TODO 12")
+ }
+ if jti, ok := v.Expected.JWTID(); ok &&
+ j.Claims().Get("jti") != jti {
+ return errors.New("TODO 12")
+ }
+ if aud, ok := v.Expected.Audience(); ok &&
+ !eq(j.Claims().Get("aud"), aud) {
+ return errors.New("TODO 12")
+ }
+
+ if v.Fn != nil {
+ return v.Fn(j.Claims())
+ }
+ return nil
+}
+
+// SetClaim sets the claim with the given val.
+func (v *Validator) SetClaim(claim string, val interface{}) {
+ v.expect()
+ v.Expected.Set(claim, val)
+}
+
+// SetIssuer sets the "iss" claim per
+// https://tools.ietf.org/html/rfc7519#section-4.1.1
+func (v *Validator) SetIssuer(iss string) {
+ v.expect()
+ v.Expected.Set("iss", iss)
+}
+
+// SetSubject sets the "sub" claim per
+// https://tools.ietf.org/html/rfc7519#section-4.1.2
+func (v *Validator) SetSubject(sub string) {
+ v.expect()
+ v.Expected.Set("sub", sub)
+}
+
+// SetAudience sets the "aud" claim per
+// https://tools.ietf.org/html/rfc7519#section-4.1.3
+func (v *Validator) SetAudience(aud string) {
+ v.expect()
+ v.Expected.Set("aud", aud)
+}
+
+// SetExpiration sets the "exp" claim per
+// https://tools.ietf.org/html/rfc7519#section-4.1.4
+func (v *Validator) SetExpiration(exp int64) {
+ v.expect()
+ v.Expected.Set("exp", exp)
+}
+
+// SetNotBefore sets the "nbf" claim per
+// https://tools.ietf.org/html/rfc7519#section-4.1.5
+func (v *Validator) SetNotBefore(nbf int64) {
+ v.expect()
+ v.Expected.Set("nbf", nbf)
+}
+
+// SetIssuedAt sets the "iat" claim per
+// https://tools.ietf.org/html/rfc7519#section-4.1.6
+func (v *Validator) SetIssuedAt(iat int64) {
+ v.expect()
+ v.Expected.Set("iat", iat)
+}
+
+// SetJWTID sets the "jti" claim per
+// https://tools.ietf.org/html/rfc7519#section-4.1.7
+func (v *Validator) SetJWTID(jti string) {
+ v.expect()
+ v.Expected.Set("jti", jti)
+}
+
+func (v *Validator) expect() {
+ if v.Expected == nil {
+ v.Expected = make(Claims)
+ }
+}