From 4dc266201f0b0cfccc60817a7c3cdb4794d1f495 Mon Sep 17 00:00:00 2001 From: Sky Johnson Date: Wed, 12 Feb 2025 19:17:11 -0600 Subject: [PATCH] BIG changes; no "safe" mode, function updates, etc --- benchmark/main.go | 148 ---------- bytecode.go | 26 +- bytecode_test.go | 162 +++++------ functions.go | 3 +- functions_test.go | 122 +++++---- table.go | 101 ++----- table_test.go | 34 ++- wrapper.go | 321 ++++++++++++---------- wrapper_bench_test.go | 237 ++++++++++++++++ wrapper_test.go | 611 +++++++++++++++++++++++++++++++++++------- 10 files changed, 1105 insertions(+), 660 deletions(-) delete mode 100644 benchmark/main.go create mode 100644 wrapper_bench_test.go diff --git a/benchmark/main.go b/benchmark/main.go deleted file mode 100644 index 0398ffe..0000000 --- a/benchmark/main.go +++ /dev/null @@ -1,148 +0,0 @@ -package main - -import ( - "fmt" - "time" - - luajit "git.sharkk.net/Sky/LuaJIT-to-Go" -) - -type benchCase struct { - name string - code string -} - -var cases = []benchCase{ - { - name: "Simple Addition", - code: `return 1 + 1`, - }, - { - name: "Loop Sum", - code: ` - local sum = 0 - for i = 1, 1000 do - sum = sum + i - end - return sum - `, - }, - { - name: "Function Call", - code: ` - local result = 0 - for i = 1, 100 do - result = result + i - end - return result - `, - }, - { - name: "Table Creation", - code: ` - local t = {} - for i = 1, 100 do - t[i] = i * 2 - end - return t[50] - `, - }, - { - name: "String Operations", - code: ` - local s = "hello" - for i = 1, 10 do - s = s .. " world" - end - return #s - `, - }, -} - -func runBenchmark(L *luajit.State, code string, duration time.Duration) (time.Duration, int64) { - start := time.Now() - deadline := start.Add(duration) - var ops int64 - - for time.Now().Before(deadline) { - if err := L.DoString(code); err != nil { - fmt.Printf("Error executing code: %v\n", err) - return 0, 0 - } - L.Pop(1) - ops++ - } - - return time.Since(start), ops -} - -func runBytecodeTest(L *luajit.State, code string, duration time.Duration) (time.Duration, int64) { - // First compile the bytecode - bytecode, err := L.CompileBytecode(code, "bench") - if err != nil { - fmt.Printf("Error compiling bytecode: %v\n", err) - return 0, 0 - } - - start := time.Now() - deadline := start.Add(duration) - var ops int64 - - for time.Now().Before(deadline) { - if err := L.LoadBytecode(bytecode, "bench"); err != nil { - fmt.Printf("Error executing bytecode: %v\n", err) - return 0, 0 - } - ops++ - } - - return time.Since(start), ops -} - -func benchmarkCase(newState func() *luajit.State, bc benchCase) { - fmt.Printf("\n%s:\n", bc.name) - - // Direct execution benchmark - L := newState() - if L == nil { - fmt.Printf(" Failed to create Lua state\n") - return - } - execTime, ops := runBenchmark(L, bc.code, 2*time.Second) - L.Close() - if ops > 0 { - opsPerSec := float64(ops) / execTime.Seconds() - fmt.Printf(" Direct: %.0f ops/sec\n", opsPerSec) - } - - // Bytecode execution benchmark - L = newState() - if L == nil { - fmt.Printf(" Failed to create Lua state\n") - return - } - execTime, ops = runBytecodeTest(L, bc.code, 2*time.Second) - L.Close() - if ops > 0 { - opsPerSec := float64(ops) / execTime.Seconds() - fmt.Printf(" Bytecode: %.0f ops/sec\n", opsPerSec) - } -} - -func main() { - modes := []struct { - name string - newState func() *luajit.State - }{ - {"Safe", luajit.NewSafe}, - {"Unsafe", luajit.New}, - } - - for _, mode := range modes { - fmt.Printf("\n=== %s Mode ===\n", mode.name) - - for _, c := range cases { - benchmarkCase(mode.newState, c) - } - } -} diff --git a/bytecode.go b/bytecode.go index d318c6f..14c579f 100644 --- a/bytecode.go +++ b/bytecode.go @@ -51,7 +51,8 @@ import ( "unsafe" ) -func (s *State) compileBytecodeUnsafe(code string, name string) ([]byte, error) { +// CompileBytecode compiles a Lua chunk to bytecode without executing it +func (s *State) CompileBytecode(code string, name string) ([]byte, error) { // First load the string but don't execute it ccode := C.CString(code) defer C.free(unsafe.Pointer(ccode)) @@ -94,7 +95,8 @@ func (s *State) compileBytecodeUnsafe(code string, name string) ([]byte, error) return bytecode, nil } -func (s *State) loadBytecodeUnsafe(bytecode []byte, name string) error { +// LoadBytecode loads precompiled bytecode and executes it +func (s *State) LoadBytecode(bytecode []byte, name string) error { cname := C.CString(name) defer C.free(unsafe.Pointer(cname)) @@ -125,26 +127,6 @@ func (s *State) loadBytecodeUnsafe(bytecode []byte, name string) error { return nil } -// CompileBytecode compiles a Lua chunk to bytecode without executing it -func (s *State) CompileBytecode(code string, name string) ([]byte, error) { - if s.safeStack { - return stackGuardValue[[]byte](s, func() ([]byte, error) { - return s.compileBytecodeUnsafe(code, name) - }) - } - return s.compileBytecodeUnsafe(code, name) -} - -// LoadBytecode loads precompiled bytecode and executes it -func (s *State) LoadBytecode(bytecode []byte, name string) error { - if s.safeStack { - return stackGuardErr(s, func() error { - return s.loadBytecodeUnsafe(bytecode, name) - }) - } - return s.loadBytecodeUnsafe(bytecode, name) -} - // Helper function to compile and immediately load/execute bytecode func (s *State) CompileAndLoad(code string, name string) error { bytecode, err := s.CompileBytecode(code, name) diff --git a/bytecode_test.go b/bytecode_test.go index cc3b9f4..4e2822c 100644 --- a/bytecode_test.go +++ b/bytecode_test.go @@ -28,82 +28,70 @@ func TestBytecodeCompilation(t *testing.T) { }, } - for _, f := range factories { - for _, tt := range tests { - t.Run(f.name+"/"+tt.name, func(t *testing.T) { - L := f.new() - if L == nil { - t.Fatal("Failed to create Lua state") - } - defer L.Close() + for _, tt := range tests { + L := New() + if L == nil { + t.Fatal("Failed to create Lua state") + } + defer L.Close() - bytecode, err := L.CompileBytecode(tt.code, "test") - if (err != nil) != tt.wantErr { - t.Errorf("CompileBytecode() error = %v, wantErr %v", err, tt.wantErr) - return - } + bytecode, err := L.CompileBytecode(tt.code, "test") + if (err != nil) != tt.wantErr { + t.Errorf("CompileBytecode() error = %v, wantErr %v", err, tt.wantErr) + return + } - if !tt.wantErr { - if len(bytecode) == 0 { - t.Error("CompileBytecode() returned empty bytecode") - } - } - }) + if !tt.wantErr { + if len(bytecode) == 0 { + t.Error("CompileBytecode() returned empty bytecode") + } } } } func TestBytecodeExecution(t *testing.T) { - for _, f := range factories { - t.Run(f.name, func(t *testing.T) { - L := f.new() - if L == nil { - t.Fatal("Failed to create Lua state") - } - defer L.Close() + L := New() + if L == nil { + t.Fatal("Failed to create Lua state") + } + defer L.Close() - // Compile some test code - code := ` - function add(a, b) - return a + b - end - result = add(40, 2) - ` + // Compile some test code + code := ` + function add(a, b) + return a + b + end + result = add(40, 2) + ` - bytecode, err := L.CompileBytecode(code, "test") - if err != nil { - t.Fatalf("CompileBytecode() error = %v", err) - } + bytecode, err := L.CompileBytecode(code, "test") + if err != nil { + t.Fatalf("CompileBytecode() error = %v", err) + } - // Load and execute the bytecode - if err := L.LoadBytecode(bytecode, "test"); err != nil { - t.Fatalf("LoadBytecode() error = %v", err) - } + // Load and execute the bytecode + if err := L.LoadBytecode(bytecode, "test"); err != nil { + t.Fatalf("LoadBytecode() error = %v", err) + } - // Verify the result - L.GetGlobal("result") - if result := L.ToNumber(-1); result != 42 { - t.Errorf("got result = %v, want 42", result) - } - }) + // Verify the result + L.GetGlobal("result") + if result := L.ToNumber(-1); result != 42 { + t.Errorf("got result = %v, want 42", result) } } func TestInvalidBytecode(t *testing.T) { - for _, f := range factories { - t.Run(f.name, func(t *testing.T) { - L := f.new() - if L == nil { - t.Fatal("Failed to create Lua state") - } - defer L.Close() + L := New() + if L == nil { + t.Fatal("Failed to create Lua state") + } + defer L.Close() - // Test with invalid bytecode - invalidBytecode := []byte("this is not valid bytecode") - if err := L.LoadBytecode(invalidBytecode, "test"); err == nil { - t.Error("LoadBytecode() expected error with invalid bytecode") - } - }) + // Test with invalid bytecode + invalidBytecode := []byte("this is not valid bytecode") + if err := L.LoadBytecode(invalidBytecode, "test"); err == nil { + t.Error("LoadBytecode() expected error with invalid bytecode") } } @@ -140,39 +128,35 @@ func TestBytecodeRoundTrip(t *testing.T) { }, } - for _, f := range factories { - for _, tt := range tests { - t.Run(f.name+"/"+tt.name, func(t *testing.T) { - // First state for compilation - L1 := f.new() - if L1 == nil { - t.Fatal("Failed to create first Lua state") - } - defer L1.Close() + for _, tt := range tests { + // First state for compilation + L1 := New() + if L1 == nil { + t.Fatal("Failed to create first Lua state") + } + defer L1.Close() - // Compile the code - bytecode, err := L1.CompileBytecode(tt.code, "test") - if err != nil { - t.Fatalf("CompileBytecode() error = %v", err) - } + // Compile the code + bytecode, err := L1.CompileBytecode(tt.code, "test") + if err != nil { + t.Fatalf("CompileBytecode() error = %v", err) + } - // Second state for execution - L2 := f.new() - if L2 == nil { - t.Fatal("Failed to create second Lua state") - } - defer L2.Close() + // Second state for execution + L2 := New() + if L2 == nil { + t.Fatal("Failed to create second Lua state") + } + defer L2.Close() - // Load and execute the bytecode - if err := L2.LoadBytecode(bytecode, "test"); err != nil { - t.Fatalf("LoadBytecode() error = %v", err) - } + // Load and execute the bytecode + if err := L2.LoadBytecode(bytecode, "test"); err != nil { + t.Fatalf("LoadBytecode() error = %v", err) + } - // Run the check function - if err := tt.check(L2); err != nil { - t.Errorf("check failed: %v", err) - } - }) + // Run the check function + if err := tt.check(L2); err != nil { + t.Errorf("check failed: %v", err) } } } diff --git a/functions.go b/functions.go index c9e297e..2770174 100644 --- a/functions.go +++ b/functions.go @@ -31,7 +31,7 @@ var ( //export goFunctionWrapper func goFunctionWrapper(L *C.lua_State) C.int { - state := &State{L: L, safeStack: true} + state := &State{L: L} // Get upvalue using standard Lua 5.1 macro ptr := C.lua_touserdata(L, C.get_upvalue_index(1)) @@ -54,7 +54,6 @@ func goFunctionWrapper(L *C.lua_State) C.int { } func (s *State) PushGoFunction(fn GoFunction) error { - // Push lightuserdata as upvalue and create closure ptr := C.malloc(1) if ptr == nil { return fmt.Errorf("failed to allocate memory for function pointer") diff --git a/functions_test.go b/functions_test.go index 1a18bad..9a57245 100644 --- a/functions_test.go +++ b/functions_test.go @@ -1,89 +1,87 @@ package luajit -import "testing" +import ( + "testing" +) func TestGoFunctions(t *testing.T) { - for _, f := range factories { - t.Run(f.name, func(t *testing.T) { - L := f.new() - if L == nil { - t.Fatal("Failed to create Lua state") - } - defer L.Close() - defer L.Cleanup() + L := New() + if L == nil { + t.Fatal("Failed to create Lua state") + } + defer L.Close() + defer L.Cleanup() - addFunc := func(s *State) int { - s.PushNumber(s.ToNumber(1) + s.ToNumber(2)) - return 1 - } + addFunc := func(s *State) int { + s.PushNumber(s.ToNumber(1) + s.ToNumber(2)) + return 1 + } - if err := L.RegisterGoFunction("add", addFunc); err != nil { - t.Fatalf("Failed to register function: %v", err) - } + if err := L.RegisterGoFunction("add", addFunc); err != nil { + t.Fatalf("Failed to register function: %v", err) + } - // Test basic function call - if err := L.DoString("result = add(40, 2)"); err != nil { - t.Fatalf("Failed to call function: %v", err) - } + // Test basic function call + if err := L.DoString("result = add(40, 2)"); err != nil { + t.Fatalf("Failed to call function: %v", err) + } - L.GetGlobal("result") - if result := L.ToNumber(-1); result != 42 { - t.Errorf("got %v, want 42", result) - } - L.Pop(1) + L.GetGlobal("result") + if result := L.ToNumber(-1); result != 42 { + t.Errorf("got %v, want 42", result) + } + L.Pop(1) - // Test multiple return values - multiFunc := func(s *State) int { - s.PushString("hello") - s.PushNumber(42) - s.PushBoolean(true) - return 3 - } + // Test multiple return values + multiFunc := func(s *State) int { + s.PushString("hello") + s.PushNumber(42) + s.PushBoolean(true) + return 3 + } - if err := L.RegisterGoFunction("multi", multiFunc); err != nil { - t.Fatalf("Failed to register multi function: %v", err) - } + if err := L.RegisterGoFunction("multi", multiFunc); err != nil { + t.Fatalf("Failed to register multi function: %v", err) + } - code := ` + code := ` a, b, c = multi() result = (a == "hello" and b == 42 and c == true) ` - if err := L.DoString(code); err != nil { - t.Fatalf("Failed to call multi function: %v", err) - } + if err := L.DoString(code); err != nil { + t.Fatalf("Failed to call multi function: %v", err) + } - L.GetGlobal("result") - if !L.ToBoolean(-1) { - t.Error("Multiple return values test failed") - } - L.Pop(1) + L.GetGlobal("result") + if !L.ToBoolean(-1) { + t.Error("Multiple return values test failed") + } + L.Pop(1) - // Test error handling - errFunc := func(s *State) int { - s.PushString("test error") - return -1 - } + // Test error handling + errFunc := func(s *State) int { + s.PushString("test error") + return -1 + } - if err := L.RegisterGoFunction("err", errFunc); err != nil { - t.Fatalf("Failed to register error function: %v", err) - } + if err := L.RegisterGoFunction("err", errFunc); err != nil { + t.Fatalf("Failed to register error function: %v", err) + } - if err := L.DoString("err()"); err == nil { - t.Error("Expected error from error function") - } + if err := L.DoString("err()"); err == nil { + t.Error("Expected error from error function") + } - // Test unregistering - L.UnregisterGoFunction("add") - if err := L.DoString("add(1, 2)"); err == nil { - t.Error("Expected error calling unregistered function") - } - }) + // Test unregistering + L.UnregisterGoFunction("add") + if err := L.DoString("add(1, 2)"); err == nil { + t.Error("Expected error calling unregistered function") } } func TestStackSafety(t *testing.T) { - L := NewSafe() + L := New() if L == nil { t.Fatal("Failed to create Lua state") } diff --git a/table.go b/table.go index 9f74265..57286a6 100644 --- a/table.go +++ b/table.go @@ -22,57 +22,20 @@ type TableValue interface { func (s *State) GetTableLength(index int) int { return int(C.get_table_length(s.L, C.int(index))) } +// PushTable pushes a Go map onto the Lua stack as a table +func (s *State) PushTable(table map[string]interface{}) error { + s.NewTable() + for k, v := range table { + if err := s.PushValue(v); err != nil { + return err + } + s.SetField(-2, k) + } + return nil +} + // ToTable converts a Lua table to a Go map func (s *State) ToTable(index int) (map[string]interface{}, error) { - if s.safeStack { - return stackGuardValue[map[string]interface{}](s, func() (map[string]interface{}, error) { - if !s.IsTable(index) { - return nil, fmt.Errorf("not a table at index %d", index) - } - return s.toTableUnsafe(index) - }) - } - if !s.IsTable(index) { - return nil, fmt.Errorf("not a table at index %d", index) - } - return s.toTableUnsafe(index) -} - -func (s *State) pushTableSafe(table map[string]interface{}) error { - size := 2 - if err := s.checkStack(size); err != nil { - return fmt.Errorf("insufficient stack space: %w", err) - } - - s.NewTable() - for k, v := range table { - if err := s.pushValueSafe(v); err != nil { - return err - } - s.SetField(-2, k) - } - return nil -} - -func (s *State) pushTableUnsafe(table map[string]interface{}) error { - s.NewTable() - for k, v := range table { - if err := s.pushValueUnsafe(v); err != nil { - return err - } - s.SetField(-2, k) - } - return nil -} - -func (s *State) toTableSafe(index int) (map[string]interface{}, error) { - if err := s.checkStack(2); err != nil { - return nil, err - } - return s.toTableUnsafe(index) -} - -func (s *State) toTableUnsafe(index int) (map[string]interface{}, error) { absIdx := s.absIndex(index) table := make(map[string]interface{}) @@ -111,7 +74,7 @@ func (s *State) toTableUnsafe(index int) (map[string]interface{}, error) { key = fmt.Sprintf("%g", s.ToNumber(-2)) } - value, err := s.toValueUnsafe(-1) + value, err := s.ToValue(-1) if err != nil { s.Pop(1) return nil, err @@ -133,45 +96,15 @@ func (s *State) toTableUnsafe(index int) (map[string]interface{}, error) { // NewTable creates a new table and pushes it onto the stack func (s *State) NewTable() { - if s.safeStack { - if err := s.checkStack(1); err != nil { - // Since we can't return an error, we'll push nil instead - s.PushNil() - return - } - } C.lua_createtable(s.L, 0, 0) } -// SetTable sets a table field with cached absolute index +// SetTable sets a table field func (s *State) SetTable(index int) { - absIdx := index - if s.safeStack && (index < 0 && index > LUA_REGISTRYINDEX) { - absIdx = s.GetTop() + index + 1 - } - C.lua_settable(s.L, C.int(absIdx)) + C.lua_settable(s.L, C.int(index)) } -// GetTable gets a table field with cached absolute index +// GetTable gets a table field func (s *State) GetTable(index int) { - absIdx := index - if s.safeStack && (index < 0 && index > LUA_REGISTRYINDEX) { - absIdx = s.GetTop() + index + 1 - } - - if s.safeStack { - if err := s.checkStack(1); err != nil { - s.PushNil() - return - } - } - C.lua_gettable(s.L, C.int(absIdx)) -} - -// PushTable pushes a Go map onto the Lua stack as a table with stack checking -func (s *State) PushTable(table map[string]interface{}) error { - if s.safeStack { - return s.pushTableSafe(table) - } - return s.pushTableUnsafe(table) + C.lua_gettable(s.L, C.int(index)) } diff --git a/table_test.go b/table_test.go index 70995fc..4d0d9e8 100644 --- a/table_test.go +++ b/table_test.go @@ -34,28 +34,24 @@ func TestTableOperations(t *testing.T) { }, } - for _, f := range factories { - for _, tt := range tests { - t.Run(f.name+"/"+tt.name, func(t *testing.T) { - L := f.new() - if L == nil { - t.Fatal("Failed to create Lua state") - } - defer L.Close() + for _, tt := range tests { + L := New() + if L == nil { + t.Fatal("Failed to create Lua state") + } + defer L.Close() - if err := L.PushTable(tt.data); err != nil { - t.Fatalf("PushTable() error = %v", err) - } + if err := L.PushTable(tt.data); err != nil { + t.Fatalf("PushTable() error = %v", err) + } - got, err := L.ToTable(-1) - if err != nil { - t.Fatalf("ToTable() error = %v", err) - } + got, err := L.ToTable(-1) + if err != nil { + t.Fatalf("ToTable() error = %v", err) + } - if !tablesEqual(got, tt.data) { - t.Errorf("table mismatch\ngot = %v\nwant = %v", got, tt.data) - } - }) + if !tablesEqual(got, tt.data) { + t.Errorf("table mismatch\ngot = %v\nwant = %v", got, tt.data) } } } diff --git a/wrapper.go b/wrapper.go index f4098c4..eb2a3f4 100644 --- a/wrapper.go +++ b/wrapper.go @@ -10,49 +10,82 @@ package luajit #include #include -static int do_string(lua_State *L, const char *s) { - int status = luaL_loadstring(L, s); - if (status) return status; - return lua_pcall(L, 0, LUA_MULTRET, 0); +// Simple wrapper around luaL_loadstring +static int load_chunk(lua_State *L, const char *s) { + return luaL_loadstring(L, s); } +// Direct wrapper around lua_pcall +static int protected_call(lua_State *L, int nargs, int nresults, int errfunc) { + return lua_pcall(L, nargs, nresults, errfunc); +} + +// Combined load and execute with no results +static int do_string(lua_State *L, const char *s) { + return luaL_dostring(L, s); +} + +// Combined load and execute file static int do_file(lua_State *L, const char *filename) { - int status = luaL_loadfile(L, filename); - if (status) return status; - return lua_pcall(L, 0, LUA_MULTRET, 0); + return luaL_dofile(L, filename); +} + +// Execute string with multiple returns +static int execute_string(lua_State *L, const char *s) { + int base = lua_gettop(L); // Save stack position + int status = luaL_loadstring(L, s); + if (status) return -status; // Return negative status for load errors + + status = lua_pcall(L, 0, LUA_MULTRET, 0); + if (status) return -status; // Return negative status for runtime errors + + return lua_gettop(L) - base; // Return number of results +} + +// Get absolute stack index (converts negative indices) +static int get_abs_index(lua_State *L, int idx) { + if (idx > 0 || idx <= LUA_REGISTRYINDEX) return idx; + return lua_gettop(L) + idx + 1; +} + +// Stack manipulation helpers +static int check_stack(lua_State *L, int n) { + return lua_checkstack(L, n); +} + +static void remove_stack(lua_State *L, int idx) { + lua_remove(L, idx); +} + +static int get_field_helper(lua_State *L, int idx, const char *k) { + lua_getfield(L, idx, k); + return lua_type(L, -1); +} + +static void set_field_helper(lua_State *L, int idx, const char *k) { + lua_setfield(L, idx, k); } */ import "C" import ( "fmt" - "path/filepath" + "strings" "unsafe" ) -// State represents a Lua state with configurable stack safety +// State represents a Lua state type State struct { - L *C.lua_State - safeStack bool + L *C.lua_State } -// NewSafe creates a new Lua state with full stack safety guarantees -func NewSafe() *State { - L := C.luaL_newstate() - if L == nil { - return nil - } - C.luaL_openlibs(L) - return &State{L: L, safeStack: true} -} - -// New creates a new Lua state with minimal stack checking +// New creates a new Lua state func New() *State { L := C.luaL_newstate() if L == nil { return nil } C.luaL_openlibs(L) - return &State{L: L, safeStack: false} + return &State{L: L} } // Close closes the Lua state @@ -63,65 +96,28 @@ func (s *State) Close() { } } -// DoString executes a Lua string with appropriate stack management +// DoString executes a Lua string. func (s *State) DoString(str string) error { - cstr := C.CString(str) - defer C.free(unsafe.Pointer(cstr)) + // Save initial stack size + top := s.GetTop() - if s.safeStack { - return stackGuardErr(s, func() error { - // Save the current stack size - initialTop := s.GetTop() - - // Execute the string - status := C.do_string(s.L, cstr) - if status != 0 { - // In case of error, get error message from stack - errMsg := s.ToString(-1) - s.SetTop(initialTop) // Restore stack - return &LuaError{ - Code: int(status), - Message: errMsg, - } - } - - // Return values are now on the stack above initialTop - // We don't pop them as they may be needed by the caller - return nil - }) + // Load the string + if err := s.LoadString(str); err != nil { + return err } - status := C.do_string(s.L, cstr) - if status != 0 { - return &LuaError{ - Code: int(status), - Message: s.ToString(-1), - } + // Execute and check for errors + if err := s.Call(0, 0); err != nil { + return err } + + // Restore stack to initial size to clean up any leftovers + s.SetTop(top) return nil } // PushValue pushes a Go value onto the stack func (s *State) PushValue(v interface{}) error { - if s.safeStack { - return stackGuardErr(s, func() error { - if err := s.checkStack(1); err != nil { - return fmt.Errorf("pushing value: %w", err) - } - return s.pushValueUnsafe(v) - }) - } - return s.pushValueUnsafe(v) -} - -func (s *State) pushValueSafe(v interface{}) error { - if err := s.checkStack(1); err != nil { - return fmt.Errorf("pushing value: %w", err) - } - return s.pushValueUnsafe(v) -} - -func (s *State) pushValueUnsafe(v interface{}) error { switch v := v.(type) { case nil: s.PushNil() @@ -144,7 +140,7 @@ func (s *State) pushValueUnsafe(v interface{}) error { } return nil } - return s.pushTableUnsafe(v) + return s.PushTable(v) case []float64: s.NewTable() for i, elem := range v { @@ -156,7 +152,7 @@ func (s *State) pushValueUnsafe(v interface{}) error { s.NewTable() for i, elem := range v { s.PushNumber(float64(i + 1)) - if err := s.pushValueUnsafe(elem); err != nil { + if err := s.PushValue(elem); err != nil { return err } s.SetTable(-3) @@ -169,15 +165,6 @@ func (s *State) pushValueUnsafe(v interface{}) error { // ToValue converts a Lua value to a Go value func (s *State) ToValue(index int) (interface{}, error) { - if s.safeStack { - return stackGuardValue[interface{}](s, func() (interface{}, error) { - return s.toValueUnsafe(index) - }) - } - return s.toValueUnsafe(index) -} - -func (s *State) toValueUnsafe(index int) (interface{}, error) { switch s.GetType(index) { case TypeNil: return nil, nil @@ -191,7 +178,7 @@ func (s *State) toValueUnsafe(index int) (interface{}, error) { if !s.IsTable(index) { return nil, fmt.Errorf("not a table at index %d", index) } - return s.toTableUnsafe(index) + return s.ToTable(index) default: return nil, fmt.Errorf("unsupported type: %s", s.GetType(index)) } @@ -244,53 +231,29 @@ func (s *State) absIndex(index int) int { return s.GetTop() + index + 1 } -// SetField sets a field in a table at the given index with cached absolute index +// SetField sets a field in a table at the given index func (s *State) SetField(index int, key string) { - absIdx := index - if s.safeStack && (index < 0 && index > LUA_REGISTRYINDEX) { - absIdx = s.GetTop() + index + 1 - } - cstr := C.CString(key) defer C.free(unsafe.Pointer(cstr)) - C.lua_setfield(s.L, C.int(absIdx), cstr) + C.lua_setfield(s.L, C.int(index), cstr) } -// GetField gets a field from a table with cached absolute index +// GetField gets a field from a table func (s *State) GetField(index int, key string) { - absIdx := index - if s.safeStack && (index < 0 && index > LUA_REGISTRYINDEX) { - absIdx = s.GetTop() + index + 1 - } - - if s.safeStack { - if err := s.checkStack(1); err != nil { - s.PushNil() - return - } - } - cstr := C.CString(key) defer C.free(unsafe.Pointer(cstr)) - C.lua_getfield(s.L, C.int(absIdx), cstr) + C.lua_getfield(s.L, C.int(index), cstr) } // GetGlobal gets a global variable and pushes it onto the stack func (s *State) GetGlobal(name string) { - if s.safeStack { - if err := s.checkStack(1); err != nil { - s.PushNil() - return - } - } - cstr := C.CString(name) - defer C.free(unsafe.Pointer(cstr)) - C.lua_getfield(s.L, C.LUA_GLOBALSINDEX, cstr) + cname := C.CString(name) + defer C.free(unsafe.Pointer(cname)) + C.get_field_helper(s.L, C.LUA_GLOBALSINDEX, cname) } // SetGlobal sets a global variable from the value at the top of the stack func (s *State) SetGlobal(name string) { - // SetGlobal doesn't need stack space checking as it pops the value cstr := C.CString(name) defer C.free(unsafe.Pointer(cstr)) C.lua_setfield(s.L, C.LUA_GLOBALSINDEX, cstr) @@ -299,9 +262,6 @@ func (s *State) SetGlobal(name string) { // Remove removes element with cached absolute index func (s *State) Remove(index int) { absIdx := index - if s.safeStack && (index < 0 && index > LUA_REGISTRYINDEX) { - absIdx = s.GetTop() + index + 1 - } C.lua_remove(s.L, C.int(absIdx)) } @@ -310,14 +270,6 @@ func (s *State) DoFile(filename string) error { cfilename := C.CString(filename) defer C.free(unsafe.Pointer(cfilename)) - if s.safeStack { - return stackGuardErr(s, func() error { - return s.safeCall(func() C.int { - return C.do_file(s.L, cfilename) - }) - }) - } - status := C.do_file(s.L, cfilename) if status != 0 { return &LuaError{ @@ -328,24 +280,28 @@ func (s *State) DoFile(filename string) error { return nil } +// SetPackagePath sets the Lua package.path func (s *State) SetPackagePath(path string) error { - path = filepath.ToSlash(path) - if err := s.DoString(fmt.Sprintf(`package.path = %q`, path)); err != nil { - return fmt.Errorf("setting package.path: %w", err) - } - return nil + path = strings.ReplaceAll(path, "\\", "/") // Convert Windows paths + cmd := fmt.Sprintf(`package.path = %q`, path) + return s.DoString(cmd) } +// AddPackagePath adds a path to package.path func (s *State) AddPackagePath(path string) error { - path = filepath.ToSlash(path) - if err := s.DoString(fmt.Sprintf(`package.path = package.path .. ";%s"`, path)); err != nil { - return fmt.Errorf("adding to package.path: %w", err) - } - return nil + path = strings.ReplaceAll(path, "\\", "/") // Convert Windows paths + cmd := fmt.Sprintf(`package.path = package.path .. ";%s"`, path) + return s.DoString(cmd) } +// Call executes a function on the stack with the given number of arguments and results. +// The function and arguments should already be on the stack in the correct order +// (function first, then args from left to right). func (s *State) Call(nargs, nresults int) error { - status := C.lua_pcall(s.L, C.int(nargs), C.int(nresults), 0) + if !s.IsFunction(-nargs - 1) { + return fmt.Errorf("attempt to call a non-function") + } + status := C.protected_call(s.L, C.int(nargs), C.int(nresults), 0) if status != 0 { err := &LuaError{ Code: int(status), @@ -356,3 +312,94 @@ func (s *State) Call(nargs, nresults int) error { } return nil } + +// LoadString loads but does not execute a string of Lua code. +// The compiled code chunk is left on the stack. +func (s *State) LoadString(str string) error { + cstr := C.CString(str) + defer C.free(unsafe.Pointer(cstr)) + + status := C.load_chunk(s.L, cstr) + if status != 0 { + err := &LuaError{ + Code: int(status), + Message: s.ToString(-1), + } + s.Pop(1) + return err + } + + if !s.IsFunction(-1) { + s.Pop(1) + return fmt.Errorf("failed to load function") + } + return nil +} + +// ExecuteString executes a string of Lua code and returns the number of results. +// The results are left on the stack. +func (s *State) ExecuteString(str string) (int, error) { + base := s.GetTop() + + // First load the string + if err := s.LoadString(str); err != nil { + return 0, err + } + + // Now execute it + if err := s.Call(0, C.LUA_MULTRET); err != nil { + return 0, err + } + + return s.GetTop() - base, nil +} + +// ExecuteStringResult executes a Lua string and returns its first result as a Go value. +// It's a convenience wrapper around ExecuteString for the common case of wanting +// a single return value. The stack is restored to its original state after execution. +func (s *State) ExecuteStringResult(code string) (interface{}, error) { + top := s.GetTop() + defer s.SetTop(top) // Restore stack when we're done + + nresults, err := s.ExecuteString(code) + if err != nil { + return nil, fmt.Errorf("execution error: %w", err) + } + + if nresults == 0 { + return nil, nil + } + + // Get the result + result, err := s.ToValue(-nresults) // Get first result + if err != nil { + return nil, fmt.Errorf("error converting result: %w", err) + } + + return result, nil +} + +// DoStringResult executes a Lua string and expects a single return value. +// Unlike ExecuteStringResult, this function specifically expects exactly one +// return value and will return an error if the code returns 0 or multiple values. +func (s *State) DoStringResult(code string) (interface{}, error) { + top := s.GetTop() + defer s.SetTop(top) // Restore stack when we're done + + nresults, err := s.ExecuteString(code) + if err != nil { + return nil, fmt.Errorf("execution error: %w", err) + } + + if nresults != 1 { + return nil, fmt.Errorf("expected 1 return value, got %d", nresults) + } + + // Get the result + result, err := s.ToValue(-1) + if err != nil { + return nil, fmt.Errorf("error converting result: %w", err) + } + + return result, nil +} diff --git a/wrapper_bench_test.go b/wrapper_bench_test.go new file mode 100644 index 0000000..6316ae6 --- /dev/null +++ b/wrapper_bench_test.go @@ -0,0 +1,237 @@ +package luajit + +import ( + "testing" +) + +var benchCases = []struct { + name string + code string +}{ + { + name: "SimpleAddition", + code: `return 1 + 1`, + }, + { + name: "LoopSum", + code: ` + local sum = 0 + for i = 1, 1000 do + sum = sum + i + end + return sum + `, + }, + { + name: "FunctionCall", + code: ` + local result = 0 + for i = 1, 100 do + result = result + i + end + return result + `, + }, + { + name: "TableCreation", + code: ` + local t = {} + for i = 1, 100 do + t[i] = i * 2 + end + return t[50] + `, + }, + { + name: "StringOperations", + code: ` + local s = "hello" + for i = 1, 10 do + s = s .. " world" + end + return #s + `, + }, +} + +func BenchmarkLuaDirectExecution(b *testing.B) { + for _, bc := range benchCases { + b.Run(bc.name, func(b *testing.B) { + L := New() + if L == nil { + b.Fatal("Failed to create Lua state") + } + defer L.Close() + + // First verify we can execute the code + if err := L.DoString(bc.code); err != nil { + b.Fatalf("Failed to execute test code: %v", err) + } + L.Pop(1) // Clean up the result + + b.ResetTimer() + for i := 0; i < b.N; i++ { + // Execute string and get result + nresults, err := L.ExecuteString(bc.code) + if err != nil { + b.Fatalf("Failed to execute code: %v", err) + } + L.Pop(nresults) // Clean up any results + } + }) + } +} + +func BenchmarkLuaBytecodeExecution(b *testing.B) { + // First compile all bytecode + bytecodes := make(map[string][]byte) + for _, bc := range benchCases { + L := New() + if L == nil { + b.Fatal("Failed to create Lua state") + } + bytecode, err := L.CompileBytecode(bc.code, bc.name) + if err != nil { + L.Close() + b.Fatalf("Error compiling bytecode for %s: %v", bc.name, err) + } + bytecodes[bc.name] = bytecode + L.Close() + } + + for _, bc := range benchCases { + b.Run(bc.name, func(b *testing.B) { + L := New() + if L == nil { + b.Fatal("Failed to create Lua state") + } + defer L.Close() + + bytecode := bytecodes[bc.name] + + // First verify we can execute the bytecode + if err := L.LoadBytecode(bytecode, bc.name); err != nil { + b.Fatalf("Failed to execute test bytecode: %v", err) + } + + b.ResetTimer() + b.SetBytes(int64(len(bytecode))) // Track bytecode size in benchmarks + + for i := 0; i < b.N; i++ { + if err := L.LoadBytecode(bytecode, bc.name); err != nil { + b.Fatalf("Error executing bytecode: %v", err) + } + } + }) + } +} + +func BenchmarkTableOperations(b *testing.B) { + testData := map[string]interface{}{ + "number": 42.0, + "string": "hello", + "bool": true, + "nested": map[string]interface{}{ + "value": 123.0, + "array": []float64{1.1, 2.2, 3.3}, + }, + } + + b.Run("PushTable", func(b *testing.B) { + L := New() + if L == nil { + b.Fatal("Failed to create Lua state") + } + defer L.Close() + + // First verify we can push the table + if err := L.PushTable(testData); err != nil { + b.Fatalf("Failed to push initial table: %v", err) + } + L.Pop(1) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := L.PushTable(testData); err != nil { + b.Fatalf("Failed to push table: %v", err) + } + L.Pop(1) + } + }) + + b.Run("ToTable", func(b *testing.B) { + L := New() + if L == nil { + b.Fatal("Failed to create Lua state") + } + defer L.Close() + + // Keep a table on the stack for repeated conversions + if err := L.PushTable(testData); err != nil { + b.Fatalf("Failed to push initial table: %v", err) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + if _, err := L.ToTable(-1); err != nil { + b.Fatalf("Failed to convert table: %v", err) + } + } + }) +} + +func BenchmarkValueConversion(b *testing.B) { + testValues := []struct { + name string + value interface{} + }{ + {"Number", 42.0}, + {"String", "hello world"}, + {"Boolean", true}, + {"Nil", nil}, + } + + for _, tv := range testValues { + b.Run("Push"+tv.name, func(b *testing.B) { + L := New() + if L == nil { + b.Fatal("Failed to create Lua state") + } + defer L.Close() + + // First verify we can push the value + if err := L.PushValue(tv.value); err != nil { + b.Fatalf("Failed to push initial value: %v", err) + } + L.Pop(1) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := L.PushValue(tv.value); err != nil { + b.Fatalf("Failed to push value: %v", err) + } + L.Pop(1) + } + }) + + b.Run("To"+tv.name, func(b *testing.B) { + L := New() + if L == nil { + b.Fatal("Failed to create Lua state") + } + defer L.Close() + + // Keep a value on the stack for repeated conversions + if err := L.PushValue(tv.value); err != nil { + b.Fatalf("Failed to push initial value: %v", err) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + if _, err := L.ToValue(-1); err != nil { + b.Fatalf("Failed to convert value: %v", err) + } + } + }) + } +} diff --git a/wrapper_test.go b/wrapper_test.go index d0f888c..d7ce9e2 100644 --- a/wrapper_test.go +++ b/wrapper_test.go @@ -7,25 +7,151 @@ import ( "testing" ) -type stateFactory struct { - name string - new func() *State -} - -var factories = []stateFactory{ - {"unsafe", New}, - {"safe", NewSafe}, -} - func TestNew(t *testing.T) { - for _, f := range factories { - t.Run(f.name, func(t *testing.T) { - L := f.new() - if L == nil { - t.Fatal("Failed to create Lua state") + L := New() + if L == nil { + t.Fatal("Failed to create Lua state") + } + defer L.Close() +} + +func TestLoadString(t *testing.T) { + tests := []struct { + name string + code string + wantErr bool + }{ + { + name: "valid function", + code: "function add(a, b) return a + b end", + wantErr: false, + }, + { + name: "valid expression", + code: "return 1 + 1", + wantErr: false, + }, + { + name: "syntax error", + code: "function bad syntax", + wantErr: true, + }, + } + + for _, tt := range tests { + L := New() + if L == nil { + t.Fatal("Failed to create Lua state") + } + defer L.Close() + + err := L.LoadString(tt.code) + if (err != nil) != tt.wantErr { + t.Errorf("LoadString() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if !tt.wantErr { + // Verify the function is on the stack + if L.GetTop() != 1 { + t.Error("LoadString() did not leave exactly one value on stack") } - defer L.Close() - }) + if !L.IsFunction(-1) { + t.Error("LoadString() did not leave a function on the stack") + } + } + } +} + +func TestExecuteString(t *testing.T) { + tests := []struct { + name string + code string + wantResults int + checkResults func(*State) error + wantErr bool + wantStackSize int + }{ + { + name: "no results", + code: "local x = 1", + wantResults: 0, + wantErr: false, + }, + { + name: "single result", + code: "return 42", + wantResults: 1, + checkResults: func(L *State) error { + if n := L.ToNumber(-1); n != 42 { + return fmt.Errorf("got %v, want 42", n) + } + return nil + }, + wantErr: false, + }, + { + name: "multiple results", + code: "return 1, 'test', true", + wantResults: 3, + checkResults: func(L *State) error { + if n := L.ToNumber(-3); n != 1 { + return fmt.Errorf("first result: got %v, want 1", n) + } + if s := L.ToString(-2); s != "test" { + return fmt.Errorf("second result: got %v, want 'test'", s) + } + if b := L.ToBoolean(-1); !b { + return fmt.Errorf("third result: got %v, want true", b) + } + return nil + }, + wantErr: false, + }, + { + name: "syntax error", + code: "this is not valid lua", + wantErr: true, + }, + { + name: "runtime error", + code: "error('test error')", + wantErr: true, + }, + } + + for _, tt := range tests { + L := New() + if L == nil { + t.Fatal("Failed to create Lua state") + } + defer L.Close() + + // Record initial stack size + initialStack := L.GetTop() + + results, err := L.ExecuteString(tt.code) + if (err != nil) != tt.wantErr { + t.Errorf("ExecuteString() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if err == nil { + if results != tt.wantResults { + t.Errorf("ExecuteString() returned %d results, want %d", results, tt.wantResults) + } + + if tt.checkResults != nil { + if err := tt.checkResults(L); err != nil { + t.Errorf("Result check failed: %v", err) + } + } + + // Verify stack size matches expected results + if got := L.GetTop() - initialStack; got != tt.wantResults { + t.Errorf("Stack size grew by %d, want %d", got, tt.wantResults) + } + } } } @@ -41,20 +167,22 @@ func TestDoString(t *testing.T) { {"runtime error", "error('test error')", true}, } - for _, f := range factories { - for _, tt := range tests { - t.Run(f.name+"/"+tt.name, func(t *testing.T) { - L := f.new() - if L == nil { - t.Fatal("Failed to create Lua state") - } - defer L.Close() + for _, tt := range tests { + L := New() + if L == nil { + t.Fatal("Failed to create Lua state") + } + defer L.Close() - err := L.DoString(tt.code) - if (err != nil) != tt.wantErr { - t.Errorf("DoString() error = %v, wantErr %v", err, tt.wantErr) - } - }) + initialStack := L.GetTop() + err := L.DoString(tt.code) + if (err != nil) != tt.wantErr { + t.Errorf("DoString() error = %v, wantErr %v", err, tt.wantErr) + } + + // Verify stack is unchanged + if finalStack := L.GetTop(); finalStack != initialStack { + t.Errorf("Stack size changed from %d to %d", initialStack, finalStack) } } } @@ -107,95 +235,171 @@ func TestPushAndGetValues(t *testing.T) { }, } - for _, f := range factories { - for _, v := range values { - t.Run(f.name+"/"+v.name, func(t *testing.T) { - L := f.new() - if L == nil { - t.Fatal("Failed to create Lua state") - } - defer L.Close() + for _, v := range values { + L := New() + if L == nil { + t.Fatal("Failed to create Lua state") + } + defer L.Close() - v.push(L) - if err := v.check(L); err != nil { - t.Error(err) - } - }) + v.push(L) + if err := v.check(L); err != nil { + t.Error(err) } } } func TestStackManipulation(t *testing.T) { - for _, f := range factories { - t.Run(f.name, func(t *testing.T) { - L := f.new() - if L == nil { - t.Fatal("Failed to create Lua state") - } - defer L.Close() + L := New() + if L == nil { + t.Fatal("Failed to create Lua state") + } + defer L.Close() - // Push values - values := []string{"first", "second", "third"} - for _, v := range values { - L.PushString(v) - } + // Push values + values := []string{"first", "second", "third"} + for _, v := range values { + L.PushString(v) + } - // Check size - if top := L.GetTop(); top != len(values) { - t.Errorf("stack size = %d, want %d", top, len(values)) - } + // Check size + if top := L.GetTop(); top != len(values) { + t.Errorf("stack size = %d, want %d", top, len(values)) + } - // Pop one value - L.Pop(1) + // Pop one value + L.Pop(1) - // Check new top - if str := L.ToString(-1); str != "second" { - t.Errorf("top value = %q, want 'second'", str) - } + // Check new top + if str := L.ToString(-1); str != "second" { + t.Errorf("top value = %q, want 'second'", str) + } - // Check new size - if top := L.GetTop(); top != len(values)-1 { - t.Errorf("stack size after pop = %d, want %d", top, len(values)-1) - } - }) + // Check new size + if top := L.GetTop(); top != len(values)-1 { + t.Errorf("stack size after pop = %d, want %d", top, len(values)-1) } } func TestGlobals(t *testing.T) { - for _, f := range factories { - t.Run(f.name, func(t *testing.T) { - L := f.new() - if L == nil { - t.Fatal("Failed to create Lua state") - } - defer L.Close() + L := New() + if L == nil { + t.Fatal("Failed to create Lua state") + } + defer L.Close() - // Test via Lua - if err := L.DoString(`globalVar = "test"`); err != nil { - t.Fatalf("DoString error: %v", err) - } + // Test via Lua + if err := L.DoString(`globalVar = "test"`); err != nil { + t.Fatalf("DoString error: %v", err) + } - // Get the global - L.GetGlobal("globalVar") - if str := L.ToString(-1); str != "test" { - t.Errorf("global value = %q, want 'test'", str) - } - L.Pop(1) + // Get the global + L.GetGlobal("globalVar") + if str := L.ToString(-1); str != "test" { + t.Errorf("global value = %q, want 'test'", str) + } + L.Pop(1) - // Set and get via API - L.PushNumber(42) - L.SetGlobal("testNum") + // Set and get via API + L.PushNumber(42) + L.SetGlobal("testNum") - L.GetGlobal("testNum") - if num := L.ToNumber(-1); num != 42 { - t.Errorf("global number = %f, want 42", num) + L.GetGlobal("testNum") + if num := L.ToNumber(-1); num != 42 { + t.Errorf("global number = %f, want 42", num) + } +} + +func TestCall(t *testing.T) { + tests := []struct { + funcName string // Add explicit function name field + setup string + args []interface{} + nresults int + checkStack func(*State) error + wantErr bool + }{ + { + funcName: "add", + setup: "function add(a, b) return a + b end", + args: []interface{}{float64(40), float64(2)}, + nresults: 1, + checkStack: func(L *State) error { + if n := L.ToNumber(-1); n != 42 { + return fmt.Errorf("got %v, want 42", n) + } + return nil + }, + }, + { + funcName: "multi", + setup: "function multi() return 1, 'test', true end", + args: []interface{}{}, + nresults: 3, + checkStack: func(L *State) error { + if n := L.ToNumber(-3); n != 1 { + return fmt.Errorf("first result: got %v, want 1", n) + } + if s := L.ToString(-2); s != "test" { + return fmt.Errorf("second result: got %v, want 'test'", s) + } + if b := L.ToBoolean(-1); !b { + return fmt.Errorf("third result: got %v, want true", b) + } + return nil + }, + }, + { + funcName: "err", + setup: "function err() error('test error') end", + args: []interface{}{}, + nresults: 0, + wantErr: true, + }, + } + + for _, tt := range tests { + L := New() + if L == nil { + t.Fatal("Failed to create Lua state") + } + defer L.Close() + + // Setup function + if err := L.DoString(tt.setup); err != nil { + t.Fatalf("Setup failed: %v", err) + } + + // Get function + L.GetGlobal(tt.funcName) + if !L.IsFunction(-1) { + t.Fatal("Failed to get function") + } + + // Push arguments + for _, arg := range tt.args { + if err := L.PushValue(arg); err != nil { + t.Fatalf("Failed to push argument: %v", err) } - }) + } + + // Call function + err := L.Call(len(tt.args), tt.nresults) + if (err != nil) != tt.wantErr { + t.Errorf("Call() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if err == nil && tt.checkStack != nil { + if err := tt.checkStack(L); err != nil { + t.Errorf("Stack check failed: %v", err) + } + } } } func TestDoFile(t *testing.T) { - L := NewSafe() + L := New() defer L.Close() // Create test file @@ -223,7 +427,7 @@ func TestDoFile(t *testing.T) { } func TestRequireAndPackagePath(t *testing.T) { - L := NewSafe() + L := New() defer L.Close() tmpDir := t.TempDir() @@ -260,7 +464,7 @@ func TestRequireAndPackagePath(t *testing.T) { } func TestSetPackagePath(t *testing.T) { - L := NewSafe() + L := New() defer L.Close() customPath := "./custom/?.lua" @@ -273,4 +477,217 @@ func TestSetPackagePath(t *testing.T) { if path := L.ToString(-1); path != customPath { t.Errorf("Expected package.path=%q, got %q", customPath, path) } + + // Test that the old path is completely replaced + initialPath := L.ToString(-1) + anotherPath := "./another/?.lua" + if err := L.SetPackagePath(anotherPath); err != nil { + t.Fatalf("Second SetPackagePath failed: %v", err) + } + + L.GetGlobal("package") + L.GetField(-1, "path") + if path := L.ToString(-1); path != anotherPath { + t.Errorf("Expected package.path=%q, got %q", anotherPath, path) + } + if path := L.ToString(-1); path == initialPath { + t.Error("SetPackagePath did not replace the old path") + } +} + +func TestStackDebug(t *testing.T) { + L := New() + if L == nil { + t.Fatal("Failed to create Lua state") + } + defer L.Close() + + t.Log("Testing LoadString:") + initialTop := L.GetTop() + t.Logf("Initial stack size: %d", initialTop) + + err := L.LoadString("return 42") + if err != nil { + t.Errorf("LoadString failed: %v", err) + } + + afterLoad := L.GetTop() + t.Logf("Stack size after load: %d", afterLoad) + t.Logf("Type of top element: %s", L.GetType(-1)) + + if L.IsFunction(-1) { + t.Log("Top element is a function") + } else { + t.Log("Top element is NOT a function") + } + + // Clean up after LoadString test + L.SetTop(0) + + t.Log("\nTesting ExecuteString:") + if err := L.DoString("function test() return 1, 'hello', true end"); err != nil { + t.Errorf("DoString failed: %v", err) + } + + beforeExec := L.GetTop() + t.Logf("Stack size before execute: %d", beforeExec) + + nresults, err := L.ExecuteString("return test()") + if err != nil { + t.Errorf("ExecuteString failed: %v", err) + } + + afterExec := L.GetTop() + t.Logf("Stack size after execute: %d", afterExec) + t.Logf("Reported number of results: %d", nresults) + + // Print each stack element + for i := 1; i <= afterExec; i++ { + t.Logf("Stack[-%d] type: %s", i, L.GetType(-i)) + } + + if afterExec != nresults { + t.Errorf("Stack size (%d) doesn't match number of results (%d)", afterExec, nresults) + } +} + +func TestTemplateRendering(t *testing.T) { + L := New() + if L == nil { + t.Fatal("Failed to create Lua state") + } + defer L.Close() + defer L.Cleanup() + + // Create a simple render.template function + renderFunc := func(s *State) int { + // Template will be at index 1, data at index 2 + data, err := s.ToTable(2) + if err != nil { + s.PushString(fmt.Sprintf("failed to get data table: %v", err)) + return -1 + } + + // Push data back as global for template access + if err := s.PushTable(data); err != nil { + s.PushString(fmt.Sprintf("failed to push data table: %v", err)) + return -1 + } + s.SetGlobal("data") + + // Template processing code + luaCode := ` + local result = {} + if data.user.logged_in then + table.insert(result, '
') + table.insert(result, string.format('

Welcome, %s!

', tostring(data.user.name))) + table.insert(result, ' ') + if data.user.is_admin then + table.insert(result, '
') + table.insert(result, '

Admin Controls

') + table.insert(result, ' ') + table.insert(result, '
') + end + table.insert(result, '
') + else + table.insert(result, '
') + table.insert(result, '

Please log in to view your profile

') + table.insert(result, '
') + end + return table.concat(result, '\n')` + + result, err := s.DoStringResult(luaCode) + if err != nil { + s.PushString(fmt.Sprintf("template execution failed: %v", err)) + return -1 + } + + // Push the string result + if str, ok := result.(string); ok { + s.PushString(str) + return 1 + } + + s.PushString(fmt.Sprintf("expected string result, got %T", result)) + return -1 + } + + // Create render table and add template function + L.NewTable() + if err := L.PushGoFunction(renderFunc); err != nil { + t.Fatalf("Failed to create render function: %v", err) + } + L.SetField(-2, "template") + L.SetGlobal("render") + + // Test with logged in admin user + testCode := ` + local data = { + user = { + logged_in = true, + name = "John Doe", + email = "john@example.com", + joined_date = "2024-02-09", + is_admin = true + } + } + return render.template("test.html", data) + ` + + result, err := L.DoStringResult(testCode) + if err != nil { + t.Fatalf("Failed to execute test: %v", err) + } + + str, ok := result.(string) + if !ok { + t.Fatalf("Expected string result, got %T", result) + } + + expectedResult := `
+

Welcome, John Doe!

+ +
+

Admin Controls

+ +
+
` + + if str != expectedResult { + t.Errorf("\nExpected:\n%s\n\nGot:\n%s", expectedResult, str) + } + + // Test with logged out user + testCode = ` + local data = { + user = { + logged_in = false + } + } + return render.template("test.html", data) + ` + + result, err = L.DoStringResult(testCode) + if err != nil { + t.Fatalf("Failed to execute logged out test: %v", err) + } + + str, ok = result.(string) + if !ok { + t.Fatalf("Expected string result, got %T", result) + } + + expectedResult = `
+

Please log in to view your profile

+
` + + if str != expectedResult { + t.Errorf("\nExpected:\n%s\n\nGot:\n%s", expectedResult, str) + } }