jwt: Fix Audience validation to match spec more closely
diff --git a/jwt/claims.go b/jwt/claims.go
index a44f207..d4f94b8 100644
--- a/jwt/claims.go
+++ b/jwt/claims.go
@@ -84,7 +84,7 @@
// Since json.Unmarshal calls UnmarshalJSON,
// calling json.Unmarshal on *p would be infinitely recursive
// A temp variable is needed because &map[string]interface{}(*p) is
- // invalid Go.
+ // invalid Go. (Address of unaddressable object and all that...)
tmp := map[string]interface{}(*c)
if err = json.Unmarshal(b, &tmp); err != nil {
@@ -111,6 +111,8 @@
// Audience retrieves claim "aud" per its type in
// https://tools.ietf.org/html/rfc7519#section-4.1.3
func (c Claims) Audience() ([]string, bool) {
+ // Audience claim must be stringy. That is, it may be one string
+ // or multiple strings but it should not be anything else. E.g. an int.
switch t := c.Get("aud").(type) {
case string:
return []string{t}, true
diff --git a/jwt/eq.go b/jwt/eq.go
index a7a37a9..3113269 100644
--- a/jwt/eq.go
+++ b/jwt/eq.go
@@ -1,52 +1,47 @@
package jwt
-import "reflect"
+func verifyPrincipals(pcpls, auds []string) bool {
+ // "Each principal intended to process the JWT MUST
+ // identify itself with a value in the audience claim."
+ // - https://tools.ietf.org/html/rfc7519#section-4.1.3
-// eq returns true if the two types are either strings
-// or comparable slices.
-func eq(a, b interface{}) bool {
- t1 := reflect.TypeOf(a)
- t2 := reflect.TypeOf(b)
-
- if t1.Kind() == t2.Kind() {
- switch t1.Kind() {
- case reflect.Slice:
- return eqSlice(a, b)
- case reflect.String:
- return reflect.ValueOf(a).String() ==
- reflect.ValueOf(b).String()
- }
- }
- return false
-}
-
-// eqSlice returns true if the two interfaces are both slices
-// and are equal. For example: https://play.golang.org/p/5VLMwNE3i-
-func eqSlice(a, b interface{}) bool {
- if a == nil || b == nil {
- return false
- }
-
- v1 := reflect.ValueOf(a)
- v2 := reflect.ValueOf(b)
-
- if v1.Kind() != reflect.Slice ||
- v2.Kind() != reflect.Slice {
- return false
- }
-
- if v1.Len() == v2.Len() && v1.Len() > 0 {
- for i := 0; i < v1.Len() && i < v2.Len(); i++ {
- k1 := v1.Index(i)
- k2 := v2.Index(i)
- if k1.Type().Comparable() &&
- k2.Type().Comparable() &&
- k1.CanInterface() && k2.CanInterface() &&
- k1.Interface() != k2.Interface() {
- return false
+ found := -1
+ for i, p := range pcpls {
+ for _, v := range auds {
+ if p == v {
+ found++
+ break
}
}
- return true
+ if found != i {
+ return false
+ }
}
- return false
+ return true
+}
+
+// ValidAudience returns true iff:
+// - a and b are strings and a == b
+// - a is string, b is []string and a is in b
+// - a is []string, b is []string and all of a is in b
+// - a is []string, b is string and len(a) == 1 and a[0] == b
+func ValidAudience(a, b interface{}) bool {
+ s1, ok := a.(string)
+ if ok {
+ if s2, ok := b.(string); ok {
+ return s1 == s2
+ }
+ a2, ok := b.([]string)
+ return ok && verifyPrincipals([]string{s1}, a2)
+ }
+
+ a1, ok := a.([]string)
+ if !ok {
+ return false
+ }
+ if a2, ok := b.([]string); ok {
+ return verifyPrincipals(a1, a2)
+ }
+ s2, ok := b.(string)
+ return ok && len(a1) == 1 && a1[0] == s2
}
diff --git a/jwt/eq_test.go b/jwt/eq_test.go
new file mode 100644
index 0000000..5f9d4fd
--- /dev/null
+++ b/jwt/eq_test.go
@@ -0,0 +1,26 @@
+package jwt_test
+
+import (
+ "testing"
+
+ "github.com/SermoDigital/jose/jwt"
+)
+
+func TestValidAudience(t *testing.T) {
+ tests := [...]struct {
+ a interface{}
+ b interface{}
+ v bool
+ }{
+ 0: {"https://www.google.com", "https://www.google.com", true},
+ 1: {[]string{"example.com", "google.com"}, []string{"example.com"}, false},
+ 2: {500, 43, false},
+ 3: {"google.com", "facebook.com", false},
+ 4: {[]string{"example.com"}, []string{"example.com", "foo.com"}, true},
+ }
+ for i, v := range tests {
+ if x := jwt.ValidAudience(v.a, v.b); x != v.v {
+ t.Fatalf("#%d: wanted %t, got %t", i, v.v, x)
+ }
+ }
+}
diff --git a/jwt/jwt.go b/jwt/jwt.go
index 18f5758..d29c43a 100644
--- a/jwt/jwt.go
+++ b/jwt/jwt.go
@@ -42,13 +42,7 @@
NBF time.Duration // NBFLeeway
Fn ValidateFunc // See ValidateFunc for more information.
- _ struct{}
-}
-
-var defaultClaims = []string{
- "iss", "sub", "aud",
- "exp", "nbf", "iat",
- "jti",
+ _ struct{} // Require explicitly-named struct fields.
}
// Validate validates the JWT based on the expected claims in v.
@@ -75,7 +69,8 @@
}
if aud, ok := v.Expected.Audience(); ok {
- if aud2, _ := j.Claims().Audience(); !eq(aud, aud2) {
+ aud2, ok := j.Claims().Audience()
+ if !ok || !ValidAudience(aud, aud2) {
return ErrInvalidAUDClaim
}
}