package jwt

import (
	"encoding/json"
	"reflect"
	"time"

	"github.com/SermoDigital/jose"
)

// Claims implements a set of JOSE Claims with the addition of some helper
// methods, similar to net/url.Values.
//
// It is a plain map from claim name to claim value, so the usual Go map rules
// apply: reading from a nil Claims is safe, writing to a nil Claims panics.
type Claims map[string]interface{}

// Validate checks the time-based claims of c against now, following
// https://tools.ietf.org/html/rfc7519#section-4.1
//
// If an "exp" claim is present, ErrTokenIsExpired is returned when now is
// after the expiration time plus expLeeway. If an "nbf" claim is present,
// ErrTokenNotYetValid is returned when now is before the not-before time
// minus nbfLeeway. A claim that is absent, or that GetTime cannot convert to
// a time, is not checked. Validate returns nil when no check fails.
func (c Claims) Validate(now time.Time, expLeeway, nbfLeeway time.Duration) error {
	if exp, ok := c.Expiration(); ok && now.After(exp.Add(expLeeway)) {
		return ErrTokenIsExpired
	}

	// RFC 7519 section 4.1.5 accepts a token when the current time is after
	// OR EQUAL TO "nbf", so only a strictly earlier now is rejected. Using
	// !now.After(...) here would wrongly reject a token at the exact instant
	// it becomes valid.
	if nbf, ok := c.NotBefore(); ok && now.Before(nbf.Add(-nbfLeeway)) {
		return ErrTokenNotYetValid
	}
	return nil
}

// Get returns the value stored under key, or nil when the key is absent.
// It is safe to call on a nil Claims: indexing a nil map yields the zero
// value, which for interface{} is nil.
func (c Claims) Get(key string) interface{} {
	return c[key]
}

// Set sets Claims[key] = val. It'll overwrite without warning.
//
// c must be non-nil: assigning into a nil map panics.
func (c Claims) Set(key string, val interface{}) {
	c[key] = val
}

// Del removes the value that corresponds with key from the Claims.
// Deleting a key that is not present (or deleting from a nil Claims) is a
// no-op.
func (c Claims) Del(key string) {
	delete(c, key)
}

// Has reports whether key is present in the Claims. A key that was explicitly
// set to nil still counts as present.
func (c Claims) Has(key string) bool {
	if _, present := c[key]; present {
		return true
	}
	return false
}

// MarshalJSON implements json.Marshaler for Claims.
//
// A nil or empty Claims is encoded as the empty JSON object "{}"; any other
// Claims is encoded as a JSON object holding its entries.
func (c Claims) MarshalJSON() ([]byte, error) {
	// A json.Marshaler must return valid JSON. Returning no bytes at all
	// makes json.Marshal fail with "unexpected end of JSON input", so an
	// empty claims set is emitted as an empty object instead.
	if len(c) == 0 {
		return []byte("{}"), nil
	}
	// Convert to the plain map type so json.Marshal does not call back into
	// this method and recurse forever.
	return json.Marshal(map[string]interface{}(c))
}

// Base64 implements the jose.Encoder interface. It marshals the Claims with
// MarshalJSON and passes the resulting bytes through jose.Base64Encode. A
// marshalling error is returned unchanged with a nil byte slice.
func (c Claims) Base64() ([]byte, error) {
	raw, err := c.MarshalJSON()
	if err == nil {
		return jose.Base64Encode(raw), nil
	}
	return nil, err
}

// UnmarshalJSON implements json.Unmarshaler for Claims.
//
// A nil input is ignored and leaves the Claims untouched. Otherwise the input
// is passed through jose.DecodeEscaped and the result is decoded as a JSON
// object into the Claims.
func (c *Claims) UnmarshalJSON(b []byte) error {
	if b == nil {
		return nil
	}

	decoded, err := jose.DecodeEscaped(b)
	if err != nil {
		return err
	}

	// json.Unmarshal on *c itself would call this method again and recurse
	// forever, so decode through a value of the plain map type instead. A
	// named temporary is required because the conversion
	// map[string]interface{}(*c) is not addressable.
	plain := map[string]interface{}(*c)
	if err := json.Unmarshal(decoded, &plain); err != nil {
		return err
	}

	// Store the map back: json.Unmarshal allocates a new one when plain
	// started out nil.
	*c = Claims(plain)
	return nil
}

// Issuer returns the "iss" claim, see
// https://tools.ietf.org/html/rfc7519#section-4.1.1
//
// The boolean is false when the claim is missing or is not a string.
func (c Claims) Issuer() (string, bool) {
	iss, isString := c["iss"].(string)
	return iss, isString
}

// Subject returns the "sub" claim, see
// https://tools.ietf.org/html/rfc7519#section-4.1.2
//
// The boolean is false when the claim is missing or is not a string.
func (c Claims) Subject() (string, bool) {
	sub, isString := c["sub"].(string)
	return sub, isString
}

// Audience returns the "aud" claim, see
// https://tools.ietf.org/html/rfc7519#section-4.1.3
//
// The claim may be stored as a single string, a []string, or a
// []interface{} whose elements are all strings. A single string is returned
// as a one-element slice and a []string is returned as-is. Every other shape
// (including a missing claim, an empty []interface{}, or a []interface{}
// containing a non-string) yields (nil, false).
func (c Claims) Audience() ([]string, bool) {
	switch aud := c.Get("aud").(type) {
	case string:
		return []string{aud}, true
	case []string:
		return aud, true
	case []interface{}:
		return stringify(aud...)
	default:
		// Anything else is not string-like, so it cannot be an audience.
		return nil, false
	}
}

// stringify converts its arguments to a []string. It succeeds only when at
// least one argument is given and every argument has the dynamic type string;
// otherwise it returns (nil, false).
func stringify(a ...interface{}) ([]string, bool) {
	if len(a) == 0 {
		return nil, false
	}

	out := make([]string, 0, len(a))
	for _, elem := range a {
		text, isString := elem.(string)
		if !isString {
			return nil, false
		}
		out = append(out, text)
	}
	return out, true
}

// Expiration returns the "exp" claim as a time, see
// https://tools.ietf.org/html/rfc7519#section-4.1.4
//
// The lookup is delegated to GetTime; the boolean is false when GetTime
// cannot produce a time for the claim.
func (c Claims) Expiration() (time.Time, bool) {
	exp, ok := c.GetTime("exp")
	return exp, ok
}

// NotBefore returns the "nbf" claim as a time, see
// https://tools.ietf.org/html/rfc7519#section-4.1.5
//
// The lookup is delegated to GetTime; the boolean is false when GetTime
// cannot produce a time for the claim.
func (c Claims) NotBefore() (time.Time, bool) {
	nbf, ok := c.GetTime("nbf")
	return nbf, ok
}

// IssuedAt returns the "iat" claim as a time, see
// https://tools.ietf.org/html/rfc7519#section-4.1.6
//
// The lookup is delegated to GetTime; the boolean is false when GetTime
// cannot produce a time for the claim.
func (c Claims) IssuedAt() (time.Time, bool) {
	iat, ok := c.GetTime("iat")
	return iat, ok
}

// JWTID returns the "jti" claim, see
// https://tools.ietf.org/html/rfc7519#section-4.1.7
//
// The boolean is false when the claim is missing or is not a string.
func (c Claims) JWTID() (string, bool) {
	jti, isString := c["jti"].(string)
	return jti, isString
}

// RemoveIssuer drops the "iss" claim from c; it is a no-op when the claim is
// not set.
func (c Claims) RemoveIssuer() {
	delete(c, "iss")
}

// RemoveSubject drops the "sub" claim from c; it is a no-op when the claim is
// not set.
func (c Claims) RemoveSubject() {
	delete(c, "sub")
}

// RemoveAudience drops the "aud" claim from c; it is a no-op when the claim
// is not set.
func (c Claims) RemoveAudience() {
	delete(c, "aud")
}

// RemoveExpiration drops the "exp" claim from c; it is a no-op when the claim
// is not set.
func (c Claims) RemoveExpiration() {
	delete(c, "exp")
}

// RemoveNotBefore drops the "nbf" claim from c; it is a no-op when the claim
// is not set.
func (c Claims) RemoveNotBefore() {
	delete(c, "nbf")
}

// RemoveIssuedAt drops the "iat" claim from c; it is a no-op when the claim
// is not set.
func (c Claims) RemoveIssuedAt() {
	delete(c, "iat")
}

// RemoveJWTID drops the "jti" claim from c; it is a no-op when the claim is
// not set.
func (c Claims) RemoveJWTID() {
	delete(c, "jti")
}

// SetIssuer stores issuer as the "iss" claim, replacing any previous value.
// See https://tools.ietf.org/html/rfc7519#section-4.1.1
func (c Claims) SetIssuer(issuer string) {
	c["iss"] = issuer
}

// SetSubject stores subject as the "sub" claim, replacing any previous value.
// See https://tools.ietf.org/html/rfc7519#section-4.1.2
func (c Claims) SetSubject(subject string) {
	c["sub"] = subject
}

// SetAudience stores the "aud" claim, replacing any previous value.
// See https://tools.ietf.org/html/rfc7519#section-4.1.3
//
// Exactly one audience is stored as a plain string; any other number of
// audiences (including none) is stored as the []string that was passed in.
func (c Claims) SetAudience(audience ...string) {
	if len(audience) != 1 {
		c["aud"] = audience
		return
	}
	c["aud"] = audience[0]
}

// SetExpiration stores expiration as the "exp" claim, in whole seconds since
// the UNIX epoch. See https://tools.ietf.org/html/rfc7519#section-4.1.4
func (c Claims) SetExpiration(expiration time.Time) {
	c["exp"] = expiration.Unix()
}

// SetNotBefore stores notBefore as the "nbf" claim, in whole seconds since
// the UNIX epoch. See https://tools.ietf.org/html/rfc7519#section-4.1.5
func (c Claims) SetNotBefore(notBefore time.Time) {
	c["nbf"] = notBefore.Unix()
}

// SetIssuedAt stores issuedAt as the "iat" claim, in whole seconds since the
// UNIX epoch. See https://tools.ietf.org/html/rfc7519#section-4.1.6
func (c Claims) SetIssuedAt(issuedAt time.Time) {
	c["iat"] = issuedAt.Unix()
}

// SetJWTID stores uniqueID as the "jti" claim, replacing any previous value.
// See https://tools.ietf.org/html/rfc7519#section-4.1.7
func (c Claims) SetJWTID(uniqueID string) {
	c["jti"] = uniqueID
}

// zero is the zero time.Time, returned by GetTime when no time is available.
var zero time.Time

// GetTime interprets the claim stored under key as a UNIX time (seconds since
// the epoch).
//
// Values whose kind is int, int32, int64, uint, uint32, uint64 or float64 are
// accepted; float32 is not, because it lacks the precision to hold a UNIX
// time. Any other kind, or a missing/nil claim, yields the zero time and
// false.
//
// Numbers decoded from JSON into the claims map are always float64, but
// values placed in the map directly may use any of the accepted kinds.
//
// The value is converted with a plain int64 conversion: the fractional part
// of a float is dropped, and unsigned or floating-point values outside the
// int64 range are not detected.
func (c Claims) GetTime(key string) (time.Time, bool) {
	raw := c.Get(key)
	if raw == nil {
		return zero, false
	}

	var seconds int64
	// Switching on the reflect kind (rather than the concrete type) also
	// accepts named types built on the supported numeric types.
	switch rv := reflect.ValueOf(raw); rv.Kind() {
	case reflect.Int, reflect.Int32, reflect.Int64:
		seconds = rv.Int()
	case reflect.Uint, reflect.Uint32, reflect.Uint64:
		seconds = int64(rv.Uint())
	case reflect.Float64:
		seconds = int64(rv.Float())
	default:
		return zero, false
	}
	return time.Unix(seconds, 0), true
}

// SetTime stores t under key as a UNIX time, i.e. the int64 number of whole
// seconds since the epoch returned by t.Unix(). Any sub-second part of t is
// not stored.
func (c Claims) SetTime(key string, t time.Time) {
	c[key] = t.Unix()
}

// Compile-time checks that Claims satisfies json.Marshaler and that *Claims
// satisfies json.Unmarshaler.
var _ json.Marshaler = Claims(nil)
var _ json.Unmarshaler = (*Claims)(nil)
