WithTransform() supports transform function with explicit types
diff --git a/matchers.go b/matchers.go index 872de9e..ba41103 100644 --- a/matchers.go +++ b/matchers.go
@@ -371,13 +371,11 @@ } //WithTransform applies the `transform` to the actual value and matches it against `matcher`. -// var plus1 = func(i interface{}) interface{} { return i.(int) + 1 } +//The given transform must be a function of one parameter that returns one value. +// var plus1 = func(i int) int { return i + 1 } // Expect(1).To(WithTransform(plus1, Equal(2)) // //And(), Or(), Not() and WithTransform() allow matchers to be composed into complex expressions. -func WithTransform(transform func(interface{}) interface{}, matcher types.GomegaMatcher) types.GomegaMatcher { - return &matchers.WithTransformMatcher{ - Transform: transform, - Matcher: matcher, - } +func WithTransform(transform interface{}, matcher types.GomegaMatcher) types.GomegaMatcher { + return matchers.NewWithTransformMatcher(transform, matcher) }
diff --git a/matchers/with_transform.go b/matchers/with_transform.go index c7a650a..430223e 100644 --- a/matchers/with_transform.go +++ b/matchers/with_transform.go
@@ -1,25 +1,61 @@ package matchers -import "github.com/onsi/gomega/types" +import ( + "fmt" + "github.com/onsi/gomega/types" + "reflect" +) type WithTransformMatcher struct { // input - Transform func(interface{}) interface{} + Transform interface{} // must be a function of one parameter that returns one value Matcher types.GomegaMatcher + // cached value + transformArgType reflect.Type + // state transformedValue interface{} } +func NewWithTransformMatcher(transform interface{}, matcher types.GomegaMatcher) *WithTransformMatcher { + if transform == nil { + panic("transform function cannot be nil") + } + txType := reflect.TypeOf(transform) + if txType.NumIn() != 1 { + panic("transform function must have 1 argument") + } + if txType.NumOut() != 1 { + panic("transform function must have 1 return value") + } + + return &WithTransformMatcher{ + Transform: transform, + Matcher: matcher, + transformArgType: reflect.TypeOf(transform).In(0), + } +} + func (m *WithTransformMatcher) Match(actual interface{}) (bool, error) { - m.transformedValue = m.Transform(actual) + // return error if actual's type is incompatible with Transform function's argument type + actualType := reflect.TypeOf(actual) + if !actualType.AssignableTo(m.transformArgType) { + return false, fmt.Errorf("Transform function expects '%s' but we have '%s'", m.transformArgType, actualType) + } + + // call the Transform function with `actual` + fn := reflect.ValueOf(m.Transform) + result := fn.Call([]reflect.Value{reflect.ValueOf(actual)}) + m.transformedValue = result[0].Interface() // expect exactly one value + return m.Matcher.Match(m.transformedValue) } -func (m *WithTransformMatcher) FailureMessage(actual interface{}) (message string) { +func (m *WithTransformMatcher) FailureMessage(_ interface{}) (message string) { return m.Matcher.FailureMessage(m.transformedValue) } -func (m *WithTransformMatcher) NegatedFailureMessage(actual interface{}) (message string) { +func (m *WithTransformMatcher) NegatedFailureMessage(_ interface{}) (message string) { return m.Matcher.NegatedFailureMessage(m.transformedValue) }
diff --git a/matchers/with_transform_test.go b/matchers/with_transform_test.go index ced53bd..4d1a363 100644 --- a/matchers/with_transform_test.go +++ b/matchers/with_transform_test.go
@@ -1,18 +1,56 @@ package matchers_test import ( + "errors" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) var _ = Describe("WithTransformMatcher", func() { - var plus1 = func(i interface{}) interface{} { return i.(int) + 1 } + var plus1 = func(i int) int { return i + 1 } + + Context("Panic if transform function invalid", func() { + panicsWithTransformer := func(transform interface{}) { + ExpectWithOffset(1, func() { WithTransform(transform, nil) }).To(Panic()) + } + It("nil", func() { + panicsWithTransformer(nil) + }) + Context("Invalid number of args, but correct return value count", func() { + It("zero", func() { + panicsWithTransformer(func() int { return 5 }) + }) + It("two", func() { + panicsWithTransformer(func(i, j int) int { return 5 }) + }) + }) + Context("Invalid number of return values, but correct number of arguments", func() { + It("zero", func() { + panicsWithTransformer(func(i int) {}) + }) + It("two", func() { + panicsWithTransformer(func(i int) (int, int) { return 5, 6 }) + }) + }) + }) It("works with positive cases", func() { Expect(1).To(WithTransform(plus1, Equal(2))) Expect(1).To(WithTransform(plus1, WithTransform(plus1, Equal(3)))) Expect(1).To(WithTransform(plus1, And(Equal(2), BeNumerically(">", 1)))) + + // transform expects custom type + type S struct { + A int + B string + } + transformer := func(s S) string { return s.B } + Expect(S{1, "hi"}).To(WithTransform(transformer, Equal("hi"))) + + // transform expects interface + errString := func(e error) string { return e.Error() } + Expect(errors.New("abc")).To(WithTransform(errString, Equal("abc"))) }) It("works with negative cases", func() { @@ -25,7 +63,7 @@ It("gives a descriptive message", func() { m := WithTransform(plus1, Equal(3)) Expect(m.Match(1)).To(BeFalse()) - Expect(m.FailureMessage(input)).To(Equal("Expected\n <int>: 2\nto equal\n <int>: 3")) + Expect(m.FailureMessage(1)).To(Equal("Expected\n <int>: 2\nto equal\n <int>: 3")) }) }) @@ -33,7 +71,16 @@ It("gives a descriptive message", func() { m := Not(WithTransform(plus1, Equal(3))) Expect(m.Match(2)).To(BeFalse()) - Expect(m.FailureMessage(input)).To(Equal("Expected\n <int>: 3\nnot to equal\n <int>: 3")) + Expect(m.FailureMessage(2)).To(Equal("Expected\n <int>: 3\nnot to equal\n <int>: 3")) + }) + }) + + Context("actual value is incompatible with transform function's argument type", func() { + It("gracefully fails if transform cannot be performed", func() { + m := WithTransform(plus1, Equal(3)) + result, err := m.Match("hi") // give it a string but transform expects int; doesn't panic + Expect(result).To(BeFalse()) + Expect(err).To(MatchError("Transform function expects 'int' but we have 'string'")) }) }) })