LuaJIT-to-Go/wrapper_test.go

277 lines
5.6 KiB
Go

package luajit
import (
"fmt"
"os"
"path/filepath"
"testing"
)
// stateFactory pairs a subtest label with a constructor so the table-driven
// tests below can run the same assertions against each kind of Lua state.
type stateFactory struct {
	name string
	new  func() *State
}

// factories lists the constructors every multi-state test iterates over:
// New (labelled "unsafe") and NewSafe (labelled "safe").
var factories = []stateFactory{
	{name: "unsafe", new: New},
	{name: "safe", new: NewSafe},
}
// TestNew checks that every factory yields a non-nil state, which is then closed.
func TestNew(t *testing.T) {
	for i := range factories {
		factory := factories[i]
		t.Run(factory.name, func(t *testing.T) {
			state := factory.new()
			if state == nil {
				t.Fatal("Failed to create Lua state")
			}
			defer state.Close()
		})
	}
}
// TestDoString feeds a table of Lua snippets to DoString on every state
// flavour. It checks only whether an error is reported, not the error text.
func TestDoString(t *testing.T) {
	cases := []struct {
		name    string
		code    string
		wantErr bool
	}{
		{name: "simple addition", code: "return 1 + 1"},
		{name: "set global", code: "test = 42"},
		{name: "syntax error", code: "this is not valid lua", wantErr: true},
		{name: "runtime error", code: "error('test error')", wantErr: true},
	}

	for _, factory := range factories {
		for _, tc := range cases {
			subtest := factory.name + "/" + tc.name
			t.Run(subtest, func(t *testing.T) {
				state := factory.new()
				if state == nil {
					t.Fatal("Failed to create Lua state")
				}
				defer state.Close()

				err := state.DoString(tc.code)
				if gotErr := err != nil; gotErr != tc.wantErr {
					t.Errorf("DoString() error = %v, wantErr %v", err, tc.wantErr)
				}
			})
		}
	}
}
// TestPushAndGetValues pushes one value of each basic Lua type onto a fresh
// state and reads it back from the top of the stack (index -1).
func TestPushAndGetValues(t *testing.T) {
	type roundTrip struct {
		name  string
		push  func(*State)
		check func(*State) error
	}

	cases := []roundTrip{
		{
			name: "string",
			push: func(L *State) { L.PushString("hello") },
			check: func(L *State) error {
				got := L.ToString(-1)
				if got == "hello" {
					return nil
				}
				return fmt.Errorf("got %q, want %q", got, "hello")
			},
		},
		{
			name: "number",
			push: func(L *State) { L.PushNumber(42.5) },
			check: func(L *State) error {
				got := L.ToNumber(-1)
				if got == 42.5 {
					return nil
				}
				return fmt.Errorf("got %f, want %f", got, 42.5)
			},
		},
		{
			name: "boolean",
			push: func(L *State) { L.PushBoolean(true) },
			check: func(L *State) error {
				got := L.ToBoolean(-1)
				if got {
					return nil
				}
				return fmt.Errorf("got %v, want true", got)
			},
		},
		{
			name: "nil",
			push: func(L *State) { L.PushNil() },
			check: func(L *State) error {
				typ := L.GetType(-1)
				if typ == TypeNil {
					return nil
				}
				return fmt.Errorf("got type %v, want TypeNil", typ)
			},
		},
	}

	for _, factory := range factories {
		for _, tc := range cases {
			t.Run(factory.name+"/"+tc.name, func(t *testing.T) {
				state := factory.new()
				if state == nil {
					t.Fatal("Failed to create Lua state")
				}
				defer state.Close()

				tc.push(state)
				err := tc.check(state)
				if err != nil {
					t.Error(err)
				}
			})
		}
	}
}
// TestStackManipulation pushes three strings and then checks that GetTop,
// Pop and ToString agree on the resulting stack layout.
func TestStackManipulation(t *testing.T) {
	for _, factory := range factories {
		t.Run(factory.name, func(t *testing.T) {
			state := factory.new()
			if state == nil {
				t.Fatal("Failed to create Lua state")
			}
			defer state.Close()

			words := []string{"first", "second", "third"}
			for i := range words {
				state.PushString(words[i])
			}
			pushed := len(words)

			// The test expects a fresh state's stack to start empty, so the
			// top index should equal the number of pushes.
			if top := state.GetTop(); top != pushed {
				t.Errorf("stack size = %d, want %d", top, pushed)
			}

			// Dropping one element should expose the second string on top.
			state.Pop(1)
			if str := state.ToString(-1); str != "second" {
				t.Errorf("top value = %q, want 'second'", str)
			}
			if top := state.GetTop(); top != pushed-1 {
				t.Errorf("stack size after pop = %d, want %d", top, pushed-1)
			}
		})
	}
}
// TestGlobals checks two paths:
//   - a global assigned from a Lua chunk is readable through GetGlobal;
//   - a number stored with SetGlobal round-trips through GetGlobal.
func TestGlobals(t *testing.T) {
	for _, factory := range factories {
		t.Run(factory.name, func(t *testing.T) {
			state := factory.new()
			if state == nil {
				t.Fatal("Failed to create Lua state")
			}
			defer state.Close()

			// Assign from Lua, read back through the stack API.
			err := state.DoString(`globalVar = "test"`)
			if err != nil {
				t.Fatalf("DoString error: %v", err)
			}
			state.GetGlobal("globalVar")
			got := state.ToString(-1)
			if got != "test" {
				t.Errorf("global value = %q, want 'test'", got)
			}
			state.Pop(1)

			// Store the pushed number under a global name, then fetch it again.
			state.PushNumber(42)
			state.SetGlobal("testNum")
			state.GetGlobal("testNum")
			num := state.ToNumber(-1)
			if num != 42 {
				t.Errorf("global number = %f, want 42", num)
			}
		})
	}
}
func TestDoFile(t *testing.T) {
L := NewSafe()
defer L.Close()
// Create test file
content := []byte(`
function add(a, b)
return a + b
end
result = add(40, 2)
`)
tmpDir := t.TempDir()
filename := filepath.Join(tmpDir, "test.lua")
if err := os.WriteFile(filename, content, 0644); err != nil {
t.Fatalf("Failed to create test file: %v", err)
}
if err := L.DoFile(filename); err != nil {
t.Fatalf("DoFile failed: %v", err)
}
L.GetGlobal("result")
if result := L.ToNumber(-1); result != 42 {
t.Errorf("Expected result=42, got %v", result)
}
}
func TestRequireAndPackagePath(t *testing.T) {
L := NewSafe()
defer L.Close()
tmpDir := t.TempDir()
// Create module file
moduleContent := []byte(`
local M = {}
function M.multiply(a, b)
return a * b
end
return M
`)
if err := os.WriteFile(filepath.Join(tmpDir, "mathmod.lua"), moduleContent, 0644); err != nil {
t.Fatalf("Failed to create module file: %v", err)
}
// Add module path and test require
if err := L.AddPackagePath(filepath.Join(tmpDir, "?.lua")); err != nil {
t.Fatalf("AddPackagePath failed: %v", err)
}
if err := L.DoString(`
local math = require("mathmod")
result = math.multiply(6, 7)
`); err != nil {
t.Fatalf("Failed to require module: %v", err)
}
L.GetGlobal("result")
if result := L.ToNumber(-1); result != 42 {
t.Errorf("Expected result=42, got %v", result)
}
}
func TestSetPackagePath(t *testing.T) {
L := NewSafe()
defer L.Close()
customPath := "./custom/?.lua"
if err := L.SetPackagePath(customPath); err != nil {
t.Fatalf("SetPackagePath failed: %v", err)
}
L.GetGlobal("package")
L.GetField(-1, "path")
if path := L.ToString(-1); path != customPath {
t.Errorf("Expected package.path=%q, got %q", customPath, path)
}
}