jws: added JWS validation and Parsing + tests
diff --git a/.tags b/.tags
new file mode 100644
index 0000000..d01d39b
--- /dev/null
+++ b/.tags
@@ -0,0 +1,6 @@
+!_TAG_FILE_FORMAT 2 /extended format; --format=1 will not append ;" to lines/
+!_TAG_FILE_SORTED 1 /0=unsorted, 1=sorted, 2=foldcase/
+!_TAG_PROGRAM_AUTHOR Darren Hiebert /dhiebert@users.sourceforge.net/
+!_TAG_PROGRAM_NAME Exuberant Ctags //
+!_TAG_PROGRAM_URL http://ctags.sourceforge.net /official site/
+!_TAG_PROGRAM_VERSION 5.8 //
diff --git a/base64.go b/base64.go
index 1956791..6be62b8 100644
--- a/base64.go
+++ b/base64.go
@@ -5,7 +5,7 @@
// Encoder is satisfied if the type can marshal itself into a valid
// structure for a JWS.
type Encoder interface {
- // Base64 implies -> JSON -> Base64
+ // Base64 implies T -> JSON -> Base64
Base64() ([]byte, error)
}
diff --git a/header.go b/header.go
index ef58feb..b5fc9e8 100644
--- a/header.go
+++ b/header.go
@@ -47,6 +47,30 @@
return h.MarshalJSON()
}
+// UnmarshalJSON implements json.Unmarshaler for Header.
+func (h *Header) UnmarshalJSON(b []byte) error {
+ if b == nil {
+ return nil
+ }
+
+ b, err := DecodeEscaped(b)
+ if err != nil {
+ return err
+ }
+
+ // Since json.Unmarshal calls UnmarshalJSON,
+ // calling json.Unmarshal on *p would be infinitely recursive
+ // A temp variable is needed because &map[string]interface{}(*p) is
+ // invalid Go.
+
+ tmp := map[string]interface{}(*h)
+ if err = json.Unmarshal(b, &tmp); err != nil {
+ return err
+ }
+ *h = Header(tmp)
+ return nil
+}
+
// Protected Headers are base64-encoded after they're marshaled into
// JSON.
type Protected Header
@@ -96,20 +120,9 @@
// UnmarshalJSON implements json.Unmarshaler for Protected.
func (p *Protected) UnmarshalJSON(b []byte) error {
- b, err := DecodeEscaped(b)
- if err != nil {
- return err
- }
-
- // Since json.Unmarshal calls UnmarshalJSON,
- // calling json.Unmarshal on *p would be infinitely recursive
- // A temp variable is needed because &Header(*p) is invalid Go.
-
- tmp := map[string]interface{}(*p)
- if err = json.Unmarshal(b, &tmp); err != nil {
- return err
- }
- *p = Protected(tmp)
+ var h Header
+ h.UnmarshalJSON(b)
+ *p = Protected(h)
return nil
}
diff --git a/jws/ecdsa.go b/jws/ecdsa.go
index 21a6328..ce46b47 100644
--- a/jws/ecdsa.go
+++ b/jws/ecdsa.go
@@ -104,8 +104,10 @@
func (m *SigningMethodECDSA) Hasher() crypto.Hash { return m.Hash }
// MarshalJSON is in case somebody decides to place SigningMethodECDSA
-// inside the Claims in the "alg" portion. In order to keep things sane,
-// marshalling this will do the same thing as jws.SetProtected("alg", m.Alg())
+// inside the Header, presumably because they (wrongly) decided it was a good
+// idea to use the SigningMethod itself instead of the SigningMethod's Alg
+// method. In order to keep things sane, marshalling this will simply
+// return the JSON-compatible representation of m.Alg().
func (m *SigningMethodECDSA) MarshalJSON() ([]byte, error) {
return []byte(`"` + m.Alg() + `"`), nil
}
diff --git a/jws/errors.go b/jws/errors.go
index 6c792a2..2d48368 100644
--- a/jws/errors.go
+++ b/jws/errors.go
@@ -14,4 +14,39 @@
// ErrCouldNotUnmarshal is returned when Parse's json.Unmarshaler
// parameter returns an error.
ErrCouldNotUnmarshal = errors.New("custom unmarshal failed")
+
+ // ErrNotCompact signals that the provided potential JWS is not
+ // in its compact representation.
+ ErrNotCompact = errors.New("not a compact JWS")
+
+ // ErrDuplicateHeaderParameter signals that there are duplicate parameters
+ // in the provided Headers.
+ ErrDuplicateHeaderParameter = errors.New("duplicate parameters in the JOSE Header")
+
+ // ErrTwoEmptyHeaders is returned if both Headers are empty.
+ ErrTwoEmptyHeaders = errors.New("both headers cannot be empty")
+
+ // ErrNotEnoughKeys is returned when not enough keys are provided for
+ // the given SigningMethods.
+ ErrNotEnoughKeys = errors.New("not enough keys (for given methods)")
+
+ // ErrDidNotValidate means the given JWT did not properly validate
+ ErrDidNotValidate = errors.New("did not validate")
+
+ // ErrNoAlgorithm means no algorithm ("alg") was found in the Protected
+ // Header.
+ ErrNoAlgorithm = errors.New("no algorithm found")
+
+ // ErrAlgorithmDoesntExist means the algorithm asked for cannot be
+ // found inside the signingMethod cache.
+ ErrAlgorithmDoesntExist = errors.New("algorithm doesn't exist")
+
+ // ErrMismatchedAlgorithms means the algorithm inside the JWT was
+ // different than the algorithm the caller wanted to use.
+ ErrMismatchedAlgorithms = errors.New("mismatched algorithms")
+
+ // ErrCannotValidate means the JWS cannot be validated for various
+ // reasons. For example, if there aren't any signatures/payloads/headers
+ // to actually validate.
+ ErrCannotValidate = errors.New("cannot validate")
)
diff --git a/jws/jws.go b/jws/jws.go
index de0f355..4147d37 100644
--- a/jws/jws.go
+++ b/jws/jws.go
@@ -1,8 +1,9 @@
package jws
import (
+ "bytes"
"encoding/json"
- "fmt"
+ "sort"
"github.com/SermoDigital/jose"
)
@@ -13,8 +14,7 @@
plcache rawBase64
clean bool
- sb []sigHead
- methods []SigningMethod
+ sb []sigHead
}
// sigHead represents the 'signatures' member of the JWS' "general"
@@ -31,6 +31,18 @@
protected jose.Protected `json:"-"`
unprotected jose.Header `json:"-"`
clean bool `json:"-"`
+
+ method SigningMethod
+}
+
+func (s *sigHead) unmarshal() error {
+ if err := s.protected.UnmarshalJSON(s.Protected); err != nil {
+ return err
+ }
+ if err := s.unprotected.UnmarshalJSON(s.Unprotected); err != nil {
+ return err
+ }
+ return nil
}
// New creates a new JWS with the provided SigningMethods.
@@ -41,16 +53,31 @@
protected: jose.Protected{
"alg": methods[i].Alg(),
},
- unprotected: make(jose.Header),
+ unprotected: jose.Header{},
+ method: methods[i],
}
}
return &JWS{
payload: &payload{v: content},
sb: sb,
- methods: methods,
}
}
+func (s *sigHead) assignMethod(p jose.Protected) error {
+ alg, ok := p.Get("alg").(string)
+ if !ok {
+ return ErrNoAlgorithm
+ }
+
+ sm := GetSigningMethod(alg)
+ if sm == nil {
+ return ErrNoAlgorithm
+ }
+
+ s.method = sm
+ return nil
+}
+
type generic struct {
Payload rawBase64 `json:"payload"`
sigHead
@@ -61,13 +88,9 @@
// JWS per https://tools.ietf.org/html/rfc7515#section-5.2
//
// It accepts a json.Unmarshaler in order to properly parse
-// the payload. The reason for this is sometimes the payload
-// might implement the json.Marshaler interface, and since
-// the JWS' payload member is an interface{}, a simple
-// json.Unmarshal call cannot magically identify the original
-// type. So, in order to keep the caller from having to do extra
-// parsing of the payload, the a json.Unmarshaler can be passed
-// which will be called to unmarshal the payload however the caller
+// the payload. In order to keep the caller from having to do extra
+// parsing of the payload, a json.Unmarshaler can be passed
+// which will be then to unmarshal the payload however the caller
// wishes. Do note that if json.Unmarshal returns an error the
// original payload will be used as if no json.Unmarshaler was
// passed.
@@ -109,29 +132,36 @@
func (g *generic) parseGeneral(u ...json.Unmarshaler) (*JWS, error) {
- var (
- p payload
- err error
- )
-
+ var p payload
if len(u) > 0 {
- if k := u[0]; k.UnmarshalJSON(g.Payload) != nil {
- p.v = u
- err = ErrCouldNotUnmarshal
- }
+ p.u = u[0]
}
- if err != nil {
- fmt.Println(string(g.Payload))
- if err := json.Unmarshal(g.Payload, &p); err != nil {
+ if err := p.UnmarshalJSON(g.Payload); err != nil {
+ return nil, err
+ }
+
+ for i := range g.Signatures {
+ if err := g.Signatures[i].unmarshal(); err != nil {
return nil, err
}
+ if err := checkHeaders(jose.Header(g.Signatures[i].protected), g.Signatures[i].unprotected); err != nil {
+ return nil, err
+ }
+
+ if err := g.Signatures[i].assignMethod(g.Signatures[i].protected); err != nil {
+ return nil, err
+ }
+
+ g.clean = true
}
return &JWS{
payload: &p,
+ plcache: g.Payload,
+ clean: true,
sb: g.Signatures,
- }, err
+ }, nil
}
// ParseFlat parses a JWS serialized into its "flat" form per
@@ -159,8 +189,23 @@
return nil, err
}
+ if err := g.sigHead.unmarshal(); err != nil {
+ return nil, err
+ }
+ g.sigHead.clean = true
+
+ if err := checkHeaders(jose.Header(g.sigHead.protected), g.sigHead.unprotected); err != nil {
+ return nil, err
+ }
+
+ if err := g.sigHead.assignMethod(g.sigHead.protected); err != nil {
+ return nil, err
+ }
+
return &JWS{
payload: &p,
+ plcache: g.Payload,
+ clean: true,
sb: []sigHead{g.sigHead},
}, nil
}
@@ -168,8 +213,136 @@
// ParseCompact parses a JWS serialized into its "compact" form per
// https://tools.ietf.org/html/rfc7515#section-7.1
// into a physical JWS per
-// https://tools.ietf.org/html/rfc7515#section-5.2//
+// https://tools.ietf.org/html/rfc7515#section-5.2
+//
// For information on the json.Unmarshaler parameter, see Parse.
func ParseCompact(encoded []byte, u ...json.Unmarshaler) (*JWS, error) {
- return nil, nil
+
+ parts := bytes.Split(encoded, []byte{'.'})
+ if len(parts) != 3 {
+ return nil, ErrNotCompact
+ }
+
+ var p jose.Protected
+ if err := p.UnmarshalJSON(parts[0]); err != nil {
+ return nil, err
+ }
+
+ s := sigHead{
+ protected: p,
+ clean: true,
+ }
+
+ if err := s.assignMethod(p); err != nil {
+ return nil, err
+ }
+
+ j := JWS{
+ payload: &payload{},
+ sb: []sigHead{s},
+ }
+
+ if err := j.payload.UnmarshalJSON(parts[1]); err != nil {
+ return nil, err
+ }
+
+ j.clean = true
+
+ if err := j.sb[0].Signature.UnmarshalJSON(parts[2]); err != nil {
+ return nil, err
+ }
+
+ return &j, nil
+}
+
+// IgnoreDupes should be set to true if the internal duplicate header key check
+// should ignore duplicate Header keys instead of reporting an error when
+// duplicate Header keys are found.
+//
+// Note: Duplicate Header keys are defined in
+// https://tools.ietf.org/html/rfc7515#section-5.2
+// meaning keys that both the protected and unprotected
+// Headers possess.
+var IgnoreDupes bool
+
+// checkHeaders returns an error per the constraints described in
+// IgnoreDupes' comment.
+func checkHeaders(a, b jose.Header) error {
+ if len(a)+len(b) == 0 {
+ return ErrTwoEmptyHeaders
+ }
+ for key := range a {
+ if b.Has(key) && !IgnoreDupes {
+ return ErrDuplicateHeaderParameter
+ }
+ }
+ return nil
+}
+
+const Any int = -1
+
+// ValidateMulti validates the current JWS as-is. Since it's meant to be
+// called after parsing a stream of bytes into a JWS, it doesn't do any
+// internal parsing like the Sign, Flat, Compact, or General methods do.
+// idx represents which signatures need to validate
+// in order for the JWS to be considered valid.
+// Use the constant `Any` (-1) if _any_ should validate the JWS. Otherwise,
+// use the indexes of the signatures that need to validate in order
+// for the JWS to be considered valid.
+// Note: if idx is omitted it defaults to requiring _all_
+// signatures validate, and the JWS spec required _at least_ one
+// signature to validate in order for the JWS to be considered
+// valid.
+func (j *JWS) ValidateMulti(keys []interface{}, methods []SigningMethod, idx ...int) error {
+
+ if len(j.sb) != len(methods) {
+ return ErrNotEnoughMethods
+ }
+
+ if len(keys) < 1 ||
+ len(keys) > 1 && len(keys) != len(j.sb) {
+ return ErrNotEnoughKeys
+ }
+
+ if len(keys) == 1 {
+ k := keys[0]
+ keys = make([]interface{}, len(methods))
+ for i := range keys {
+ keys[i] = k
+ }
+ }
+
+ any := len(idx) == 1 && idx[0] == Any
+ if !any {
+ sort.Ints(idx)
+ }
+
+ rp := 0
+ for i := range j.sb {
+ if j.sb[i].validate(j.plcache, keys[i], methods[i]) == nil &&
+ any || (rp < len(idx) && idx[rp] == i) {
+ rp++
+ }
+ }
+
+ if rp < len(idx) {
+ return ErrDidNotValidate
+ }
+ return nil
+}
+
+// Validate validates the current JWS as-is. Refer to ValidateMulti
+// for more information.
+func (j *JWS) Validate(key interface{}, method SigningMethod) error {
+ if len(j.sb) < 1 {
+ return ErrCannotValidate
+ }
+ return j.sb[0].validate(j.plcache, key, method)
+}
+
+func (s *sigHead) validate(pl []byte, key interface{}, method SigningMethod) error {
+ if s.method != method {
+ return ErrMismatchedAlgorithms
+ }
+ return method.Verify(format(s.Protected, pl), s.Signature, key)
}
diff --git a/jws/jws_serialize.go b/jws/jws_serialize.go
index a79c59c..46ecd38 100644
--- a/jws/jws_serialize.go
+++ b/jws/jws_serialize.go
@@ -11,7 +11,6 @@
if len(j.sb) < 1 {
return nil, ErrNotEnoughMethods
}
-
if err := j.sign(key); err != nil {
return nil, err
}
@@ -26,6 +25,10 @@
// General serializes the JWS into its "general" form per
// https://tools.ietf.org/html/rfc7515#section-7.2.1
+//
+// If only one key is passed it's used for all the provided
+// SigningMethods. Otherwise, len(keys) must equal the number
+// of SigningMethods added.
func (j *JWS) General(keys ...interface{}) ([]byte, error) {
if err := j.sign(keys...); err != nil {
return nil, err
@@ -67,13 +70,26 @@
return err
}
+ if len(keys) < 1 ||
+ len(keys) > 1 && len(keys) != len(j.sb) {
+ return ErrNotEnoughKeys
+ }
+
+ if len(keys) == 1 {
+ k := keys[0]
+ keys = make([]interface{}, len(j.sb))
+ for i := range keys {
+ keys[i] = k
+ }
+ }
+
for i := range j.sb {
if err := j.sb[i].cache(); err != nil {
return err
}
raw := format(j.sb[i].Protected, j.plcache)
- sig, err := j.methods[i].Sign(raw, keys[i])
+ sig, err := j.sb[i].method.Sign(raw, keys[i])
if err != nil {
return err
}
diff --git a/jws/jws_serialize_test.go b/jws/jws_serialize_test.go
index 61a828e..124a3b4 100644
--- a/jws/jws_serialize_test.go
+++ b/jws/jws_serialize_test.go
@@ -9,11 +9,15 @@
)
var dataRaw = struct {
+ H jose.Protected
Name string
Scopes []string
Admin bool
Data struct{ Foo, Bar int }
}{
+ H: jose.Protected{
+ "1234": "5678",
+ },
Name: "Eric",
Scopes: []string{
"user.account.info",
diff --git a/jws/jws_test.go b/jws/jws_test.go
index 8afb931..ebd7cf0 100644
--- a/jws/jws_test.go
+++ b/jws/jws_test.go
@@ -10,10 +10,13 @@
type easy []byte
func (e *easy) UnmarshalJSON(b []byte) error {
+ if len(b) > 1 && b[0] == '"' && b[len(b)-1] == '"' {
+ b = b[1 : len(b)-1]
+ }
// json.Marshal encodes easy as it would a []byte, so in
// `"base64"` format.
- dst := make([]byte, base64.StdEncoding.DecodedLen(len(b)-2))
- n, err := base64.StdEncoding.Decode(dst, b[1:len(b)-1])
+ dst := make([]byte, base64.StdEncoding.DecodedLen(len(b)))
+ n, err := base64.StdEncoding.Decode(dst, b)
if err != nil {
return err
}
@@ -23,7 +26,7 @@
var _ json.Unmarshaler = (*easy)(nil)
-var easyData = easy(`"easy data!"`)
+var easyData = easy("easy data!")
func TestParseWithUnmarshaler(t *testing.T) {
j := New(easyData, SigningMethodRS512)
@@ -42,3 +45,144 @@
Error(t, easyData, *j2.payload.v.(*easy))
}
}
+
+func TestParseCompact(t *testing.T) {
+ j := New(easyData, SigningMethodRS512)
+ b, err := j.Compact(rsaPriv)
+ if err != nil {
+ t.Error(err)
+ }
+
+ j2, err := ParseCompact(b)
+ if err != nil {
+ t.Error(err)
+ }
+
+ var k easy
+ if err := k.UnmarshalJSON([]byte(j2.payload.v.(string))); err != nil {
+ t.Error(err)
+ }
+
+ if !bytes.Equal(k, easyData) {
+ Error(t, easyData, k)
+ }
+}
+
+func TestParseGeneral(t *testing.T) {
+ sm := []SigningMethod{SigningMethodRS512, SigningMethodPS384, SigningMethodPS256}
+ j := New(easyData, sm...)
+ b, err := j.General(rsaPriv)
+ if err != nil {
+ t.Error(err)
+ }
+
+ j2, err := ParseGeneral(b)
+ if err != nil {
+ t.Error(err)
+ }
+
+ for i, v := range j2.sb {
+ k := v.protected.Get("alg").(string)
+ if k != sm[i].Alg() {
+ Error(t, sm[i].Alg(), k)
+ }
+ }
+}
+
+func TestValidateMulti(t *testing.T) {
+ sm := []SigningMethod{SigningMethodRS512, SigningMethodPS384, SigningMethodPS256}
+ j := New(easyData, sm...)
+ b, err := j.General(rsaPriv)
+ if err != nil {
+ t.Error(err)
+ }
+
+ j2, err := ParseGeneral(b)
+ if err != nil {
+ t.Error(err)
+ }
+
+ keys := []interface{}{rsaPub, rsaPub, rsaPub}
+ if err := j2.ValidateMulti(keys, sm, Any); err != nil {
+ t.Error(err)
+ }
+}
+
+func TestValidateMultiMismatchedAlgs(t *testing.T) {
+ sm := []SigningMethod{SigningMethodRS256, SigningMethodPS384, SigningMethodPS512}
+ j := New(easyData, sm...)
+ b, err := j.General(rsaPriv)
+ if err != nil {
+ t.Error(err)
+ }
+
+ j2, err := ParseGeneral(b)
+ if err != nil {
+ t.Error(err)
+ }
+
+ // Shuffle it.
+ sm = []SigningMethod{SigningMethodRS512, SigningMethodPS256, SigningMethodPS384}
+
+ keys := []interface{}{rsaPub, rsaPub, rsaPub}
+ if err := j2.ValidateMulti(keys, sm, Any); err == nil {
+ t.Error("Should NOT be nil")
+ }
+}
+
+func TestValidateMultiNotEnoughMethods(t *testing.T) {
+ sm := []SigningMethod{SigningMethodRS256, SigningMethodPS384, SigningMethodPS512}
+ j := New(easyData, sm...)
+ b, err := j.General(rsaPriv)
+ if err != nil {
+ t.Error(err)
+ }
+
+ j2, err := ParseGeneral(b)
+ if err != nil {
+ t.Error(err)
+ }
+
+ sm = sm[0 : len(sm)-1]
+
+ keys := []interface{}{rsaPub, rsaPub, rsaPub}
+ if err := j2.ValidateMulti(keys, sm, Any); err == nil {
+ t.Error("Should NOT be nil")
+ }
+}
+
+func TestValidateMultiNotEnoughKeys(t *testing.T) {
+ sm := []SigningMethod{SigningMethodRS256, SigningMethodPS384, SigningMethodPS512}
+ j := New(easyData, sm...)
+ b, err := j.General(rsaPriv)
+ if err != nil {
+ t.Error(err)
+ }
+
+ j2, err := ParseGeneral(b)
+ if err != nil {
+ t.Error(err)
+ }
+
+ keys := []interface{}{rsaPub, rsaPub}
+ if err := j2.ValidateMulti(keys, sm, Any); err == nil {
+ t.Error("Should NOT be nil")
+ }
+}
+
+func TestValidate(t *testing.T) {
+ j := New(easyData, SigningMethodPS512)
+ b, err := j.Flat(rsaPriv)
+ if err != nil {
+ t.Error(err)
+ }
+
+ j2, err := ParseFlat(b)
+ if err != nil {
+ t.Error(err)
+ }
+
+ if err := j2.Validate(rsaPub, SigningMethodPS512); err != nil {
+ t.Error(err)
+ }
+}
diff --git a/jws/none.go b/jws/none.go
index a1b66b2..5d0a4bf 100644
--- a/jws/none.go
+++ b/jws/none.go
@@ -8,12 +8,12 @@
)
func init() {
- crypto.RegisterHash(crypto.Hash(0), H)
+ crypto.RegisterHash(crypto.Hash(0), h)
RegisterSigningMethod(Unsecured)
}
-// H is passed to crypto.RegisterHash.
-func H() hash.Hash { return &f{Writer: nil} }
+// h is passed to crypto.RegisterHash.
+func h() hash.Hash { return &f{Writer: nil} }
type f struct{ io.Writer }
diff --git a/jws/payload.go b/jws/payload.go
index 3cc203c..34ba6d8 100644
--- a/jws/payload.go
+++ b/jws/payload.go
@@ -43,6 +43,7 @@
p.v = p.u
return err
}
+
return json.Unmarshal(b2, &p.v)
}