tests: added better testing coverage jws: added copycat claims
diff --git a/.gitignore b/.gitignore index daf913b..6289cdb 100644 --- a/.gitignore +++ b/.gitignore
@@ -22,3 +22,6 @@ *.exe *.test *.prof + +*.out +*.tmp \ No newline at end of file
diff --git a/.tags b/.tags index d01d39b..10d8498 100644 --- a/.tags +++ b/.tags
@@ -4,3 +4,4 @@ !_TAG_PROGRAM_NAME Exuberant Ctags // !_TAG_PROGRAM_URL http://ctags.sourceforge.net /official site/ !_TAG_PROGRAM_VERSION 5.8 // +onChange /home/eric/gopath/src/github.com/SermoDigital/jose/cover.html /^ function onChange() {$/;" f
diff --git a/base64.go b/base64.go index 6be62b8..f7275fb 100644 --- a/base64.go +++ b/base64.go
@@ -5,7 +5,7 @@ // Encoder is satisfied if the type can marshal itself into a valid // structure for a JWS. type Encoder interface { - // Base64 implies T -> JSON -> Base64 + // Base64 implies T -> JSON -> RawURLEncodingBase64 Base64() ([]byte, error) }
diff --git a/cover.html b/cover.html new file mode 100644 index 0000000..316b917 --- /dev/null +++ b/cover.html
@@ -0,0 +1,268 @@ + +<!DOCTYPE html> +<html> + <head> + <meta http-equiv="Content-Type" content="text/html; charset=utf-8"> + <style> + body { + background: black; + color: rgb(80, 80, 80); + } + body, pre, #legend span { + font-family: Menlo, monospace; + font-weight: bold; + } + #topbar { + background: black; + position: fixed; + top: 0; left: 0; right: 0; + height: 42px; + border-bottom: 1px solid rgb(80, 80, 80); + } + #content { + margin-top: 50px; + } + #nav, #legend { + float: left; + margin-left: 10px; + } + #legend { + margin-top: 12px; + } + #nav { + margin-top: 10px; + } + #legend span { + margin: 0 5px; + } + .cov0 { color: rgb(192, 0, 0) } +.cov1 { color: rgb(128, 128, 128) } +.cov2 { color: rgb(116, 140, 131) } +.cov3 { color: rgb(104, 152, 134) } +.cov4 { color: rgb(92, 164, 137) } +.cov5 { color: rgb(80, 176, 140) } +.cov6 { color: rgb(68, 188, 143) } +.cov7 { color: rgb(56, 200, 146) } +.cov8 { color: rgb(44, 212, 149) } +.cov9 { color: rgb(32, 224, 152) } +.cov10 { color: rgb(20, 236, 155) } + + </style> + </head> + <body> + <div id="topbar"> + <div id="nav"> + <select id="files"> + + <option value="file0">github.com/SermoDigital/jose/base64.go (100.0%)</option> + + <option value="file1">github.com/SermoDigital/jose/header.go (58.1%)</option> + + </select> + </div> + <div id="legend"> + <span>not tracked</span> + + <span class="cov0">not covered</span> + <span class="cov8">covered</span> + + </div> + </div> + <div id="content"> + + <pre class="file" id="file0" >package jose + +import "encoding/base64" + +// Encoder is satisfied if the type can marshal itself into a valid +// structure for a JWS. +type Encoder interface { + // Base64 implies T -> JSON -> RawURLEncodingBase64 + Base64() ([]byte, error) +} + +// Base64Decode decodes a base64-encoded byte slice. +func Base64Decode(b []byte) ([]byte, error) <span class="cov8" title="1">{ + buf := make([]byte, base64.RawURLEncoding.DecodedLen(len(b))) + n, err := base64.RawURLEncoding.Decode(buf, b) + return buf[:n], err +}</span> + +// Base64Encode encodes a byte slice. +func Base64Encode(b []byte) []byte <span class="cov8" title="1">{ + buf := make([]byte, base64.RawURLEncoding.EncodedLen(len(b))) + base64.RawURLEncoding.Encode(buf, b) + return buf +}</span> + +// EncodeEscape base64-encodes a byte slice but escapes it for JSON. +// It'll return the format: `"base64"` +func EncodeEscape(b []byte) []byte <span class="cov8" title="1">{ + buf := make([]byte, base64.RawURLEncoding.EncodedLen(len(b))+2) + buf[0] = '"' + base64.RawURLEncoding.Encode(buf[1:], b) + buf[len(buf)-1] = '"' + return buf +}</span> + +// DecodeEscaped decodes a base64-encoded byte slice straight from a JSON +// structure. It assumes it's in the format: `"base64"`, but can handle +// cases where it's not. +func DecodeEscaped(b []byte) ([]byte, error) <span class="cov8" title="1">{ + if len(b) > 1 && b[0] == '"' && b[len(b)-1] == '"' </span><span class="cov8" title="1">{ + b = b[1 : len(b)-1] + }</span> + <span class="cov8" title="1">return Base64Decode(b)</span> +} +</pre> + + <pre class="file" id="file1" style="display: none">package jose + +import "encoding/json" + +// Header implements a JOSE Header with the addition of some helper +// methods, similar to net/url.Values. +type Header map[string]interface{} + +// Get retrieves the value corresponding with key from the Header. +func (h Header) Get(key string) interface{} <span class="cov8" title="1">{ + if h == nil </span><span class="cov8" title="1">{ + return nil + }</span> + <span class="cov8" title="1">return h[key]</span> +} + +// Set sets Claims[key] = val. It'll overwrite without warning. +func (h Header) Set(key string, val interface{}) <span class="cov8" title="1">{ + h[key] = val +}</span> + +// Del removes the value that corresponds with key from the Header. +func (h Header) Del(key string) <span class="cov8" title="1">{ + delete(h, key) +}</span> + +// Has returns true if a value for the given key exists inside the Header. +func (h Header) Has(key string) bool <span class="cov8" title="1">{ + _, ok := h[key] + return ok +}</span> + +// MarshalJSON implements json.Marshaler for Header. +func (h Header) MarshalJSON() ([]byte, error) <span class="cov8" title="1">{ + if h == nil || len(h) == 0 </span><span class="cov0" title="0">{ + return nil, nil + }</span> + <span class="cov8" title="1">b, err := json.Marshal(map[string]interface{}(h)) + if err != nil </span><span class="cov0" title="0">{ + return nil, err + }</span> + <span class="cov8" title="1">return EncodeEscape(b), nil</span> +} + +// Base64 implements the Encoder interface. +func (h Header) Base64() ([]byte, error) <span class="cov0" title="0">{ + return h.MarshalJSON() +}</span> + +// UnmarshalJSON implements json.Unmarshaler for Header. +func (h *Header) UnmarshalJSON(b []byte) error <span class="cov8" title="1">{ + if b == nil </span><span class="cov0" title="0">{ + return nil + }</span> + + <span class="cov8" title="1">b, err := DecodeEscaped(b) + if err != nil </span><span class="cov0" title="0">{ + return err + }</span> + + // Since json.Unmarshal calls UnmarshalJSON, + // calling json.Unmarshal on *p would be infinitely recursive + // A temp variable is needed because &map[string]interface{}(*p) is + // invalid Go. + + <span class="cov8" title="1">tmp := map[string]interface{}(*h) + if err = json.Unmarshal(b, &tmp); err != nil </span><span class="cov0" title="0">{ + return err + }</span> + <span class="cov8" title="1">*h = Header(tmp) + return nil</span> +} + +// Protected Headers are base64-encoded after they're marshaled into +// JSON. +type Protected Header + +// Get retrieves the value corresponding with key from the Protected Header. +func (p Protected) Get(key string) interface{} <span class="cov0" title="0">{ + if p == nil </span><span class="cov0" title="0">{ + return nil + }</span> + <span class="cov0" title="0">return p[key]</span> +} + +// Set sets Protected[key] = val. It'll overwrite without warning. +func (p Protected) Set(key string, val interface{}) <span class="cov0" title="0">{ + p[key] = val +}</span> + +// Del removes the value that corresponds with key from the Protected Header. +func (p Protected) Del(key string) <span class="cov0" title="0">{ + delete(p, key) +}</span> + +// Has returns true if a value for the given key exists inside the Protected +// Header. +func (p Protected) Has(key string) bool <span class="cov0" title="0">{ + _, ok := p[key] + return ok +}</span> + +// MarshalJSON implements json.Marshaler for Protected. +func (p Protected) MarshalJSON() ([]byte, error) <span class="cov8" title="1">{ + b, err := json.Marshal(map[string]interface{}(p)) + if err != nil </span><span class="cov0" title="0">{ + return nil, err + }</span> + <span class="cov8" title="1">return EncodeEscape(b), nil</span> +} + +// Base64 implements the Encoder interface. +func (p Protected) Base64() ([]byte, error) <span class="cov0" title="0">{ + b, err := json.Marshal(map[string]interface{}(p)) + if err != nil </span><span class="cov0" title="0">{ + return nil, err + }</span> + <span class="cov0" title="0">return Base64Encode(b), nil</span> +} + +// UnmarshalJSON implements json.Unmarshaler for Protected. +func (p *Protected) UnmarshalJSON(b []byte) error <span class="cov8" title="1">{ + var h Header + h.UnmarshalJSON(b) + *p = Protected(h) + return nil +}</span> + +var ( + _ json.Marshaler = (Protected)(nil) + _ json.Unmarshaler = (*Protected)(nil) +) +</pre> + + </div> + </body> + <script> + (function() { + var files = document.getElementById('files'); + var visible = document.getElementById('file0'); + files.addEventListener('change', onChange, false); + function onChange() { + visible.style.display = 'none'; + visible = document.getElementById(files.value); + visible.style.display = 'block'; + window.scrollTo(0, 0); + } + })(); + </script> +</html>
diff --git a/header_test.go b/header_test.go index ead6e57..e6103f2 100644 --- a/header_test.go +++ b/header_test.go
@@ -25,3 +25,54 @@ Error(t, p["alg"], p2["alg"]) } } + +func TestMarshalHeader(t *testing.T) { + h := Header{ + "alg": "HM256", + } + + b, err := json.Marshal(h) + if err != nil { + t.Error(err) + } + + var p2 Protected + + if json.Unmarshal(b, &p2); err != nil { + t.Error(err) + } + + if p2["alg"] != h["alg"] { + Error(t, h["alg"], p2["alg"]) + } +} + +func TestBasicHeaderFunctions(t *testing.T) { + var h Header + + if v := h.Get("b"); v != nil { + Error(t, nil, v) + } + + h = Header{} + + h.Set("a", "b") + + if v := h.Get("a"); v != "b" { + Error(t, "a", v) + } + + if !h.Has("a") { + t.Error("h should have `a`") + } + + if v := h.Get("b"); v != nil { + Error(t, nil, v) + } + + h.Del("a") + + if v := h.Get("a"); v != nil { + Error(t, nil, v) + } +}
diff --git a/jws/claims.go b/jws/claims.go new file mode 100644 index 0000000..c866156 --- /dev/null +++ b/jws/claims.go
@@ -0,0 +1,189 @@ +package jws + +import ( + "encoding/json" + + "github.com/SermoDigital/jose" + "github.com/SermoDigital/jose/jwt" +) + +// Claims represents a set of JOSE Claims. +type Claims jwt.Claims + +// Get retrieves the value corresponding with key from the Claims. +func (c Claims) Get(key string) interface{} { + return jwt.Claims(c).Get(key) +} + +// Set sets Claims[key] = val. It'll overwrite without warning. +func (c Claims) Set(key string, val interface{}) { + jwt.Claims(c).Set(key, val) +} + +// Del removes the value that corresponds with key from the Claims. +func (c Claims) Del(key string) { + jwt.Claims(c).Del(key) +} + +// Has returns true if a value for the given key exists inside the Claims. +func (c Claims) Has(key string) bool { + return jwt.Claims(c).Has(key) +} + +// MarshalJSON implements json.Marshaler for Claims. +func (c Claims) MarshalJSON() ([]byte, error) { + return jwt.Claims(c).MarshalJSON() +} + +// Base64 implements the Encoder interface. +func (c Claims) Base64() ([]byte, error) { + return jwt.Claims(c).Base64() +} + +// UnmarshalJSON implements json.Unmarshaler for Claims. +func (c *Claims) UnmarshalJSON(b []byte) error { + if b == nil { + return nil + } + + b, err := jose.DecodeEscaped(b) + if err != nil { + return err + } + + // Since json.Unmarshal calls UnmarshalJSON, + // calling json.Unmarshal on *p would be infinitely recursive + // A temp variable is needed because &map[string]interface{}(*p) is + // invalid Go. + + tmp := map[string]interface{}(*c) + if err = json.Unmarshal(b, &tmp); err != nil { + return err + } + *c = Claims(tmp) + return nil +} + +// Issuer retrieves claim "iss" per its type in +// https://tools.ietf.org/html/rfc7519#section-4.1.1 +func (c Claims) Issuer() (string, bool) { + return jwt.Claims(c).Issuer() +} + +// Subject retrieves claim "sub" per its type in +// https://tools.ietf.org/html/rfc7519#section-4.1.2 +func (c Claims) Subject() (string, bool) { + return jwt.Claims(c).Subject() +} + +// Audience retrieves claim "aud" per its type in +// https://tools.ietf.org/html/rfc7519#section-4.1.3 +func (c Claims) Audience() (interface{}, bool) { + return jwt.Claims(c).Audience() +} + +// Expiration retrieves claim "exp" per its type in +// https://tools.ietf.org/html/rfc7519#section-4.1.4 +func (c Claims) Expiration() (int64, bool) { + return jwt.Claims(c).Expiration() +} + +// NotBefore retrieves claim "nbf" per its type in +// https://tools.ietf.org/html/rfc7519#section-4.1.5 +func (c Claims) NotBefore() (int64, bool) { + return jwt.Claims(c).NotBefore() +} + +// IssuedAt retrieves claim "iat" per its type in +// https://tools.ietf.org/html/rfc7519#section-4.1.6 +func (c Claims) IssuedAt() (int64, bool) { + return jwt.Claims(c).IssuedAt() +} + +// JWTID retrieves claim "jti" per its type in +// https://tools.ietf.org/html/rfc7519#section-4.1.7 +func (c Claims) JWTID() (string, bool) { + return jwt.Claims(c).JWTID() +} + +// RemoveIssuer deletes claim "iss" from c. +func (c Claims) RemoveIssuer() { + jwt.Claims(c).RemoveIssuer() +} + +// RemoveSubject deletes claim "sub" from c. +func (c Claims) RemoveSubject() { + jwt.Claims(c).RemoveIssuer() +} + +// RemoveAudience deletes claim "aud" from c. +func (c Claims) RemoveAudience() { + jwt.Claims(c).Audience() +} + +// RemoveExpiration deletes claim "exp" from c. +func (c Claims) RemoveExpiration() { + jwt.Claims(c).RemoveExpiration() +} + +// RemoveNotBefore deletes claim "nbf" from c. +func (c Claims) RemoveNotBefore() { + jwt.Claims(c).NotBefore() +} + +// RemoveIssuedAt deletes claim "iat" from c. +func (c Claims) RemoveIssuedAt() { + jwt.Claims(c).IssuedAt() +} + +// RemoveJWTID deletes claim "jti" from c. +func (c Claims) RemoveJWTID() { + jwt.Claims(c).RemoveJWTID() +} + +// SetIssuer sets claim "iss" per its type in +// https://tools.ietf.org/html/rfc7519#section-4.1.1 +func (c Claims) SetIssuer(issuer string) { + jwt.Claims(c).SetIssuer(issuer) +} + +// SetSubject sets claim "iss" per its type in +// https://tools.ietf.org/html/rfc7519#section-4.1.2 +func (c Claims) SetSubject(subject string) { + jwt.Claims(c).SetSubject(subject) +} + +// SetAudience sets claim "aud" per its type in +// https://tools.ietf.org/html/rfc7519#section-4.1.3 +func (c Claims) SetAudience(audience ...string) { + jwt.Claims(c).SetAudience(audience...) +} + +// SetExpiration sets claim "exp" per its type in +// https://tools.ietf.org/html/rfc7519#section-4.1.4 +func (c Claims) SetExpiration(expiration int64) { + jwt.Claims(c).SetExpiration(expiration) +} + +// SetNotBefore sets claim "nbf" per its type in +// https://tools.ietf.org/html/rfc7519#section-4.1.5 +func (c Claims) SetNotBefore(notBefore int64) { + jwt.Claims(c).SetNotBefore(notBefore) +} + +// SetIssuedAt sets claim "iat" per its type in +// https://tools.ietf.org/html/rfc7519#section-4.1.6 +func (c Claims) SetIssuedAt(issuedAt int64) { + jwt.Claims(c).SetIssuedAt(issuedAt) +} + +// SetJWTID sets claim "jti" per its type in +// https://tools.ietf.org/html/rfc7519#section-4.1.7 +func (c Claims) SetJWTID(uniqueID string) { + jwt.Claims(c).SetJWTID(uniqueID) +} + +var ( + _ json.Marshaler = (Claims)(nil) + _ json.Unmarshaler = (*Claims)(nil) +)
diff --git a/jws/cover.html b/jws/cover.html new file mode 100644 index 0000000..6357801 --- /dev/null +++ b/jws/cover.html
@@ -0,0 +1,1279 @@ + +<!DOCTYPE html> +<html> + <head> + <meta http-equiv="Content-Type" content="text/html; charset=utf-8"> + <style> + body { + background: black; + color: rgb(80, 80, 80); + } + body, pre, #legend span { + font-family: Menlo, monospace; + font-weight: bold; + } + #topbar { + background: black; + position: fixed; + top: 0; left: 0; right: 0; + height: 42px; + border-bottom: 1px solid rgb(80, 80, 80); + } + #content { + margin-top: 50px; + } + #nav, #legend { + float: left; + margin-left: 10px; + } + #legend { + margin-top: 12px; + } + #nav { + margin-top: 10px; + } + #legend span { + margin: 0 5px; + } + .cov0 { color: rgb(192, 0, 0) } +.cov1 { color: rgb(128, 128, 128) } +.cov2 { color: rgb(116, 140, 131) } +.cov3 { color: rgb(104, 152, 134) } +.cov4 { color: rgb(92, 164, 137) } +.cov5 { color: rgb(80, 176, 140) } +.cov6 { color: rgb(68, 188, 143) } +.cov7 { color: rgb(56, 200, 146) } +.cov8 { color: rgb(44, 212, 149) } +.cov9 { color: rgb(32, 224, 152) } +.cov10 { color: rgb(20, 236, 155) } + + </style> + </head> + <body> + <div id="topbar"> + <div id="nav"> + <select id="files"> + + <option value="file0">github.com/SermoDigital/jose/jws/claims.go (5.4%)</option> + + <option value="file1">github.com/SermoDigital/jose/jws/jws.go (66.3%)</option> + + <option value="file2">github.com/SermoDigital/jose/jws/jws_serialize.go (73.1%)</option> + + <option value="file3">github.com/SermoDigital/jose/jws/jws_validate.go (76.5%)</option> + + <option value="file4">github.com/SermoDigital/jose/jws/jwt.go (75.0%)</option> + + <option value="file5">github.com/SermoDigital/jose/jws/payload.go (81.2%)</option> + + <option value="file6">github.com/SermoDigital/jose/jws/rawbase64.go (100.0%)</option> + + <option value="file7">github.com/SermoDigital/jose/jws/signing_methods.go (84.6%)</option> + + </select> + </div> + <div id="legend"> + <span>not tracked</span> + + <span class="cov0">not covered</span> + <span class="cov8">covered</span> + + </div> + </div> + <div id="content"> + + <pre class="file" id="file0" >package jws + +import ( + "encoding/json" + + "github.com/SermoDigital/jose" + "github.com/SermoDigital/jose/jwt" +) + +// Claims represents a set of JOSE Claims. +type Claims jwt.Claims + +// Get retrieves the value corresponding with key from the Claims. +func (c Claims) Get(key string) interface{} <span class="cov8" title="1">{ + return jwt.Claims(c).Get(key) +}</span> + +// Set sets Claims[key] = val. It'll overwrite without warning. +func (c Claims) Set(key string, val interface{}) <span class="cov0" title="0">{ + jwt.Claims(c).Set(key, val) +}</span> + +// Del removes the value that corresponds with key from the Claims. +func (c Claims) Del(key string) <span class="cov0" title="0">{ + jwt.Claims(c).Del(key) +}</span> + +// Has returns true if a value for the given key exists inside the Claims. +func (c Claims) Has(key string) bool <span class="cov0" title="0">{ + return jwt.Claims(c).Has(key) +}</span> + +// MarshalJSON implements json.Marshaler for Claims. +func (c Claims) MarshalJSON() ([]byte, error) <span class="cov8" title="1">{ + return jwt.Claims(c).MarshalJSON() +}</span> + +// Base64 implements the Encoder interface. +func (c Claims) Base64() ([]byte, error) <span class="cov0" title="0">{ + return jwt.Claims(c).Base64() +}</span> + +// UnmarshalJSON implements json.Unmarshaler for Claims. +func (c *Claims) UnmarshalJSON(b []byte) error <span class="cov0" title="0">{ + if b == nil </span><span class="cov0" title="0">{ + return nil + }</span> + + <span class="cov0" title="0">b, err := jose.DecodeEscaped(b) + if err != nil </span><span class="cov0" title="0">{ + return err + }</span> + + // Since json.Unmarshal calls UnmarshalJSON, + // calling json.Unmarshal on *p would be infinitely recursive + // A temp variable is needed because &map[string]interface{}(*p) is + // invalid Go. + + <span class="cov0" title="0">tmp := map[string]interface{}(*c) + if err = json.Unmarshal(b, &tmp); err != nil </span><span class="cov0" title="0">{ + return err + }</span> + <span class="cov0" title="0">*c = Claims(tmp) + return nil</span> +} + +// Issuer retrieves claim "iss" per its type in +// https://tools.ietf.org/html/rfc7519#section-4.1.1 +func (c Claims) Issuer() (string, bool) <span class="cov0" title="0">{ + return jwt.Claims(c).Issuer() +}</span> + +// Subject retrieves claim "sub" per its type in +// https://tools.ietf.org/html/rfc7519#section-4.1.2 +func (c Claims) Subject() (string, bool) <span class="cov0" title="0">{ + return jwt.Claims(c).Subject() +}</span> + +// Audience retrieves claim "aud" per its type in +// https://tools.ietf.org/html/rfc7519#section-4.1.3 +func (c Claims) Audience() (interface{}, bool) <span class="cov0" title="0">{ + return jwt.Claims(c).Audience() +}</span> + +// Expiration retrieves claim "exp" per its type in +// https://tools.ietf.org/html/rfc7519#section-4.1.4 +func (c Claims) Expiration() (int64, bool) <span class="cov0" title="0">{ + return jwt.Claims(c).Expiration() +}</span> + +// NotBefore retrieves claim "nbf" per its type in +// https://tools.ietf.org/html/rfc7519#section-4.1.5 +func (c Claims) NotBefore() (int64, bool) <span class="cov0" title="0">{ + return jwt.Claims(c).NotBefore() +}</span> + +// IssuedAt retrieves claim "iat" per its type in +// https://tools.ietf.org/html/rfc7519#section-4.1.6 +func (c Claims) IssuedAt() (int64, bool) <span class="cov0" title="0">{ + return jwt.Claims(c).IssuedAt() +}</span> + +// JWTID retrieves claim "jti" per its type in +// https://tools.ietf.org/html/rfc7519#section-4.1.7 +func (c Claims) JWTID() (string, bool) <span class="cov0" title="0">{ + return jwt.Claims(c).JWTID() +}</span> + +// RemoveIssuer deletes claim "iss" from c. +func (c Claims) RemoveIssuer() <span class="cov0" title="0">{ + jwt.Claims(c).RemoveIssuer() +}</span> + +// RemoveSubject deletes claim "sub" from c. +func (c Claims) RemoveSubject() <span class="cov0" title="0">{ + jwt.Claims(c).RemoveIssuer() +}</span> + +// RemoveAudience deletes claim "aud" from c. +func (c Claims) RemoveAudience() <span class="cov0" title="0">{ + jwt.Claims(c).Audience() +}</span> + +// RemoveExpiration deletes claim "exp" from c. +func (c Claims) RemoveExpiration() <span class="cov0" title="0">{ + jwt.Claims(c).RemoveExpiration() +}</span> + +// RemoveNotBefore deletes claim "nbf" from c. +func (c Claims) RemoveNotBefore() <span class="cov0" title="0">{ + jwt.Claims(c).NotBefore() +}</span> + +// RemoveIssuedAt deletes claim "iat" from c. +func (c Claims) RemoveIssuedAt() <span class="cov0" title="0">{ + jwt.Claims(c).IssuedAt() +}</span> + +// RemoveJWTID deletes claim "jti" from c. +func (c Claims) RemoveJWTID() <span class="cov0" title="0">{ + jwt.Claims(c).RemoveJWTID() +}</span> + +// SetIssuer sets claim "iss" per its type in +// https://tools.ietf.org/html/rfc7519#section-4.1.1 +func (c Claims) SetIssuer(issuer string) <span class="cov0" title="0">{ + jwt.Claims(c).SetIssuer(issuer) +}</span> + +// SetSubject sets claim "iss" per its type in +// https://tools.ietf.org/html/rfc7519#section-4.1.2 +func (c Claims) SetSubject(subject string) <span class="cov0" title="0">{ + jwt.Claims(c).SetSubject(subject) +}</span> + +// SetAudience sets claim "aud" per its type in +// https://tools.ietf.org/html/rfc7519#section-4.1.3 +func (c Claims) SetAudience(audience ...string) <span class="cov0" title="0">{ + jwt.Claims(c).SetAudience(audience...) +}</span> + +// SetExpiration sets claim "exp" per its type in +// https://tools.ietf.org/html/rfc7519#section-4.1.4 +func (c Claims) SetExpiration(expiration int64) <span class="cov0" title="0">{ + jwt.Claims(c).SetExpiration(expiration) +}</span> + +// SetNotBefore sets claim "nbf" per its type in +// https://tools.ietf.org/html/rfc7519#section-4.1.5 +func (c Claims) SetNotBefore(notBefore int64) <span class="cov0" title="0">{ + jwt.Claims(c).SetNotBefore(notBefore) +}</span> + +// SetIssuedAt sets claim "iat" per its type in +// https://tools.ietf.org/html/rfc7519#section-4.1.6 +func (c Claims) SetIssuedAt(issuedAt int64) <span class="cov0" title="0">{ + jwt.Claims(c).SetIssuedAt(issuedAt) +}</span> + +// SetJWTID sets claim "jti" per its type in +// https://tools.ietf.org/html/rfc7519#section-4.1.7 +func (c Claims) SetJWTID(uniqueID string) <span class="cov0" title="0">{ + jwt.Claims(c).SetJWTID(uniqueID) +}</span> + +var ( + _ json.Marshaler = (Claims)(nil) + _ json.Unmarshaler = (*Claims)(nil) +) +</pre> + + <pre class="file" id="file1" style="display: none">package jws + +import ( + "bytes" + "encoding/json" + + "github.com/SermoDigital/jose" + "github.com/SermoDigital/jose/crypto" +) + +// JWS implements a JWS per RFC 7515. +type JWS interface { + // Payload Returns the payload. + Payload() interface{} + + // SetPayload sets the payload with the given value. + SetPayload(interface{}) + + // Protected returns the JWS' Protected Header. + // i represents the index of the Protected Header. + // Left empty, it defaults to 0. + Protected(...int) jose.Protected + + // Header returns the JWS' unprotected Header. + // i represents the index of the Protected Header. + // Left empty, it defaults to 0. + Header(...int) jose.Header + + // Verify validates the current JWS' signature as-is. Refer to + // ValidateMulti for more information. + Verify(key interface{}, method crypto.SigningMethod) error + + // ValidateMulti validates the current JWS' signature as-is. Since it's + // meant to be called after parsing a stream of bytes into a JWS, it + // shouldn't do any internal parsing like the Sign, Flat, Compact, or + // General methods do. + VerifyMulti(keys []interface{}, methods []crypto.SigningMethod, o *SigningOpts) error + + // VerifyCallback validates the current JWS' signature as-is. It + // accepts a callback function that can be used to access header + // parameters to lookup needed information. For example, looking + // up the "kid" parameter. + // The return slice must be a slice of keys used in the verification + // of the JWS. + VerifyCallback(fn VerifyCallback, methods []crypto.SigningMethod, o *SigningOpts) error + + // General serializes the JWS into its "general" form per + // https://tools.ietf.org/html/rfc7515#section-7.2.1 + General(keys ...interface{}) ([]byte, error) + + // Flat serializes the JWS to its "flattened" form per + // https://tools.ietf.org/html/rfc7515#section-7.2.2 + Flat(key interface{}) ([]byte, error) + + // Compact serializes the JWS into its "compact" form per + // https://tools.ietf.org/html/rfc7515#section-7.1 + Compact(key interface{}) ([]byte, error) + + // IsJWT returns true if the JWS is a JWT. + IsJWT() bool +} + +// jws represents a specific jws. +type jws struct { + payload *payload + plcache rawBase64 + clean bool + + sb []sigHead + + isJWT bool +} + +// Payload returns the jws' payload. +func (j *jws) Payload() interface{} <span class="cov8" title="1">{ return j.payload.v }</span> + +// SetPayload sets the jws' raw, unexported payload. +func (j *jws) SetPayload(val interface{}) <span class="cov8" title="1">{ j.payload.v = val }</span> + +// Protected returns the JWS' Protected Header. +// i represents the index of the Protected Header. +// Left empty, it defaults to 0. +func (j *jws) Protected(i ...int) jose.Protected <span class="cov0" title="0">{ + if len(i) == 0 </span><span class="cov0" title="0">{ + return j.sb[0].protected + }</span> + <span class="cov0" title="0">return j.sb[i[0]].protected</span> +} + +// Header returns the JWS' unprotected Header. +// i represents the index of the Protected Header. +// Left empty, it defaults to 0. +func (j *jws) Header(i ...int) jose.Header <span class="cov0" title="0">{ + if len(i) == 0 </span><span class="cov0" title="0">{ + return j.sb[0].unprotected + }</span> + <span class="cov0" title="0">return j.sb[i[0]].unprotected</span> +} + +// sigHead represents the 'signatures' member of the jws' "general" +// serialization form per +// https://tools.ietf.org/html/rfc7515#section-7.2.1 +// +// It's embedded inside the "flat" structure in order to properly +// create the "flat" jws. +type sigHead struct { + Protected rawBase64 `json:"protected,omitempty"` + Unprotected rawBase64 `json:"header,omitempty"` + Signature crypto.Signature `json:"signature"` + + protected jose.Protected + unprotected jose.Header + clean bool + + method crypto.SigningMethod +} + +func (s *sigHead) unmarshal() error <span class="cov8" title="1">{ + if err := s.protected.UnmarshalJSON(s.Protected); err != nil </span><span class="cov0" title="0">{ + return err + }</span> + <span class="cov8" title="1">if err := s.unprotected.UnmarshalJSON(s.Unprotected); err != nil </span><span class="cov0" title="0">{ + return err + }</span> + <span class="cov8" title="1">return nil</span> +} + +// New creates a JWS with the provided crypto.SigningMethods. +func New(content interface{}, methods ...crypto.SigningMethod) JWS <span class="cov8" title="1">{ + sb := make([]sigHead, len(methods)) + for i := range methods </span><span class="cov8" title="1">{ + sb[i] = sigHead{ + protected: jose.Protected{ + "alg": methods[i].Alg(), + }, + unprotected: jose.Header{}, + method: methods[i], + } + }</span> + <span class="cov8" title="1">return &jws{ + payload: &payload{v: content}, + sb: sb, + }</span> +} + +func (s *sigHead) assignMethod(p jose.Protected) error <span class="cov8" title="1">{ + alg, ok := p.Get("alg").(string) + if !ok </span><span class="cov0" title="0">{ + return ErrNoAlgorithm + }</span> + + <span class="cov8" title="1">sm := GetSigningMethod(alg) + if sm == nil </span><span class="cov0" title="0">{ + return ErrNoAlgorithm + }</span> + + <span class="cov8" title="1">s.method = sm + return nil</span> +} + +type generic struct { + Payload rawBase64 `json:"payload"` + sigHead + Signatures []sigHead `json:"signatures,omitempty"` +} + +// Parse parses any of the three serialized jws forms into a physical +// jws per https://tools.ietf.org/html/rfc7515#section-5.2 +// +// It accepts a json.Unmarshaler in order to properly parse +// the payload. In order to keep the caller from having to do extra +// parsing of the payload, a json.Unmarshaler can be passed +// which will be then to unmarshal the payload however the caller +// wishes. Do note that if json.Unmarshal returns an error the +// original payload will be used as if no json.Unmarshaler was +// passed. +// +// Internally, Parse applies some heuristics and then calls either +// ParseGeneral, ParseFlat, or ParseCompact. +// It should only be called if, for whatever reason, you do not +// know which form the serialized JWT is in. +// +// It cannot parse a JWT. +func Parse(encoded []byte, u ...json.Unmarshaler) (JWS, error) <span class="cov8" title="1">{ + // Try and unmarshal into a generic struct that'll + // hopefully hold either of the two JSON serialization + // formats. + var g generic + + // Not valid JSON. Let's try compact. + if err := json.Unmarshal(encoded, &g); err != nil </span><span class="cov0" title="0">{ + return ParseCompact(encoded, u...) + }</span> + + <span class="cov8" title="1">if g.Signatures == nil </span><span class="cov8" title="1">{ + return g.parseFlat(u...) + }</span> + <span class="cov0" title="0">return g.parseGeneral(u...)</span> +} + +// ParseGeneral parses a jws serialized into its "general" form per +// https://tools.ietf.org/html/rfc7515#section-7.2.1 +// into a physical jws per +// https://tools.ietf.org/html/rfc7515#section-5.2 +// +// For information on the json.Unmarshaler parameter, see Parse. +func ParseGeneral(encoded []byte, u ...json.Unmarshaler) (JWS, error) <span class="cov8" title="1">{ + var g generic + if err := json.Unmarshal(encoded, &g); err != nil </span><span class="cov0" title="0">{ + return nil, err + }</span> + <span class="cov8" title="1">return g.parseGeneral(u...)</span> +} + +func (g *generic) parseGeneral(u ...json.Unmarshaler) (JWS, error) <span class="cov8" title="1">{ + + var p payload + if len(u) > 0 </span><span class="cov0" title="0">{ + p.u = u[0] + }</span> + + <span class="cov8" title="1">if err := p.UnmarshalJSON(g.Payload); err != nil </span><span class="cov0" title="0">{ + return nil, err + }</span> + + <span class="cov8" title="1">for i := range g.Signatures </span><span class="cov8" title="1">{ + if err := g.Signatures[i].unmarshal(); err != nil </span><span class="cov0" title="0">{ + return nil, err + }</span> + <span class="cov8" title="1">if err := checkHeaders(jose.Header(g.Signatures[i].protected), g.Signatures[i].unprotected); err != nil </span><span class="cov0" title="0">{ + return nil, err + }</span> + + <span class="cov8" title="1">if err := g.Signatures[i].assignMethod(g.Signatures[i].protected); err != nil </span><span class="cov0" title="0">{ + return nil, err + }</span> + + <span class="cov8" title="1">g.clean = true</span> + } + + <span class="cov8" title="1">return &jws{ + payload: &p, + plcache: g.Payload, + clean: true, + sb: g.Signatures, + }, nil</span> +} + +// ParseFlat parses a jws serialized into its "flat" form per +// https://tools.ietf.org/html/rfc7515#section-7.2.2 +// into a physical jws per +// https://tools.ietf.org/html/rfc7515#section-5.2 +// +// For information on the json.Unmarshaler parameter, see Parse. +func ParseFlat(encoded []byte, u ...json.Unmarshaler) (JWS, error) <span class="cov8" title="1">{ + var g generic + if err := json.Unmarshal(encoded, &g); err != nil </span><span class="cov0" title="0">{ + return nil, err + }</span> + <span class="cov8" title="1">return g.parseFlat(u...)</span> +} + +func (g *generic) parseFlat(u ...json.Unmarshaler) (JWS, error) <span class="cov8" title="1">{ + + var p payload + if len(u) > 0 </span><span class="cov8" title="1">{ + p.u = u[0] + }</span> + + <span class="cov8" title="1">if err := p.UnmarshalJSON(g.Payload); err != nil </span><span class="cov0" title="0">{ + return nil, err + }</span> + + <span class="cov8" title="1">if err := g.sigHead.unmarshal(); err != nil </span><span class="cov0" title="0">{ + return nil, err + }</span> + <span class="cov8" title="1">g.sigHead.clean = true + + if err := checkHeaders(jose.Header(g.sigHead.protected), g.sigHead.unprotected); err != nil </span><span class="cov0" title="0">{ + return nil, err + }</span> + + <span class="cov8" title="1">if err := g.sigHead.assignMethod(g.sigHead.protected); err != nil </span><span class="cov0" title="0">{ + return nil, err + }</span> + + <span class="cov8" title="1">return &jws{ + payload: &p, + plcache: g.Payload, + clean: true, + sb: []sigHead{g.sigHead}, + }, nil</span> +} + +// ParseCompact parses a jws serialized into its "compact" form per +// https://tools.ietf.org/html/rfc7515#section-7.1 +// into a physical jws per +// https://tools.ietf.org/html/rfc7515#section-5.2 +// +// For information on the json.Unmarshaler parameter, see Parse. +func ParseCompact(encoded []byte, u ...json.Unmarshaler) (JWS, error) <span class="cov8" title="1">{ + return parseCompact(encoded, false) +}</span> + +func parseCompact(encoded []byte, jwt bool) (*jws, error) <span class="cov8" title="1">{ + + // This section loosely follows + // https://tools.ietf.org/html/rfc7519#section-7.2 + // because it's used to parse _both_ jws and JWTs. + + parts := bytes.Split(encoded, []byte{'.'}) + if len(parts) != 3 </span><span class="cov0" title="0">{ + return nil, ErrNotCompact + }</span> + + <span class="cov8" title="1">var p jose.Protected + if err := p.UnmarshalJSON(parts[0]); err != nil </span><span class="cov0" title="0">{ + return nil, err + }</span> + + <span class="cov8" title="1">s := sigHead{ + Protected: parts[0], + protected: p, + Signature: parts[2], + clean: true, + } + + if err := s.assignMethod(p); err != nil </span><span class="cov0" title="0">{ + return nil, err + }</span> + + <span class="cov8" title="1">j := jws{ + payload: &payload{}, + plcache: parts[1], + sb: []sigHead{s}, + isJWT: jwt, + } + + if err := j.payload.UnmarshalJSON(parts[1]); err != nil </span><span class="cov0" title="0">{ + return nil, err + }</span> + + <span class="cov8" title="1">j.clean = true + + if err := j.sb[0].Signature.UnmarshalJSON(parts[2]); err != nil </span><span class="cov0" title="0">{ + return nil, err + }</span> + + // https://tools.ietf.org/html/rfc7519#section-7.2.8 + <span class="cov8" title="1">cty, ok := p.Get("cty").(string) + if ok && cty == "JWT" </span><span class="cov0" title="0">{ + return &j, ErrHoldsJWE + }</span> + <span class="cov8" title="1">return &j, nil</span> +} + +// IgnoreDupes should be set to true if the internal duplicate header key check +// should ignore duplicate Header keys instead of reporting an error when +// duplicate Header keys are found. +// +// Note: +// Duplicate Header keys are defined in +// https://tools.ietf.org/html/rfc7515#section-5.2 +// meaning keys that both the protected and unprotected +// Headers possess. +var IgnoreDupes bool + +// checkHeaders returns an error per the constraints described in +// IgnoreDupes' comment. +func checkHeaders(a, b jose.Header) error <span class="cov8" title="1">{ + if len(a)+len(b) == 0 </span><span class="cov0" title="0">{ + return ErrTwoEmptyHeaders + }</span> + <span class="cov8" title="1">for key := range a </span><span class="cov8" title="1">{ + if b.Has(key) && !IgnoreDupes </span><span class="cov0" title="0">{ + return ErrDuplicateHeaderParameter + }</span> + } + <span class="cov8" title="1">return nil</span> +} + +var _ JWS = (*jws)(nil) +</pre> + + <pre class="file" id="file2" style="display: none">package jws + +import ( + "bytes" + "encoding/json" +) + +// Flat serializes the JWS to its "flattened" form per +// https://tools.ietf.org/html/rfc7515#section-7.2.2 +func (j *jws) Flat(key interface{}) ([]byte, error) <span class="cov8" title="1">{ + if len(j.sb) < 1 </span><span class="cov0" title="0">{ + return nil, ErrNotEnoughMethods + }</span> + <span class="cov8" title="1">if err := j.sign(key); err != nil </span><span class="cov0" title="0">{ + return nil, err + }</span> + <span class="cov8" title="1">return json.Marshal(struct { + Payload rawBase64 `json:"payload"` + sigHead + }{ + Payload: j.plcache, + sigHead: j.sb[0], + })</span> +} + +// General serializes the JWS into its "general" form per +// https://tools.ietf.org/html/rfc7515#section-7.2.1 +// +// If only one key is passed it's used for all the provided +// crypto.SigningMethods. Otherwise, len(keys) must equal the number +// of crypto.SigningMethods added. +func (j *jws) General(keys ...interface{}) ([]byte, error) <span class="cov8" title="1">{ + if err := j.sign(keys...); err != nil </span><span class="cov0" title="0">{ + return nil, err + }</span> + <span class="cov8" title="1">return json.Marshal(struct { + Payload rawBase64 `json:"payload"` + Signatures []sigHead `json:"signatures"` + }{ + Payload: j.plcache, + Signatures: j.sb, + })</span> +} + +// Compact serializes the JWS into its "compact" form per +// https://tools.ietf.org/html/rfc7515#section-7.1 +func (j *jws) Compact(key interface{}) ([]byte, error) <span class="cov8" title="1">{ + if len(j.sb) < 1 </span><span class="cov0" title="0">{ + return nil, ErrNotEnoughMethods + }</span> + + <span class="cov8" title="1">if err := j.sign(key); err != nil </span><span class="cov0" title="0">{ + return nil, err + }</span> + + <span class="cov8" title="1">sig, err := j.sb[0].Signature.Base64() + if err != nil </span><span class="cov0" title="0">{ + return nil, err + }</span> + <span class="cov8" title="1">return format( + j.sb[0].Protected, + j.plcache, + sig, + ), nil</span> +} + +// sign signs each index of j's sb member. +func (j *jws) sign(keys ...interface{}) error <span class="cov8" title="1">{ + if err := j.cache(); err != nil </span><span class="cov0" title="0">{ + return err + }</span> + + <span class="cov8" title="1">if len(keys) < 1 || + len(keys) > 1 && len(keys) != len(j.sb) </span><span class="cov0" title="0">{ + return ErrNotEnoughKeys + }</span> + + <span class="cov8" title="1">if len(keys) == 1 </span><span class="cov8" title="1">{ + k := keys[0] + keys = make([]interface{}, len(j.sb)) + for i := range keys </span><span class="cov8" title="1">{ + keys[i] = k + }</span> + } + + <span class="cov8" title="1">for i := range j.sb </span><span class="cov8" title="1">{ + if err := j.sb[i].cache(); err != nil </span><span class="cov0" title="0">{ + return err + }</span> + + <span class="cov8" title="1">raw := format(j.sb[i].Protected, j.plcache) + sig, err := j.sb[i].method.Sign(raw, keys[i]) + if err != nil </span><span class="cov0" title="0">{ + return err + }</span> + <span class="cov8" title="1">j.sb[i].Signature = sig</span> + } + + <span class="cov8" title="1">return nil</span> +} + +// cache marshals the payload, but only if it's changed since the last cache. +func (j *jws) cache() error <span class="cov8" title="1">{ + if !j.clean </span><span class="cov8" title="1">{ + var err error + j.plcache, err = j.payload.Base64() + j.clean = err == nil + return err + }</span> + <span class="cov0" title="0">return nil</span> +} + +// cache marshals the protected and unprotected headers, but only if +// they've changed since their last cache. +func (s *sigHead) cache() error <span class="cov8" title="1">{ + if !s.clean </span><span class="cov8" title="1">{ + var err error + + s.Protected, err = s.protected.Base64() + if err != nil </span><span class="cov0" title="0">{ + goto err_return</span> + } + + <span class="cov8" title="1">s.Unprotected, err = s.unprotected.Base64() + if err != nil </span><span class="cov0" title="0">{ + goto err_return</span> + } + + <span class="cov8" title="1">err_return: + s.clean = err == nil + return err</span> + } + <span class="cov0" title="0">return nil</span> +} + +// format formats a slice of bytes in the order given, joining +// them with a period. +func format(a ...[]byte) []byte <span class="cov8" title="1">{ + return bytes.Join(a, []byte{'.'}) +}</span> +</pre> + + <pre class="file" id="file3" style="display: none">package jws + +import ( + "fmt" + + "github.com/SermoDigital/jose/crypto" +) + +// VerifyCallback is a callback function that can be used to access header +// parameters to lookup needed information. For example, looking +// up the "kid" parameter. +// The return slice must be a slice of keys used in the verification +// of the JWS. +type VerifyCallback func(JWS) ([]interface{}, error) + +// VerifyCallback validates the current JWS' signature as-is. It +// accepts a callback function that can be used to access header +// parameters to lookup needed information. For example, looking +// up the "kid" parameter. +// The return slice must be a slice of keys used in the verification +// of the JWS. +func (j *jws) VerifyCallback(fn VerifyCallback, methods []crypto.SigningMethod, o *SigningOpts) error <span class="cov8" title="1">{ + keys, err := fn(j) + if err != nil </span><span class="cov8" title="1">{ + return err + }</span> + <span class="cov8" title="1">return j.VerifyMulti(keys, methods, o)</span> +} + +// IsMultiError returns true if the given error is type MultiError. +func IsMultiError(err error) bool <span class="cov0" title="0">{ + _, ok := err.(MultiError) + return ok +}</span> + +// MultiError is a slice of errors. +type MultiError []error + +func (m MultiError) sanityCheck() error <span class="cov8" title="1">{ + if m == nil </span><span class="cov8" title="1">{ + return nil + }</span> + <span class="cov8" title="1">return m</span> +} + +// Errors implements the error interface. +func (m MultiError) Error() string <span class="cov0" title="0">{ + s, n := "", 0 + for _, e := range m </span><span class="cov0" title="0">{ + if e != nil </span><span class="cov0" title="0">{ + if n == 0 </span><span class="cov0" title="0">{ + s = e.Error() + }</span> + <span class="cov0" title="0">n++</span> + } + } + <span class="cov0" title="0">switch n </span>{ + <span class="cov0" title="0">case 0: + return "(0 errors)"</span> + <span class="cov0" title="0">case 1: + return s</span> + <span class="cov0" title="0">case 2: + return s + " (and 1 other error)"</span> + } + <span class="cov0" title="0">return fmt.Sprintf("%s (and %d other errors)", s, n-1)</span> +} + +// Any means any of the JWS signatures need to verify. +// Refer to verifyMulti for more information. +const Any int = 0 + +// VerifyMulti verifies the current JWS as-is. Since it's meant to be +// called after parsing a stream of bytes into a JWS, it doesn't do any +// internal parsing like the Sign, Flat, Compact, or General methods do. +func (j *jws) VerifyMulti(keys []interface{}, methods []crypto.SigningMethod, o *SigningOpts) error <span class="cov8" title="1">{ + + // Catch a simple mistake. Parameter o is irrelevant in this scenario. + if len(keys) == 1 && + len(methods) == 1 && + len(j.sb) == 1 </span><span class="cov8" title="1">{ + return j.Verify(keys[0], methods[0]) + }</span> + + <span class="cov8" title="1">if len(j.sb) != len(methods) </span><span class="cov8" title="1">{ + return ErrNotEnoughMethods + }</span> + + <span class="cov8" title="1">if len(keys) < 1 || + len(keys) > 1 && len(keys) != len(j.sb) </span><span class="cov8" title="1">{ + return ErrNotEnoughKeys + }</span> + + // TODO do this better. + <span class="cov8" title="1">if len(keys) == 1 </span><span class="cov8" title="1">{ + k := keys[0] + keys = make([]interface{}, len(methods)) + for i := range keys </span><span class="cov8" title="1">{ + keys[i] = k + }</span> + } + + <span class="cov8" title="1">var o2 SigningOpts + if o == nil </span><span class="cov8" title="1">{ + o = &SigningOpts{} + }</span> + + <span class="cov8" title="1">var m MultiError + for i := range j.sb </span><span class="cov8" title="1">{ + err := j.sb[i].verify(j.plcache, keys[i], methods[i]) + if err != nil </span><span class="cov8" title="1">{ + m = append(m, err) + }</span><span class="cov8" title="1"> else { + o2.Inc() + if o.Needs(i) </span><span class="cov8" title="1">{ + o2.Append(i) + }</span> + } + } + + <span class="cov8" title="1">if err := o.Validate(&o2); err != nil </span><span class="cov8" title="1">{ + return err + }</span> + <span class="cov8" title="1">return m.sanityCheck()</span> +} + +// SigningOpts is a struct which holds options for validating +// JWS signatures. +// Number represents the cumulative which signatures need to verify +// in order for the JWS to be considered valid. +// Leave 'Number' empty or set it to the constant 'Any' if any number of +// valid signatures (greater than one) should verify the JWS. +// +// Use the indices of the signatures that need to verify in order +// for the JWS to be considered valid if specific signatures need +// to verify in order for the JWS to be considered valid. +// +// Note: +// The JWS spec requires *at least* one +// signature to verify in order for the JWS to be considered valid. +type SigningOpts struct { + // Minimum of signatures which need to verify. + Number int + + // Indices of specific signatures which need to verify. + Indices []int + ptr int + + _ struct{} +} + +// Append appends x to s's Indices member. +func (s *SigningOpts) Append(x int) <span class="cov8" title="1">{ + s.Indices = append(s.Indices, x) +}</span> + +// Needs returns true if x resides inside s's Indices member +// for the given index. If true, it increments s's internal +// index. It's used to match two SigningOpts Indices members. +func (s *SigningOpts) Needs(x int) bool <span class="cov8" title="1">{ + if s.ptr < len(s.Indices) && + s.Indices[s.ptr] == x </span><span class="cov8" title="1">{ + s.ptr++ + return true + }</span> + <span class="cov8" title="1">return false</span> +} + +// Inc increments s's Number member by one. +func (s *SigningOpts) Inc() <span class="cov8" title="1">{ s.Number++ }</span> + +// Validate returns any errors found while validating the +// provided SigningOpts. The receiver validates the parameter `have`. +// It'll return an error if the passed SigningOpts' Number member is less +// than s's or if the passed SigningOpts' Indices slice isn't equal to s's. +func (s *SigningOpts) Validate(have *SigningOpts) error <span class="cov8" title="1">{ + if have.Number < s.Number || + (s.Indices != nil && + !eq(s.Indices, have.Indices)) </span><span class="cov8" title="1">{ + return ErrNotEnoughValidSignatures + }</span> + <span class="cov8" title="1">return nil</span> +} + +func eq(a, b []int) bool <span class="cov8" title="1">{ + if a == nil && b == nil </span><span class="cov0" title="0">{ + return true + }</span> + <span class="cov8" title="1">if a == nil || b == nil || len(a) != len(b) </span><span class="cov0" title="0">{ + return false + }</span> + <span class="cov8" title="1">for i := range a </span><span class="cov8" title="1">{ + if a[i] != b[i] </span><span class="cov0" title="0">{ + return false + }</span> + } + <span class="cov8" title="1">return true</span> +} + +// Verify verifies the current JWS as-is. Refer to verifyMulti +// for more information. +func (j *jws) Verify(key interface{}, method crypto.SigningMethod) error <span class="cov8" title="1">{ + if len(j.sb) < 1 </span><span class="cov8" title="1">{ + return ErrCannotValidate + }</span> + <span class="cov8" title="1">return j.sb[0].verify(j.plcache, key, method)</span> +} + +func (s *sigHead) verify(pl []byte, key interface{}, method crypto.SigningMethod) error <span class="cov8" title="1">{ + if s.method != method </span><span class="cov8" title="1">{ + return ErrMismatchedAlgorithms + }</span> + <span class="cov8" title="1">return method.Verify(format(s.Protected, pl), s.Signature, key)</span> +} +</pre> + + <pre class="file" id="file4" style="display: none">package jws + +import ( + "time" + + "github.com/SermoDigital/jose/crypto" + "github.com/SermoDigital/jose/jwt" +) + +// NewJWT creates a new JWT with the given claims. +func NewJWT(claims Claims, method crypto.SigningMethod) jwt.JWT <span class="cov8" title="1">{ + j := New(claims, method).(*jws) + j.isJWT = true + return j +}</span> + +// Serialize helps implements jwt.JWT. +func (j *jws) Serialize(key interface{}) ([]byte, error) <span class="cov8" title="1">{ + if j.isJWT </span><span class="cov8" title="1">{ + return j.Compact(key) + }</span> + <span class="cov0" title="0">return nil, ErrIsNotJWT</span> +} + +// Claims helps implements jwt.JWT. +func (j *jws) Claims() jwt.Claims <span class="cov8" title="1">{ + if j.isJWT </span><span class="cov8" title="1">{ + if c, ok := j.payload.v.(Claims); ok </span><span class="cov8" title="1">{ + return jwt.Claims(c) + }</span> + } + <span class="cov0" title="0">return nil</span> +} + +// ParseJWT parses a serialized jwt.JWT into a physical jwt.JWT. +// If its payload isn't a set of claims (or able to be coerced into +// a set of claims) it'll return an error stating the +// JWT isn't a JWT. +func ParseJWT(encoded []byte) (jwt.JWT, error) <span class="cov8" title="1">{ + t, err := parseCompact(encoded, true) + if err != nil </span><span class="cov0" title="0">{ + return nil, err + }</span> + <span class="cov8" title="1">c, ok := t.Payload().(map[string]interface{}) + if !ok </span><span class="cov0" title="0">{ + return nil, ErrIsNotJWT + }</span> + <span class="cov8" title="1">t.SetPayload(Claims(c)) + return t, nil</span> +} + +// IsJWT returns true if the JWS is a JWT. +func (j *jws) IsJWT() bool <span class="cov0" title="0">{ return j.isJWT }</span> + +func (j *jws) Validate(key interface{}, m crypto.SigningMethod, v ...*jwt.Validator) error <span class="cov8" title="1">{ + if j.isJWT </span><span class="cov8" title="1">{ + if err := j.Verify(key, m); err != nil </span><span class="cov0" title="0">{ + return err + }</span> + <span class="cov8" title="1">var v1 jwt.Validator + if len(v) > 0 </span><span class="cov8" title="1">{ + v1 = *v[0] + }</span> + <span class="cov8" title="1">c, ok := j.payload.v.(Claims) + if ok </span><span class="cov8" title="1">{ + if err := v1.Validate(j); err != nil </span><span class="cov0" title="0">{ + return err + }</span> + <span class="cov8" title="1">return jwt.Claims(c).Validate(time.Now().Unix(), v1.EXP, v1.NBF)</span> + } + } + <span class="cov0" title="0">return ErrIsNotJWT</span> +} + +// Conv converts a func(Claims) error to type jwt.ValidateFunc. +func Conv(fn func(Claims) error) jwt.ValidateFunc <span class="cov8" title="1">{ + if fn == nil </span><span class="cov0" title="0">{ + return nil + }</span> + <span class="cov8" title="1">return func(c jwt.Claims) error </span><span class="cov8" title="1">{ + return fn(Claims(c)) + }</span> +} + +// NewValidator returns a pointer to a jwt.Validator structure containing +// the info to be used in the validation of a JWT. +func NewValidator(c Claims, exp, nbf int64, fn func(Claims) error) *jwt.Validator <span class="cov8" title="1">{ + return &jwt.Validator{ + Expected: jwt.Claims(c), + EXP: exp, + NBF: nbf, + Fn: Conv(fn), + } +}</span> + +var _ jwt.JWT = (*jws)(nil) +</pre> + + <pre class="file" id="file5" style="display: none">package jws + +import ( + "encoding/json" + + "github.com/SermoDigital/jose" +) + +// payload represents the payload of a JWS. +type payload struct { + v interface{} + u json.Unmarshaler + _ struct{} +} + +// MarshalJSON implements json.Marshaler for payload. +func (p *payload) MarshalJSON() ([]byte, error) <span class="cov8" title="1">{ + b, err := json.Marshal(p.v) + if err != nil </span><span class="cov0" title="0">{ + return nil, err + }</span> + <span class="cov8" title="1">return jose.EncodeEscape(b), nil</span> +} + +// Base64 implements jose.Encoder. +func (p *payload) Base64() ([]byte, error) <span class="cov8" title="1">{ + b, err := json.Marshal(p.v) + if err != nil </span><span class="cov0" title="0">{ + return nil, err + }</span> + <span class="cov8" title="1">return jose.Base64Encode(b), nil</span> +} + +// MarshalJSON implements json.Unmarshaler for payload. +func (p *payload) UnmarshalJSON(b []byte) error <span class="cov8" title="1">{ + b2, err := jose.DecodeEscaped(b) + if err != nil </span><span class="cov0" title="0">{ + return err + }</span> + + <span class="cov8" title="1">if p.u != nil </span><span class="cov8" title="1">{ + err := p.u.UnmarshalJSON(b2) + p.v = p.u + return err + }</span> + + <span class="cov8" title="1">return json.Unmarshal(b2, &p.v)</span> +} + +var ( + _ json.Marshaler = (*payload)(nil) + _ json.Unmarshaler = (*payload)(nil) + _ jose.Encoder = (*payload)(nil) +) +</pre> + + <pre class="file" id="file6" style="display: none">package jws + +import "encoding/json" + +type rawBase64 []byte + +// MarshalJSON implements json.Marshaler for rawBase64. +func (r rawBase64) MarshalJSON() ([]byte, error) <span class="cov8" title="1">{ + buf := make([]byte, len(r)+2) + buf[0] = '"' + copy(buf[1:], r) + buf[len(buf)-1] = '"' + return buf, nil +}</span> + +// MarshalJSON implements json.Unmarshaler for rawBase64. +func (r *rawBase64) UnmarshalJSON(b []byte) error <span class="cov8" title="1">{ + if len(b) > 1 && b[0] == '"' && b[len(b)-1] == '"' </span><span class="cov8" title="1">{ + b = b[1 : len(b)-1] + }</span> + <span class="cov8" title="1">*r = rawBase64(b) + return nil</span> +} + +var ( + _ json.Marshaler = (rawBase64)(nil) + _ json.Unmarshaler = (*rawBase64)(nil) +) +</pre> + + <pre class="file" id="file7" style="display: none">package jws + +import ( + "sync" + + "github.com/SermoDigital/jose/crypto" +) + +var ( + mu = &sync.RWMutex{} + + signingMethods = map[string]crypto.SigningMethod{ + crypto.SigningMethodES256.Alg(): crypto.SigningMethodES256, + crypto.SigningMethodES384.Alg(): crypto.SigningMethodES384, + crypto.SigningMethodES512.Alg(): crypto.SigningMethodES512, + + crypto.SigningMethodPS256.Alg(): crypto.SigningMethodPS256, + crypto.SigningMethodPS384.Alg(): crypto.SigningMethodPS384, + crypto.SigningMethodPS512.Alg(): crypto.SigningMethodPS512, + + crypto.SigningMethodRS256.Alg(): crypto.SigningMethodRS256, + crypto.SigningMethodRS384.Alg(): crypto.SigningMethodRS384, + crypto.SigningMethodRS512.Alg(): crypto.SigningMethodRS512, + + crypto.SigningMethodHS256.Alg(): crypto.SigningMethodHS256, + crypto.SigningMethodHS384.Alg(): crypto.SigningMethodHS384, + crypto.SigningMethodHS512.Alg(): crypto.SigningMethodHS512, + + crypto.Unsecured.Alg(): crypto.Unsecured, + } +) + +// RegisterSigningMethod registers the crypto.SigningMethod in the global map. +// This is typically done inside the caller's init function. +func RegisterSigningMethod(sm crypto.SigningMethod) <span class="cov8" title="1">{ + if GetSigningMethod(sm.Alg()) != nil </span><span class="cov0" title="0">{ + panic("jose/jws: cannot duplicate signing methods")</span> + } + + <span class="cov8" title="1">if !sm.Hasher().Available() </span><span class="cov0" title="0">{ + panic("jose/jws: specific hash is unavailable")</span> + } + + <span class="cov8" title="1">mu.Lock() + signingMethods[sm.Alg()] = sm + mu.Unlock()</span> +} + +// RemoveSigningMethod removes the crypto.SigningMethod from the global map. +func RemoveSigningMethod(sm crypto.SigningMethod) <span class="cov8" title="1">{ + mu.Lock() + delete(signingMethods, sm.Alg()) + mu.Unlock() +}</span> + +// GetSigningMethod retrieves a crypto.SigningMethod from the global map. +func GetSigningMethod(alg string) crypto.SigningMethod <span class="cov8" title="1">{ + mu.RLock() + defer mu.RUnlock() + return signingMethods[alg] +}</span> +</pre> + + </div> + </body> + <script> + (function() { + var files = document.getElementById('files'); + var visible = document.getElementById('file0'); + files.addEventListener('change', onChange, false); + function onChange() { + visible.style.display = 'none'; + visible = document.getElementById(files.value); + visible.style.display = 'block'; + window.scrollTo(0, 0); + } + })(); + </script> +</html>
diff --git a/jws/errors.go b/jws/errors.go index 6f6de54..6120bc1 100644 --- a/jws/errors.go +++ b/jws/errors.go
@@ -52,4 +52,8 @@ // ErrHoldsJWE means the given JWS holds a JWE inside its payload. ErrHoldsJWE = errors.New("JWS holds JWE") + + // ErrNotEnoughValidSignatures means the JWS did not meet the required + // number of signatures. + ErrNotEnoughValidSignatures = errors.New("not enough valid signatures in the JWS") )
diff --git a/jws/jws.go b/jws/jws.go index e448caa..6673c9c 100644 --- a/jws/jws.go +++ b/jws/jws.go
@@ -179,10 +179,12 @@ // ParseGeneral, ParseFlat, or ParseCompact. // It should only be called if, for whatever reason, you do not // know which form the serialized JWT is in. +// +// It cannot parse a JWT. func Parse(encoded []byte, u ...json.Unmarshaler) (JWS, error) { // Try and unmarshal into a generic struct that'll // hopefully hold either of the two JSON serialization - // formats.s + // formats. var g generic // Not valid JSON. Let's try compact. @@ -317,7 +319,9 @@ } s := sigHead{ + Protected: parts[0], protected: p, + Signature: parts[2], clean: true, } @@ -327,6 +331,7 @@ j := jws{ payload: &payload{}, + plcache: parts[1], sb: []sigHead{s}, isJWT: jwt, } @@ -373,3 +378,5 @@ } return nil } + +var _ JWS = (*jws)(nil)
diff --git a/jws/jws_test.go b/jws/jws_test.go index 483d604..7f27500 100644 --- a/jws/jws_test.go +++ b/jws/jws_test.go
@@ -4,6 +4,7 @@ "bytes" "encoding/base64" "encoding/json" + "errors" "math/rand" "testing" @@ -121,6 +122,30 @@ } } +func TestVerifyMultiOneKey(t *testing.T) { + sm := []crypto.SigningMethod{ + crypto.SigningMethodRS256, + crypto.SigningMethodPS384, + crypto.SigningMethodPS512, + } + + j := New(easyData, sm...) + b, err := j.General(rsaPriv) + if err != nil { + t.Error(err) + } + + j2, err := ParseGeneral(b) + if err != nil { + t.Error(err) + } + + keys := []interface{}{rsaPub} + if err := j2.VerifyMulti(keys, sm, nil); err != nil { + t.Error(err) + } +} + func TestVerifyMultiMismatchedAlgs(t *testing.T) { sm := []crypto.SigningMethod{ crypto.SigningMethodRS256, @@ -205,6 +230,64 @@ } } +func TestVerifyMultiSigningOpts(t *testing.T) { + sm := []crypto.SigningMethod{ + crypto.SigningMethodRS256, + crypto.SigningMethodPS384, + crypto.SigningMethodPS512, + } + + j := New(easyData, sm...) + b, err := j.General(rsaPriv) + if err != nil { + t.Error(err) + } + + j2, err := ParseGeneral(b) + if err != nil { + t.Error(err) + } + + o := SigningOpts{ + Number: 3, + Indices: []int{0, 1, 2}, + } + + keys := []interface{}{rsaPub, rsaPub, rsaPub} + if err := j2.VerifyMulti(keys, sm, &o); err != nil { + t.Error(err) + } +} + +func TestVerifyMultiSigningOptsErr(t *testing.T) { + sm := []crypto.SigningMethod{ + crypto.SigningMethodRS256, + crypto.SigningMethodPS384, + crypto.SigningMethodPS512, + } + + j := New(easyData, sm...) + b, err := j.General(rsaPriv) + if err != nil { + t.Error(err) + } + + j2, err := ParseGeneral(b) + if err != nil { + t.Error(err) + } + + o := SigningOpts{ + Number: 4, + Indices: []int{0, 1, 2, 3}, + } + + keys := []interface{}{rsaPub, rsaPub, rsaPub} + if err := j2.VerifyMulti(keys, sm, &o); err == nil { + t.Error("Should not be nil!") + } +} + func TestVerify(t *testing.T) { j := New(easyData, crypto.SigningMethodPS512) b, err := j.Flat(rsaPriv) @@ -221,3 +304,62 @@ t.Error(err) } } + +func TestVerifyCallback(t *testing.T) { + j := New(easyData, crypto.SigningMethodPS512) + b, err := j.Flat(rsaPriv) + if err != nil { + t.Error(err) + } + + j2, err := ParseFlat(b) + if err != nil { + t.Error(err) + } + + cb := func(j JWS) ([]interface{}, error) { + return []interface{}{rsaPub}, nil + } + + if err := j2.VerifyCallback(cb, []crypto.SigningMethod{crypto.SigningMethodPS512}, nil); err != nil { + t.Error(err) + } +} + +func TestVerifyCallbackErr(t *testing.T) { + j := New(easyData, crypto.SigningMethodPS512) + b, err := j.Flat(rsaPriv) + if err != nil { + t.Error(err) + } + + j2, err := ParseFlat(b) + if err != nil { + t.Error(err) + } + + cb := func(j JWS) ([]interface{}, error) { + return nil, errors.New("k") + } + + if err := j2.VerifyCallback(cb, []crypto.SigningMethod{crypto.SigningMethodPS512}, nil); err == nil { + t.Error("Should not be nil!") + } +} + +func TestVerifyNoSBs(t *testing.T) { + j := New(easyData, crypto.SigningMethodPS512) + b, err := j.Flat(rsaPriv) + if err != nil { + t.Error(err) + } + + j2, err := ParseFlat(b) + if err != nil { + t.Error(err) + } + j2.(*jws).sb = nil + if err := j2.Verify(rsaPub, crypto.SigningMethodPS512); err != ErrCannotValidate { + Error(t, ErrCannotValidate, err) + } +}
diff --git a/jws/jws_validate.go b/jws/jws_validate.go index 453eb8f..1948880 100644 --- a/jws/jws_validate.go +++ b/jws/jws_validate.go
@@ -1,7 +1,6 @@ package jws import ( - "errors" "fmt" "github.com/SermoDigital/jose/crypto" @@ -174,12 +173,10 @@ // It'll return an error if the passed SigningOpts' Number member is less // than s's or if the passed SigningOpts' Indices slice isn't equal to s's. func (s *SigningOpts) Validate(have *SigningOpts) error { - if have.Number < s.Number { - return errors.New("TODO 2") - } - if s.Indices != nil && - !eq(s.Indices, have.Indices) { - return errors.New("TODO 3") + if have.Number < s.Number || + (s.Indices != nil && + !eq(s.Indices, have.Indices)) { + return ErrNotEnoughValidSignatures } return nil }
diff --git a/jws/jwt.go b/jws/jwt.go index c6051dc..08e8301 100644 --- a/jws/jwt.go +++ b/jws/jwt.go
@@ -7,9 +7,6 @@ "github.com/SermoDigital/jose/jwt" ) -// Claims represents a set of JOSE Claims. -type Claims jwt.Claims - // NewJWT creates a new JWT with the given claims. func NewJWT(claims Claims, method crypto.SigningMethod) jwt.JWT { j := New(claims, method).(*jws) @@ -64,7 +61,6 @@ if len(v) > 0 { v1 = *v[0] } - c, ok := j.payload.v.(Claims) if ok { if err := v1.Validate(j); err != nil { @@ -86,13 +82,14 @@ } } -// NewOpts returns a pointer to a jwt.Validator structure containing +// NewValidator returns a pointer to a jwt.Validator structure containing // the info to be used in the validation of a JWT. -func NewOpts(c Claims, exp, nbf int64) *jwt.Validator { +func NewValidator(c Claims, exp, nbf int64, fn func(Claims) error) *jwt.Validator { return &jwt.Validator{ Expected: jwt.Claims(c), EXP: exp, NBF: nbf, + Fn: Conv(fn), } }
diff --git a/jws/jwt_test.go b/jws/jwt_test.go index aac4051..91662a0 100644 --- a/jws/jwt_test.go +++ b/jws/jwt_test.go
@@ -1,7 +1,9 @@ package jws import ( + "errors" "testing" + "time" "github.com/SermoDigital/jose/crypto" ) @@ -39,4 +41,37 @@ w.Claims().Get("scopes").([]string)[0] != "user.account.info" { Error(t, claims, w.Claims()) } + + if err := w.Validate(rsaPub, crypto.SigningMethodRS512); err != nil { + t.Error(err) + } +} + +func TestJWTValidator(t *testing.T) { + j := NewJWT(claims, crypto.SigningMethodRS512) + j.Claims().SetIssuer("example.com") + + b, err := j.Serialize(rsaPriv) + if err != nil { + t.Error(err) + } + + w, err := ParseJWT(b) + if err != nil { + t.Error(err) + } + + d := time.Now().Add(1 * time.Hour).Unix() + fn := func(c Claims) error { + if c.Get("name") != "Eric" && + c.Get("admin") != true && + c.Get("scopes").([]string)[0] != "user.account.info" { + return errors.New("invalid") + } + return nil + } + v := NewValidator(Claims{"iss": "example.com"}, d, d, fn) + if err := w.Validate(rsaPub, crypto.SigningMethodRS512, v); err != nil { + t.Error(err) + } }
diff --git a/jwt/claims.go b/jwt/claims.go index 8700d48..3461c58 100644 --- a/jwt/claims.go +++ b/jwt/claims.go
@@ -60,16 +60,16 @@ if c == nil || len(c) == 0 { return nil, nil } - b, err := json.Marshal(map[string]interface{}(c)) - if err != nil { - return nil, err - } - return jose.EncodeEscape(b), nil + return json.Marshal(map[string]interface{}(c)) } // Base64 implements the Encoder interface. func (c Claims) Base64() ([]byte, error) { - return c.MarshalJSON() + b, err := c.MarshalJSON() + if err != nil { + return nil, err + } + return jose.Base64Encode(b), nil } // UnmarshalJSON implements json.Unmarshaler for Claims.
diff --git a/jwt/jwt.go b/jwt/jwt.go index 1b83c39..172e2f7 100644 --- a/jwt/jwt.go +++ b/jwt/jwt.go
@@ -2,6 +2,7 @@ import ( "errors" + "fmt" "github.com/SermoDigital/jose/crypto" ) @@ -59,23 +60,24 @@ func (v *Validator) Validate(j JWT) error { if iss, ok := v.Expected.Issuer(); ok && j.Claims().Get("iss") != iss { - return errors.New("TODO 12") + fmt.Println(iss, j.Claims().Get("iss")) + return errors.New("TODO 1") } if sub, ok := v.Expected.Subject(); ok && j.Claims().Get("sub") != sub { - return errors.New("TODO 12") + return errors.New("TODO 2") } if iat, ok := v.Expected.IssuedAt(); ok && j.Claims().Get("iat") != iat { - return errors.New("TODO 12") + return errors.New("TODO 3") } if jti, ok := v.Expected.JWTID(); ok && j.Claims().Get("jti") != jti { - return errors.New("TODO 12") + return errors.New("TODO 4") } if aud, ok := v.Expected.Audience(); ok && !eq(j.Claims().Get("aud"), aud) { - return errors.New("TODO 12") + return errors.New("TODO 5") } if v.Fn != nil {