BIG changes; no "safe" mode, function updates, etc

This commit is contained in:
Sky Johnson 2025-02-12 19:17:11 -06:00
parent 7c79616cac
commit 4dc266201f
10 changed files with 1105 additions and 660 deletions

View File

@ -1,148 +0,0 @@
package main
import (
"fmt"
"time"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
type benchCase struct {
name string
code string
}
var cases = []benchCase{
{
name: "Simple Addition",
code: `return 1 + 1`,
},
{
name: "Loop Sum",
code: `
local sum = 0
for i = 1, 1000 do
sum = sum + i
end
return sum
`,
},
{
name: "Function Call",
code: `
local result = 0
for i = 1, 100 do
result = result + i
end
return result
`,
},
{
name: "Table Creation",
code: `
local t = {}
for i = 1, 100 do
t[i] = i * 2
end
return t[50]
`,
},
{
name: "String Operations",
code: `
local s = "hello"
for i = 1, 10 do
s = s .. " world"
end
return #s
`,
},
}
func runBenchmark(L *luajit.State, code string, duration time.Duration) (time.Duration, int64) {
start := time.Now()
deadline := start.Add(duration)
var ops int64
for time.Now().Before(deadline) {
if err := L.DoString(code); err != nil {
fmt.Printf("Error executing code: %v\n", err)
return 0, 0
}
L.Pop(1)
ops++
}
return time.Since(start), ops
}
func runBytecodeTest(L *luajit.State, code string, duration time.Duration) (time.Duration, int64) {
// First compile the bytecode
bytecode, err := L.CompileBytecode(code, "bench")
if err != nil {
fmt.Printf("Error compiling bytecode: %v\n", err)
return 0, 0
}
start := time.Now()
deadline := start.Add(duration)
var ops int64
for time.Now().Before(deadline) {
if err := L.LoadBytecode(bytecode, "bench"); err != nil {
fmt.Printf("Error executing bytecode: %v\n", err)
return 0, 0
}
ops++
}
return time.Since(start), ops
}
func benchmarkCase(newState func() *luajit.State, bc benchCase) {
fmt.Printf("\n%s:\n", bc.name)
// Direct execution benchmark
L := newState()
if L == nil {
fmt.Printf(" Failed to create Lua state\n")
return
}
execTime, ops := runBenchmark(L, bc.code, 2*time.Second)
L.Close()
if ops > 0 {
opsPerSec := float64(ops) / execTime.Seconds()
fmt.Printf(" Direct: %.0f ops/sec\n", opsPerSec)
}
// Bytecode execution benchmark
L = newState()
if L == nil {
fmt.Printf(" Failed to create Lua state\n")
return
}
execTime, ops = runBytecodeTest(L, bc.code, 2*time.Second)
L.Close()
if ops > 0 {
opsPerSec := float64(ops) / execTime.Seconds()
fmt.Printf(" Bytecode: %.0f ops/sec\n", opsPerSec)
}
}
func main() {
modes := []struct {
name string
newState func() *luajit.State
}{
{"Safe", luajit.NewSafe},
{"Unsafe", luajit.New},
}
for _, mode := range modes {
fmt.Printf("\n=== %s Mode ===\n", mode.name)
for _, c := range cases {
benchmarkCase(mode.newState, c)
}
}
}

View File

@ -51,7 +51,8 @@ import (
"unsafe"
)
func (s *State) compileBytecodeUnsafe(code string, name string) ([]byte, error) {
// CompileBytecode compiles a Lua chunk to bytecode without executing it
func (s *State) CompileBytecode(code string, name string) ([]byte, error) {
// First load the string but don't execute it
ccode := C.CString(code)
defer C.free(unsafe.Pointer(ccode))
@ -94,7 +95,8 @@ func (s *State) compileBytecodeUnsafe(code string, name string) ([]byte, error)
return bytecode, nil
}
func (s *State) loadBytecodeUnsafe(bytecode []byte, name string) error {
// LoadBytecode loads precompiled bytecode and executes it
func (s *State) LoadBytecode(bytecode []byte, name string) error {
cname := C.CString(name)
defer C.free(unsafe.Pointer(cname))
@ -125,26 +127,6 @@ func (s *State) loadBytecodeUnsafe(bytecode []byte, name string) error {
return nil
}
// CompileBytecode compiles a Lua chunk to bytecode without executing it
func (s *State) CompileBytecode(code string, name string) ([]byte, error) {
if s.safeStack {
return stackGuardValue[[]byte](s, func() ([]byte, error) {
return s.compileBytecodeUnsafe(code, name)
})
}
return s.compileBytecodeUnsafe(code, name)
}
// LoadBytecode loads precompiled bytecode and executes it
func (s *State) LoadBytecode(bytecode []byte, name string) error {
if s.safeStack {
return stackGuardErr(s, func() error {
return s.loadBytecodeUnsafe(bytecode, name)
})
}
return s.loadBytecodeUnsafe(bytecode, name)
}
// Helper function to compile and immediately load/execute bytecode
func (s *State) CompileAndLoad(code string, name string) error {
bytecode, err := s.CompileBytecode(code, name)

View File

@ -28,82 +28,70 @@ func TestBytecodeCompilation(t *testing.T) {
},
}
for _, f := range factories {
for _, tt := range tests {
t.Run(f.name+"/"+tt.name, func(t *testing.T) {
L := f.new()
if L == nil {
t.Fatal("Failed to create Lua state")
}
defer L.Close()
for _, tt := range tests {
L := New()
if L == nil {
t.Fatal("Failed to create Lua state")
}
defer L.Close()
bytecode, err := L.CompileBytecode(tt.code, "test")
if (err != nil) != tt.wantErr {
t.Errorf("CompileBytecode() error = %v, wantErr %v", err, tt.wantErr)
return
}
bytecode, err := L.CompileBytecode(tt.code, "test")
if (err != nil) != tt.wantErr {
t.Errorf("CompileBytecode() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !tt.wantErr {
if len(bytecode) == 0 {
t.Error("CompileBytecode() returned empty bytecode")
}
}
})
if !tt.wantErr {
if len(bytecode) == 0 {
t.Error("CompileBytecode() returned empty bytecode")
}
}
}
}
func TestBytecodeExecution(t *testing.T) {
for _, f := range factories {
t.Run(f.name, func(t *testing.T) {
L := f.new()
if L == nil {
t.Fatal("Failed to create Lua state")
}
defer L.Close()
L := New()
if L == nil {
t.Fatal("Failed to create Lua state")
}
defer L.Close()
// Compile some test code
code := `
function add(a, b)
return a + b
end
result = add(40, 2)
`
// Compile some test code
code := `
function add(a, b)
return a + b
end
result = add(40, 2)
`
bytecode, err := L.CompileBytecode(code, "test")
if err != nil {
t.Fatalf("CompileBytecode() error = %v", err)
}
bytecode, err := L.CompileBytecode(code, "test")
if err != nil {
t.Fatalf("CompileBytecode() error = %v", err)
}
// Load and execute the bytecode
if err := L.LoadBytecode(bytecode, "test"); err != nil {
t.Fatalf("LoadBytecode() error = %v", err)
}
// Load and execute the bytecode
if err := L.LoadBytecode(bytecode, "test"); err != nil {
t.Fatalf("LoadBytecode() error = %v", err)
}
// Verify the result
L.GetGlobal("result")
if result := L.ToNumber(-1); result != 42 {
t.Errorf("got result = %v, want 42", result)
}
})
// Verify the result
L.GetGlobal("result")
if result := L.ToNumber(-1); result != 42 {
t.Errorf("got result = %v, want 42", result)
}
}
func TestInvalidBytecode(t *testing.T) {
for _, f := range factories {
t.Run(f.name, func(t *testing.T) {
L := f.new()
if L == nil {
t.Fatal("Failed to create Lua state")
}
defer L.Close()
L := New()
if L == nil {
t.Fatal("Failed to create Lua state")
}
defer L.Close()
// Test with invalid bytecode
invalidBytecode := []byte("this is not valid bytecode")
if err := L.LoadBytecode(invalidBytecode, "test"); err == nil {
t.Error("LoadBytecode() expected error with invalid bytecode")
}
})
// Test with invalid bytecode
invalidBytecode := []byte("this is not valid bytecode")
if err := L.LoadBytecode(invalidBytecode, "test"); err == nil {
t.Error("LoadBytecode() expected error with invalid bytecode")
}
}
@ -140,39 +128,35 @@ func TestBytecodeRoundTrip(t *testing.T) {
},
}
for _, f := range factories {
for _, tt := range tests {
t.Run(f.name+"/"+tt.name, func(t *testing.T) {
// First state for compilation
L1 := f.new()
if L1 == nil {
t.Fatal("Failed to create first Lua state")
}
defer L1.Close()
for _, tt := range tests {
// First state for compilation
L1 := New()
if L1 == nil {
t.Fatal("Failed to create first Lua state")
}
defer L1.Close()
// Compile the code
bytecode, err := L1.CompileBytecode(tt.code, "test")
if err != nil {
t.Fatalf("CompileBytecode() error = %v", err)
}
// Compile the code
bytecode, err := L1.CompileBytecode(tt.code, "test")
if err != nil {
t.Fatalf("CompileBytecode() error = %v", err)
}
// Second state for execution
L2 := f.new()
if L2 == nil {
t.Fatal("Failed to create second Lua state")
}
defer L2.Close()
// Second state for execution
L2 := New()
if L2 == nil {
t.Fatal("Failed to create second Lua state")
}
defer L2.Close()
// Load and execute the bytecode
if err := L2.LoadBytecode(bytecode, "test"); err != nil {
t.Fatalf("LoadBytecode() error = %v", err)
}
// Load and execute the bytecode
if err := L2.LoadBytecode(bytecode, "test"); err != nil {
t.Fatalf("LoadBytecode() error = %v", err)
}
// Run the check function
if err := tt.check(L2); err != nil {
t.Errorf("check failed: %v", err)
}
})
// Run the check function
if err := tt.check(L2); err != nil {
t.Errorf("check failed: %v", err)
}
}
}

View File

@ -31,7 +31,7 @@ var (
//export goFunctionWrapper
func goFunctionWrapper(L *C.lua_State) C.int {
state := &State{L: L, safeStack: true}
state := &State{L: L}
// Get upvalue using standard Lua 5.1 macro
ptr := C.lua_touserdata(L, C.get_upvalue_index(1))
@ -54,7 +54,6 @@ func goFunctionWrapper(L *C.lua_State) C.int {
}
func (s *State) PushGoFunction(fn GoFunction) error {
// Push lightuserdata as upvalue and create closure
ptr := C.malloc(1)
if ptr == nil {
return fmt.Errorf("failed to allocate memory for function pointer")

View File

@ -1,89 +1,87 @@
package luajit
import "testing"
import (
"testing"
)
func TestGoFunctions(t *testing.T) {
for _, f := range factories {
t.Run(f.name, func(t *testing.T) {
L := f.new()
if L == nil {
t.Fatal("Failed to create Lua state")
}
defer L.Close()
defer L.Cleanup()
L := New()
if L == nil {
t.Fatal("Failed to create Lua state")
}
defer L.Close()
defer L.Cleanup()
addFunc := func(s *State) int {
s.PushNumber(s.ToNumber(1) + s.ToNumber(2))
return 1
}
addFunc := func(s *State) int {
s.PushNumber(s.ToNumber(1) + s.ToNumber(2))
return 1
}
if err := L.RegisterGoFunction("add", addFunc); err != nil {
t.Fatalf("Failed to register function: %v", err)
}
if err := L.RegisterGoFunction("add", addFunc); err != nil {
t.Fatalf("Failed to register function: %v", err)
}
// Test basic function call
if err := L.DoString("result = add(40, 2)"); err != nil {
t.Fatalf("Failed to call function: %v", err)
}
// Test basic function call
if err := L.DoString("result = add(40, 2)"); err != nil {
t.Fatalf("Failed to call function: %v", err)
}
L.GetGlobal("result")
if result := L.ToNumber(-1); result != 42 {
t.Errorf("got %v, want 42", result)
}
L.Pop(1)
L.GetGlobal("result")
if result := L.ToNumber(-1); result != 42 {
t.Errorf("got %v, want 42", result)
}
L.Pop(1)
// Test multiple return values
multiFunc := func(s *State) int {
s.PushString("hello")
s.PushNumber(42)
s.PushBoolean(true)
return 3
}
// Test multiple return values
multiFunc := func(s *State) int {
s.PushString("hello")
s.PushNumber(42)
s.PushBoolean(true)
return 3
}
if err := L.RegisterGoFunction("multi", multiFunc); err != nil {
t.Fatalf("Failed to register multi function: %v", err)
}
if err := L.RegisterGoFunction("multi", multiFunc); err != nil {
t.Fatalf("Failed to register multi function: %v", err)
}
code := `
code := `
a, b, c = multi()
result = (a == "hello" and b == 42 and c == true)
`
if err := L.DoString(code); err != nil {
t.Fatalf("Failed to call multi function: %v", err)
}
if err := L.DoString(code); err != nil {
t.Fatalf("Failed to call multi function: %v", err)
}
L.GetGlobal("result")
if !L.ToBoolean(-1) {
t.Error("Multiple return values test failed")
}
L.Pop(1)
L.GetGlobal("result")
if !L.ToBoolean(-1) {
t.Error("Multiple return values test failed")
}
L.Pop(1)
// Test error handling
errFunc := func(s *State) int {
s.PushString("test error")
return -1
}
// Test error handling
errFunc := func(s *State) int {
s.PushString("test error")
return -1
}
if err := L.RegisterGoFunction("err", errFunc); err != nil {
t.Fatalf("Failed to register error function: %v", err)
}
if err := L.RegisterGoFunction("err", errFunc); err != nil {
t.Fatalf("Failed to register error function: %v", err)
}
if err := L.DoString("err()"); err == nil {
t.Error("Expected error from error function")
}
if err := L.DoString("err()"); err == nil {
t.Error("Expected error from error function")
}
// Test unregistering
L.UnregisterGoFunction("add")
if err := L.DoString("add(1, 2)"); err == nil {
t.Error("Expected error calling unregistered function")
}
})
// Test unregistering
L.UnregisterGoFunction("add")
if err := L.DoString("add(1, 2)"); err == nil {
t.Error("Expected error calling unregistered function")
}
}
func TestStackSafety(t *testing.T) {
L := NewSafe()
L := New()
if L == nil {
t.Fatal("Failed to create Lua state")
}

101
table.go
View File

@ -22,57 +22,20 @@ type TableValue interface {
func (s *State) GetTableLength(index int) int { return int(C.get_table_length(s.L, C.int(index))) }
// PushTable pushes a Go map onto the Lua stack as a table
func (s *State) PushTable(table map[string]interface{}) error {
s.NewTable()
for k, v := range table {
if err := s.PushValue(v); err != nil {
return err
}
s.SetField(-2, k)
}
return nil
}
// ToTable converts a Lua table to a Go map
func (s *State) ToTable(index int) (map[string]interface{}, error) {
if s.safeStack {
return stackGuardValue[map[string]interface{}](s, func() (map[string]interface{}, error) {
if !s.IsTable(index) {
return nil, fmt.Errorf("not a table at index %d", index)
}
return s.toTableUnsafe(index)
})
}
if !s.IsTable(index) {
return nil, fmt.Errorf("not a table at index %d", index)
}
return s.toTableUnsafe(index)
}
func (s *State) pushTableSafe(table map[string]interface{}) error {
size := 2
if err := s.checkStack(size); err != nil {
return fmt.Errorf("insufficient stack space: %w", err)
}
s.NewTable()
for k, v := range table {
if err := s.pushValueSafe(v); err != nil {
return err
}
s.SetField(-2, k)
}
return nil
}
func (s *State) pushTableUnsafe(table map[string]interface{}) error {
s.NewTable()
for k, v := range table {
if err := s.pushValueUnsafe(v); err != nil {
return err
}
s.SetField(-2, k)
}
return nil
}
func (s *State) toTableSafe(index int) (map[string]interface{}, error) {
if err := s.checkStack(2); err != nil {
return nil, err
}
return s.toTableUnsafe(index)
}
func (s *State) toTableUnsafe(index int) (map[string]interface{}, error) {
absIdx := s.absIndex(index)
table := make(map[string]interface{})
@ -111,7 +74,7 @@ func (s *State) toTableUnsafe(index int) (map[string]interface{}, error) {
key = fmt.Sprintf("%g", s.ToNumber(-2))
}
value, err := s.toValueUnsafe(-1)
value, err := s.ToValue(-1)
if err != nil {
s.Pop(1)
return nil, err
@ -133,45 +96,15 @@ func (s *State) toTableUnsafe(index int) (map[string]interface{}, error) {
// NewTable creates a new table and pushes it onto the stack
func (s *State) NewTable() {
if s.safeStack {
if err := s.checkStack(1); err != nil {
// Since we can't return an error, we'll push nil instead
s.PushNil()
return
}
}
C.lua_createtable(s.L, 0, 0)
}
// SetTable sets a table field with cached absolute index
// SetTable sets a table field
func (s *State) SetTable(index int) {
absIdx := index
if s.safeStack && (index < 0 && index > LUA_REGISTRYINDEX) {
absIdx = s.GetTop() + index + 1
}
C.lua_settable(s.L, C.int(absIdx))
C.lua_settable(s.L, C.int(index))
}
// GetTable gets a table field with cached absolute index
// GetTable gets a table field
func (s *State) GetTable(index int) {
absIdx := index
if s.safeStack && (index < 0 && index > LUA_REGISTRYINDEX) {
absIdx = s.GetTop() + index + 1
}
if s.safeStack {
if err := s.checkStack(1); err != nil {
s.PushNil()
return
}
}
C.lua_gettable(s.L, C.int(absIdx))
}
// PushTable pushes a Go map onto the Lua stack as a table with stack checking
func (s *State) PushTable(table map[string]interface{}) error {
if s.safeStack {
return s.pushTableSafe(table)
}
return s.pushTableUnsafe(table)
C.lua_gettable(s.L, C.int(index))
}

View File

@ -34,28 +34,24 @@ func TestTableOperations(t *testing.T) {
},
}
for _, f := range factories {
for _, tt := range tests {
t.Run(f.name+"/"+tt.name, func(t *testing.T) {
L := f.new()
if L == nil {
t.Fatal("Failed to create Lua state")
}
defer L.Close()
for _, tt := range tests {
L := New()
if L == nil {
t.Fatal("Failed to create Lua state")
}
defer L.Close()
if err := L.PushTable(tt.data); err != nil {
t.Fatalf("PushTable() error = %v", err)
}
if err := L.PushTable(tt.data); err != nil {
t.Fatalf("PushTable() error = %v", err)
}
got, err := L.ToTable(-1)
if err != nil {
t.Fatalf("ToTable() error = %v", err)
}
got, err := L.ToTable(-1)
if err != nil {
t.Fatalf("ToTable() error = %v", err)
}
if !tablesEqual(got, tt.data) {
t.Errorf("table mismatch\ngot = %v\nwant = %v", got, tt.data)
}
})
if !tablesEqual(got, tt.data) {
t.Errorf("table mismatch\ngot = %v\nwant = %v", got, tt.data)
}
}
}

View File

@ -10,49 +10,82 @@ package luajit
#include <lauxlib.h>
#include <stdlib.h>
static int do_string(lua_State *L, const char *s) {
int status = luaL_loadstring(L, s);
if (status) return status;
return lua_pcall(L, 0, LUA_MULTRET, 0);
// Simple wrapper around luaL_loadstring
static int load_chunk(lua_State *L, const char *s) {
return luaL_loadstring(L, s);
}
// Direct wrapper around lua_pcall
static int protected_call(lua_State *L, int nargs, int nresults, int errfunc) {
return lua_pcall(L, nargs, nresults, errfunc);
}
// Combined load and execute with no results
static int do_string(lua_State *L, const char *s) {
return luaL_dostring(L, s);
}
// Combined load and execute file
static int do_file(lua_State *L, const char *filename) {
int status = luaL_loadfile(L, filename);
if (status) return status;
return lua_pcall(L, 0, LUA_MULTRET, 0);
return luaL_dofile(L, filename);
}
// Execute string with multiple returns
static int execute_string(lua_State *L, const char *s) {
int base = lua_gettop(L); // Save stack position
int status = luaL_loadstring(L, s);
if (status) return -status; // Return negative status for load errors
status = lua_pcall(L, 0, LUA_MULTRET, 0);
if (status) return -status; // Return negative status for runtime errors
return lua_gettop(L) - base; // Return number of results
}
// Get absolute stack index (converts negative indices)
static int get_abs_index(lua_State *L, int idx) {
if (idx > 0 || idx <= LUA_REGISTRYINDEX) return idx;
return lua_gettop(L) + idx + 1;
}
// Stack manipulation helpers
static int check_stack(lua_State *L, int n) {
return lua_checkstack(L, n);
}
static void remove_stack(lua_State *L, int idx) {
lua_remove(L, idx);
}
static int get_field_helper(lua_State *L, int idx, const char *k) {
lua_getfield(L, idx, k);
return lua_type(L, -1);
}
static void set_field_helper(lua_State *L, int idx, const char *k) {
lua_setfield(L, idx, k);
}
*/
import "C"
import (
"fmt"
"path/filepath"
"strings"
"unsafe"
)
// State represents a Lua state with configurable stack safety
// State represents a Lua state
type State struct {
L *C.lua_State
safeStack bool
L *C.lua_State
}
// NewSafe creates a new Lua state with full stack safety guarantees
func NewSafe() *State {
L := C.luaL_newstate()
if L == nil {
return nil
}
C.luaL_openlibs(L)
return &State{L: L, safeStack: true}
}
// New creates a new Lua state with minimal stack checking
// New creates a new Lua state
func New() *State {
L := C.luaL_newstate()
if L == nil {
return nil
}
C.luaL_openlibs(L)
return &State{L: L, safeStack: false}
return &State{L: L}
}
// Close closes the Lua state
@ -63,65 +96,28 @@ func (s *State) Close() {
}
}
// DoString executes a Lua string with appropriate stack management
// DoString executes a Lua string.
func (s *State) DoString(str string) error {
cstr := C.CString(str)
defer C.free(unsafe.Pointer(cstr))
// Save initial stack size
top := s.GetTop()
if s.safeStack {
return stackGuardErr(s, func() error {
// Save the current stack size
initialTop := s.GetTop()
// Execute the string
status := C.do_string(s.L, cstr)
if status != 0 {
// In case of error, get error message from stack
errMsg := s.ToString(-1)
s.SetTop(initialTop) // Restore stack
return &LuaError{
Code: int(status),
Message: errMsg,
}
}
// Return values are now on the stack above initialTop
// We don't pop them as they may be needed by the caller
return nil
})
// Load the string
if err := s.LoadString(str); err != nil {
return err
}
status := C.do_string(s.L, cstr)
if status != 0 {
return &LuaError{
Code: int(status),
Message: s.ToString(-1),
}
// Execute and check for errors
if err := s.Call(0, 0); err != nil {
return err
}
// Restore stack to initial size to clean up any leftovers
s.SetTop(top)
return nil
}
// PushValue pushes a Go value onto the stack
func (s *State) PushValue(v interface{}) error {
if s.safeStack {
return stackGuardErr(s, func() error {
if err := s.checkStack(1); err != nil {
return fmt.Errorf("pushing value: %w", err)
}
return s.pushValueUnsafe(v)
})
}
return s.pushValueUnsafe(v)
}
func (s *State) pushValueSafe(v interface{}) error {
if err := s.checkStack(1); err != nil {
return fmt.Errorf("pushing value: %w", err)
}
return s.pushValueUnsafe(v)
}
func (s *State) pushValueUnsafe(v interface{}) error {
switch v := v.(type) {
case nil:
s.PushNil()
@ -144,7 +140,7 @@ func (s *State) pushValueUnsafe(v interface{}) error {
}
return nil
}
return s.pushTableUnsafe(v)
return s.PushTable(v)
case []float64:
s.NewTable()
for i, elem := range v {
@ -156,7 +152,7 @@ func (s *State) pushValueUnsafe(v interface{}) error {
s.NewTable()
for i, elem := range v {
s.PushNumber(float64(i + 1))
if err := s.pushValueUnsafe(elem); err != nil {
if err := s.PushValue(elem); err != nil {
return err
}
s.SetTable(-3)
@ -169,15 +165,6 @@ func (s *State) pushValueUnsafe(v interface{}) error {
// ToValue converts a Lua value to a Go value
func (s *State) ToValue(index int) (interface{}, error) {
if s.safeStack {
return stackGuardValue[interface{}](s, func() (interface{}, error) {
return s.toValueUnsafe(index)
})
}
return s.toValueUnsafe(index)
}
func (s *State) toValueUnsafe(index int) (interface{}, error) {
switch s.GetType(index) {
case TypeNil:
return nil, nil
@ -191,7 +178,7 @@ func (s *State) toValueUnsafe(index int) (interface{}, error) {
if !s.IsTable(index) {
return nil, fmt.Errorf("not a table at index %d", index)
}
return s.toTableUnsafe(index)
return s.ToTable(index)
default:
return nil, fmt.Errorf("unsupported type: %s", s.GetType(index))
}
@ -244,53 +231,29 @@ func (s *State) absIndex(index int) int {
return s.GetTop() + index + 1
}
// SetField sets a field in a table at the given index with cached absolute index
// SetField sets a field in a table at the given index
func (s *State) SetField(index int, key string) {
absIdx := index
if s.safeStack && (index < 0 && index > LUA_REGISTRYINDEX) {
absIdx = s.GetTop() + index + 1
}
cstr := C.CString(key)
defer C.free(unsafe.Pointer(cstr))
C.lua_setfield(s.L, C.int(absIdx), cstr)
C.lua_setfield(s.L, C.int(index), cstr)
}
// GetField gets a field from a table with cached absolute index
// GetField gets a field from a table
func (s *State) GetField(index int, key string) {
absIdx := index
if s.safeStack && (index < 0 && index > LUA_REGISTRYINDEX) {
absIdx = s.GetTop() + index + 1
}
if s.safeStack {
if err := s.checkStack(1); err != nil {
s.PushNil()
return
}
}
cstr := C.CString(key)
defer C.free(unsafe.Pointer(cstr))
C.lua_getfield(s.L, C.int(absIdx), cstr)
C.lua_getfield(s.L, C.int(index), cstr)
}
// GetGlobal gets a global variable and pushes it onto the stack
func (s *State) GetGlobal(name string) {
if s.safeStack {
if err := s.checkStack(1); err != nil {
s.PushNil()
return
}
}
cstr := C.CString(name)
defer C.free(unsafe.Pointer(cstr))
C.lua_getfield(s.L, C.LUA_GLOBALSINDEX, cstr)
cname := C.CString(name)
defer C.free(unsafe.Pointer(cname))
C.get_field_helper(s.L, C.LUA_GLOBALSINDEX, cname)
}
// SetGlobal sets a global variable from the value at the top of the stack
func (s *State) SetGlobal(name string) {
// SetGlobal doesn't need stack space checking as it pops the value
cstr := C.CString(name)
defer C.free(unsafe.Pointer(cstr))
C.lua_setfield(s.L, C.LUA_GLOBALSINDEX, cstr)
@ -299,9 +262,6 @@ func (s *State) SetGlobal(name string) {
// Remove removes element with cached absolute index
func (s *State) Remove(index int) {
absIdx := index
if s.safeStack && (index < 0 && index > LUA_REGISTRYINDEX) {
absIdx = s.GetTop() + index + 1
}
C.lua_remove(s.L, C.int(absIdx))
}
@ -310,14 +270,6 @@ func (s *State) DoFile(filename string) error {
cfilename := C.CString(filename)
defer C.free(unsafe.Pointer(cfilename))
if s.safeStack {
return stackGuardErr(s, func() error {
return s.safeCall(func() C.int {
return C.do_file(s.L, cfilename)
})
})
}
status := C.do_file(s.L, cfilename)
if status != 0 {
return &LuaError{
@ -328,24 +280,28 @@ func (s *State) DoFile(filename string) error {
return nil
}
// SetPackagePath sets the Lua package.path
func (s *State) SetPackagePath(path string) error {
path = filepath.ToSlash(path)
if err := s.DoString(fmt.Sprintf(`package.path = %q`, path)); err != nil {
return fmt.Errorf("setting package.path: %w", err)
}
return nil
path = strings.ReplaceAll(path, "\\", "/") // Convert Windows paths
cmd := fmt.Sprintf(`package.path = %q`, path)
return s.DoString(cmd)
}
// AddPackagePath adds a path to package.path
func (s *State) AddPackagePath(path string) error {
path = filepath.ToSlash(path)
if err := s.DoString(fmt.Sprintf(`package.path = package.path .. ";%s"`, path)); err != nil {
return fmt.Errorf("adding to package.path: %w", err)
}
return nil
path = strings.ReplaceAll(path, "\\", "/") // Convert Windows paths
cmd := fmt.Sprintf(`package.path = package.path .. ";%s"`, path)
return s.DoString(cmd)
}
// Call executes a function on the stack with the given number of arguments and results.
// The function and arguments should already be on the stack in the correct order
// (function first, then args from left to right).
func (s *State) Call(nargs, nresults int) error {
status := C.lua_pcall(s.L, C.int(nargs), C.int(nresults), 0)
if !s.IsFunction(-nargs - 1) {
return fmt.Errorf("attempt to call a non-function")
}
status := C.protected_call(s.L, C.int(nargs), C.int(nresults), 0)
if status != 0 {
err := &LuaError{
Code: int(status),
@ -356,3 +312,94 @@ func (s *State) Call(nargs, nresults int) error {
}
return nil
}
// LoadString loads but does not execute a string of Lua code.
// The compiled code chunk is left on the stack.
func (s *State) LoadString(str string) error {
cstr := C.CString(str)
defer C.free(unsafe.Pointer(cstr))
status := C.load_chunk(s.L, cstr)
if status != 0 {
err := &LuaError{
Code: int(status),
Message: s.ToString(-1),
}
s.Pop(1)
return err
}
if !s.IsFunction(-1) {
s.Pop(1)
return fmt.Errorf("failed to load function")
}
return nil
}
// ExecuteString executes a string of Lua code and returns the number of results.
// The results are left on the stack.
func (s *State) ExecuteString(str string) (int, error) {
base := s.GetTop()
// First load the string
if err := s.LoadString(str); err != nil {
return 0, err
}
// Now execute it
if err := s.Call(0, C.LUA_MULTRET); err != nil {
return 0, err
}
return s.GetTop() - base, nil
}
// ExecuteStringResult executes a Lua string and returns its first result as a Go value.
// It's a convenience wrapper around ExecuteString for the common case of wanting
// a single return value. The stack is restored to its original state after execution.
func (s *State) ExecuteStringResult(code string) (interface{}, error) {
top := s.GetTop()
defer s.SetTop(top) // Restore stack when we're done
nresults, err := s.ExecuteString(code)
if err != nil {
return nil, fmt.Errorf("execution error: %w", err)
}
if nresults == 0 {
return nil, nil
}
// Get the result
result, err := s.ToValue(-nresults) // Get first result
if err != nil {
return nil, fmt.Errorf("error converting result: %w", err)
}
return result, nil
}
// DoStringResult executes a Lua string and expects a single return value.
// Unlike ExecuteStringResult, this function specifically expects exactly one
// return value and will return an error if the code returns 0 or multiple values.
func (s *State) DoStringResult(code string) (interface{}, error) {
top := s.GetTop()
defer s.SetTop(top) // Restore stack when we're done
nresults, err := s.ExecuteString(code)
if err != nil {
return nil, fmt.Errorf("execution error: %w", err)
}
if nresults != 1 {
return nil, fmt.Errorf("expected 1 return value, got %d", nresults)
}
// Get the result
result, err := s.ToValue(-1)
if err != nil {
return nil, fmt.Errorf("error converting result: %w", err)
}
return result, nil
}

237
wrapper_bench_test.go Normal file
View File

@ -0,0 +1,237 @@
package luajit
import (
"testing"
)
var benchCases = []struct {
name string
code string
}{
{
name: "SimpleAddition",
code: `return 1 + 1`,
},
{
name: "LoopSum",
code: `
local sum = 0
for i = 1, 1000 do
sum = sum + i
end
return sum
`,
},
{
name: "FunctionCall",
code: `
local result = 0
for i = 1, 100 do
result = result + i
end
return result
`,
},
{
name: "TableCreation",
code: `
local t = {}
for i = 1, 100 do
t[i] = i * 2
end
return t[50]
`,
},
{
name: "StringOperations",
code: `
local s = "hello"
for i = 1, 10 do
s = s .. " world"
end
return #s
`,
},
}
func BenchmarkLuaDirectExecution(b *testing.B) {
for _, bc := range benchCases {
b.Run(bc.name, func(b *testing.B) {
L := New()
if L == nil {
b.Fatal("Failed to create Lua state")
}
defer L.Close()
// First verify we can execute the code
if err := L.DoString(bc.code); err != nil {
b.Fatalf("Failed to execute test code: %v", err)
}
L.Pop(1) // Clean up the result
b.ResetTimer()
for i := 0; i < b.N; i++ {
// Execute string and get result
nresults, err := L.ExecuteString(bc.code)
if err != nil {
b.Fatalf("Failed to execute code: %v", err)
}
L.Pop(nresults) // Clean up any results
}
})
}
}
func BenchmarkLuaBytecodeExecution(b *testing.B) {
// First compile all bytecode
bytecodes := make(map[string][]byte)
for _, bc := range benchCases {
L := New()
if L == nil {
b.Fatal("Failed to create Lua state")
}
bytecode, err := L.CompileBytecode(bc.code, bc.name)
if err != nil {
L.Close()
b.Fatalf("Error compiling bytecode for %s: %v", bc.name, err)
}
bytecodes[bc.name] = bytecode
L.Close()
}
for _, bc := range benchCases {
b.Run(bc.name, func(b *testing.B) {
L := New()
if L == nil {
b.Fatal("Failed to create Lua state")
}
defer L.Close()
bytecode := bytecodes[bc.name]
// First verify we can execute the bytecode
if err := L.LoadBytecode(bytecode, bc.name); err != nil {
b.Fatalf("Failed to execute test bytecode: %v", err)
}
b.ResetTimer()
b.SetBytes(int64(len(bytecode))) // Track bytecode size in benchmarks
for i := 0; i < b.N; i++ {
if err := L.LoadBytecode(bytecode, bc.name); err != nil {
b.Fatalf("Error executing bytecode: %v", err)
}
}
})
}
}
func BenchmarkTableOperations(b *testing.B) {
testData := map[string]interface{}{
"number": 42.0,
"string": "hello",
"bool": true,
"nested": map[string]interface{}{
"value": 123.0,
"array": []float64{1.1, 2.2, 3.3},
},
}
b.Run("PushTable", func(b *testing.B) {
L := New()
if L == nil {
b.Fatal("Failed to create Lua state")
}
defer L.Close()
// First verify we can push the table
if err := L.PushTable(testData); err != nil {
b.Fatalf("Failed to push initial table: %v", err)
}
L.Pop(1)
b.ResetTimer()
for i := 0; i < b.N; i++ {
if err := L.PushTable(testData); err != nil {
b.Fatalf("Failed to push table: %v", err)
}
L.Pop(1)
}
})
b.Run("ToTable", func(b *testing.B) {
L := New()
if L == nil {
b.Fatal("Failed to create Lua state")
}
defer L.Close()
// Keep a table on the stack for repeated conversions
if err := L.PushTable(testData); err != nil {
b.Fatalf("Failed to push initial table: %v", err)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
if _, err := L.ToTable(-1); err != nil {
b.Fatalf("Failed to convert table: %v", err)
}
}
})
}
func BenchmarkValueConversion(b *testing.B) {
testValues := []struct {
name string
value interface{}
}{
{"Number", 42.0},
{"String", "hello world"},
{"Boolean", true},
{"Nil", nil},
}
for _, tv := range testValues {
b.Run("Push"+tv.name, func(b *testing.B) {
L := New()
if L == nil {
b.Fatal("Failed to create Lua state")
}
defer L.Close()
// First verify we can push the value
if err := L.PushValue(tv.value); err != nil {
b.Fatalf("Failed to push initial value: %v", err)
}
L.Pop(1)
b.ResetTimer()
for i := 0; i < b.N; i++ {
if err := L.PushValue(tv.value); err != nil {
b.Fatalf("Failed to push value: %v", err)
}
L.Pop(1)
}
})
b.Run("To"+tv.name, func(b *testing.B) {
L := New()
if L == nil {
b.Fatal("Failed to create Lua state")
}
defer L.Close()
// Keep a value on the stack for repeated conversions
if err := L.PushValue(tv.value); err != nil {
b.Fatalf("Failed to push initial value: %v", err)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
if _, err := L.ToValue(-1); err != nil {
b.Fatalf("Failed to convert value: %v", err)
}
}
})
}
}

View File

@ -7,25 +7,151 @@ import (
"testing"
)
type stateFactory struct {
name string
new func() *State
}
var factories = []stateFactory{
{"unsafe", New},
{"safe", NewSafe},
}
func TestNew(t *testing.T) {
for _, f := range factories {
t.Run(f.name, func(t *testing.T) {
L := f.new()
if L == nil {
t.Fatal("Failed to create Lua state")
L := New()
if L == nil {
t.Fatal("Failed to create Lua state")
}
defer L.Close()
}
func TestLoadString(t *testing.T) {
tests := []struct {
name string
code string
wantErr bool
}{
{
name: "valid function",
code: "function add(a, b) return a + b end",
wantErr: false,
},
{
name: "valid expression",
code: "return 1 + 1",
wantErr: false,
},
{
name: "syntax error",
code: "function bad syntax",
wantErr: true,
},
}
for _, tt := range tests {
L := New()
if L == nil {
t.Fatal("Failed to create Lua state")
}
defer L.Close()
err := L.LoadString(tt.code)
if (err != nil) != tt.wantErr {
t.Errorf("LoadString() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !tt.wantErr {
// Verify the function is on the stack
if L.GetTop() != 1 {
t.Error("LoadString() did not leave exactly one value on stack")
}
defer L.Close()
})
if !L.IsFunction(-1) {
t.Error("LoadString() did not leave a function on the stack")
}
}
}
}
func TestExecuteString(t *testing.T) {
tests := []struct {
name string
code string
wantResults int
checkResults func(*State) error
wantErr bool
wantStackSize int
}{
{
name: "no results",
code: "local x = 1",
wantResults: 0,
wantErr: false,
},
{
name: "single result",
code: "return 42",
wantResults: 1,
checkResults: func(L *State) error {
if n := L.ToNumber(-1); n != 42 {
return fmt.Errorf("got %v, want 42", n)
}
return nil
},
wantErr: false,
},
{
name: "multiple results",
code: "return 1, 'test', true",
wantResults: 3,
checkResults: func(L *State) error {
if n := L.ToNumber(-3); n != 1 {
return fmt.Errorf("first result: got %v, want 1", n)
}
if s := L.ToString(-2); s != "test" {
return fmt.Errorf("second result: got %v, want 'test'", s)
}
if b := L.ToBoolean(-1); !b {
return fmt.Errorf("third result: got %v, want true", b)
}
return nil
},
wantErr: false,
},
{
name: "syntax error",
code: "this is not valid lua",
wantErr: true,
},
{
name: "runtime error",
code: "error('test error')",
wantErr: true,
},
}
for _, tt := range tests {
L := New()
if L == nil {
t.Fatal("Failed to create Lua state")
}
defer L.Close()
// Record initial stack size
initialStack := L.GetTop()
results, err := L.ExecuteString(tt.code)
if (err != nil) != tt.wantErr {
t.Errorf("ExecuteString() error = %v, wantErr %v", err, tt.wantErr)
return
}
if err == nil {
if results != tt.wantResults {
t.Errorf("ExecuteString() returned %d results, want %d", results, tt.wantResults)
}
if tt.checkResults != nil {
if err := tt.checkResults(L); err != nil {
t.Errorf("Result check failed: %v", err)
}
}
// Verify stack size matches expected results
if got := L.GetTop() - initialStack; got != tt.wantResults {
t.Errorf("Stack size grew by %d, want %d", got, tt.wantResults)
}
}
}
}
@ -41,20 +167,22 @@ func TestDoString(t *testing.T) {
{"runtime error", "error('test error')", true},
}
for _, f := range factories {
for _, tt := range tests {
t.Run(f.name+"/"+tt.name, func(t *testing.T) {
L := f.new()
if L == nil {
t.Fatal("Failed to create Lua state")
}
defer L.Close()
for _, tt := range tests {
L := New()
if L == nil {
t.Fatal("Failed to create Lua state")
}
defer L.Close()
err := L.DoString(tt.code)
if (err != nil) != tt.wantErr {
t.Errorf("DoString() error = %v, wantErr %v", err, tt.wantErr)
}
})
initialStack := L.GetTop()
err := L.DoString(tt.code)
if (err != nil) != tt.wantErr {
t.Errorf("DoString() error = %v, wantErr %v", err, tt.wantErr)
}
// Verify stack is unchanged
if finalStack := L.GetTop(); finalStack != initialStack {
t.Errorf("Stack size changed from %d to %d", initialStack, finalStack)
}
}
}
@ -107,95 +235,171 @@ func TestPushAndGetValues(t *testing.T) {
},
}
for _, f := range factories {
for _, v := range values {
t.Run(f.name+"/"+v.name, func(t *testing.T) {
L := f.new()
if L == nil {
t.Fatal("Failed to create Lua state")
}
defer L.Close()
for _, v := range values {
L := New()
if L == nil {
t.Fatal("Failed to create Lua state")
}
defer L.Close()
v.push(L)
if err := v.check(L); err != nil {
t.Error(err)
}
})
v.push(L)
if err := v.check(L); err != nil {
t.Error(err)
}
}
}
func TestStackManipulation(t *testing.T) {
for _, f := range factories {
t.Run(f.name, func(t *testing.T) {
L := f.new()
if L == nil {
t.Fatal("Failed to create Lua state")
}
defer L.Close()
L := New()
if L == nil {
t.Fatal("Failed to create Lua state")
}
defer L.Close()
// Push values
values := []string{"first", "second", "third"}
for _, v := range values {
L.PushString(v)
}
// Push values
values := []string{"first", "second", "third"}
for _, v := range values {
L.PushString(v)
}
// Check size
if top := L.GetTop(); top != len(values) {
t.Errorf("stack size = %d, want %d", top, len(values))
}
// Check size
if top := L.GetTop(); top != len(values) {
t.Errorf("stack size = %d, want %d", top, len(values))
}
// Pop one value
L.Pop(1)
// Pop one value
L.Pop(1)
// Check new top
if str := L.ToString(-1); str != "second" {
t.Errorf("top value = %q, want 'second'", str)
}
// Check new top
if str := L.ToString(-1); str != "second" {
t.Errorf("top value = %q, want 'second'", str)
}
// Check new size
if top := L.GetTop(); top != len(values)-1 {
t.Errorf("stack size after pop = %d, want %d", top, len(values)-1)
}
})
// Check new size
if top := L.GetTop(); top != len(values)-1 {
t.Errorf("stack size after pop = %d, want %d", top, len(values)-1)
}
}
func TestGlobals(t *testing.T) {
for _, f := range factories {
t.Run(f.name, func(t *testing.T) {
L := f.new()
if L == nil {
t.Fatal("Failed to create Lua state")
}
defer L.Close()
L := New()
if L == nil {
t.Fatal("Failed to create Lua state")
}
defer L.Close()
// Test via Lua
if err := L.DoString(`globalVar = "test"`); err != nil {
t.Fatalf("DoString error: %v", err)
}
// Test via Lua
if err := L.DoString(`globalVar = "test"`); err != nil {
t.Fatalf("DoString error: %v", err)
}
// Get the global
L.GetGlobal("globalVar")
if str := L.ToString(-1); str != "test" {
t.Errorf("global value = %q, want 'test'", str)
}
L.Pop(1)
// Get the global
L.GetGlobal("globalVar")
if str := L.ToString(-1); str != "test" {
t.Errorf("global value = %q, want 'test'", str)
}
L.Pop(1)
// Set and get via API
L.PushNumber(42)
L.SetGlobal("testNum")
// Set and get via API
L.PushNumber(42)
L.SetGlobal("testNum")
L.GetGlobal("testNum")
if num := L.ToNumber(-1); num != 42 {
t.Errorf("global number = %f, want 42", num)
L.GetGlobal("testNum")
if num := L.ToNumber(-1); num != 42 {
t.Errorf("global number = %f, want 42", num)
}
}
func TestCall(t *testing.T) {
tests := []struct {
funcName string // Add explicit function name field
setup string
args []interface{}
nresults int
checkStack func(*State) error
wantErr bool
}{
{
funcName: "add",
setup: "function add(a, b) return a + b end",
args: []interface{}{float64(40), float64(2)},
nresults: 1,
checkStack: func(L *State) error {
if n := L.ToNumber(-1); n != 42 {
return fmt.Errorf("got %v, want 42", n)
}
return nil
},
},
{
funcName: "multi",
setup: "function multi() return 1, 'test', true end",
args: []interface{}{},
nresults: 3,
checkStack: func(L *State) error {
if n := L.ToNumber(-3); n != 1 {
return fmt.Errorf("first result: got %v, want 1", n)
}
if s := L.ToString(-2); s != "test" {
return fmt.Errorf("second result: got %v, want 'test'", s)
}
if b := L.ToBoolean(-1); !b {
return fmt.Errorf("third result: got %v, want true", b)
}
return nil
},
},
{
funcName: "err",
setup: "function err() error('test error') end",
args: []interface{}{},
nresults: 0,
wantErr: true,
},
}
for _, tt := range tests {
L := New()
if L == nil {
t.Fatal("Failed to create Lua state")
}
defer L.Close()
// Setup function
if err := L.DoString(tt.setup); err != nil {
t.Fatalf("Setup failed: %v", err)
}
// Get function
L.GetGlobal(tt.funcName)
if !L.IsFunction(-1) {
t.Fatal("Failed to get function")
}
// Push arguments
for _, arg := range tt.args {
if err := L.PushValue(arg); err != nil {
t.Fatalf("Failed to push argument: %v", err)
}
})
}
// Call function
err := L.Call(len(tt.args), tt.nresults)
if (err != nil) != tt.wantErr {
t.Errorf("Call() error = %v, wantErr %v", err, tt.wantErr)
return
}
if err == nil && tt.checkStack != nil {
if err := tt.checkStack(L); err != nil {
t.Errorf("Stack check failed: %v", err)
}
}
}
}
func TestDoFile(t *testing.T) {
L := NewSafe()
L := New()
defer L.Close()
// Create test file
@ -223,7 +427,7 @@ func TestDoFile(t *testing.T) {
}
func TestRequireAndPackagePath(t *testing.T) {
L := NewSafe()
L := New()
defer L.Close()
tmpDir := t.TempDir()
@ -260,7 +464,7 @@ func TestRequireAndPackagePath(t *testing.T) {
}
func TestSetPackagePath(t *testing.T) {
L := NewSafe()
L := New()
defer L.Close()
customPath := "./custom/?.lua"
@ -273,4 +477,217 @@ func TestSetPackagePath(t *testing.T) {
if path := L.ToString(-1); path != customPath {
t.Errorf("Expected package.path=%q, got %q", customPath, path)
}
// Test that the old path is completely replaced
initialPath := L.ToString(-1)
anotherPath := "./another/?.lua"
if err := L.SetPackagePath(anotherPath); err != nil {
t.Fatalf("Second SetPackagePath failed: %v", err)
}
L.GetGlobal("package")
L.GetField(-1, "path")
if path := L.ToString(-1); path != anotherPath {
t.Errorf("Expected package.path=%q, got %q", anotherPath, path)
}
if path := L.ToString(-1); path == initialPath {
t.Error("SetPackagePath did not replace the old path")
}
}
func TestStackDebug(t *testing.T) {
L := New()
if L == nil {
t.Fatal("Failed to create Lua state")
}
defer L.Close()
t.Log("Testing LoadString:")
initialTop := L.GetTop()
t.Logf("Initial stack size: %d", initialTop)
err := L.LoadString("return 42")
if err != nil {
t.Errorf("LoadString failed: %v", err)
}
afterLoad := L.GetTop()
t.Logf("Stack size after load: %d", afterLoad)
t.Logf("Type of top element: %s", L.GetType(-1))
if L.IsFunction(-1) {
t.Log("Top element is a function")
} else {
t.Log("Top element is NOT a function")
}
// Clean up after LoadString test
L.SetTop(0)
t.Log("\nTesting ExecuteString:")
if err := L.DoString("function test() return 1, 'hello', true end"); err != nil {
t.Errorf("DoString failed: %v", err)
}
beforeExec := L.GetTop()
t.Logf("Stack size before execute: %d", beforeExec)
nresults, err := L.ExecuteString("return test()")
if err != nil {
t.Errorf("ExecuteString failed: %v", err)
}
afterExec := L.GetTop()
t.Logf("Stack size after execute: %d", afterExec)
t.Logf("Reported number of results: %d", nresults)
// Print each stack element
for i := 1; i <= afterExec; i++ {
t.Logf("Stack[-%d] type: %s", i, L.GetType(-i))
}
if afterExec != nresults {
t.Errorf("Stack size (%d) doesn't match number of results (%d)", afterExec, nresults)
}
}
func TestTemplateRendering(t *testing.T) {
L := New()
if L == nil {
t.Fatal("Failed to create Lua state")
}
defer L.Close()
defer L.Cleanup()
// Create a simple render.template function
renderFunc := func(s *State) int {
// Template will be at index 1, data at index 2
data, err := s.ToTable(2)
if err != nil {
s.PushString(fmt.Sprintf("failed to get data table: %v", err))
return -1
}
// Push data back as global for template access
if err := s.PushTable(data); err != nil {
s.PushString(fmt.Sprintf("failed to push data table: %v", err))
return -1
}
s.SetGlobal("data")
// Template processing code
luaCode := `
local result = {}
if data.user.logged_in then
table.insert(result, '<div class="profile">')
table.insert(result, string.format(' <h1>Welcome, %s!</h1>', tostring(data.user.name)))
table.insert(result, ' <div class="user-info">')
table.insert(result, string.format(' <p>Email: %s</p>', tostring(data.user.email)))
table.insert(result, string.format(' <p>Member since: %s</p>', tostring(data.user.joined_date)))
table.insert(result, ' </div>')
if data.user.is_admin then
table.insert(result, ' <div class="admin-panel">')
table.insert(result, ' <h2>Admin Controls</h2>')
table.insert(result, ' <!-- admin content -->')
table.insert(result, ' </div>')
end
table.insert(result, '</div>')
else
table.insert(result, '<div class="profile">')
table.insert(result, ' <h1>Please log in to view your profile</h1>')
table.insert(result, '</div>')
end
return table.concat(result, '\n')`
result, err := s.DoStringResult(luaCode)
if err != nil {
s.PushString(fmt.Sprintf("template execution failed: %v", err))
return -1
}
// Push the string result
if str, ok := result.(string); ok {
s.PushString(str)
return 1
}
s.PushString(fmt.Sprintf("expected string result, got %T", result))
return -1
}
// Create render table and add template function
L.NewTable()
if err := L.PushGoFunction(renderFunc); err != nil {
t.Fatalf("Failed to create render function: %v", err)
}
L.SetField(-2, "template")
L.SetGlobal("render")
// Test with logged in admin user
testCode := `
local data = {
user = {
logged_in = true,
name = "John Doe",
email = "john@example.com",
joined_date = "2024-02-09",
is_admin = true
}
}
return render.template("test.html", data)
`
result, err := L.DoStringResult(testCode)
if err != nil {
t.Fatalf("Failed to execute test: %v", err)
}
str, ok := result.(string)
if !ok {
t.Fatalf("Expected string result, got %T", result)
}
expectedResult := `<div class="profile">
<h1>Welcome, John Doe!</h1>
<div class="user-info">
<p>Email: john@example.com</p>
<p>Member since: 2024-02-09</p>
</div>
<div class="admin-panel">
<h2>Admin Controls</h2>
<!-- admin content -->
</div>
</div>`
if str != expectedResult {
t.Errorf("\nExpected:\n%s\n\nGot:\n%s", expectedResult, str)
}
// Test with logged out user
testCode = `
local data = {
user = {
logged_in = false
}
}
return render.template("test.html", data)
`
result, err = L.DoStringResult(testCode)
if err != nil {
t.Fatalf("Failed to execute logged out test: %v", err)
}
str, ok = result.(string)
if !ok {
t.Fatalf("Expected string result, got %T", result)
}
expectedResult = `<div class="profile">
<h1>Please log in to view your profile</h1>
</div>`
if str != expectedResult {
t.Errorf("\nExpected:\n%s\n\nGot:\n%s", expectedResult, str)
}
}