1
0

Usability improvements with reflection

This commit is contained in:
Sky Johnson 2025-04-26 15:56:01 -05:00
parent e80b22c45a
commit 1b0e5ea6e7
5 changed files with 92 additions and 64 deletions

View File

@ -6,27 +6,26 @@ import (
)
// Contains asserts that a contains b
func Contains(t test, a any, b any) {
if contains(a, b) {
func Contains(t test, container any, element any) {
if contains(container, element) {
return
}
t.Errorf(twoValues, file(), "Contains", a, b)
t.Errorf(formatMessage("Contains", element, container))
t.FailNow()
}
// NotContains asserts that a doesn't contain b
func NotContains(t test, a any, b any) {
if !contains(a, b) {
func NotContains(t test, container any, element any) {
if !contains(container, element) {
return
}
t.Errorf(twoValues, file(), "NotContains", a, b)
t.Errorf(formatMessage("NotContains", element, container))
t.FailNow()
}
// contains returns whether container contains the given the element
// It works with strings, maps and slices
// contains returns whether container contains the given element
func contains(container any, element any) bool {
containerValue := reflect.ValueOf(container)
@ -37,7 +36,6 @@ func contains(container any, element any) bool {
case reflect.Map:
keys := containerValue.MapKeys()
for _, key := range keys {
if key.Interface() == element {
return true
@ -58,24 +56,24 @@ func contains(container any, element any) bool {
return false
}
matchingElements := 0
for i := 0; i < containerValue.Len(); i++ {
if containerValue.Index(i).Interface() == elementValue.Index(matchingElements).Interface() {
matchingElements++
} else {
matchingElements = 0
// Check for subsequence
for i := 0; i <= containerValue.Len()-elementLength; i++ {
match := true
for j := range elementLength {
if containerValue.Index(i+j).Interface() != elementValue.Index(j).Interface() {
match = false
break
}
}
if matchingElements == elementLength {
if match {
return true
}
}
return false
}
for i := 0; i < containerValue.Len(); i++ {
// Check for single element
for i := range containerValue.Len() {
if containerValue.Index(i).Interface() == element {
return true
}

View File

@ -1,33 +1,35 @@
package assert
import "reflect"
import (
"reflect"
)
// Equal asserts that the two given values are equal
func Equal[T comparable](t test, a T, b T) {
if a == b {
func Equal[T comparable](t test, expected T, actual T) {
if expected == actual {
return
}
t.Errorf(twoValues, file(), "Equal", a, b)
t.Errorf(formatMessage("Equal", expected, actual))
t.FailNow()
}
// NotEqual asserts that the two given values are not equal
func NotEqual[T comparable](t test, a T, b T) {
if a != b {
func NotEqual[T comparable](t test, expected T, actual T) {
if expected != actual {
return
}
t.Errorf(twoValues, file(), "NotEqual", a, b)
t.Errorf(formatMessage("NotEqual", expected, actual))
t.FailNow()
}
// DeepEqual asserts that the two given values are deeply equal
func DeepEqual[T any](t test, a T, b T) {
if reflect.DeepEqual(a, b) {
func DeepEqual[T any](t test, expected T, actual T) {
if reflect.DeepEqual(expected, actual) {
return
}
t.Errorf(twoValues, file(), "DeepEqual", a, b)
t.Errorf(formatMessage("DeepEqual", expected, actual))
t.FailNow()
}

View File

@ -1,41 +1,58 @@
package assert
import (
"runtime/debug"
"fmt"
"path/filepath"
"runtime"
"strings"
)
const oneValue = `
%s
const messageFormat = `
%s:%d
assert.%s
%v
Expected: %v
Actual: %v
`
const twoValues = `
%s
const singleValueFormat = `
%s:%d
assert.%s
%v
%v
Value: %v
`
// file returns the first line containing "_test.go" in the debug stack
func file() string {
stack := string(debug.Stack())
lines := strings.Split(stack, "\n")
name := ""
// fileInfo returns file and line information for the test caller
func fileInfo() (string, int) {
// Skip frames: runtime.Callers, this function, assertion function, and test helper
// This gets us to the actual test file that called the assertion
const skip = 3
_, file, line, ok := runtime.Caller(skip)
if !ok {
return "unknown_file", 0
}
for _, line := range lines {
if strings.Contains(line, "_test.go") {
space := strings.LastIndex(line, " ")
// Extract just the filename without the full path
fileName := filepath.Base(file)
if space != -1 {
line = line[:space]
}
name = strings.TrimSpace(line)
break
// If we're still not in a test file, search up the stack
if !strings.Contains(fileName, "_test.go") {
// Try one more frame up
_, file, line, ok = runtime.Caller(skip + 1)
if ok {
fileName = filepath.Base(file)
}
}
return name
return fileName, line
}
// formatMessage formats an error message with the file location
func formatMessage(methodName string, expected, actual any) string {
file, line := fileInfo()
return fmt.Sprintf(messageFormat, file, line, methodName, expected, actual)
}
// formatSingleValueMessage formats an error message for single-value assertions
func formatSingleValueMessage(methodName string, value any) string {
file, line := fileInfo()
return fmt.Sprintf(singleValueFormat, file, line, methodName, value)
}

15
nil.go
View File

@ -3,22 +3,22 @@ package assert
import "reflect"
// Nil asserts that the given value equals nil
func Nil(t test, a any) {
if isNil(a) {
func Nil(t test, value any) {
if isNil(value) {
return
}
t.Errorf(oneValue, file(), "Nil", a)
t.Errorf(formatSingleValueMessage("Nil", value))
t.FailNow()
}
// NotNil asserts that the given value does not equal nil
func NotNil(t test, a any) {
if !isNil(a) {
func NotNil(t test, value any) {
if !isNil(value) {
return
}
t.Errorf(oneValue, file(), "NotNil", a)
t.Errorf(formatSingleValueMessage("NotNil", value))
t.FailNow()
}
@ -29,8 +29,9 @@ func isNil(object any) bool {
}
value := reflect.ValueOf(object)
kind := value.Kind()
switch value.Kind() {
switch kind {
case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.UnsafePointer, reflect.Interface, reflect.Slice:
return value.IsNil()
}

18
true.go
View File

@ -1,11 +1,21 @@
package assert
// True asserts that the given value is true
func True(t test, a bool) {
Equal(t, a, true)
func True(t test, value bool) {
if value {
return
}
t.Errorf(formatMessage("True", true, value))
t.FailNow()
}
// False asserts that the given value is false
func False(t test, a bool) {
Equal(t, a, false)
func False(t test, value bool) {
if !value {
return
}
t.Errorf(formatMessage("False", false, value))
t.FailNow()
}