| package jws |
| |
| import ( |
| "bytes" |
| "encoding/json" |
| ) |
| |
| // Flat serializes the JWS to its "flattened" form per |
| // https://tools.ietf.org/html/rfc7515#section-7.2.2 |
| func (j *jws) Flat(key interface{}) ([]byte, error) { |
| if len(j.sb) < 1 { |
| return nil, ErrNotEnoughMethods |
| } |
| if err := j.sign(key); err != nil { |
| return nil, err |
| } |
| return json.Marshal(struct { |
| Payload rawBase64 `json:"payload"` |
| sigHead |
| }{ |
| Payload: j.plcache, |
| sigHead: j.sb[0], |
| }) |
| } |
| |
| // General serializes the JWS into its "general" form per |
| // https://tools.ietf.org/html/rfc7515#section-7.2.1 |
| // |
| // If only one key is passed it's used for all the provided |
| // crypto.SigningMethods. Otherwise, len(keys) must equal the number |
| // of crypto.SigningMethods added. |
| func (j *jws) General(keys ...interface{}) ([]byte, error) { |
| if err := j.sign(keys...); err != nil { |
| return nil, err |
| } |
| return json.Marshal(struct { |
| Payload rawBase64 `json:"payload"` |
| Signatures []sigHead `json:"signatures"` |
| }{ |
| Payload: j.plcache, |
| Signatures: j.sb, |
| }) |
| } |
| |
| // Compact serializes the JWS into its "compact" form per |
| // https://tools.ietf.org/html/rfc7515#section-7.1 |
| func (j *jws) Compact(key interface{}) ([]byte, error) { |
| if len(j.sb) < 1 { |
| return nil, ErrNotEnoughMethods |
| } |
| |
| if err := j.sign(key); err != nil { |
| return nil, err |
| } |
| |
| sig, err := j.sb[0].Signature.Base64() |
| if err != nil { |
| return nil, err |
| } |
| return format( |
| j.sb[0].Protected, |
| j.plcache, |
| sig, |
| ), nil |
| } |
| |
| // sign signs each index of j's sb member. |
| func (j *jws) sign(keys ...interface{}) error { |
| if err := j.cache(); err != nil { |
| return err |
| } |
| |
| if len(keys) < 1 || |
| len(keys) > 1 && len(keys) != len(j.sb) { |
| return ErrNotEnoughKeys |
| } |
| |
| if len(keys) == 1 { |
| k := keys[0] |
| keys = make([]interface{}, len(j.sb)) |
| for i := range keys { |
| keys[i] = k |
| } |
| } |
| |
| for i := range j.sb { |
| if err := j.sb[i].cache(); err != nil { |
| return err |
| } |
| |
| raw := format(j.sb[i].Protected, j.plcache) |
| sig, err := j.sb[i].method.Sign(raw, keys[i]) |
| if err != nil { |
| return err |
| } |
| j.sb[i].Signature = sig |
| } |
| |
| return nil |
| } |
| |
| // cache marshals the payload, but only if it's changed since the last cache. |
| func (j *jws) cache() (err error) { |
| if !j.clean { |
| j.plcache, err = j.payload.Base64() |
| j.clean = err == nil |
| } |
| return err |
| } |
| |
| // cache marshals the protected and unprotected headers, but only if |
| // they've changed since their last cache. |
| func (s *sigHead) cache() (err error) { |
| if !s.clean { |
| s.Protected, err = s.protected.Base64() |
| if err != nil { |
| return err |
| } |
| s.Unprotected, err = s.unprotected.Base64() |
| if err != nil { |
| return err |
| } |
| } |
| s.clean = true |
| return nil |
| } |
| |
// format concatenates parts in the order given, placing a single '.'
// between neighbouring parts. For example, ("a", "b", "c") yields "a.b.c".
// The result is always a freshly allocated slice, so callers may modify it.
func format(parts ...[]byte) []byte {
	dot := []byte{'.'}
	return bytes.Join(parts, dot)
}