BIG changes; no "safe" mode, function updates, etc
This commit is contained in:
parent
7c79616cac
commit
4dc266201f
|
@ -1,148 +0,0 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
type benchCase struct {
|
||||
name string
|
||||
code string
|
||||
}
|
||||
|
||||
var cases = []benchCase{
|
||||
{
|
||||
name: "Simple Addition",
|
||||
code: `return 1 + 1`,
|
||||
},
|
||||
{
|
||||
name: "Loop Sum",
|
||||
code: `
|
||||
local sum = 0
|
||||
for i = 1, 1000 do
|
||||
sum = sum + i
|
||||
end
|
||||
return sum
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "Function Call",
|
||||
code: `
|
||||
local result = 0
|
||||
for i = 1, 100 do
|
||||
result = result + i
|
||||
end
|
||||
return result
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "Table Creation",
|
||||
code: `
|
||||
local t = {}
|
||||
for i = 1, 100 do
|
||||
t[i] = i * 2
|
||||
end
|
||||
return t[50]
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "String Operations",
|
||||
code: `
|
||||
local s = "hello"
|
||||
for i = 1, 10 do
|
||||
s = s .. " world"
|
||||
end
|
||||
return #s
|
||||
`,
|
||||
},
|
||||
}
|
||||
|
||||
func runBenchmark(L *luajit.State, code string, duration time.Duration) (time.Duration, int64) {
|
||||
start := time.Now()
|
||||
deadline := start.Add(duration)
|
||||
var ops int64
|
||||
|
||||
for time.Now().Before(deadline) {
|
||||
if err := L.DoString(code); err != nil {
|
||||
fmt.Printf("Error executing code: %v\n", err)
|
||||
return 0, 0
|
||||
}
|
||||
L.Pop(1)
|
||||
ops++
|
||||
}
|
||||
|
||||
return time.Since(start), ops
|
||||
}
|
||||
|
||||
func runBytecodeTest(L *luajit.State, code string, duration time.Duration) (time.Duration, int64) {
|
||||
// First compile the bytecode
|
||||
bytecode, err := L.CompileBytecode(code, "bench")
|
||||
if err != nil {
|
||||
fmt.Printf("Error compiling bytecode: %v\n", err)
|
||||
return 0, 0
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
deadline := start.Add(duration)
|
||||
var ops int64
|
||||
|
||||
for time.Now().Before(deadline) {
|
||||
if err := L.LoadBytecode(bytecode, "bench"); err != nil {
|
||||
fmt.Printf("Error executing bytecode: %v\n", err)
|
||||
return 0, 0
|
||||
}
|
||||
ops++
|
||||
}
|
||||
|
||||
return time.Since(start), ops
|
||||
}
|
||||
|
||||
func benchmarkCase(newState func() *luajit.State, bc benchCase) {
|
||||
fmt.Printf("\n%s:\n", bc.name)
|
||||
|
||||
// Direct execution benchmark
|
||||
L := newState()
|
||||
if L == nil {
|
||||
fmt.Printf(" Failed to create Lua state\n")
|
||||
return
|
||||
}
|
||||
execTime, ops := runBenchmark(L, bc.code, 2*time.Second)
|
||||
L.Close()
|
||||
if ops > 0 {
|
||||
opsPerSec := float64(ops) / execTime.Seconds()
|
||||
fmt.Printf(" Direct: %.0f ops/sec\n", opsPerSec)
|
||||
}
|
||||
|
||||
// Bytecode execution benchmark
|
||||
L = newState()
|
||||
if L == nil {
|
||||
fmt.Printf(" Failed to create Lua state\n")
|
||||
return
|
||||
}
|
||||
execTime, ops = runBytecodeTest(L, bc.code, 2*time.Second)
|
||||
L.Close()
|
||||
if ops > 0 {
|
||||
opsPerSec := float64(ops) / execTime.Seconds()
|
||||
fmt.Printf(" Bytecode: %.0f ops/sec\n", opsPerSec)
|
||||
}
|
||||
}
|
||||
|
||||
func main() {
|
||||
modes := []struct {
|
||||
name string
|
||||
newState func() *luajit.State
|
||||
}{
|
||||
{"Safe", luajit.NewSafe},
|
||||
{"Unsafe", luajit.New},
|
||||
}
|
||||
|
||||
for _, mode := range modes {
|
||||
fmt.Printf("\n=== %s Mode ===\n", mode.name)
|
||||
|
||||
for _, c := range cases {
|
||||
benchmarkCase(mode.newState, c)
|
||||
}
|
||||
}
|
||||
}
|
26
bytecode.go
26
bytecode.go
|
@ -51,7 +51,8 @@ import (
|
|||
"unsafe"
|
||||
)
|
||||
|
||||
func (s *State) compileBytecodeUnsafe(code string, name string) ([]byte, error) {
|
||||
// CompileBytecode compiles a Lua chunk to bytecode without executing it
|
||||
func (s *State) CompileBytecode(code string, name string) ([]byte, error) {
|
||||
// First load the string but don't execute it
|
||||
ccode := C.CString(code)
|
||||
defer C.free(unsafe.Pointer(ccode))
|
||||
|
@ -94,7 +95,8 @@ func (s *State) compileBytecodeUnsafe(code string, name string) ([]byte, error)
|
|||
return bytecode, nil
|
||||
}
|
||||
|
||||
func (s *State) loadBytecodeUnsafe(bytecode []byte, name string) error {
|
||||
// LoadBytecode loads precompiled bytecode and executes it
|
||||
func (s *State) LoadBytecode(bytecode []byte, name string) error {
|
||||
cname := C.CString(name)
|
||||
defer C.free(unsafe.Pointer(cname))
|
||||
|
||||
|
@ -125,26 +127,6 @@ func (s *State) loadBytecodeUnsafe(bytecode []byte, name string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// CompileBytecode compiles a Lua chunk to bytecode without executing it
|
||||
func (s *State) CompileBytecode(code string, name string) ([]byte, error) {
|
||||
if s.safeStack {
|
||||
return stackGuardValue[[]byte](s, func() ([]byte, error) {
|
||||
return s.compileBytecodeUnsafe(code, name)
|
||||
})
|
||||
}
|
||||
return s.compileBytecodeUnsafe(code, name)
|
||||
}
|
||||
|
||||
// LoadBytecode loads precompiled bytecode and executes it
|
||||
func (s *State) LoadBytecode(bytecode []byte, name string) error {
|
||||
if s.safeStack {
|
||||
return stackGuardErr(s, func() error {
|
||||
return s.loadBytecodeUnsafe(bytecode, name)
|
||||
})
|
||||
}
|
||||
return s.loadBytecodeUnsafe(bytecode, name)
|
||||
}
|
||||
|
||||
// Helper function to compile and immediately load/execute bytecode
|
||||
func (s *State) CompileAndLoad(code string, name string) error {
|
||||
bytecode, err := s.CompileBytecode(code, name)
|
||||
|
|
|
@ -28,10 +28,8 @@ func TestBytecodeCompilation(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
for _, f := range factories {
|
||||
for _, tt := range tests {
|
||||
t.Run(f.name+"/"+tt.name, func(t *testing.T) {
|
||||
L := f.new()
|
||||
L := New()
|
||||
if L == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
|
@ -48,15 +46,11 @@ func TestBytecodeCompilation(t *testing.T) {
|
|||
t.Error("CompileBytecode() returned empty bytecode")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBytecodeExecution(t *testing.T) {
|
||||
for _, f := range factories {
|
||||
t.Run(f.name, func(t *testing.T) {
|
||||
L := f.new()
|
||||
L := New()
|
||||
if L == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
|
@ -85,14 +79,10 @@ func TestBytecodeExecution(t *testing.T) {
|
|||
if result := L.ToNumber(-1); result != 42 {
|
||||
t.Errorf("got result = %v, want 42", result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestInvalidBytecode(t *testing.T) {
|
||||
for _, f := range factories {
|
||||
t.Run(f.name, func(t *testing.T) {
|
||||
L := f.new()
|
||||
L := New()
|
||||
if L == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
|
@ -103,8 +93,6 @@ func TestInvalidBytecode(t *testing.T) {
|
|||
if err := L.LoadBytecode(invalidBytecode, "test"); err == nil {
|
||||
t.Error("LoadBytecode() expected error with invalid bytecode")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBytecodeRoundTrip(t *testing.T) {
|
||||
|
@ -140,11 +128,9 @@ func TestBytecodeRoundTrip(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
for _, f := range factories {
|
||||
for _, tt := range tests {
|
||||
t.Run(f.name+"/"+tt.name, func(t *testing.T) {
|
||||
// First state for compilation
|
||||
L1 := f.new()
|
||||
L1 := New()
|
||||
if L1 == nil {
|
||||
t.Fatal("Failed to create first Lua state")
|
||||
}
|
||||
|
@ -157,7 +143,7 @@ func TestBytecodeRoundTrip(t *testing.T) {
|
|||
}
|
||||
|
||||
// Second state for execution
|
||||
L2 := f.new()
|
||||
L2 := New()
|
||||
if L2 == nil {
|
||||
t.Fatal("Failed to create second Lua state")
|
||||
}
|
||||
|
@ -172,7 +158,5 @@ func TestBytecodeRoundTrip(t *testing.T) {
|
|||
if err := tt.check(L2); err != nil {
|
||||
t.Errorf("check failed: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -31,7 +31,7 @@ var (
|
|||
|
||||
//export goFunctionWrapper
|
||||
func goFunctionWrapper(L *C.lua_State) C.int {
|
||||
state := &State{L: L, safeStack: true}
|
||||
state := &State{L: L}
|
||||
|
||||
// Get upvalue using standard Lua 5.1 macro
|
||||
ptr := C.lua_touserdata(L, C.get_upvalue_index(1))
|
||||
|
@ -54,7 +54,6 @@ func goFunctionWrapper(L *C.lua_State) C.int {
|
|||
}
|
||||
|
||||
func (s *State) PushGoFunction(fn GoFunction) error {
|
||||
// Push lightuserdata as upvalue and create closure
|
||||
ptr := C.malloc(1)
|
||||
if ptr == nil {
|
||||
return fmt.Errorf("failed to allocate memory for function pointer")
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
package luajit
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGoFunctions(t *testing.T) {
|
||||
for _, f := range factories {
|
||||
t.Run(f.name, func(t *testing.T) {
|
||||
L := f.new()
|
||||
L := New()
|
||||
if L == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
|
@ -78,12 +78,10 @@ func TestGoFunctions(t *testing.T) {
|
|||
if err := L.DoString("add(1, 2)"); err == nil {
|
||||
t.Error("Expected error calling unregistered function")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStackSafety(t *testing.T) {
|
||||
L := NewSafe()
|
||||
L := New()
|
||||
if L == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
|
|
101
table.go
101
table.go
|
@ -22,57 +22,20 @@ type TableValue interface {
|
|||
|
||||
func (s *State) GetTableLength(index int) int { return int(C.get_table_length(s.L, C.int(index))) }
|
||||
|
||||
// PushTable pushes a Go map onto the Lua stack as a table
|
||||
func (s *State) PushTable(table map[string]interface{}) error {
|
||||
s.NewTable()
|
||||
for k, v := range table {
|
||||
if err := s.PushValue(v); err != nil {
|
||||
return err
|
||||
}
|
||||
s.SetField(-2, k)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ToTable converts a Lua table to a Go map
|
||||
func (s *State) ToTable(index int) (map[string]interface{}, error) {
|
||||
if s.safeStack {
|
||||
return stackGuardValue[map[string]interface{}](s, func() (map[string]interface{}, error) {
|
||||
if !s.IsTable(index) {
|
||||
return nil, fmt.Errorf("not a table at index %d", index)
|
||||
}
|
||||
return s.toTableUnsafe(index)
|
||||
})
|
||||
}
|
||||
if !s.IsTable(index) {
|
||||
return nil, fmt.Errorf("not a table at index %d", index)
|
||||
}
|
||||
return s.toTableUnsafe(index)
|
||||
}
|
||||
|
||||
func (s *State) pushTableSafe(table map[string]interface{}) error {
|
||||
size := 2
|
||||
if err := s.checkStack(size); err != nil {
|
||||
return fmt.Errorf("insufficient stack space: %w", err)
|
||||
}
|
||||
|
||||
s.NewTable()
|
||||
for k, v := range table {
|
||||
if err := s.pushValueSafe(v); err != nil {
|
||||
return err
|
||||
}
|
||||
s.SetField(-2, k)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *State) pushTableUnsafe(table map[string]interface{}) error {
|
||||
s.NewTable()
|
||||
for k, v := range table {
|
||||
if err := s.pushValueUnsafe(v); err != nil {
|
||||
return err
|
||||
}
|
||||
s.SetField(-2, k)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *State) toTableSafe(index int) (map[string]interface{}, error) {
|
||||
if err := s.checkStack(2); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.toTableUnsafe(index)
|
||||
}
|
||||
|
||||
func (s *State) toTableUnsafe(index int) (map[string]interface{}, error) {
|
||||
absIdx := s.absIndex(index)
|
||||
table := make(map[string]interface{})
|
||||
|
||||
|
@ -111,7 +74,7 @@ func (s *State) toTableUnsafe(index int) (map[string]interface{}, error) {
|
|||
key = fmt.Sprintf("%g", s.ToNumber(-2))
|
||||
}
|
||||
|
||||
value, err := s.toValueUnsafe(-1)
|
||||
value, err := s.ToValue(-1)
|
||||
if err != nil {
|
||||
s.Pop(1)
|
||||
return nil, err
|
||||
|
@ -133,45 +96,15 @@ func (s *State) toTableUnsafe(index int) (map[string]interface{}, error) {
|
|||
|
||||
// NewTable creates a new table and pushes it onto the stack
|
||||
func (s *State) NewTable() {
|
||||
if s.safeStack {
|
||||
if err := s.checkStack(1); err != nil {
|
||||
// Since we can't return an error, we'll push nil instead
|
||||
s.PushNil()
|
||||
return
|
||||
}
|
||||
}
|
||||
C.lua_createtable(s.L, 0, 0)
|
||||
}
|
||||
|
||||
// SetTable sets a table field with cached absolute index
|
||||
// SetTable sets a table field
|
||||
func (s *State) SetTable(index int) {
|
||||
absIdx := index
|
||||
if s.safeStack && (index < 0 && index > LUA_REGISTRYINDEX) {
|
||||
absIdx = s.GetTop() + index + 1
|
||||
}
|
||||
C.lua_settable(s.L, C.int(absIdx))
|
||||
C.lua_settable(s.L, C.int(index))
|
||||
}
|
||||
|
||||
// GetTable gets a table field with cached absolute index
|
||||
// GetTable gets a table field
|
||||
func (s *State) GetTable(index int) {
|
||||
absIdx := index
|
||||
if s.safeStack && (index < 0 && index > LUA_REGISTRYINDEX) {
|
||||
absIdx = s.GetTop() + index + 1
|
||||
}
|
||||
|
||||
if s.safeStack {
|
||||
if err := s.checkStack(1); err != nil {
|
||||
s.PushNil()
|
||||
return
|
||||
}
|
||||
}
|
||||
C.lua_gettable(s.L, C.int(absIdx))
|
||||
}
|
||||
|
||||
// PushTable pushes a Go map onto the Lua stack as a table with stack checking
|
||||
func (s *State) PushTable(table map[string]interface{}) error {
|
||||
if s.safeStack {
|
||||
return s.pushTableSafe(table)
|
||||
}
|
||||
return s.pushTableUnsafe(table)
|
||||
C.lua_gettable(s.L, C.int(index))
|
||||
}
|
||||
|
|
|
@ -34,10 +34,8 @@ func TestTableOperations(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
for _, f := range factories {
|
||||
for _, tt := range tests {
|
||||
t.Run(f.name+"/"+tt.name, func(t *testing.T) {
|
||||
L := f.new()
|
||||
L := New()
|
||||
if L == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
|
@ -55,8 +53,6 @@ func TestTableOperations(t *testing.T) {
|
|||
if !tablesEqual(got, tt.data) {
|
||||
t.Errorf("table mismatch\ngot = %v\nwant = %v", got, tt.data)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
317
wrapper.go
317
wrapper.go
|
@ -10,49 +10,82 @@ package luajit
|
|||
#include <lauxlib.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
static int do_string(lua_State *L, const char *s) {
|
||||
int status = luaL_loadstring(L, s);
|
||||
if (status) return status;
|
||||
return lua_pcall(L, 0, LUA_MULTRET, 0);
|
||||
// Simple wrapper around luaL_loadstring
|
||||
static int load_chunk(lua_State *L, const char *s) {
|
||||
return luaL_loadstring(L, s);
|
||||
}
|
||||
|
||||
// Direct wrapper around lua_pcall
|
||||
static int protected_call(lua_State *L, int nargs, int nresults, int errfunc) {
|
||||
return lua_pcall(L, nargs, nresults, errfunc);
|
||||
}
|
||||
|
||||
// Combined load and execute with no results
|
||||
static int do_string(lua_State *L, const char *s) {
|
||||
return luaL_dostring(L, s);
|
||||
}
|
||||
|
||||
// Combined load and execute file
|
||||
static int do_file(lua_State *L, const char *filename) {
|
||||
int status = luaL_loadfile(L, filename);
|
||||
if (status) return status;
|
||||
return lua_pcall(L, 0, LUA_MULTRET, 0);
|
||||
return luaL_dofile(L, filename);
|
||||
}
|
||||
|
||||
// Execute string with multiple returns
|
||||
static int execute_string(lua_State *L, const char *s) {
|
||||
int base = lua_gettop(L); // Save stack position
|
||||
int status = luaL_loadstring(L, s);
|
||||
if (status) return -status; // Return negative status for load errors
|
||||
|
||||
status = lua_pcall(L, 0, LUA_MULTRET, 0);
|
||||
if (status) return -status; // Return negative status for runtime errors
|
||||
|
||||
return lua_gettop(L) - base; // Return number of results
|
||||
}
|
||||
|
||||
// Get absolute stack index (converts negative indices)
|
||||
static int get_abs_index(lua_State *L, int idx) {
|
||||
if (idx > 0 || idx <= LUA_REGISTRYINDEX) return idx;
|
||||
return lua_gettop(L) + idx + 1;
|
||||
}
|
||||
|
||||
// Stack manipulation helpers
|
||||
static int check_stack(lua_State *L, int n) {
|
||||
return lua_checkstack(L, n);
|
||||
}
|
||||
|
||||
static void remove_stack(lua_State *L, int idx) {
|
||||
lua_remove(L, idx);
|
||||
}
|
||||
|
||||
static int get_field_helper(lua_State *L, int idx, const char *k) {
|
||||
lua_getfield(L, idx, k);
|
||||
return lua_type(L, -1);
|
||||
}
|
||||
|
||||
static void set_field_helper(lua_State *L, int idx, const char *k) {
|
||||
lua_setfield(L, idx, k);
|
||||
}
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// State represents a Lua state with configurable stack safety
|
||||
// State represents a Lua state
|
||||
type State struct {
|
||||
L *C.lua_State
|
||||
safeStack bool
|
||||
}
|
||||
|
||||
// NewSafe creates a new Lua state with full stack safety guarantees
|
||||
func NewSafe() *State {
|
||||
L := C.luaL_newstate()
|
||||
if L == nil {
|
||||
return nil
|
||||
}
|
||||
C.luaL_openlibs(L)
|
||||
return &State{L: L, safeStack: true}
|
||||
}
|
||||
|
||||
// New creates a new Lua state with minimal stack checking
|
||||
// New creates a new Lua state
|
||||
func New() *State {
|
||||
L := C.luaL_newstate()
|
||||
if L == nil {
|
||||
return nil
|
||||
}
|
||||
C.luaL_openlibs(L)
|
||||
return &State{L: L, safeStack: false}
|
||||
return &State{L: L}
|
||||
}
|
||||
|
||||
// Close closes the Lua state
|
||||
|
@ -63,65 +96,28 @@ func (s *State) Close() {
|
|||
}
|
||||
}
|
||||
|
||||
// DoString executes a Lua string with appropriate stack management
|
||||
// DoString executes a Lua string.
|
||||
func (s *State) DoString(str string) error {
|
||||
cstr := C.CString(str)
|
||||
defer C.free(unsafe.Pointer(cstr))
|
||||
// Save initial stack size
|
||||
top := s.GetTop()
|
||||
|
||||
if s.safeStack {
|
||||
return stackGuardErr(s, func() error {
|
||||
// Save the current stack size
|
||||
initialTop := s.GetTop()
|
||||
|
||||
// Execute the string
|
||||
status := C.do_string(s.L, cstr)
|
||||
if status != 0 {
|
||||
// In case of error, get error message from stack
|
||||
errMsg := s.ToString(-1)
|
||||
s.SetTop(initialTop) // Restore stack
|
||||
return &LuaError{
|
||||
Code: int(status),
|
||||
Message: errMsg,
|
||||
}
|
||||
// Load the string
|
||||
if err := s.LoadString(str); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Return values are now on the stack above initialTop
|
||||
// We don't pop them as they may be needed by the caller
|
||||
return nil
|
||||
})
|
||||
// Execute and check for errors
|
||||
if err := s.Call(0, 0); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
status := C.do_string(s.L, cstr)
|
||||
if status != 0 {
|
||||
return &LuaError{
|
||||
Code: int(status),
|
||||
Message: s.ToString(-1),
|
||||
}
|
||||
}
|
||||
// Restore stack to initial size to clean up any leftovers
|
||||
s.SetTop(top)
|
||||
return nil
|
||||
}
|
||||
|
||||
// PushValue pushes a Go value onto the stack
|
||||
func (s *State) PushValue(v interface{}) error {
|
||||
if s.safeStack {
|
||||
return stackGuardErr(s, func() error {
|
||||
if err := s.checkStack(1); err != nil {
|
||||
return fmt.Errorf("pushing value: %w", err)
|
||||
}
|
||||
return s.pushValueUnsafe(v)
|
||||
})
|
||||
}
|
||||
return s.pushValueUnsafe(v)
|
||||
}
|
||||
|
||||
func (s *State) pushValueSafe(v interface{}) error {
|
||||
if err := s.checkStack(1); err != nil {
|
||||
return fmt.Errorf("pushing value: %w", err)
|
||||
}
|
||||
return s.pushValueUnsafe(v)
|
||||
}
|
||||
|
||||
func (s *State) pushValueUnsafe(v interface{}) error {
|
||||
switch v := v.(type) {
|
||||
case nil:
|
||||
s.PushNil()
|
||||
|
@ -144,7 +140,7 @@ func (s *State) pushValueUnsafe(v interface{}) error {
|
|||
}
|
||||
return nil
|
||||
}
|
||||
return s.pushTableUnsafe(v)
|
||||
return s.PushTable(v)
|
||||
case []float64:
|
||||
s.NewTable()
|
||||
for i, elem := range v {
|
||||
|
@ -156,7 +152,7 @@ func (s *State) pushValueUnsafe(v interface{}) error {
|
|||
s.NewTable()
|
||||
for i, elem := range v {
|
||||
s.PushNumber(float64(i + 1))
|
||||
if err := s.pushValueUnsafe(elem); err != nil {
|
||||
if err := s.PushValue(elem); err != nil {
|
||||
return err
|
||||
}
|
||||
s.SetTable(-3)
|
||||
|
@ -169,15 +165,6 @@ func (s *State) pushValueUnsafe(v interface{}) error {
|
|||
|
||||
// ToValue converts a Lua value to a Go value
|
||||
func (s *State) ToValue(index int) (interface{}, error) {
|
||||
if s.safeStack {
|
||||
return stackGuardValue[interface{}](s, func() (interface{}, error) {
|
||||
return s.toValueUnsafe(index)
|
||||
})
|
||||
}
|
||||
return s.toValueUnsafe(index)
|
||||
}
|
||||
|
||||
func (s *State) toValueUnsafe(index int) (interface{}, error) {
|
||||
switch s.GetType(index) {
|
||||
case TypeNil:
|
||||
return nil, nil
|
||||
|
@ -191,7 +178,7 @@ func (s *State) toValueUnsafe(index int) (interface{}, error) {
|
|||
if !s.IsTable(index) {
|
||||
return nil, fmt.Errorf("not a table at index %d", index)
|
||||
}
|
||||
return s.toTableUnsafe(index)
|
||||
return s.ToTable(index)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported type: %s", s.GetType(index))
|
||||
}
|
||||
|
@ -244,53 +231,29 @@ func (s *State) absIndex(index int) int {
|
|||
return s.GetTop() + index + 1
|
||||
}
|
||||
|
||||
// SetField sets a field in a table at the given index with cached absolute index
|
||||
// SetField sets a field in a table at the given index
|
||||
func (s *State) SetField(index int, key string) {
|
||||
absIdx := index
|
||||
if s.safeStack && (index < 0 && index > LUA_REGISTRYINDEX) {
|
||||
absIdx = s.GetTop() + index + 1
|
||||
}
|
||||
|
||||
cstr := C.CString(key)
|
||||
defer C.free(unsafe.Pointer(cstr))
|
||||
C.lua_setfield(s.L, C.int(absIdx), cstr)
|
||||
C.lua_setfield(s.L, C.int(index), cstr)
|
||||
}
|
||||
|
||||
// GetField gets a field from a table with cached absolute index
|
||||
// GetField gets a field from a table
|
||||
func (s *State) GetField(index int, key string) {
|
||||
absIdx := index
|
||||
if s.safeStack && (index < 0 && index > LUA_REGISTRYINDEX) {
|
||||
absIdx = s.GetTop() + index + 1
|
||||
}
|
||||
|
||||
if s.safeStack {
|
||||
if err := s.checkStack(1); err != nil {
|
||||
s.PushNil()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
cstr := C.CString(key)
|
||||
defer C.free(unsafe.Pointer(cstr))
|
||||
C.lua_getfield(s.L, C.int(absIdx), cstr)
|
||||
C.lua_getfield(s.L, C.int(index), cstr)
|
||||
}
|
||||
|
||||
// GetGlobal gets a global variable and pushes it onto the stack
|
||||
func (s *State) GetGlobal(name string) {
|
||||
if s.safeStack {
|
||||
if err := s.checkStack(1); err != nil {
|
||||
s.PushNil()
|
||||
return
|
||||
}
|
||||
}
|
||||
cstr := C.CString(name)
|
||||
defer C.free(unsafe.Pointer(cstr))
|
||||
C.lua_getfield(s.L, C.LUA_GLOBALSINDEX, cstr)
|
||||
cname := C.CString(name)
|
||||
defer C.free(unsafe.Pointer(cname))
|
||||
C.get_field_helper(s.L, C.LUA_GLOBALSINDEX, cname)
|
||||
}
|
||||
|
||||
// SetGlobal sets a global variable from the value at the top of the stack
|
||||
func (s *State) SetGlobal(name string) {
|
||||
// SetGlobal doesn't need stack space checking as it pops the value
|
||||
cstr := C.CString(name)
|
||||
defer C.free(unsafe.Pointer(cstr))
|
||||
C.lua_setfield(s.L, C.LUA_GLOBALSINDEX, cstr)
|
||||
|
@ -299,9 +262,6 @@ func (s *State) SetGlobal(name string) {
|
|||
// Remove removes element with cached absolute index
|
||||
func (s *State) Remove(index int) {
|
||||
absIdx := index
|
||||
if s.safeStack && (index < 0 && index > LUA_REGISTRYINDEX) {
|
||||
absIdx = s.GetTop() + index + 1
|
||||
}
|
||||
C.lua_remove(s.L, C.int(absIdx))
|
||||
}
|
||||
|
||||
|
@ -310,14 +270,6 @@ func (s *State) DoFile(filename string) error {
|
|||
cfilename := C.CString(filename)
|
||||
defer C.free(unsafe.Pointer(cfilename))
|
||||
|
||||
if s.safeStack {
|
||||
return stackGuardErr(s, func() error {
|
||||
return s.safeCall(func() C.int {
|
||||
return C.do_file(s.L, cfilename)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
status := C.do_file(s.L, cfilename)
|
||||
if status != 0 {
|
||||
return &LuaError{
|
||||
|
@ -328,24 +280,28 @@ func (s *State) DoFile(filename string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// SetPackagePath sets the Lua package.path
|
||||
func (s *State) SetPackagePath(path string) error {
|
||||
path = filepath.ToSlash(path)
|
||||
if err := s.DoString(fmt.Sprintf(`package.path = %q`, path)); err != nil {
|
||||
return fmt.Errorf("setting package.path: %w", err)
|
||||
}
|
||||
return nil
|
||||
path = strings.ReplaceAll(path, "\\", "/") // Convert Windows paths
|
||||
cmd := fmt.Sprintf(`package.path = %q`, path)
|
||||
return s.DoString(cmd)
|
||||
}
|
||||
|
||||
// AddPackagePath adds a path to package.path
|
||||
func (s *State) AddPackagePath(path string) error {
|
||||
path = filepath.ToSlash(path)
|
||||
if err := s.DoString(fmt.Sprintf(`package.path = package.path .. ";%s"`, path)); err != nil {
|
||||
return fmt.Errorf("adding to package.path: %w", err)
|
||||
}
|
||||
return nil
|
||||
path = strings.ReplaceAll(path, "\\", "/") // Convert Windows paths
|
||||
cmd := fmt.Sprintf(`package.path = package.path .. ";%s"`, path)
|
||||
return s.DoString(cmd)
|
||||
}
|
||||
|
||||
// Call executes a function on the stack with the given number of arguments and results.
|
||||
// The function and arguments should already be on the stack in the correct order
|
||||
// (function first, then args from left to right).
|
||||
func (s *State) Call(nargs, nresults int) error {
|
||||
status := C.lua_pcall(s.L, C.int(nargs), C.int(nresults), 0)
|
||||
if !s.IsFunction(-nargs - 1) {
|
||||
return fmt.Errorf("attempt to call a non-function")
|
||||
}
|
||||
status := C.protected_call(s.L, C.int(nargs), C.int(nresults), 0)
|
||||
if status != 0 {
|
||||
err := &LuaError{
|
||||
Code: int(status),
|
||||
|
@ -356,3 +312,94 @@ func (s *State) Call(nargs, nresults int) error {
|
|||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadString loads but does not execute a string of Lua code.
|
||||
// The compiled code chunk is left on the stack.
|
||||
func (s *State) LoadString(str string) error {
|
||||
cstr := C.CString(str)
|
||||
defer C.free(unsafe.Pointer(cstr))
|
||||
|
||||
status := C.load_chunk(s.L, cstr)
|
||||
if status != 0 {
|
||||
err := &LuaError{
|
||||
Code: int(status),
|
||||
Message: s.ToString(-1),
|
||||
}
|
||||
s.Pop(1)
|
||||
return err
|
||||
}
|
||||
|
||||
if !s.IsFunction(-1) {
|
||||
s.Pop(1)
|
||||
return fmt.Errorf("failed to load function")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExecuteString executes a string of Lua code and returns the number of results.
|
||||
// The results are left on the stack.
|
||||
func (s *State) ExecuteString(str string) (int, error) {
|
||||
base := s.GetTop()
|
||||
|
||||
// First load the string
|
||||
if err := s.LoadString(str); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Now execute it
|
||||
if err := s.Call(0, C.LUA_MULTRET); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return s.GetTop() - base, nil
|
||||
}
|
||||
|
||||
// ExecuteStringResult executes a Lua string and returns its first result as a Go value.
|
||||
// It's a convenience wrapper around ExecuteString for the common case of wanting
|
||||
// a single return value. The stack is restored to its original state after execution.
|
||||
func (s *State) ExecuteStringResult(code string) (interface{}, error) {
|
||||
top := s.GetTop()
|
||||
defer s.SetTop(top) // Restore stack when we're done
|
||||
|
||||
nresults, err := s.ExecuteString(code)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("execution error: %w", err)
|
||||
}
|
||||
|
||||
if nresults == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Get the result
|
||||
result, err := s.ToValue(-nresults) // Get first result
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error converting result: %w", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// DoStringResult executes a Lua string and expects a single return value.
|
||||
// Unlike ExecuteStringResult, this function specifically expects exactly one
|
||||
// return value and will return an error if the code returns 0 or multiple values.
|
||||
func (s *State) DoStringResult(code string) (interface{}, error) {
|
||||
top := s.GetTop()
|
||||
defer s.SetTop(top) // Restore stack when we're done
|
||||
|
||||
nresults, err := s.ExecuteString(code)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("execution error: %w", err)
|
||||
}
|
||||
|
||||
if nresults != 1 {
|
||||
return nil, fmt.Errorf("expected 1 return value, got %d", nresults)
|
||||
}
|
||||
|
||||
// Get the result
|
||||
result, err := s.ToValue(-1)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error converting result: %w", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
|
237
wrapper_bench_test.go
Normal file
237
wrapper_bench_test.go
Normal file
|
@ -0,0 +1,237 @@
|
|||
package luajit
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
var benchCases = []struct {
|
||||
name string
|
||||
code string
|
||||
}{
|
||||
{
|
||||
name: "SimpleAddition",
|
||||
code: `return 1 + 1`,
|
||||
},
|
||||
{
|
||||
name: "LoopSum",
|
||||
code: `
|
||||
local sum = 0
|
||||
for i = 1, 1000 do
|
||||
sum = sum + i
|
||||
end
|
||||
return sum
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "FunctionCall",
|
||||
code: `
|
||||
local result = 0
|
||||
for i = 1, 100 do
|
||||
result = result + i
|
||||
end
|
||||
return result
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "TableCreation",
|
||||
code: `
|
||||
local t = {}
|
||||
for i = 1, 100 do
|
||||
t[i] = i * 2
|
||||
end
|
||||
return t[50]
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "StringOperations",
|
||||
code: `
|
||||
local s = "hello"
|
||||
for i = 1, 10 do
|
||||
s = s .. " world"
|
||||
end
|
||||
return #s
|
||||
`,
|
||||
},
|
||||
}
|
||||
|
||||
func BenchmarkLuaDirectExecution(b *testing.B) {
|
||||
for _, bc := range benchCases {
|
||||
b.Run(bc.name, func(b *testing.B) {
|
||||
L := New()
|
||||
if L == nil {
|
||||
b.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer L.Close()
|
||||
|
||||
// First verify we can execute the code
|
||||
if err := L.DoString(bc.code); err != nil {
|
||||
b.Fatalf("Failed to execute test code: %v", err)
|
||||
}
|
||||
L.Pop(1) // Clean up the result
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Execute string and get result
|
||||
nresults, err := L.ExecuteString(bc.code)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to execute code: %v", err)
|
||||
}
|
||||
L.Pop(nresults) // Clean up any results
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkLuaBytecodeExecution(b *testing.B) {
|
||||
// First compile all bytecode
|
||||
bytecodes := make(map[string][]byte)
|
||||
for _, bc := range benchCases {
|
||||
L := New()
|
||||
if L == nil {
|
||||
b.Fatal("Failed to create Lua state")
|
||||
}
|
||||
bytecode, err := L.CompileBytecode(bc.code, bc.name)
|
||||
if err != nil {
|
||||
L.Close()
|
||||
b.Fatalf("Error compiling bytecode for %s: %v", bc.name, err)
|
||||
}
|
||||
bytecodes[bc.name] = bytecode
|
||||
L.Close()
|
||||
}
|
||||
|
||||
for _, bc := range benchCases {
|
||||
b.Run(bc.name, func(b *testing.B) {
|
||||
L := New()
|
||||
if L == nil {
|
||||
b.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer L.Close()
|
||||
|
||||
bytecode := bytecodes[bc.name]
|
||||
|
||||
// First verify we can execute the bytecode
|
||||
if err := L.LoadBytecode(bytecode, bc.name); err != nil {
|
||||
b.Fatalf("Failed to execute test bytecode: %v", err)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.SetBytes(int64(len(bytecode))) // Track bytecode size in benchmarks
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
if err := L.LoadBytecode(bytecode, bc.name); err != nil {
|
||||
b.Fatalf("Error executing bytecode: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkTableOperations(b *testing.B) {
|
||||
testData := map[string]interface{}{
|
||||
"number": 42.0,
|
||||
"string": "hello",
|
||||
"bool": true,
|
||||
"nested": map[string]interface{}{
|
||||
"value": 123.0,
|
||||
"array": []float64{1.1, 2.2, 3.3},
|
||||
},
|
||||
}
|
||||
|
||||
b.Run("PushTable", func(b *testing.B) {
|
||||
L := New()
|
||||
if L == nil {
|
||||
b.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer L.Close()
|
||||
|
||||
// First verify we can push the table
|
||||
if err := L.PushTable(testData); err != nil {
|
||||
b.Fatalf("Failed to push initial table: %v", err)
|
||||
}
|
||||
L.Pop(1)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
if err := L.PushTable(testData); err != nil {
|
||||
b.Fatalf("Failed to push table: %v", err)
|
||||
}
|
||||
L.Pop(1)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("ToTable", func(b *testing.B) {
|
||||
L := New()
|
||||
if L == nil {
|
||||
b.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer L.Close()
|
||||
|
||||
// Keep a table on the stack for repeated conversions
|
||||
if err := L.PushTable(testData); err != nil {
|
||||
b.Fatalf("Failed to push initial table: %v", err)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
if _, err := L.ToTable(-1); err != nil {
|
||||
b.Fatalf("Failed to convert table: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkValueConversion(b *testing.B) {
|
||||
testValues := []struct {
|
||||
name string
|
||||
value interface{}
|
||||
}{
|
||||
{"Number", 42.0},
|
||||
{"String", "hello world"},
|
||||
{"Boolean", true},
|
||||
{"Nil", nil},
|
||||
}
|
||||
|
||||
for _, tv := range testValues {
|
||||
b.Run("Push"+tv.name, func(b *testing.B) {
|
||||
L := New()
|
||||
if L == nil {
|
||||
b.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer L.Close()
|
||||
|
||||
// First verify we can push the value
|
||||
if err := L.PushValue(tv.value); err != nil {
|
||||
b.Fatalf("Failed to push initial value: %v", err)
|
||||
}
|
||||
L.Pop(1)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
if err := L.PushValue(tv.value); err != nil {
|
||||
b.Fatalf("Failed to push value: %v", err)
|
||||
}
|
||||
L.Pop(1)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("To"+tv.name, func(b *testing.B) {
|
||||
L := New()
|
||||
if L == nil {
|
||||
b.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer L.Close()
|
||||
|
||||
// Keep a value on the stack for repeated conversions
|
||||
if err := L.PushValue(tv.value); err != nil {
|
||||
b.Fatalf("Failed to push initial value: %v", err)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
if _, err := L.ToValue(-1); err != nil {
|
||||
b.Fatalf("Failed to convert value: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
487
wrapper_test.go
487
wrapper_test.go
|
@ -7,25 +7,151 @@ import (
|
|||
"testing"
|
||||
)
|
||||
|
||||
type stateFactory struct {
|
||||
name string
|
||||
new func() *State
|
||||
}
|
||||
|
||||
var factories = []stateFactory{
|
||||
{"unsafe", New},
|
||||
{"safe", NewSafe},
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
for _, f := range factories {
|
||||
t.Run(f.name, func(t *testing.T) {
|
||||
L := f.new()
|
||||
L := New()
|
||||
if L == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer L.Close()
|
||||
})
|
||||
}
|
||||
|
||||
func TestLoadString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
code string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid function",
|
||||
code: "function add(a, b) return a + b end",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid expression",
|
||||
code: "return 1 + 1",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "syntax error",
|
||||
code: "function bad syntax",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
L := New()
|
||||
if L == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer L.Close()
|
||||
|
||||
err := L.LoadString(tt.code)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("LoadString() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if !tt.wantErr {
|
||||
// Verify the function is on the stack
|
||||
if L.GetTop() != 1 {
|
||||
t.Error("LoadString() did not leave exactly one value on stack")
|
||||
}
|
||||
if !L.IsFunction(-1) {
|
||||
t.Error("LoadString() did not leave a function on the stack")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
code string
|
||||
wantResults int
|
||||
checkResults func(*State) error
|
||||
wantErr bool
|
||||
wantStackSize int
|
||||
}{
|
||||
{
|
||||
name: "no results",
|
||||
code: "local x = 1",
|
||||
wantResults: 0,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "single result",
|
||||
code: "return 42",
|
||||
wantResults: 1,
|
||||
checkResults: func(L *State) error {
|
||||
if n := L.ToNumber(-1); n != 42 {
|
||||
return fmt.Errorf("got %v, want 42", n)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "multiple results",
|
||||
code: "return 1, 'test', true",
|
||||
wantResults: 3,
|
||||
checkResults: func(L *State) error {
|
||||
if n := L.ToNumber(-3); n != 1 {
|
||||
return fmt.Errorf("first result: got %v, want 1", n)
|
||||
}
|
||||
if s := L.ToString(-2); s != "test" {
|
||||
return fmt.Errorf("second result: got %v, want 'test'", s)
|
||||
}
|
||||
if b := L.ToBoolean(-1); !b {
|
||||
return fmt.Errorf("third result: got %v, want true", b)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "syntax error",
|
||||
code: "this is not valid lua",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "runtime error",
|
||||
code: "error('test error')",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
L := New()
|
||||
if L == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer L.Close()
|
||||
|
||||
// Record initial stack size
|
||||
initialStack := L.GetTop()
|
||||
|
||||
results, err := L.ExecuteString(tt.code)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ExecuteString() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
if results != tt.wantResults {
|
||||
t.Errorf("ExecuteString() returned %d results, want %d", results, tt.wantResults)
|
||||
}
|
||||
|
||||
if tt.checkResults != nil {
|
||||
if err := tt.checkResults(L); err != nil {
|
||||
t.Errorf("Result check failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify stack size matches expected results
|
||||
if got := L.GetTop() - initialStack; got != tt.wantResults {
|
||||
t.Errorf("Stack size grew by %d, want %d", got, tt.wantResults)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -41,20 +167,22 @@ func TestDoString(t *testing.T) {
|
|||
{"runtime error", "error('test error')", true},
|
||||
}
|
||||
|
||||
for _, f := range factories {
|
||||
for _, tt := range tests {
|
||||
t.Run(f.name+"/"+tt.name, func(t *testing.T) {
|
||||
L := f.new()
|
||||
L := New()
|
||||
if L == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer L.Close()
|
||||
|
||||
initialStack := L.GetTop()
|
||||
err := L.DoString(tt.code)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("DoString() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
|
||||
// Verify stack is unchanged
|
||||
if finalStack := L.GetTop(); finalStack != initialStack {
|
||||
t.Errorf("Stack size changed from %d to %d", initialStack, finalStack)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -107,10 +235,8 @@ func TestPushAndGetValues(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
for _, f := range factories {
|
||||
for _, v := range values {
|
||||
t.Run(f.name+"/"+v.name, func(t *testing.T) {
|
||||
L := f.new()
|
||||
L := New()
|
||||
if L == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
|
@ -120,15 +246,11 @@ func TestPushAndGetValues(t *testing.T) {
|
|||
if err := v.check(L); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStackManipulation(t *testing.T) {
|
||||
for _, f := range factories {
|
||||
t.Run(f.name, func(t *testing.T) {
|
||||
L := f.new()
|
||||
L := New()
|
||||
if L == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
|
@ -157,14 +279,10 @@ func TestStackManipulation(t *testing.T) {
|
|||
if top := L.GetTop(); top != len(values)-1 {
|
||||
t.Errorf("stack size after pop = %d, want %d", top, len(values)-1)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGlobals(t *testing.T) {
|
||||
for _, f := range factories {
|
||||
t.Run(f.name, func(t *testing.T) {
|
||||
L := f.new()
|
||||
L := New()
|
||||
if L == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
|
@ -190,12 +308,98 @@ func TestGlobals(t *testing.T) {
|
|||
if num := L.ToNumber(-1); num != 42 {
|
||||
t.Errorf("global number = %f, want 42", num)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestCall(t *testing.T) {
|
||||
tests := []struct {
|
||||
funcName string // Add explicit function name field
|
||||
setup string
|
||||
args []interface{}
|
||||
nresults int
|
||||
checkStack func(*State) error
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
funcName: "add",
|
||||
setup: "function add(a, b) return a + b end",
|
||||
args: []interface{}{float64(40), float64(2)},
|
||||
nresults: 1,
|
||||
checkStack: func(L *State) error {
|
||||
if n := L.ToNumber(-1); n != 42 {
|
||||
return fmt.Errorf("got %v, want 42", n)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
funcName: "multi",
|
||||
setup: "function multi() return 1, 'test', true end",
|
||||
args: []interface{}{},
|
||||
nresults: 3,
|
||||
checkStack: func(L *State) error {
|
||||
if n := L.ToNumber(-3); n != 1 {
|
||||
return fmt.Errorf("first result: got %v, want 1", n)
|
||||
}
|
||||
if s := L.ToString(-2); s != "test" {
|
||||
return fmt.Errorf("second result: got %v, want 'test'", s)
|
||||
}
|
||||
if b := L.ToBoolean(-1); !b {
|
||||
return fmt.Errorf("third result: got %v, want true", b)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
funcName: "err",
|
||||
setup: "function err() error('test error') end",
|
||||
args: []interface{}{},
|
||||
nresults: 0,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
L := New()
|
||||
if L == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer L.Close()
|
||||
|
||||
// Setup function
|
||||
if err := L.DoString(tt.setup); err != nil {
|
||||
t.Fatalf("Setup failed: %v", err)
|
||||
}
|
||||
|
||||
// Get function
|
||||
L.GetGlobal(tt.funcName)
|
||||
if !L.IsFunction(-1) {
|
||||
t.Fatal("Failed to get function")
|
||||
}
|
||||
|
||||
// Push arguments
|
||||
for _, arg := range tt.args {
|
||||
if err := L.PushValue(arg); err != nil {
|
||||
t.Fatalf("Failed to push argument: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Call function
|
||||
err := L.Call(len(tt.args), tt.nresults)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Call() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if err == nil && tt.checkStack != nil {
|
||||
if err := tt.checkStack(L); err != nil {
|
||||
t.Errorf("Stack check failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDoFile(t *testing.T) {
|
||||
L := NewSafe()
|
||||
L := New()
|
||||
defer L.Close()
|
||||
|
||||
// Create test file
|
||||
|
@ -223,7 +427,7 @@ func TestDoFile(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestRequireAndPackagePath(t *testing.T) {
|
||||
L := NewSafe()
|
||||
L := New()
|
||||
defer L.Close()
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
|
@ -260,7 +464,7 @@ func TestRequireAndPackagePath(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSetPackagePath(t *testing.T) {
|
||||
L := NewSafe()
|
||||
L := New()
|
||||
defer L.Close()
|
||||
|
||||
customPath := "./custom/?.lua"
|
||||
|
@ -273,4 +477,217 @@ func TestSetPackagePath(t *testing.T) {
|
|||
if path := L.ToString(-1); path != customPath {
|
||||
t.Errorf("Expected package.path=%q, got %q", customPath, path)
|
||||
}
|
||||
|
||||
// Test that the old path is completely replaced
|
||||
initialPath := L.ToString(-1)
|
||||
anotherPath := "./another/?.lua"
|
||||
if err := L.SetPackagePath(anotherPath); err != nil {
|
||||
t.Fatalf("Second SetPackagePath failed: %v", err)
|
||||
}
|
||||
|
||||
L.GetGlobal("package")
|
||||
L.GetField(-1, "path")
|
||||
if path := L.ToString(-1); path != anotherPath {
|
||||
t.Errorf("Expected package.path=%q, got %q", anotherPath, path)
|
||||
}
|
||||
if path := L.ToString(-1); path == initialPath {
|
||||
t.Error("SetPackagePath did not replace the old path")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStackDebug(t *testing.T) {
|
||||
L := New()
|
||||
if L == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer L.Close()
|
||||
|
||||
t.Log("Testing LoadString:")
|
||||
initialTop := L.GetTop()
|
||||
t.Logf("Initial stack size: %d", initialTop)
|
||||
|
||||
err := L.LoadString("return 42")
|
||||
if err != nil {
|
||||
t.Errorf("LoadString failed: %v", err)
|
||||
}
|
||||
|
||||
afterLoad := L.GetTop()
|
||||
t.Logf("Stack size after load: %d", afterLoad)
|
||||
t.Logf("Type of top element: %s", L.GetType(-1))
|
||||
|
||||
if L.IsFunction(-1) {
|
||||
t.Log("Top element is a function")
|
||||
} else {
|
||||
t.Log("Top element is NOT a function")
|
||||
}
|
||||
|
||||
// Clean up after LoadString test
|
||||
L.SetTop(0)
|
||||
|
||||
t.Log("\nTesting ExecuteString:")
|
||||
if err := L.DoString("function test() return 1, 'hello', true end"); err != nil {
|
||||
t.Errorf("DoString failed: %v", err)
|
||||
}
|
||||
|
||||
beforeExec := L.GetTop()
|
||||
t.Logf("Stack size before execute: %d", beforeExec)
|
||||
|
||||
nresults, err := L.ExecuteString("return test()")
|
||||
if err != nil {
|
||||
t.Errorf("ExecuteString failed: %v", err)
|
||||
}
|
||||
|
||||
afterExec := L.GetTop()
|
||||
t.Logf("Stack size after execute: %d", afterExec)
|
||||
t.Logf("Reported number of results: %d", nresults)
|
||||
|
||||
// Print each stack element
|
||||
for i := 1; i <= afterExec; i++ {
|
||||
t.Logf("Stack[-%d] type: %s", i, L.GetType(-i))
|
||||
}
|
||||
|
||||
if afterExec != nresults {
|
||||
t.Errorf("Stack size (%d) doesn't match number of results (%d)", afterExec, nresults)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTemplateRendering(t *testing.T) {
|
||||
L := New()
|
||||
if L == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer L.Close()
|
||||
defer L.Cleanup()
|
||||
|
||||
// Create a simple render.template function
|
||||
renderFunc := func(s *State) int {
|
||||
// Template will be at index 1, data at index 2
|
||||
data, err := s.ToTable(2)
|
||||
if err != nil {
|
||||
s.PushString(fmt.Sprintf("failed to get data table: %v", err))
|
||||
return -1
|
||||
}
|
||||
|
||||
// Push data back as global for template access
|
||||
if err := s.PushTable(data); err != nil {
|
||||
s.PushString(fmt.Sprintf("failed to push data table: %v", err))
|
||||
return -1
|
||||
}
|
||||
s.SetGlobal("data")
|
||||
|
||||
// Template processing code
|
||||
luaCode := `
|
||||
local result = {}
|
||||
if data.user.logged_in then
|
||||
table.insert(result, '<div class="profile">')
|
||||
table.insert(result, string.format(' <h1>Welcome, %s!</h1>', tostring(data.user.name)))
|
||||
table.insert(result, ' <div class="user-info">')
|
||||
table.insert(result, string.format(' <p>Email: %s</p>', tostring(data.user.email)))
|
||||
table.insert(result, string.format(' <p>Member since: %s</p>', tostring(data.user.joined_date)))
|
||||
table.insert(result, ' </div>')
|
||||
if data.user.is_admin then
|
||||
table.insert(result, ' <div class="admin-panel">')
|
||||
table.insert(result, ' <h2>Admin Controls</h2>')
|
||||
table.insert(result, ' <!-- admin content -->')
|
||||
table.insert(result, ' </div>')
|
||||
end
|
||||
table.insert(result, '</div>')
|
||||
else
|
||||
table.insert(result, '<div class="profile">')
|
||||
table.insert(result, ' <h1>Please log in to view your profile</h1>')
|
||||
table.insert(result, '</div>')
|
||||
end
|
||||
return table.concat(result, '\n')`
|
||||
|
||||
result, err := s.DoStringResult(luaCode)
|
||||
if err != nil {
|
||||
s.PushString(fmt.Sprintf("template execution failed: %v", err))
|
||||
return -1
|
||||
}
|
||||
|
||||
// Push the string result
|
||||
if str, ok := result.(string); ok {
|
||||
s.PushString(str)
|
||||
return 1
|
||||
}
|
||||
|
||||
s.PushString(fmt.Sprintf("expected string result, got %T", result))
|
||||
return -1
|
||||
}
|
||||
|
||||
// Create render table and add template function
|
||||
L.NewTable()
|
||||
if err := L.PushGoFunction(renderFunc); err != nil {
|
||||
t.Fatalf("Failed to create render function: %v", err)
|
||||
}
|
||||
L.SetField(-2, "template")
|
||||
L.SetGlobal("render")
|
||||
|
||||
// Test with logged in admin user
|
||||
testCode := `
|
||||
local data = {
|
||||
user = {
|
||||
logged_in = true,
|
||||
name = "John Doe",
|
||||
email = "john@example.com",
|
||||
joined_date = "2024-02-09",
|
||||
is_admin = true
|
||||
}
|
||||
}
|
||||
return render.template("test.html", data)
|
||||
`
|
||||
|
||||
result, err := L.DoStringResult(testCode)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute test: %v", err)
|
||||
}
|
||||
|
||||
str, ok := result.(string)
|
||||
if !ok {
|
||||
t.Fatalf("Expected string result, got %T", result)
|
||||
}
|
||||
|
||||
expectedResult := `<div class="profile">
|
||||
<h1>Welcome, John Doe!</h1>
|
||||
<div class="user-info">
|
||||
<p>Email: john@example.com</p>
|
||||
<p>Member since: 2024-02-09</p>
|
||||
</div>
|
||||
<div class="admin-panel">
|
||||
<h2>Admin Controls</h2>
|
||||
<!-- admin content -->
|
||||
</div>
|
||||
</div>`
|
||||
|
||||
if str != expectedResult {
|
||||
t.Errorf("\nExpected:\n%s\n\nGot:\n%s", expectedResult, str)
|
||||
}
|
||||
|
||||
// Test with logged out user
|
||||
testCode = `
|
||||
local data = {
|
||||
user = {
|
||||
logged_in = false
|
||||
}
|
||||
}
|
||||
return render.template("test.html", data)
|
||||
`
|
||||
|
||||
result, err = L.DoStringResult(testCode)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute logged out test: %v", err)
|
||||
}
|
||||
|
||||
str, ok = result.(string)
|
||||
if !ok {
|
||||
t.Fatalf("Expected string result, got %T", result)
|
||||
}
|
||||
|
||||
expectedResult = `<div class="profile">
|
||||
<h1>Please log in to view your profile</h1>
|
||||
</div>`
|
||||
|
||||
if str != expectedResult {
|
||||
t.Errorf("\nExpected:\n%s\n\nGot:\n%s", expectedResult, str)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user