1
0

Usability improvements with reflection

This commit is contained in:
Sky Johnson 2025-04-26 15:56:01 -05:00
parent e80b22c45a
commit 1b0e5ea6e7
5 changed files with 92 additions and 64 deletions

View File

@ -6,27 +6,26 @@ import (
) )
// Contains asserts that a contains b // Contains asserts that a contains b
func Contains(t test, a any, b any) { func Contains(t test, container any, element any) {
if contains(a, b) { if contains(container, element) {
return return
} }
t.Errorf(twoValues, file(), "Contains", a, b) t.Errorf(formatMessage("Contains", element, container))
t.FailNow() t.FailNow()
} }
// NotContains asserts that a doesn't contain b // NotContains asserts that a doesn't contain b
func NotContains(t test, a any, b any) { func NotContains(t test, container any, element any) {
if !contains(a, b) { if !contains(container, element) {
return return
} }
t.Errorf(twoValues, file(), "NotContains", a, b) t.Errorf(formatMessage("NotContains", element, container))
t.FailNow() t.FailNow()
} }
// contains returns whether container contains the given the element // contains returns whether container contains the given element
// It works with strings, maps and slices
func contains(container any, element any) bool { func contains(container any, element any) bool {
containerValue := reflect.ValueOf(container) containerValue := reflect.ValueOf(container)
@ -37,7 +36,6 @@ func contains(container any, element any) bool {
case reflect.Map: case reflect.Map:
keys := containerValue.MapKeys() keys := containerValue.MapKeys()
for _, key := range keys { for _, key := range keys {
if key.Interface() == element { if key.Interface() == element {
return true return true
@ -58,24 +56,24 @@ func contains(container any, element any) bool {
return false return false
} }
matchingElements := 0 // Check for subsequence
for i := 0; i <= containerValue.Len()-elementLength; i++ {
for i := 0; i < containerValue.Len(); i++ { match := true
if containerValue.Index(i).Interface() == elementValue.Index(matchingElements).Interface() { for j := range elementLength {
matchingElements++ if containerValue.Index(i+j).Interface() != elementValue.Index(j).Interface() {
} else { match = false
matchingElements = 0 break
}
} }
if match {
if matchingElements == elementLength {
return true return true
} }
} }
return false return false
} }
for i := 0; i < containerValue.Len(); i++ { // Check for single element
for i := range containerValue.Len() {
if containerValue.Index(i).Interface() == element { if containerValue.Index(i).Interface() == element {
return true return true
} }

View File

@ -1,33 +1,35 @@
package assert package assert
import "reflect" import (
"reflect"
)
// Equal asserts that the two given values are equal // Equal asserts that the two given values are equal
func Equal[T comparable](t test, a T, b T) { func Equal[T comparable](t test, expected T, actual T) {
if a == b { if expected == actual {
return return
} }
t.Errorf(twoValues, file(), "Equal", a, b) t.Errorf(formatMessage("Equal", expected, actual))
t.FailNow() t.FailNow()
} }
// NotEqual asserts that the two given values are not equal // NotEqual asserts that the two given values are not equal
func NotEqual[T comparable](t test, a T, b T) { func NotEqual[T comparable](t test, expected T, actual T) {
if a != b { if expected != actual {
return return
} }
t.Errorf(twoValues, file(), "NotEqual", a, b) t.Errorf(formatMessage("NotEqual", expected, actual))
t.FailNow() t.FailNow()
} }
// DeepEqual asserts that the two given values are deeply equal // DeepEqual asserts that the two given values are deeply equal
func DeepEqual[T any](t test, a T, b T) { func DeepEqual[T any](t test, expected T, actual T) {
if reflect.DeepEqual(a, b) { if reflect.DeepEqual(expected, actual) {
return return
} }
t.Errorf(twoValues, file(), "DeepEqual", a, b) t.Errorf(formatMessage("DeepEqual", expected, actual))
t.FailNow() t.FailNow()
} }

View File

@ -1,41 +1,58 @@
package assert package assert
import ( import (
"runtime/debug" "fmt"
"path/filepath"
"runtime"
"strings" "strings"
) )
const oneValue = ` const messageFormat = `
%s %s:%d
assert.%s assert.%s
%v Expected: %v
Actual: %v
` `
const twoValues = ` const singleValueFormat = `
%s %s:%d
assert.%s assert.%s
%v Value: %v
%v
` `
// file returns the first line containing "_test.go" in the debug stack // fileInfo returns file and line information for the test caller
func file() string { func fileInfo() (string, int) {
stack := string(debug.Stack()) // Skip frames: runtime.Callers, this function, assertion function, and test helper
lines := strings.Split(stack, "\n") // This gets us to the actual test file that called the assertion
name := "" const skip = 3
_, file, line, ok := runtime.Caller(skip)
if !ok {
return "unknown_file", 0
}
for _, line := range lines { // Extract just the filename without the full path
if strings.Contains(line, "_test.go") { fileName := filepath.Base(file)
space := strings.LastIndex(line, " ")
if space != -1 { // If we're still not in a test file, search up the stack
line = line[:space] if !strings.Contains(fileName, "_test.go") {
} // Try one more frame up
_, file, line, ok = runtime.Caller(skip + 1)
name = strings.TrimSpace(line) if ok {
break fileName = filepath.Base(file)
} }
} }
return name return fileName, line
}
// formatMessage formats an error message with the file location
func formatMessage(methodName string, expected, actual any) string {
file, line := fileInfo()
return fmt.Sprintf(messageFormat, file, line, methodName, expected, actual)
}
// formatSingleValueMessage formats an error message for single-value assertions
func formatSingleValueMessage(methodName string, value any) string {
file, line := fileInfo()
return fmt.Sprintf(singleValueFormat, file, line, methodName, value)
} }

15
nil.go
View File

@ -3,22 +3,22 @@ package assert
import "reflect" import "reflect"
// Nil asserts that the given value equals nil // Nil asserts that the given value equals nil
func Nil(t test, a any) { func Nil(t test, value any) {
if isNil(a) { if isNil(value) {
return return
} }
t.Errorf(oneValue, file(), "Nil", a) t.Errorf(formatSingleValueMessage("Nil", value))
t.FailNow() t.FailNow()
} }
// NotNil asserts that the given value does not equal nil // NotNil asserts that the given value does not equal nil
func NotNil(t test, a any) { func NotNil(t test, value any) {
if !isNil(a) { if !isNil(value) {
return return
} }
t.Errorf(oneValue, file(), "NotNil", a) t.Errorf(formatSingleValueMessage("NotNil", value))
t.FailNow() t.FailNow()
} }
@ -29,8 +29,9 @@ func isNil(object any) bool {
} }
value := reflect.ValueOf(object) value := reflect.ValueOf(object)
kind := value.Kind()
switch value.Kind() { switch kind {
case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.UnsafePointer, reflect.Interface, reflect.Slice: case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.UnsafePointer, reflect.Interface, reflect.Slice:
return value.IsNil() return value.IsNil()
} }

18
true.go
View File

@ -1,11 +1,21 @@
package assert package assert
// True asserts that the given value is true // True asserts that the given value is true
func True(t test, a bool) { func True(t test, value bool) {
Equal(t, a, true) if value {
return
}
t.Errorf(formatMessage("True", true, value))
t.FailNow()
} }
// False asserts that the given value is false // False asserts that the given value is false
func False(t test, a bool) { func False(t test, value bool) {
Equal(t, a, false) if !value {
return
}
t.Errorf(formatMessage("False", false, value))
t.FailNow()
} }