jws: added JWS validation and Parsing + tests
diff --git a/.tags b/.tags new file mode 100644 index 0000000..d01d39b --- /dev/null +++ b/.tags
@@ -0,0 +1,6 @@ +!_TAG_FILE_FORMAT 2 /extended format; --format=1 will not append ;" to lines/ +!_TAG_FILE_SORTED 1 /0=unsorted, 1=sorted, 2=foldcase/ +!_TAG_PROGRAM_AUTHOR Darren Hiebert /dhiebert@users.sourceforge.net/ +!_TAG_PROGRAM_NAME Exuberant Ctags // +!_TAG_PROGRAM_URL http://ctags.sourceforge.net /official site/ +!_TAG_PROGRAM_VERSION 5.8 //
diff --git a/base64.go b/base64.go index 1956791..6be62b8 100644 --- a/base64.go +++ b/base64.go
@@ -5,7 +5,7 @@ // Encoder is satisfied if the type can marshal itself into a valid // structure for a JWS. type Encoder interface { - // Base64 implies -> JSON -> Base64 + // Base64 implies T -> JSON -> Base64 Base64() ([]byte, error) }
diff --git a/header.go b/header.go index ef58feb..b5fc9e8 100644 --- a/header.go +++ b/header.go
@@ -47,6 +47,30 @@ return h.MarshalJSON() } +// UnmarshalJSON implements json.Unmarshaler for Header. +func (h *Header) UnmarshalJSON(b []byte) error { + if b == nil { + return nil + } + + b, err := DecodeEscaped(b) + if err != nil { + return err + } + + // Since json.Unmarshal calls UnmarshalJSON, + // calling json.Unmarshal on *p would be infinitely recursive + // A temp variable is needed because &map[string]interface{}(*p) is + // invalid Go. + + tmp := map[string]interface{}(*h) + if err = json.Unmarshal(b, &tmp); err != nil { + return err + } + *h = Header(tmp) + return nil +} + // Protected Headers are base64-encoded after they're marshaled into // JSON. type Protected Header @@ -96,20 +120,9 @@ // UnmarshalJSON implements json.Unmarshaler for Protected. func (p *Protected) UnmarshalJSON(b []byte) error { - b, err := DecodeEscaped(b) - if err != nil { - return err - } - - // Since json.Unmarshal calls UnmarshalJSON, - // calling json.Unmarshal on *p would be infinitely recursive - // A temp variable is needed because &Header(*p) is invalid Go. - - tmp := map[string]interface{}(*p) - if err = json.Unmarshal(b, &tmp); err != nil { - return err - } - *p = Protected(tmp) + var h Header + h.UnmarshalJSON(b) + *p = Protected(h) return nil }
diff --git a/jws/ecdsa.go b/jws/ecdsa.go index 21a6328..ce46b47 100644 --- a/jws/ecdsa.go +++ b/jws/ecdsa.go
@@ -104,8 +104,10 @@ func (m *SigningMethodECDSA) Hasher() crypto.Hash { return m.Hash } // MarshalJSON is in case somebody decides to place SigningMethodECDSA -// inside the Claims in the "alg" portion. In order to keep things sane, -// marshalling this will do the same thing as jws.SetProtected("alg", m.Alg()) +// inside the Header, presumably because they (wrongly) decided it was a good +// idea to use the SigningMethod itself instead of the SigningMethod's Alg +// method. In order to keep things sane, marshalling this will simply +// return the JSON-compatible representation of m.Alg(). func (m *SigningMethodECDSA) MarshalJSON() ([]byte, error) { return []byte(`"` + m.Alg() + `"`), nil }
diff --git a/jws/errors.go b/jws/errors.go index 6c792a2..2d48368 100644 --- a/jws/errors.go +++ b/jws/errors.go
@@ -14,4 +14,39 @@ // ErrCouldNotUnmarshal is returned when Parse's json.Unmarshaler // parameter returns an error. ErrCouldNotUnmarshal = errors.New("custom unmarshal failed") + + // ErrNotCompact signals that the provided potential JWS is not + // in its compact representation. + ErrNotCompact = errors.New("not a compact JWS") + + // ErrDuplicateHeaderParameter signals that there are duplicate parameters + // in the provided Headers. + ErrDuplicateHeaderParameter = errors.New("duplicate parameters in the JOSE Header") + + // ErrTwoEmptyHeaders is returned if both Headers are empty. + ErrTwoEmptyHeaders = errors.New("both headers cannot be empty") + + // ErrNotEnoughKeys is returned when not enough keys are provided for + // the given SigningMethods. + ErrNotEnoughKeys = errors.New("not enough keys (for given methods)") + + // ErrDidNotValidate means the given JWT did not properly validate + ErrDidNotValidate = errors.New("did not validate") + + // ErrNoAlgorithm means no algorithm ("alg") was found in the Protected + // Header. + ErrNoAlgorithm = errors.New("no algorithm found") + + // ErrAlgorithmDoesntExist means the algorithm asked for cannot be + // found inside the signingMethod cache. + ErrAlgorithmDoesntExist = errors.New("algorithm doesn't exist") + + // ErrMismatchedAlgorithms means the algorithm inside the JWT was + // different than the algorithm the caller wanted to use. + ErrMismatchedAlgorithms = errors.New("mismatched algorithms") + + // ErrCannotValidate means the JWS cannot be validated for various + // reasons. For example, if there aren't any signatures/payloads/headers + // to actually validate. + ErrCannotValidate = errors.New("cannot validate") )
diff --git a/jws/jws.go b/jws/jws.go index de0f355..4147d37 100644 --- a/jws/jws.go +++ b/jws/jws.go
@@ -1,8 +1,9 @@ package jws import ( + "bytes" "encoding/json" - "fmt" + "sort" "github.com/SermoDigital/jose" ) @@ -13,8 +14,7 @@ plcache rawBase64 clean bool - sb []sigHead - methods []SigningMethod + sb []sigHead } // sigHead represents the 'signatures' member of the JWS' "general" @@ -31,6 +31,18 @@ protected jose.Protected `json:"-"` unprotected jose.Header `json:"-"` clean bool `json:"-"` + + method SigningMethod +} + +func (s *sigHead) unmarshal() error { + if err := s.protected.UnmarshalJSON(s.Protected); err != nil { + return err + } + if err := s.unprotected.UnmarshalJSON(s.Unprotected); err != nil { + return err + } + return nil } // New creates a new JWS with the provided SigningMethods. @@ -41,16 +53,31 @@ protected: jose.Protected{ "alg": methods[i].Alg(), }, - unprotected: make(jose.Header), + unprotected: jose.Header{}, + method: methods[i], } } return &JWS{ payload: &payload{v: content}, sb: sb, - methods: methods, } } +func (s *sigHead) assignMethod(p jose.Protected) error { + alg, ok := p.Get("alg").(string) + if !ok { + return ErrNoAlgorithm + } + + sm := GetSigningMethod(alg) + if sm == nil { + return ErrNoAlgorithm + } + + s.method = sm + return nil +} + type generic struct { Payload rawBase64 `json:"payload"` sigHead @@ -61,13 +88,9 @@ // JWS per https://tools.ietf.org/html/rfc7515#section-5.2 // // It accepts a json.Unmarshaler in order to properly parse -// the payload. The reason for this is sometimes the payload -// might implement the json.Marshaler interface, and since -// the JWS' payload member is an interface{}, a simple -// json.Unmarshal call cannot magically identify the original -// type. So, in order to keep the caller from having to do extra -// parsing of the payload, the a json.Unmarshaler can be passed -// which will be called to unmarshal the payload however the caller +// the payload. In order to keep the caller from having to do extra +// parsing of the payload, a json.Unmarshaler can be passed +// which will be then to unmarshal the payload however the caller // wishes. Do note that if json.Unmarshal returns an error the // original payload will be used as if no json.Unmarshaler was // passed. @@ -109,29 +132,36 @@ func (g *generic) parseGeneral(u ...json.Unmarshaler) (*JWS, error) { - var ( - p payload - err error - ) - + var p payload if len(u) > 0 { - if k := u[0]; k.UnmarshalJSON(g.Payload) != nil { - p.v = u - err = ErrCouldNotUnmarshal - } + p.u = u[0] } - if err != nil { - fmt.Println(string(g.Payload)) - if err := json.Unmarshal(g.Payload, &p); err != nil { + if err := p.UnmarshalJSON(g.Payload); err != nil { + return nil, err + } + + for i := range g.Signatures { + if err := g.Signatures[i].unmarshal(); err != nil { return nil, err } + if err := checkHeaders(jose.Header(g.Signatures[i].protected), g.Signatures[i].unprotected); err != nil { + return nil, err + } + + if err := g.Signatures[i].assignMethod(g.Signatures[i].protected); err != nil { + return nil, err + } + + g.clean = true } return &JWS{ payload: &p, + plcache: g.Payload, + clean: true, sb: g.Signatures, - }, err + }, nil } // ParseFlat parses a JWS serialized into its "flat" form per @@ -159,8 +189,23 @@ return nil, err } + if err := g.sigHead.unmarshal(); err != nil { + return nil, err + } + g.sigHead.clean = true + + if err := checkHeaders(jose.Header(g.sigHead.protected), g.sigHead.unprotected); err != nil { + return nil, err + } + + if err := g.sigHead.assignMethod(g.sigHead.protected); err != nil { + return nil, err + } + return &JWS{ payload: &p, + plcache: g.Payload, + clean: true, sb: []sigHead{g.sigHead}, }, nil } @@ -168,8 +213,136 @@ // ParseCompact parses a JWS serialized into its "compact" form per // https://tools.ietf.org/html/rfc7515#section-7.1 // into a physical JWS per -// https://tools.ietf.org/html/rfc7515#section-5.2// +// https://tools.ietf.org/html/rfc7515#section-5.2 +// // For information on the json.Unmarshaler parameter, see Parse. func ParseCompact(encoded []byte, u ...json.Unmarshaler) (*JWS, error) { - return nil, nil + + parts := bytes.Split(encoded, []byte{'.'}) + if len(parts) != 3 { + return nil, ErrNotCompact + } + + var p jose.Protected + if err := p.UnmarshalJSON(parts[0]); err != nil { + return nil, err + } + + s := sigHead{ + protected: p, + clean: true, + } + + if err := s.assignMethod(p); err != nil { + return nil, err + } + + j := JWS{ + payload: &payload{}, + sb: []sigHead{s}, + } + + if err := j.payload.UnmarshalJSON(parts[1]); err != nil { + return nil, err + } + + j.clean = true + + if err := j.sb[0].Signature.UnmarshalJSON(parts[2]); err != nil { + return nil, err + } + + return &j, nil +} + +// IgnoreDupes should be set to true if the internal duplicate header key check +// should ignore duplicate Header keys instead of reporting an error when +// duplicate Header keys are found. +// +// Note: Duplicate Header keys are defined in +// https://tools.ietf.org/html/rfc7515#section-5.2 +// meaning keys that both the protected and unprotected +// Headers possess. +var IgnoreDupes bool + +// checkHeaders returns an error per the constraints described in +// IgnoreDupes' comment. +func checkHeaders(a, b jose.Header) error { + if len(a)+len(b) == 0 { + return ErrTwoEmptyHeaders + } + for key := range a { + if b.Has(key) && !IgnoreDupes { + return ErrDuplicateHeaderParameter + } + } + return nil +} + +const Any int = -1 + +// ValidateMulti validates the current JWS as-is. Since it's meant to be +// called after parsing a stream of bytes into a JWS, it doesn't do any +// internal parsing like the Sign, Flat, Compact, or General methods do. +// idx represents which signatures need to validate +// in order for the JWS to be considered valid. +// Use the constant `Any` (-1) if _any_ should validate the JWS. Otherwise, +// use the indexes of the signatures that need to validate in order +// for the JWS to be considered valid. +// Note: if idx is omitted it defaults to requiring _all_ +// signatures validate, and the JWS spec required _at least_ one +// signature to validate in order for the JWS to be considered +// valid. +func (j *JWS) ValidateMulti(keys []interface{}, methods []SigningMethod, idx ...int) error { + + if len(j.sb) != len(methods) { + return ErrNotEnoughMethods + } + + if len(keys) < 1 || + len(keys) > 1 && len(keys) != len(j.sb) { + return ErrNotEnoughKeys + } + + if len(keys) == 1 { + k := keys[0] + keys = make([]interface{}, len(methods)) + for i := range keys { + keys[i] = k + } + } + + any := len(idx) == 1 && idx[0] == Any + if !any { + sort.Ints(idx) + } + + rp := 0 + for i := range j.sb { + if j.sb[i].validate(j.plcache, keys[i], methods[i]) == nil && + any || (rp < len(idx) && idx[rp] == i) { + rp++ + } + } + + if rp < len(idx) { + return ErrDidNotValidate + } + return nil +} + +// Validate validates the current JWS as-is. Refer to ValidateMulti +// for more information. +func (j *JWS) Validate(key interface{}, method SigningMethod) error { + if len(j.sb) < 1 { + return ErrCannotValidate + } + return j.sb[0].validate(j.plcache, key, method) +} + +func (s *sigHead) validate(pl []byte, key interface{}, method SigningMethod) error { + if s.method != method { + return ErrMismatchedAlgorithms + } + return method.Verify(format(s.Protected, pl), s.Signature, key) }
diff --git a/jws/jws_serialize.go b/jws/jws_serialize.go index a79c59c..46ecd38 100644 --- a/jws/jws_serialize.go +++ b/jws/jws_serialize.go
@@ -11,7 +11,6 @@ if len(j.sb) < 1 { return nil, ErrNotEnoughMethods } - if err := j.sign(key); err != nil { return nil, err } @@ -26,6 +25,10 @@ // General serializes the JWS into its "general" form per // https://tools.ietf.org/html/rfc7515#section-7.2.1 +// +// If only one key is passed it's used for all the provided +// SigningMethods. Otherwise, len(keys) must equal the number +// of SigningMethods added. func (j *JWS) General(keys ...interface{}) ([]byte, error) { if err := j.sign(keys...); err != nil { return nil, err @@ -67,13 +70,26 @@ return err } + if len(keys) < 1 || + len(keys) > 1 && len(keys) != len(j.sb) { + return ErrNotEnoughKeys + } + + if len(keys) == 1 { + k := keys[0] + keys = make([]interface{}, len(j.sb)) + for i := range keys { + keys[i] = k + } + } + for i := range j.sb { if err := j.sb[i].cache(); err != nil { return err } raw := format(j.sb[i].Protected, j.plcache) - sig, err := j.methods[i].Sign(raw, keys[i]) + sig, err := j.sb[i].method.Sign(raw, keys[i]) if err != nil { return err }
diff --git a/jws/jws_serialize_test.go b/jws/jws_serialize_test.go index 61a828e..124a3b4 100644 --- a/jws/jws_serialize_test.go +++ b/jws/jws_serialize_test.go
@@ -9,11 +9,15 @@ ) var dataRaw = struct { + H jose.Protected Name string Scopes []string Admin bool Data struct{ Foo, Bar int } }{ + H: jose.Protected{ + "1234": "5678", + }, Name: "Eric", Scopes: []string{ "user.account.info",
diff --git a/jws/jws_test.go b/jws/jws_test.go index 8afb931..ebd7cf0 100644 --- a/jws/jws_test.go +++ b/jws/jws_test.go
@@ -10,10 +10,13 @@ type easy []byte func (e *easy) UnmarshalJSON(b []byte) error { + if len(b) > 1 && b[0] == '"' && b[len(b)-1] == '"' { + b = b[1 : len(b)-1] + } // json.Marshal encodes easy as it would a []byte, so in // `"base64"` format. - dst := make([]byte, base64.StdEncoding.DecodedLen(len(b)-2)) - n, err := base64.StdEncoding.Decode(dst, b[1:len(b)-1]) + dst := make([]byte, base64.StdEncoding.DecodedLen(len(b))) + n, err := base64.StdEncoding.Decode(dst, b) if err != nil { return err } @@ -23,7 +26,7 @@ var _ json.Unmarshaler = (*easy)(nil) -var easyData = easy(`"easy data!"`) +var easyData = easy("easy data!") func TestParseWithUnmarshaler(t *testing.T) { j := New(easyData, SigningMethodRS512) @@ -42,3 +45,144 @@ Error(t, easyData, *j2.payload.v.(*easy)) } } + +func TestParseCompact(t *testing.T) { + j := New(easyData, SigningMethodRS512) + b, err := j.Compact(rsaPriv) + if err != nil { + t.Error(err) + } + + j2, err := ParseCompact(b) + if err != nil { + t.Error(err) + } + + var k easy + if err := k.UnmarshalJSON([]byte(j2.payload.v.(string))); err != nil { + t.Error(err) + } + + if !bytes.Equal(k, easyData) { + Error(t, easyData, k) + } +} + +func TestParseGeneral(t *testing.T) { + sm := []SigningMethod{SigningMethodRS512, SigningMethodPS384, SigningMethodPS256} + j := New(easyData, sm...) + b, err := j.General(rsaPriv) + if err != nil { + t.Error(err) + } + + j2, err := ParseGeneral(b) + if err != nil { + t.Error(err) + } + + for i, v := range j2.sb { + k := v.protected.Get("alg").(string) + if k != sm[i].Alg() { + Error(t, sm[i].Alg(), k) + } + } +} + +func TestValidateMulti(t *testing.T) { + sm := []SigningMethod{SigningMethodRS512, SigningMethodPS384, SigningMethodPS256} + j := New(easyData, sm...) + b, err := j.General(rsaPriv) + if err != nil { + t.Error(err) + } + + j2, err := ParseGeneral(b) + if err != nil { + t.Error(err) + } + + keys := []interface{}{rsaPub, rsaPub, rsaPub} + if err := j2.ValidateMulti(keys, sm, Any); err != nil { + t.Error(err) + } +} + +func TestValidateMultiMismatchedAlgs(t *testing.T) { + sm := []SigningMethod{SigningMethodRS256, SigningMethodPS384, SigningMethodPS512} + j := New(easyData, sm...) + b, err := j.General(rsaPriv) + if err != nil { + t.Error(err) + } + + j2, err := ParseGeneral(b) + if err != nil { + t.Error(err) + } + + // Shuffle it. + sm = []SigningMethod{SigningMethodRS512, SigningMethodPS256, SigningMethodPS384} + + keys := []interface{}{rsaPub, rsaPub, rsaPub} + if err := j2.ValidateMulti(keys, sm, Any); err == nil { + t.Error("Should NOT be nil") + } +} + +func TestValidateMultiNotEnoughMethods(t *testing.T) { + sm := []SigningMethod{SigningMethodRS256, SigningMethodPS384, SigningMethodPS512} + j := New(easyData, sm...) + b, err := j.General(rsaPriv) + if err != nil { + t.Error(err) + } + + j2, err := ParseGeneral(b) + if err != nil { + t.Error(err) + } + + sm = sm[0 : len(sm)-1] + + keys := []interface{}{rsaPub, rsaPub, rsaPub} + if err := j2.ValidateMulti(keys, sm, Any); err == nil { + t.Error("Should NOT be nil") + } +} + +func TestValidateMultiNotEnoughKeys(t *testing.T) { + sm := []SigningMethod{SigningMethodRS256, SigningMethodPS384, SigningMethodPS512} + j := New(easyData, sm...) + b, err := j.General(rsaPriv) + if err != nil { + t.Error(err) + } + + j2, err := ParseGeneral(b) + if err != nil { + t.Error(err) + } + + keys := []interface{}{rsaPub, rsaPub} + if err := j2.ValidateMulti(keys, sm, Any); err == nil { + t.Error("Should NOT be nil") + } +} + +func TestValidate(t *testing.T) { + j := New(easyData, SigningMethodPS512) + b, err := j.Flat(rsaPriv) + if err != nil { + t.Error(err) + } + + j2, err := ParseFlat(b) + if err != nil { + t.Error(err) + } + + if err := j2.Validate(rsaPub, SigningMethodPS512); err != nil { + t.Error(err) + } +}
diff --git a/jws/none.go b/jws/none.go index a1b66b2..5d0a4bf 100644 --- a/jws/none.go +++ b/jws/none.go
@@ -8,12 +8,12 @@ ) func init() { - crypto.RegisterHash(crypto.Hash(0), H) + crypto.RegisterHash(crypto.Hash(0), h) RegisterSigningMethod(Unsecured) } -// H is passed to crypto.RegisterHash. -func H() hash.Hash { return &f{Writer: nil} } +// h is passed to crypto.RegisterHash. +func h() hash.Hash { return &f{Writer: nil} } type f struct{ io.Writer }
diff --git a/jws/payload.go b/jws/payload.go index 3cc203c..34ba6d8 100644 --- a/jws/payload.go +++ b/jws/payload.go
@@ -43,6 +43,7 @@ p.v = p.u return err } + return json.Unmarshal(b2, &p.v) }