110 lines
2.4 KiB
Go
110 lines
2.4 KiB
Go
package luajit
|
|
|
|
import "testing"
|
|
|
|
func TestGoFunctions(t *testing.T) {
|
|
for _, f := range factories {
|
|
t.Run(f.name, func(t *testing.T) {
|
|
L := f.new()
|
|
if L == nil {
|
|
t.Fatal("Failed to create Lua state")
|
|
}
|
|
defer L.Close()
|
|
defer L.Cleanup()
|
|
|
|
addFunc := func(s *State) int {
|
|
s.PushNumber(s.ToNumber(1) + s.ToNumber(2))
|
|
return 1
|
|
}
|
|
|
|
if err := L.RegisterGoFunction("add", addFunc); err != nil {
|
|
t.Fatalf("Failed to register function: %v", err)
|
|
}
|
|
|
|
// Test basic function call
|
|
if err := L.DoString("result = add(40, 2)"); err != nil {
|
|
t.Fatalf("Failed to call function: %v", err)
|
|
}
|
|
|
|
L.GetGlobal("result")
|
|
if result := L.ToNumber(-1); result != 42 {
|
|
t.Errorf("got %v, want 42", result)
|
|
}
|
|
L.Pop(1)
|
|
|
|
// Test multiple return values
|
|
multiFunc := func(s *State) int {
|
|
s.PushString("hello")
|
|
s.PushNumber(42)
|
|
s.PushBoolean(true)
|
|
return 3
|
|
}
|
|
|
|
if err := L.RegisterGoFunction("multi", multiFunc); err != nil {
|
|
t.Fatalf("Failed to register multi function: %v", err)
|
|
}
|
|
|
|
code := `
|
|
a, b, c = multi()
|
|
result = (a == "hello" and b == 42 and c == true)
|
|
`
|
|
|
|
if err := L.DoString(code); err != nil {
|
|
t.Fatalf("Failed to call multi function: %v", err)
|
|
}
|
|
|
|
L.GetGlobal("result")
|
|
if !L.ToBoolean(-1) {
|
|
t.Error("Multiple return values test failed")
|
|
}
|
|
L.Pop(1)
|
|
|
|
// Test error handling
|
|
errFunc := func(s *State) int {
|
|
s.PushString("test error")
|
|
return -1
|
|
}
|
|
|
|
if err := L.RegisterGoFunction("err", errFunc); err != nil {
|
|
t.Fatalf("Failed to register error function: %v", err)
|
|
}
|
|
|
|
if err := L.DoString("err()"); err == nil {
|
|
t.Error("Expected error from error function")
|
|
}
|
|
|
|
// Test unregistering
|
|
L.UnregisterGoFunction("add")
|
|
if err := L.DoString("add(1, 2)"); err == nil {
|
|
t.Error("Expected error calling unregistered function")
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestStackSafety(t *testing.T) {
|
|
L := NewSafe()
|
|
if L == nil {
|
|
t.Fatal("Failed to create Lua state")
|
|
}
|
|
defer L.Close()
|
|
defer L.Cleanup()
|
|
|
|
// Test stack overflow protection
|
|
overflowFunc := func(s *State) int {
|
|
for i := 0; i < 100; i++ {
|
|
s.PushNumber(float64(i))
|
|
}
|
|
s.PushString("done")
|
|
return 101
|
|
}
|
|
|
|
if err := L.RegisterGoFunction("overflow", overflowFunc); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if err := L.DoString("overflow()"); err != nil {
|
|
t.Logf("Got expected error: %v", err)
|
|
}
|
|
}
|