jwt: Fix Audience validation to match spec more closely
diff --git a/jwt/claims.go b/jwt/claims.go index a44f207..d4f94b8 100644 --- a/jwt/claims.go +++ b/jwt/claims.go
@@ -84,7 +84,7 @@ // Since json.Unmarshal calls UnmarshalJSON, // calling json.Unmarshal on *p would be infinitely recursive // A temp variable is needed because &map[string]interface{}(*p) is - // invalid Go. + // invalid Go. (Address of unaddressable object and all that...) tmp := map[string]interface{}(*c) if err = json.Unmarshal(b, &tmp); err != nil { @@ -111,6 +111,8 @@ // Audience retrieves claim "aud" per its type in // https://tools.ietf.org/html/rfc7519#section-4.1.3 func (c Claims) Audience() ([]string, bool) { + // Audience claim must be stringy. That is, it may be one string + // or multiple strings but it should not be anything else. E.g. an int. switch t := c.Get("aud").(type) { case string: return []string{t}, true
diff --git a/jwt/eq.go b/jwt/eq.go index a7a37a9..3113269 100644 --- a/jwt/eq.go +++ b/jwt/eq.go
@@ -1,52 +1,47 @@ package jwt -import "reflect" +func verifyPrincipals(pcpls, auds []string) bool { + // "Each principal intended to process the JWT MUST + // identify itself with a value in the audience claim." + // - https://tools.ietf.org/html/rfc7519#section-4.1.3 -// eq returns true if the two types are either strings -// or comparable slices. -func eq(a, b interface{}) bool { - t1 := reflect.TypeOf(a) - t2 := reflect.TypeOf(b) - - if t1.Kind() == t2.Kind() { - switch t1.Kind() { - case reflect.Slice: - return eqSlice(a, b) - case reflect.String: - return reflect.ValueOf(a).String() == - reflect.ValueOf(b).String() - } - } - return false -} - -// eqSlice returns true if the two interfaces are both slices -// and are equal. For example: https://play.golang.org/p/5VLMwNE3i- -func eqSlice(a, b interface{}) bool { - if a == nil || b == nil { - return false - } - - v1 := reflect.ValueOf(a) - v2 := reflect.ValueOf(b) - - if v1.Kind() != reflect.Slice || - v2.Kind() != reflect.Slice { - return false - } - - if v1.Len() == v2.Len() && v1.Len() > 0 { - for i := 0; i < v1.Len() && i < v2.Len(); i++ { - k1 := v1.Index(i) - k2 := v2.Index(i) - if k1.Type().Comparable() && - k2.Type().Comparable() && - k1.CanInterface() && k2.CanInterface() && - k1.Interface() != k2.Interface() { - return false + found := -1 + for i, p := range pcpls { + for _, v := range auds { + if p == v { + found++ + break } } - return true + if found != i { + return false + } } - return false + return true +} + +// ValidAudience returns true iff: +// - a and b are strings and a == b +// - a is string, b is []string and a is in b +// - a is []string, b is []string and all of a is in b +// - a is []string, b is string and len(a) == 1 and a[0] == b +func ValidAudience(a, b interface{}) bool { + s1, ok := a.(string) + if ok { + if s2, ok := b.(string); ok { + return s1 == s2 + } + a2, ok := b.([]string) + return ok && verifyPrincipals([]string{s1}, a2) + } + + a1, ok := a.([]string) + if !ok { + return false + } + if a2, ok := b.([]string); ok { + return verifyPrincipals(a1, a2) + } + s2, ok := b.(string) + return ok && len(a1) == 1 && a1[0] == s2 }
diff --git a/jwt/eq_test.go b/jwt/eq_test.go new file mode 100644 index 0000000..5f9d4fd --- /dev/null +++ b/jwt/eq_test.go
@@ -0,0 +1,26 @@ +package jwt_test + +import ( + "testing" + + "github.com/SermoDigital/jose/jwt" +) + +func TestValidAudience(t *testing.T) { + tests := [...]struct { + a interface{} + b interface{} + v bool + }{ + 0: {"https://www.google.com", "https://www.google.com", true}, + 1: {[]string{"example.com", "google.com"}, []string{"example.com"}, false}, + 2: {500, 43, false}, + 3: {"google.com", "facebook.com", false}, + 4: {[]string{"example.com"}, []string{"example.com", "foo.com"}, true}, + } + for i, v := range tests { + if x := jwt.ValidAudience(v.a, v.b); x != v.v { + t.Fatalf("#%d: wanted %t, got %t", i, v.v, x) + } + } +}
diff --git a/jwt/jwt.go b/jwt/jwt.go index 18f5758..d29c43a 100644 --- a/jwt/jwt.go +++ b/jwt/jwt.go
@@ -42,13 +42,7 @@ NBF time.Duration // NBFLeeway Fn ValidateFunc // See ValidateFunc for more information. - _ struct{} -} - -var defaultClaims = []string{ - "iss", "sub", "aud", - "exp", "nbf", "iat", - "jti", + _ struct{} // Require explicitly-named struct fields. } // Validate validates the JWT based on the expected claims in v. @@ -75,7 +69,8 @@ } if aud, ok := v.Expected.Audience(); ok { - if aud2, _ := j.Claims().Audience(); !eq(aud, aud2) { + aud2, ok := j.Claims().Audience() + if !ok || !ValidAudience(aud, aud2) { return ErrInvalidAUDClaim } }