diff --git a/contains.go b/contains.go index 84437ae..d892d05 100644 --- a/contains.go +++ b/contains.go @@ -6,27 +6,26 @@ import ( ) // Contains asserts that a contains b -func Contains(t test, a any, b any) { - if contains(a, b) { +func Contains(t test, container any, element any) { + if contains(container, element) { return } - t.Errorf(twoValues, file(), "Contains", a, b) + t.Errorf(formatMessage("Contains", element, container)) t.FailNow() } // NotContains asserts that a doesn't contain b -func NotContains(t test, a any, b any) { - if !contains(a, b) { +func NotContains(t test, container any, element any) { + if !contains(container, element) { return } - t.Errorf(twoValues, file(), "NotContains", a, b) + t.Errorf(formatMessage("NotContains", element, container)) t.FailNow() } -// contains returns whether container contains the given the element -// It works with strings, maps and slices +// contains returns whether container contains the given element func contains(container any, element any) bool { containerValue := reflect.ValueOf(container) @@ -37,7 +36,6 @@ func contains(container any, element any) bool { case reflect.Map: keys := containerValue.MapKeys() - for _, key := range keys { if key.Interface() == element { return true @@ -58,24 +56,24 @@ func contains(container any, element any) bool { return false } - matchingElements := 0 - - for i := 0; i < containerValue.Len(); i++ { - if containerValue.Index(i).Interface() == elementValue.Index(matchingElements).Interface() { - matchingElements++ - } else { - matchingElements = 0 + // Check for subsequence + for i := 0; i <= containerValue.Len()-elementLength; i++ { + match := true + for j := range elementLength { + if containerValue.Index(i+j).Interface() != elementValue.Index(j).Interface() { + match = false + break + } } - - if matchingElements == elementLength { + if match { return true } } - return false } - for i := 0; i < containerValue.Len(); i++ { + // Check for single element + for i := range containerValue.Len() { if containerValue.Index(i).Interface() == element { return true } diff --git a/equal.go b/equal.go index 2f71404..7fad37e 100644 --- a/equal.go +++ b/equal.go @@ -1,33 +1,35 @@ package assert -import "reflect" +import ( + "reflect" +) // Equal asserts that the two given values are equal -func Equal[T comparable](t test, a T, b T) { - if a == b { +func Equal[T comparable](t test, expected T, actual T) { + if expected == actual { return } - t.Errorf(twoValues, file(), "Equal", a, b) + t.Errorf(formatMessage("Equal", expected, actual)) t.FailNow() } // NotEqual asserts that the two given values are not equal -func NotEqual[T comparable](t test, a T, b T) { - if a != b { +func NotEqual[T comparable](t test, expected T, actual T) { + if expected != actual { return } - t.Errorf(twoValues, file(), "NotEqual", a, b) + t.Errorf(formatMessage("NotEqual", expected, actual)) t.FailNow() } // DeepEqual asserts that the two given values are deeply equal -func DeepEqual[T any](t test, a T, b T) { - if reflect.DeepEqual(a, b) { +func DeepEqual[T any](t test, expected T, actual T) { + if reflect.DeepEqual(expected, actual) { return } - t.Errorf(twoValues, file(), "DeepEqual", a, b) + t.Errorf(formatMessage("DeepEqual", expected, actual)) t.FailNow() } diff --git a/errors.go b/errors.go index f73f53e..51249db 100644 --- a/errors.go +++ b/errors.go @@ -1,41 +1,58 @@ package assert import ( - "runtime/debug" + "fmt" + "path/filepath" + "runtime" "strings" ) -const oneValue = ` -%s +const messageFormat = ` +%s:%d assert.%s - %v + Expected: %v + Actual: %v ` -const twoValues = ` -%s +const singleValueFormat = ` +%s:%d assert.%s - %v - %v + Value: %v ` -// file returns the first line containing "_test.go" in the debug stack -func file() string { - stack := string(debug.Stack()) - lines := strings.Split(stack, "\n") - name := "" +// fileInfo returns file and line information for the test caller +func fileInfo() (string, int) { + // Skip frames: runtime.Callers, this function, assertion function, and test helper + // This gets us to the actual test file that called the assertion + const skip = 3 + _, file, line, ok := runtime.Caller(skip) + if !ok { + return "unknown_file", 0 + } - for _, line := range lines { - if strings.Contains(line, "_test.go") { - space := strings.LastIndex(line, " ") + // Extract just the filename without the full path + fileName := filepath.Base(file) - if space != -1 { - line = line[:space] - } - - name = strings.TrimSpace(line) - break + // If we're still not in a test file, search up the stack + if !strings.Contains(fileName, "_test.go") { + // Try one more frame up + _, file, line, ok = runtime.Caller(skip + 1) + if ok { + fileName = filepath.Base(file) } } - return name + return fileName, line +} + +// formatMessage formats an error message with the file location +func formatMessage(methodName string, expected, actual any) string { + file, line := fileInfo() + return fmt.Sprintf(messageFormat, file, line, methodName, expected, actual) +} + +// formatSingleValueMessage formats an error message for single-value assertions +func formatSingleValueMessage(methodName string, value any) string { + file, line := fileInfo() + return fmt.Sprintf(singleValueFormat, file, line, methodName, value) } diff --git a/nil.go b/nil.go index 684ac3a..91a195f 100644 --- a/nil.go +++ b/nil.go @@ -3,22 +3,22 @@ package assert import "reflect" // Nil asserts that the given value equals nil -func Nil(t test, a any) { - if isNil(a) { +func Nil(t test, value any) { + if isNil(value) { return } - t.Errorf(oneValue, file(), "Nil", a) + t.Errorf(formatSingleValueMessage("Nil", value)) t.FailNow() } // NotNil asserts that the given value does not equal nil -func NotNil(t test, a any) { - if !isNil(a) { +func NotNil(t test, value any) { + if !isNil(value) { return } - t.Errorf(oneValue, file(), "NotNil", a) + t.Errorf(formatSingleValueMessage("NotNil", value)) t.FailNow() } @@ -29,8 +29,9 @@ func isNil(object any) bool { } value := reflect.ValueOf(object) + kind := value.Kind() - switch value.Kind() { + switch kind { case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.UnsafePointer, reflect.Interface, reflect.Slice: return value.IsNil() } diff --git a/true.go b/true.go index d6b8cbb..ece20a2 100644 --- a/true.go +++ b/true.go @@ -1,11 +1,21 @@ package assert // True asserts that the given value is true -func True(t test, a bool) { - Equal(t, a, true) +func True(t test, value bool) { + if value { + return + } + + t.Errorf(formatMessage("True", true, value)) + t.FailNow() } // False asserts that the given value is false -func False(t test, a bool) { - Equal(t, a, false) +func False(t test, value bool) { + if !value { + return + } + + t.Errorf(formatMessage("False", false, value)) + t.FailNow() }