BIG changes; no "safe" mode, function updates, etc
This commit is contained in:
parent
7c79616cac
commit
4dc266201f
|
@ -1,148 +0,0 @@
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
|
||||||
)
|
|
||||||
|
|
||||||
type benchCase struct {
|
|
||||||
name string
|
|
||||||
code string
|
|
||||||
}
|
|
||||||
|
|
||||||
var cases = []benchCase{
|
|
||||||
{
|
|
||||||
name: "Simple Addition",
|
|
||||||
code: `return 1 + 1`,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Loop Sum",
|
|
||||||
code: `
|
|
||||||
local sum = 0
|
|
||||||
for i = 1, 1000 do
|
|
||||||
sum = sum + i
|
|
||||||
end
|
|
||||||
return sum
|
|
||||||
`,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Function Call",
|
|
||||||
code: `
|
|
||||||
local result = 0
|
|
||||||
for i = 1, 100 do
|
|
||||||
result = result + i
|
|
||||||
end
|
|
||||||
return result
|
|
||||||
`,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Table Creation",
|
|
||||||
code: `
|
|
||||||
local t = {}
|
|
||||||
for i = 1, 100 do
|
|
||||||
t[i] = i * 2
|
|
||||||
end
|
|
||||||
return t[50]
|
|
||||||
`,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "String Operations",
|
|
||||||
code: `
|
|
||||||
local s = "hello"
|
|
||||||
for i = 1, 10 do
|
|
||||||
s = s .. " world"
|
|
||||||
end
|
|
||||||
return #s
|
|
||||||
`,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
func runBenchmark(L *luajit.State, code string, duration time.Duration) (time.Duration, int64) {
|
|
||||||
start := time.Now()
|
|
||||||
deadline := start.Add(duration)
|
|
||||||
var ops int64
|
|
||||||
|
|
||||||
for time.Now().Before(deadline) {
|
|
||||||
if err := L.DoString(code); err != nil {
|
|
||||||
fmt.Printf("Error executing code: %v\n", err)
|
|
||||||
return 0, 0
|
|
||||||
}
|
|
||||||
L.Pop(1)
|
|
||||||
ops++
|
|
||||||
}
|
|
||||||
|
|
||||||
return time.Since(start), ops
|
|
||||||
}
|
|
||||||
|
|
||||||
func runBytecodeTest(L *luajit.State, code string, duration time.Duration) (time.Duration, int64) {
|
|
||||||
// First compile the bytecode
|
|
||||||
bytecode, err := L.CompileBytecode(code, "bench")
|
|
||||||
if err != nil {
|
|
||||||
fmt.Printf("Error compiling bytecode: %v\n", err)
|
|
||||||
return 0, 0
|
|
||||||
}
|
|
||||||
|
|
||||||
start := time.Now()
|
|
||||||
deadline := start.Add(duration)
|
|
||||||
var ops int64
|
|
||||||
|
|
||||||
for time.Now().Before(deadline) {
|
|
||||||
if err := L.LoadBytecode(bytecode, "bench"); err != nil {
|
|
||||||
fmt.Printf("Error executing bytecode: %v\n", err)
|
|
||||||
return 0, 0
|
|
||||||
}
|
|
||||||
ops++
|
|
||||||
}
|
|
||||||
|
|
||||||
return time.Since(start), ops
|
|
||||||
}
|
|
||||||
|
|
||||||
func benchmarkCase(newState func() *luajit.State, bc benchCase) {
|
|
||||||
fmt.Printf("\n%s:\n", bc.name)
|
|
||||||
|
|
||||||
// Direct execution benchmark
|
|
||||||
L := newState()
|
|
||||||
if L == nil {
|
|
||||||
fmt.Printf(" Failed to create Lua state\n")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
execTime, ops := runBenchmark(L, bc.code, 2*time.Second)
|
|
||||||
L.Close()
|
|
||||||
if ops > 0 {
|
|
||||||
opsPerSec := float64(ops) / execTime.Seconds()
|
|
||||||
fmt.Printf(" Direct: %.0f ops/sec\n", opsPerSec)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Bytecode execution benchmark
|
|
||||||
L = newState()
|
|
||||||
if L == nil {
|
|
||||||
fmt.Printf(" Failed to create Lua state\n")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
execTime, ops = runBytecodeTest(L, bc.code, 2*time.Second)
|
|
||||||
L.Close()
|
|
||||||
if ops > 0 {
|
|
||||||
opsPerSec := float64(ops) / execTime.Seconds()
|
|
||||||
fmt.Printf(" Bytecode: %.0f ops/sec\n", opsPerSec)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
modes := []struct {
|
|
||||||
name string
|
|
||||||
newState func() *luajit.State
|
|
||||||
}{
|
|
||||||
{"Safe", luajit.NewSafe},
|
|
||||||
{"Unsafe", luajit.New},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, mode := range modes {
|
|
||||||
fmt.Printf("\n=== %s Mode ===\n", mode.name)
|
|
||||||
|
|
||||||
for _, c := range cases {
|
|
||||||
benchmarkCase(mode.newState, c)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
26
bytecode.go
26
bytecode.go
|
@ -51,7 +51,8 @@ import (
|
||||||
"unsafe"
|
"unsafe"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *State) compileBytecodeUnsafe(code string, name string) ([]byte, error) {
|
// CompileBytecode compiles a Lua chunk to bytecode without executing it
|
||||||
|
func (s *State) CompileBytecode(code string, name string) ([]byte, error) {
|
||||||
// First load the string but don't execute it
|
// First load the string but don't execute it
|
||||||
ccode := C.CString(code)
|
ccode := C.CString(code)
|
||||||
defer C.free(unsafe.Pointer(ccode))
|
defer C.free(unsafe.Pointer(ccode))
|
||||||
|
@ -94,7 +95,8 @@ func (s *State) compileBytecodeUnsafe(code string, name string) ([]byte, error)
|
||||||
return bytecode, nil
|
return bytecode, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *State) loadBytecodeUnsafe(bytecode []byte, name string) error {
|
// LoadBytecode loads precompiled bytecode and executes it
|
||||||
|
func (s *State) LoadBytecode(bytecode []byte, name string) error {
|
||||||
cname := C.CString(name)
|
cname := C.CString(name)
|
||||||
defer C.free(unsafe.Pointer(cname))
|
defer C.free(unsafe.Pointer(cname))
|
||||||
|
|
||||||
|
@ -125,26 +127,6 @@ func (s *State) loadBytecodeUnsafe(bytecode []byte, name string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CompileBytecode compiles a Lua chunk to bytecode without executing it
|
|
||||||
func (s *State) CompileBytecode(code string, name string) ([]byte, error) {
|
|
||||||
if s.safeStack {
|
|
||||||
return stackGuardValue[[]byte](s, func() ([]byte, error) {
|
|
||||||
return s.compileBytecodeUnsafe(code, name)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return s.compileBytecodeUnsafe(code, name)
|
|
||||||
}
|
|
||||||
|
|
||||||
// LoadBytecode loads precompiled bytecode and executes it
|
|
||||||
func (s *State) LoadBytecode(bytecode []byte, name string) error {
|
|
||||||
if s.safeStack {
|
|
||||||
return stackGuardErr(s, func() error {
|
|
||||||
return s.loadBytecodeUnsafe(bytecode, name)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return s.loadBytecodeUnsafe(bytecode, name)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Helper function to compile and immediately load/execute bytecode
|
// Helper function to compile and immediately load/execute bytecode
|
||||||
func (s *State) CompileAndLoad(code string, name string) error {
|
func (s *State) CompileAndLoad(code string, name string) error {
|
||||||
bytecode, err := s.CompileBytecode(code, name)
|
bytecode, err := s.CompileBytecode(code, name)
|
||||||
|
|
162
bytecode_test.go
162
bytecode_test.go
|
@ -28,82 +28,70 @@ func TestBytecodeCompilation(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, f := range factories {
|
for _, tt := range tests {
|
||||||
for _, tt := range tests {
|
L := New()
|
||||||
t.Run(f.name+"/"+tt.name, func(t *testing.T) {
|
if L == nil {
|
||||||
L := f.new()
|
t.Fatal("Failed to create Lua state")
|
||||||
if L == nil {
|
}
|
||||||
t.Fatal("Failed to create Lua state")
|
defer L.Close()
|
||||||
}
|
|
||||||
defer L.Close()
|
|
||||||
|
|
||||||
bytecode, err := L.CompileBytecode(tt.code, "test")
|
bytecode, err := L.CompileBytecode(tt.code, "test")
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("CompileBytecode() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("CompileBytecode() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !tt.wantErr {
|
if !tt.wantErr {
|
||||||
if len(bytecode) == 0 {
|
if len(bytecode) == 0 {
|
||||||
t.Error("CompileBytecode() returned empty bytecode")
|
t.Error("CompileBytecode() returned empty bytecode")
|
||||||
}
|
}
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBytecodeExecution(t *testing.T) {
|
func TestBytecodeExecution(t *testing.T) {
|
||||||
for _, f := range factories {
|
L := New()
|
||||||
t.Run(f.name, func(t *testing.T) {
|
if L == nil {
|
||||||
L := f.new()
|
t.Fatal("Failed to create Lua state")
|
||||||
if L == nil {
|
}
|
||||||
t.Fatal("Failed to create Lua state")
|
defer L.Close()
|
||||||
}
|
|
||||||
defer L.Close()
|
|
||||||
|
|
||||||
// Compile some test code
|
// Compile some test code
|
||||||
code := `
|
code := `
|
||||||
function add(a, b)
|
function add(a, b)
|
||||||
return a + b
|
return a + b
|
||||||
end
|
end
|
||||||
result = add(40, 2)
|
result = add(40, 2)
|
||||||
`
|
`
|
||||||
|
|
||||||
bytecode, err := L.CompileBytecode(code, "test")
|
bytecode, err := L.CompileBytecode(code, "test")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("CompileBytecode() error = %v", err)
|
t.Fatalf("CompileBytecode() error = %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load and execute the bytecode
|
// Load and execute the bytecode
|
||||||
if err := L.LoadBytecode(bytecode, "test"); err != nil {
|
if err := L.LoadBytecode(bytecode, "test"); err != nil {
|
||||||
t.Fatalf("LoadBytecode() error = %v", err)
|
t.Fatalf("LoadBytecode() error = %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify the result
|
// Verify the result
|
||||||
L.GetGlobal("result")
|
L.GetGlobal("result")
|
||||||
if result := L.ToNumber(-1); result != 42 {
|
if result := L.ToNumber(-1); result != 42 {
|
||||||
t.Errorf("got result = %v, want 42", result)
|
t.Errorf("got result = %v, want 42", result)
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestInvalidBytecode(t *testing.T) {
|
func TestInvalidBytecode(t *testing.T) {
|
||||||
for _, f := range factories {
|
L := New()
|
||||||
t.Run(f.name, func(t *testing.T) {
|
if L == nil {
|
||||||
L := f.new()
|
t.Fatal("Failed to create Lua state")
|
||||||
if L == nil {
|
}
|
||||||
t.Fatal("Failed to create Lua state")
|
defer L.Close()
|
||||||
}
|
|
||||||
defer L.Close()
|
|
||||||
|
|
||||||
// Test with invalid bytecode
|
// Test with invalid bytecode
|
||||||
invalidBytecode := []byte("this is not valid bytecode")
|
invalidBytecode := []byte("this is not valid bytecode")
|
||||||
if err := L.LoadBytecode(invalidBytecode, "test"); err == nil {
|
if err := L.LoadBytecode(invalidBytecode, "test"); err == nil {
|
||||||
t.Error("LoadBytecode() expected error with invalid bytecode")
|
t.Error("LoadBytecode() expected error with invalid bytecode")
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -140,39 +128,35 @@ func TestBytecodeRoundTrip(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, f := range factories {
|
for _, tt := range tests {
|
||||||
for _, tt := range tests {
|
// First state for compilation
|
||||||
t.Run(f.name+"/"+tt.name, func(t *testing.T) {
|
L1 := New()
|
||||||
// First state for compilation
|
if L1 == nil {
|
||||||
L1 := f.new()
|
t.Fatal("Failed to create first Lua state")
|
||||||
if L1 == nil {
|
}
|
||||||
t.Fatal("Failed to create first Lua state")
|
defer L1.Close()
|
||||||
}
|
|
||||||
defer L1.Close()
|
|
||||||
|
|
||||||
// Compile the code
|
// Compile the code
|
||||||
bytecode, err := L1.CompileBytecode(tt.code, "test")
|
bytecode, err := L1.CompileBytecode(tt.code, "test")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("CompileBytecode() error = %v", err)
|
t.Fatalf("CompileBytecode() error = %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Second state for execution
|
// Second state for execution
|
||||||
L2 := f.new()
|
L2 := New()
|
||||||
if L2 == nil {
|
if L2 == nil {
|
||||||
t.Fatal("Failed to create second Lua state")
|
t.Fatal("Failed to create second Lua state")
|
||||||
}
|
}
|
||||||
defer L2.Close()
|
defer L2.Close()
|
||||||
|
|
||||||
// Load and execute the bytecode
|
// Load and execute the bytecode
|
||||||
if err := L2.LoadBytecode(bytecode, "test"); err != nil {
|
if err := L2.LoadBytecode(bytecode, "test"); err != nil {
|
||||||
t.Fatalf("LoadBytecode() error = %v", err)
|
t.Fatalf("LoadBytecode() error = %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run the check function
|
// Run the check function
|
||||||
if err := tt.check(L2); err != nil {
|
if err := tt.check(L2); err != nil {
|
||||||
t.Errorf("check failed: %v", err)
|
t.Errorf("check failed: %v", err)
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -31,7 +31,7 @@ var (
|
||||||
|
|
||||||
//export goFunctionWrapper
|
//export goFunctionWrapper
|
||||||
func goFunctionWrapper(L *C.lua_State) C.int {
|
func goFunctionWrapper(L *C.lua_State) C.int {
|
||||||
state := &State{L: L, safeStack: true}
|
state := &State{L: L}
|
||||||
|
|
||||||
// Get upvalue using standard Lua 5.1 macro
|
// Get upvalue using standard Lua 5.1 macro
|
||||||
ptr := C.lua_touserdata(L, C.get_upvalue_index(1))
|
ptr := C.lua_touserdata(L, C.get_upvalue_index(1))
|
||||||
|
@ -54,7 +54,6 @@ func goFunctionWrapper(L *C.lua_State) C.int {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *State) PushGoFunction(fn GoFunction) error {
|
func (s *State) PushGoFunction(fn GoFunction) error {
|
||||||
// Push lightuserdata as upvalue and create closure
|
|
||||||
ptr := C.malloc(1)
|
ptr := C.malloc(1)
|
||||||
if ptr == nil {
|
if ptr == nil {
|
||||||
return fmt.Errorf("failed to allocate memory for function pointer")
|
return fmt.Errorf("failed to allocate memory for function pointer")
|
||||||
|
|
|
@ -1,89 +1,87 @@
|
||||||
package luajit
|
package luajit
|
||||||
|
|
||||||
import "testing"
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
func TestGoFunctions(t *testing.T) {
|
func TestGoFunctions(t *testing.T) {
|
||||||
for _, f := range factories {
|
L := New()
|
||||||
t.Run(f.name, func(t *testing.T) {
|
if L == nil {
|
||||||
L := f.new()
|
t.Fatal("Failed to create Lua state")
|
||||||
if L == nil {
|
}
|
||||||
t.Fatal("Failed to create Lua state")
|
defer L.Close()
|
||||||
}
|
defer L.Cleanup()
|
||||||
defer L.Close()
|
|
||||||
defer L.Cleanup()
|
|
||||||
|
|
||||||
addFunc := func(s *State) int {
|
addFunc := func(s *State) int {
|
||||||
s.PushNumber(s.ToNumber(1) + s.ToNumber(2))
|
s.PushNumber(s.ToNumber(1) + s.ToNumber(2))
|
||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := L.RegisterGoFunction("add", addFunc); err != nil {
|
if err := L.RegisterGoFunction("add", addFunc); err != nil {
|
||||||
t.Fatalf("Failed to register function: %v", err)
|
t.Fatalf("Failed to register function: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test basic function call
|
// Test basic function call
|
||||||
if err := L.DoString("result = add(40, 2)"); err != nil {
|
if err := L.DoString("result = add(40, 2)"); err != nil {
|
||||||
t.Fatalf("Failed to call function: %v", err)
|
t.Fatalf("Failed to call function: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
L.GetGlobal("result")
|
L.GetGlobal("result")
|
||||||
if result := L.ToNumber(-1); result != 42 {
|
if result := L.ToNumber(-1); result != 42 {
|
||||||
t.Errorf("got %v, want 42", result)
|
t.Errorf("got %v, want 42", result)
|
||||||
}
|
}
|
||||||
L.Pop(1)
|
L.Pop(1)
|
||||||
|
|
||||||
// Test multiple return values
|
// Test multiple return values
|
||||||
multiFunc := func(s *State) int {
|
multiFunc := func(s *State) int {
|
||||||
s.PushString("hello")
|
s.PushString("hello")
|
||||||
s.PushNumber(42)
|
s.PushNumber(42)
|
||||||
s.PushBoolean(true)
|
s.PushBoolean(true)
|
||||||
return 3
|
return 3
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := L.RegisterGoFunction("multi", multiFunc); err != nil {
|
if err := L.RegisterGoFunction("multi", multiFunc); err != nil {
|
||||||
t.Fatalf("Failed to register multi function: %v", err)
|
t.Fatalf("Failed to register multi function: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
code := `
|
code := `
|
||||||
a, b, c = multi()
|
a, b, c = multi()
|
||||||
result = (a == "hello" and b == 42 and c == true)
|
result = (a == "hello" and b == 42 and c == true)
|
||||||
`
|
`
|
||||||
|
|
||||||
if err := L.DoString(code); err != nil {
|
if err := L.DoString(code); err != nil {
|
||||||
t.Fatalf("Failed to call multi function: %v", err)
|
t.Fatalf("Failed to call multi function: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
L.GetGlobal("result")
|
L.GetGlobal("result")
|
||||||
if !L.ToBoolean(-1) {
|
if !L.ToBoolean(-1) {
|
||||||
t.Error("Multiple return values test failed")
|
t.Error("Multiple return values test failed")
|
||||||
}
|
}
|
||||||
L.Pop(1)
|
L.Pop(1)
|
||||||
|
|
||||||
// Test error handling
|
// Test error handling
|
||||||
errFunc := func(s *State) int {
|
errFunc := func(s *State) int {
|
||||||
s.PushString("test error")
|
s.PushString("test error")
|
||||||
return -1
|
return -1
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := L.RegisterGoFunction("err", errFunc); err != nil {
|
if err := L.RegisterGoFunction("err", errFunc); err != nil {
|
||||||
t.Fatalf("Failed to register error function: %v", err)
|
t.Fatalf("Failed to register error function: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := L.DoString("err()"); err == nil {
|
if err := L.DoString("err()"); err == nil {
|
||||||
t.Error("Expected error from error function")
|
t.Error("Expected error from error function")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test unregistering
|
// Test unregistering
|
||||||
L.UnregisterGoFunction("add")
|
L.UnregisterGoFunction("add")
|
||||||
if err := L.DoString("add(1, 2)"); err == nil {
|
if err := L.DoString("add(1, 2)"); err == nil {
|
||||||
t.Error("Expected error calling unregistered function")
|
t.Error("Expected error calling unregistered function")
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestStackSafety(t *testing.T) {
|
func TestStackSafety(t *testing.T) {
|
||||||
L := NewSafe()
|
L := New()
|
||||||
if L == nil {
|
if L == nil {
|
||||||
t.Fatal("Failed to create Lua state")
|
t.Fatal("Failed to create Lua state")
|
||||||
}
|
}
|
||||||
|
|
101
table.go
101
table.go
|
@ -22,57 +22,20 @@ type TableValue interface {
|
||||||
|
|
||||||
func (s *State) GetTableLength(index int) int { return int(C.get_table_length(s.L, C.int(index))) }
|
func (s *State) GetTableLength(index int) int { return int(C.get_table_length(s.L, C.int(index))) }
|
||||||
|
|
||||||
|
// PushTable pushes a Go map onto the Lua stack as a table
|
||||||
|
func (s *State) PushTable(table map[string]interface{}) error {
|
||||||
|
s.NewTable()
|
||||||
|
for k, v := range table {
|
||||||
|
if err := s.PushValue(v); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s.SetField(-2, k)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// ToTable converts a Lua table to a Go map
|
// ToTable converts a Lua table to a Go map
|
||||||
func (s *State) ToTable(index int) (map[string]interface{}, error) {
|
func (s *State) ToTable(index int) (map[string]interface{}, error) {
|
||||||
if s.safeStack {
|
|
||||||
return stackGuardValue[map[string]interface{}](s, func() (map[string]interface{}, error) {
|
|
||||||
if !s.IsTable(index) {
|
|
||||||
return nil, fmt.Errorf("not a table at index %d", index)
|
|
||||||
}
|
|
||||||
return s.toTableUnsafe(index)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
if !s.IsTable(index) {
|
|
||||||
return nil, fmt.Errorf("not a table at index %d", index)
|
|
||||||
}
|
|
||||||
return s.toTableUnsafe(index)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *State) pushTableSafe(table map[string]interface{}) error {
|
|
||||||
size := 2
|
|
||||||
if err := s.checkStack(size); err != nil {
|
|
||||||
return fmt.Errorf("insufficient stack space: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
s.NewTable()
|
|
||||||
for k, v := range table {
|
|
||||||
if err := s.pushValueSafe(v); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
s.SetField(-2, k)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *State) pushTableUnsafe(table map[string]interface{}) error {
|
|
||||||
s.NewTable()
|
|
||||||
for k, v := range table {
|
|
||||||
if err := s.pushValueUnsafe(v); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
s.SetField(-2, k)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *State) toTableSafe(index int) (map[string]interface{}, error) {
|
|
||||||
if err := s.checkStack(2); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return s.toTableUnsafe(index)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *State) toTableUnsafe(index int) (map[string]interface{}, error) {
|
|
||||||
absIdx := s.absIndex(index)
|
absIdx := s.absIndex(index)
|
||||||
table := make(map[string]interface{})
|
table := make(map[string]interface{})
|
||||||
|
|
||||||
|
@ -111,7 +74,7 @@ func (s *State) toTableUnsafe(index int) (map[string]interface{}, error) {
|
||||||
key = fmt.Sprintf("%g", s.ToNumber(-2))
|
key = fmt.Sprintf("%g", s.ToNumber(-2))
|
||||||
}
|
}
|
||||||
|
|
||||||
value, err := s.toValueUnsafe(-1)
|
value, err := s.ToValue(-1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.Pop(1)
|
s.Pop(1)
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -133,45 +96,15 @@ func (s *State) toTableUnsafe(index int) (map[string]interface{}, error) {
|
||||||
|
|
||||||
// NewTable creates a new table and pushes it onto the stack
|
// NewTable creates a new table and pushes it onto the stack
|
||||||
func (s *State) NewTable() {
|
func (s *State) NewTable() {
|
||||||
if s.safeStack {
|
|
||||||
if err := s.checkStack(1); err != nil {
|
|
||||||
// Since we can't return an error, we'll push nil instead
|
|
||||||
s.PushNil()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
C.lua_createtable(s.L, 0, 0)
|
C.lua_createtable(s.L, 0, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetTable sets a table field with cached absolute index
|
// SetTable sets a table field
|
||||||
func (s *State) SetTable(index int) {
|
func (s *State) SetTable(index int) {
|
||||||
absIdx := index
|
C.lua_settable(s.L, C.int(index))
|
||||||
if s.safeStack && (index < 0 && index > LUA_REGISTRYINDEX) {
|
|
||||||
absIdx = s.GetTop() + index + 1
|
|
||||||
}
|
|
||||||
C.lua_settable(s.L, C.int(absIdx))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetTable gets a table field with cached absolute index
|
// GetTable gets a table field
|
||||||
func (s *State) GetTable(index int) {
|
func (s *State) GetTable(index int) {
|
||||||
absIdx := index
|
C.lua_gettable(s.L, C.int(index))
|
||||||
if s.safeStack && (index < 0 && index > LUA_REGISTRYINDEX) {
|
|
||||||
absIdx = s.GetTop() + index + 1
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.safeStack {
|
|
||||||
if err := s.checkStack(1); err != nil {
|
|
||||||
s.PushNil()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
C.lua_gettable(s.L, C.int(absIdx))
|
|
||||||
}
|
|
||||||
|
|
||||||
// PushTable pushes a Go map onto the Lua stack as a table with stack checking
|
|
||||||
func (s *State) PushTable(table map[string]interface{}) error {
|
|
||||||
if s.safeStack {
|
|
||||||
return s.pushTableSafe(table)
|
|
||||||
}
|
|
||||||
return s.pushTableUnsafe(table)
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -34,28 +34,24 @@ func TestTableOperations(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, f := range factories {
|
for _, tt := range tests {
|
||||||
for _, tt := range tests {
|
L := New()
|
||||||
t.Run(f.name+"/"+tt.name, func(t *testing.T) {
|
if L == nil {
|
||||||
L := f.new()
|
t.Fatal("Failed to create Lua state")
|
||||||
if L == nil {
|
}
|
||||||
t.Fatal("Failed to create Lua state")
|
defer L.Close()
|
||||||
}
|
|
||||||
defer L.Close()
|
|
||||||
|
|
||||||
if err := L.PushTable(tt.data); err != nil {
|
if err := L.PushTable(tt.data); err != nil {
|
||||||
t.Fatalf("PushTable() error = %v", err)
|
t.Fatalf("PushTable() error = %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
got, err := L.ToTable(-1)
|
got, err := L.ToTable(-1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("ToTable() error = %v", err)
|
t.Fatalf("ToTable() error = %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !tablesEqual(got, tt.data) {
|
if !tablesEqual(got, tt.data) {
|
||||||
t.Errorf("table mismatch\ngot = %v\nwant = %v", got, tt.data)
|
t.Errorf("table mismatch\ngot = %v\nwant = %v", got, tt.data)
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
321
wrapper.go
321
wrapper.go
|
@ -10,49 +10,82 @@ package luajit
|
||||||
#include <lauxlib.h>
|
#include <lauxlib.h>
|
||||||
#include <stdlib.h>
|
#include <stdlib.h>
|
||||||
|
|
||||||
static int do_string(lua_State *L, const char *s) {
|
// Simple wrapper around luaL_loadstring
|
||||||
int status = luaL_loadstring(L, s);
|
static int load_chunk(lua_State *L, const char *s) {
|
||||||
if (status) return status;
|
return luaL_loadstring(L, s);
|
||||||
return lua_pcall(L, 0, LUA_MULTRET, 0);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Direct wrapper around lua_pcall
|
||||||
|
static int protected_call(lua_State *L, int nargs, int nresults, int errfunc) {
|
||||||
|
return lua_pcall(L, nargs, nresults, errfunc);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Combined load and execute with no results
|
||||||
|
static int do_string(lua_State *L, const char *s) {
|
||||||
|
return luaL_dostring(L, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Combined load and execute file
|
||||||
static int do_file(lua_State *L, const char *filename) {
|
static int do_file(lua_State *L, const char *filename) {
|
||||||
int status = luaL_loadfile(L, filename);
|
return luaL_dofile(L, filename);
|
||||||
if (status) return status;
|
}
|
||||||
return lua_pcall(L, 0, LUA_MULTRET, 0);
|
|
||||||
|
// Execute string with multiple returns
|
||||||
|
static int execute_string(lua_State *L, const char *s) {
|
||||||
|
int base = lua_gettop(L); // Save stack position
|
||||||
|
int status = luaL_loadstring(L, s);
|
||||||
|
if (status) return -status; // Return negative status for load errors
|
||||||
|
|
||||||
|
status = lua_pcall(L, 0, LUA_MULTRET, 0);
|
||||||
|
if (status) return -status; // Return negative status for runtime errors
|
||||||
|
|
||||||
|
return lua_gettop(L) - base; // Return number of results
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get absolute stack index (converts negative indices)
|
||||||
|
static int get_abs_index(lua_State *L, int idx) {
|
||||||
|
if (idx > 0 || idx <= LUA_REGISTRYINDEX) return idx;
|
||||||
|
return lua_gettop(L) + idx + 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stack manipulation helpers
|
||||||
|
static int check_stack(lua_State *L, int n) {
|
||||||
|
return lua_checkstack(L, n);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void remove_stack(lua_State *L, int idx) {
|
||||||
|
lua_remove(L, idx);
|
||||||
|
}
|
||||||
|
|
||||||
|
static int get_field_helper(lua_State *L, int idx, const char *k) {
|
||||||
|
lua_getfield(L, idx, k);
|
||||||
|
return lua_type(L, -1);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void set_field_helper(lua_State *L, int idx, const char *k) {
|
||||||
|
lua_setfield(L, idx, k);
|
||||||
}
|
}
|
||||||
*/
|
*/
|
||||||
import "C"
|
import "C"
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"path/filepath"
|
"strings"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
)
|
)
|
||||||
|
|
||||||
// State represents a Lua state with configurable stack safety
|
// State represents a Lua state
|
||||||
type State struct {
|
type State struct {
|
||||||
L *C.lua_State
|
L *C.lua_State
|
||||||
safeStack bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewSafe creates a new Lua state with full stack safety guarantees
|
// New creates a new Lua state
|
||||||
func NewSafe() *State {
|
|
||||||
L := C.luaL_newstate()
|
|
||||||
if L == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
C.luaL_openlibs(L)
|
|
||||||
return &State{L: L, safeStack: true}
|
|
||||||
}
|
|
||||||
|
|
||||||
// New creates a new Lua state with minimal stack checking
|
|
||||||
func New() *State {
|
func New() *State {
|
||||||
L := C.luaL_newstate()
|
L := C.luaL_newstate()
|
||||||
if L == nil {
|
if L == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
C.luaL_openlibs(L)
|
C.luaL_openlibs(L)
|
||||||
return &State{L: L, safeStack: false}
|
return &State{L: L}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close closes the Lua state
|
// Close closes the Lua state
|
||||||
|
@ -63,65 +96,28 @@ func (s *State) Close() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// DoString executes a Lua string with appropriate stack management
|
// DoString executes a Lua string.
|
||||||
func (s *State) DoString(str string) error {
|
func (s *State) DoString(str string) error {
|
||||||
cstr := C.CString(str)
|
// Save initial stack size
|
||||||
defer C.free(unsafe.Pointer(cstr))
|
top := s.GetTop()
|
||||||
|
|
||||||
if s.safeStack {
|
// Load the string
|
||||||
return stackGuardErr(s, func() error {
|
if err := s.LoadString(str); err != nil {
|
||||||
// Save the current stack size
|
return err
|
||||||
initialTop := s.GetTop()
|
|
||||||
|
|
||||||
// Execute the string
|
|
||||||
status := C.do_string(s.L, cstr)
|
|
||||||
if status != 0 {
|
|
||||||
// In case of error, get error message from stack
|
|
||||||
errMsg := s.ToString(-1)
|
|
||||||
s.SetTop(initialTop) // Restore stack
|
|
||||||
return &LuaError{
|
|
||||||
Code: int(status),
|
|
||||||
Message: errMsg,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return values are now on the stack above initialTop
|
|
||||||
// We don't pop them as they may be needed by the caller
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
status := C.do_string(s.L, cstr)
|
// Execute and check for errors
|
||||||
if status != 0 {
|
if err := s.Call(0, 0); err != nil {
|
||||||
return &LuaError{
|
return err
|
||||||
Code: int(status),
|
|
||||||
Message: s.ToString(-1),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Restore stack to initial size to clean up any leftovers
|
||||||
|
s.SetTop(top)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// PushValue pushes a Go value onto the stack
|
// PushValue pushes a Go value onto the stack
|
||||||
func (s *State) PushValue(v interface{}) error {
|
func (s *State) PushValue(v interface{}) error {
|
||||||
if s.safeStack {
|
|
||||||
return stackGuardErr(s, func() error {
|
|
||||||
if err := s.checkStack(1); err != nil {
|
|
||||||
return fmt.Errorf("pushing value: %w", err)
|
|
||||||
}
|
|
||||||
return s.pushValueUnsafe(v)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return s.pushValueUnsafe(v)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *State) pushValueSafe(v interface{}) error {
|
|
||||||
if err := s.checkStack(1); err != nil {
|
|
||||||
return fmt.Errorf("pushing value: %w", err)
|
|
||||||
}
|
|
||||||
return s.pushValueUnsafe(v)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *State) pushValueUnsafe(v interface{}) error {
|
|
||||||
switch v := v.(type) {
|
switch v := v.(type) {
|
||||||
case nil:
|
case nil:
|
||||||
s.PushNil()
|
s.PushNil()
|
||||||
|
@ -144,7 +140,7 @@ func (s *State) pushValueUnsafe(v interface{}) error {
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return s.pushTableUnsafe(v)
|
return s.PushTable(v)
|
||||||
case []float64:
|
case []float64:
|
||||||
s.NewTable()
|
s.NewTable()
|
||||||
for i, elem := range v {
|
for i, elem := range v {
|
||||||
|
@ -156,7 +152,7 @@ func (s *State) pushValueUnsafe(v interface{}) error {
|
||||||
s.NewTable()
|
s.NewTable()
|
||||||
for i, elem := range v {
|
for i, elem := range v {
|
||||||
s.PushNumber(float64(i + 1))
|
s.PushNumber(float64(i + 1))
|
||||||
if err := s.pushValueUnsafe(elem); err != nil {
|
if err := s.PushValue(elem); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
s.SetTable(-3)
|
s.SetTable(-3)
|
||||||
|
@ -169,15 +165,6 @@ func (s *State) pushValueUnsafe(v interface{}) error {
|
||||||
|
|
||||||
// ToValue converts a Lua value to a Go value
|
// ToValue converts a Lua value to a Go value
|
||||||
func (s *State) ToValue(index int) (interface{}, error) {
|
func (s *State) ToValue(index int) (interface{}, error) {
|
||||||
if s.safeStack {
|
|
||||||
return stackGuardValue[interface{}](s, func() (interface{}, error) {
|
|
||||||
return s.toValueUnsafe(index)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return s.toValueUnsafe(index)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *State) toValueUnsafe(index int) (interface{}, error) {
|
|
||||||
switch s.GetType(index) {
|
switch s.GetType(index) {
|
||||||
case TypeNil:
|
case TypeNil:
|
||||||
return nil, nil
|
return nil, nil
|
||||||
|
@ -191,7 +178,7 @@ func (s *State) toValueUnsafe(index int) (interface{}, error) {
|
||||||
if !s.IsTable(index) {
|
if !s.IsTable(index) {
|
||||||
return nil, fmt.Errorf("not a table at index %d", index)
|
return nil, fmt.Errorf("not a table at index %d", index)
|
||||||
}
|
}
|
||||||
return s.toTableUnsafe(index)
|
return s.ToTable(index)
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("unsupported type: %s", s.GetType(index))
|
return nil, fmt.Errorf("unsupported type: %s", s.GetType(index))
|
||||||
}
|
}
|
||||||
|
@ -244,53 +231,29 @@ func (s *State) absIndex(index int) int {
|
||||||
return s.GetTop() + index + 1
|
return s.GetTop() + index + 1
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetField sets a field in a table at the given index with cached absolute index
|
// SetField sets a field in a table at the given index
|
||||||
func (s *State) SetField(index int, key string) {
|
func (s *State) SetField(index int, key string) {
|
||||||
absIdx := index
|
|
||||||
if s.safeStack && (index < 0 && index > LUA_REGISTRYINDEX) {
|
|
||||||
absIdx = s.GetTop() + index + 1
|
|
||||||
}
|
|
||||||
|
|
||||||
cstr := C.CString(key)
|
cstr := C.CString(key)
|
||||||
defer C.free(unsafe.Pointer(cstr))
|
defer C.free(unsafe.Pointer(cstr))
|
||||||
C.lua_setfield(s.L, C.int(absIdx), cstr)
|
C.lua_setfield(s.L, C.int(index), cstr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetField gets a field from a table with cached absolute index
|
// GetField gets a field from a table
|
||||||
func (s *State) GetField(index int, key string) {
|
func (s *State) GetField(index int, key string) {
|
||||||
absIdx := index
|
|
||||||
if s.safeStack && (index < 0 && index > LUA_REGISTRYINDEX) {
|
|
||||||
absIdx = s.GetTop() + index + 1
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.safeStack {
|
|
||||||
if err := s.checkStack(1); err != nil {
|
|
||||||
s.PushNil()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
cstr := C.CString(key)
|
cstr := C.CString(key)
|
||||||
defer C.free(unsafe.Pointer(cstr))
|
defer C.free(unsafe.Pointer(cstr))
|
||||||
C.lua_getfield(s.L, C.int(absIdx), cstr)
|
C.lua_getfield(s.L, C.int(index), cstr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetGlobal gets a global variable and pushes it onto the stack
|
// GetGlobal gets a global variable and pushes it onto the stack
|
||||||
func (s *State) GetGlobal(name string) {
|
func (s *State) GetGlobal(name string) {
|
||||||
if s.safeStack {
|
cname := C.CString(name)
|
||||||
if err := s.checkStack(1); err != nil {
|
defer C.free(unsafe.Pointer(cname))
|
||||||
s.PushNil()
|
C.get_field_helper(s.L, C.LUA_GLOBALSINDEX, cname)
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
cstr := C.CString(name)
|
|
||||||
defer C.free(unsafe.Pointer(cstr))
|
|
||||||
C.lua_getfield(s.L, C.LUA_GLOBALSINDEX, cstr)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetGlobal sets a global variable from the value at the top of the stack
|
// SetGlobal sets a global variable from the value at the top of the stack
|
||||||
func (s *State) SetGlobal(name string) {
|
func (s *State) SetGlobal(name string) {
|
||||||
// SetGlobal doesn't need stack space checking as it pops the value
|
|
||||||
cstr := C.CString(name)
|
cstr := C.CString(name)
|
||||||
defer C.free(unsafe.Pointer(cstr))
|
defer C.free(unsafe.Pointer(cstr))
|
||||||
C.lua_setfield(s.L, C.LUA_GLOBALSINDEX, cstr)
|
C.lua_setfield(s.L, C.LUA_GLOBALSINDEX, cstr)
|
||||||
|
@ -299,9 +262,6 @@ func (s *State) SetGlobal(name string) {
|
||||||
// Remove removes element with cached absolute index
|
// Remove removes element with cached absolute index
|
||||||
func (s *State) Remove(index int) {
|
func (s *State) Remove(index int) {
|
||||||
absIdx := index
|
absIdx := index
|
||||||
if s.safeStack && (index < 0 && index > LUA_REGISTRYINDEX) {
|
|
||||||
absIdx = s.GetTop() + index + 1
|
|
||||||
}
|
|
||||||
C.lua_remove(s.L, C.int(absIdx))
|
C.lua_remove(s.L, C.int(absIdx))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -310,14 +270,6 @@ func (s *State) DoFile(filename string) error {
|
||||||
cfilename := C.CString(filename)
|
cfilename := C.CString(filename)
|
||||||
defer C.free(unsafe.Pointer(cfilename))
|
defer C.free(unsafe.Pointer(cfilename))
|
||||||
|
|
||||||
if s.safeStack {
|
|
||||||
return stackGuardErr(s, func() error {
|
|
||||||
return s.safeCall(func() C.int {
|
|
||||||
return C.do_file(s.L, cfilename)
|
|
||||||
})
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
status := C.do_file(s.L, cfilename)
|
status := C.do_file(s.L, cfilename)
|
||||||
if status != 0 {
|
if status != 0 {
|
||||||
return &LuaError{
|
return &LuaError{
|
||||||
|
@ -328,24 +280,28 @@ func (s *State) DoFile(filename string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetPackagePath sets the Lua package.path
|
||||||
func (s *State) SetPackagePath(path string) error {
|
func (s *State) SetPackagePath(path string) error {
|
||||||
path = filepath.ToSlash(path)
|
path = strings.ReplaceAll(path, "\\", "/") // Convert Windows paths
|
||||||
if err := s.DoString(fmt.Sprintf(`package.path = %q`, path)); err != nil {
|
cmd := fmt.Sprintf(`package.path = %q`, path)
|
||||||
return fmt.Errorf("setting package.path: %w", err)
|
return s.DoString(cmd)
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AddPackagePath adds a path to package.path
|
||||||
func (s *State) AddPackagePath(path string) error {
|
func (s *State) AddPackagePath(path string) error {
|
||||||
path = filepath.ToSlash(path)
|
path = strings.ReplaceAll(path, "\\", "/") // Convert Windows paths
|
||||||
if err := s.DoString(fmt.Sprintf(`package.path = package.path .. ";%s"`, path)); err != nil {
|
cmd := fmt.Sprintf(`package.path = package.path .. ";%s"`, path)
|
||||||
return fmt.Errorf("adding to package.path: %w", err)
|
return s.DoString(cmd)
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Call executes a function on the stack with the given number of arguments and results.
|
||||||
|
// The function and arguments should already be on the stack in the correct order
|
||||||
|
// (function first, then args from left to right).
|
||||||
func (s *State) Call(nargs, nresults int) error {
|
func (s *State) Call(nargs, nresults int) error {
|
||||||
status := C.lua_pcall(s.L, C.int(nargs), C.int(nresults), 0)
|
if !s.IsFunction(-nargs - 1) {
|
||||||
|
return fmt.Errorf("attempt to call a non-function")
|
||||||
|
}
|
||||||
|
status := C.protected_call(s.L, C.int(nargs), C.int(nresults), 0)
|
||||||
if status != 0 {
|
if status != 0 {
|
||||||
err := &LuaError{
|
err := &LuaError{
|
||||||
Code: int(status),
|
Code: int(status),
|
||||||
|
@ -356,3 +312,94 @@ func (s *State) Call(nargs, nresults int) error {
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LoadString loads but does not execute a string of Lua code.
|
||||||
|
// The compiled code chunk is left on the stack.
|
||||||
|
func (s *State) LoadString(str string) error {
|
||||||
|
cstr := C.CString(str)
|
||||||
|
defer C.free(unsafe.Pointer(cstr))
|
||||||
|
|
||||||
|
status := C.load_chunk(s.L, cstr)
|
||||||
|
if status != 0 {
|
||||||
|
err := &LuaError{
|
||||||
|
Code: int(status),
|
||||||
|
Message: s.ToString(-1),
|
||||||
|
}
|
||||||
|
s.Pop(1)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !s.IsFunction(-1) {
|
||||||
|
s.Pop(1)
|
||||||
|
return fmt.Errorf("failed to load function")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExecuteString executes a string of Lua code and returns the number of results.
|
||||||
|
// The results are left on the stack.
|
||||||
|
func (s *State) ExecuteString(str string) (int, error) {
|
||||||
|
base := s.GetTop()
|
||||||
|
|
||||||
|
// First load the string
|
||||||
|
if err := s.LoadString(str); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now execute it
|
||||||
|
if err := s.Call(0, C.LUA_MULTRET); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.GetTop() - base, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExecuteStringResult executes a Lua string and returns its first result as a Go value.
|
||||||
|
// It's a convenience wrapper around ExecuteString for the common case of wanting
|
||||||
|
// a single return value. The stack is restored to its original state after execution.
|
||||||
|
func (s *State) ExecuteStringResult(code string) (interface{}, error) {
|
||||||
|
top := s.GetTop()
|
||||||
|
defer s.SetTop(top) // Restore stack when we're done
|
||||||
|
|
||||||
|
nresults, err := s.ExecuteString(code)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("execution error: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if nresults == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the result
|
||||||
|
result, err := s.ToValue(-nresults) // Get first result
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error converting result: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DoStringResult executes a Lua string and expects a single return value.
|
||||||
|
// Unlike ExecuteStringResult, this function specifically expects exactly one
|
||||||
|
// return value and will return an error if the code returns 0 or multiple values.
|
||||||
|
func (s *State) DoStringResult(code string) (interface{}, error) {
|
||||||
|
top := s.GetTop()
|
||||||
|
defer s.SetTop(top) // Restore stack when we're done
|
||||||
|
|
||||||
|
nresults, err := s.ExecuteString(code)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("execution error: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if nresults != 1 {
|
||||||
|
return nil, fmt.Errorf("expected 1 return value, got %d", nresults)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the result
|
||||||
|
result, err := s.ToValue(-1)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error converting result: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
237
wrapper_bench_test.go
Normal file
237
wrapper_bench_test.go
Normal file
|
@ -0,0 +1,237 @@
|
||||||
|
package luajit
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
var benchCases = []struct {
|
||||||
|
name string
|
||||||
|
code string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "SimpleAddition",
|
||||||
|
code: `return 1 + 1`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "LoopSum",
|
||||||
|
code: `
|
||||||
|
local sum = 0
|
||||||
|
for i = 1, 1000 do
|
||||||
|
sum = sum + i
|
||||||
|
end
|
||||||
|
return sum
|
||||||
|
`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "FunctionCall",
|
||||||
|
code: `
|
||||||
|
local result = 0
|
||||||
|
for i = 1, 100 do
|
||||||
|
result = result + i
|
||||||
|
end
|
||||||
|
return result
|
||||||
|
`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "TableCreation",
|
||||||
|
code: `
|
||||||
|
local t = {}
|
||||||
|
for i = 1, 100 do
|
||||||
|
t[i] = i * 2
|
||||||
|
end
|
||||||
|
return t[50]
|
||||||
|
`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "StringOperations",
|
||||||
|
code: `
|
||||||
|
local s = "hello"
|
||||||
|
for i = 1, 10 do
|
||||||
|
s = s .. " world"
|
||||||
|
end
|
||||||
|
return #s
|
||||||
|
`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkLuaDirectExecution(b *testing.B) {
|
||||||
|
for _, bc := range benchCases {
|
||||||
|
b.Run(bc.name, func(b *testing.B) {
|
||||||
|
L := New()
|
||||||
|
if L == nil {
|
||||||
|
b.Fatal("Failed to create Lua state")
|
||||||
|
}
|
||||||
|
defer L.Close()
|
||||||
|
|
||||||
|
// First verify we can execute the code
|
||||||
|
if err := L.DoString(bc.code); err != nil {
|
||||||
|
b.Fatalf("Failed to execute test code: %v", err)
|
||||||
|
}
|
||||||
|
L.Pop(1) // Clean up the result
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
// Execute string and get result
|
||||||
|
nresults, err := L.ExecuteString(bc.code)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("Failed to execute code: %v", err)
|
||||||
|
}
|
||||||
|
L.Pop(nresults) // Clean up any results
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkLuaBytecodeExecution(b *testing.B) {
|
||||||
|
// First compile all bytecode
|
||||||
|
bytecodes := make(map[string][]byte)
|
||||||
|
for _, bc := range benchCases {
|
||||||
|
L := New()
|
||||||
|
if L == nil {
|
||||||
|
b.Fatal("Failed to create Lua state")
|
||||||
|
}
|
||||||
|
bytecode, err := L.CompileBytecode(bc.code, bc.name)
|
||||||
|
if err != nil {
|
||||||
|
L.Close()
|
||||||
|
b.Fatalf("Error compiling bytecode for %s: %v", bc.name, err)
|
||||||
|
}
|
||||||
|
bytecodes[bc.name] = bytecode
|
||||||
|
L.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, bc := range benchCases {
|
||||||
|
b.Run(bc.name, func(b *testing.B) {
|
||||||
|
L := New()
|
||||||
|
if L == nil {
|
||||||
|
b.Fatal("Failed to create Lua state")
|
||||||
|
}
|
||||||
|
defer L.Close()
|
||||||
|
|
||||||
|
bytecode := bytecodes[bc.name]
|
||||||
|
|
||||||
|
// First verify we can execute the bytecode
|
||||||
|
if err := L.LoadBytecode(bytecode, bc.name); err != nil {
|
||||||
|
b.Fatalf("Failed to execute test bytecode: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
b.SetBytes(int64(len(bytecode))) // Track bytecode size in benchmarks
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
if err := L.LoadBytecode(bytecode, bc.name); err != nil {
|
||||||
|
b.Fatalf("Error executing bytecode: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkTableOperations(b *testing.B) {
|
||||||
|
testData := map[string]interface{}{
|
||||||
|
"number": 42.0,
|
||||||
|
"string": "hello",
|
||||||
|
"bool": true,
|
||||||
|
"nested": map[string]interface{}{
|
||||||
|
"value": 123.0,
|
||||||
|
"array": []float64{1.1, 2.2, 3.3},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
b.Run("PushTable", func(b *testing.B) {
|
||||||
|
L := New()
|
||||||
|
if L == nil {
|
||||||
|
b.Fatal("Failed to create Lua state")
|
||||||
|
}
|
||||||
|
defer L.Close()
|
||||||
|
|
||||||
|
// First verify we can push the table
|
||||||
|
if err := L.PushTable(testData); err != nil {
|
||||||
|
b.Fatalf("Failed to push initial table: %v", err)
|
||||||
|
}
|
||||||
|
L.Pop(1)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
if err := L.PushTable(testData); err != nil {
|
||||||
|
b.Fatalf("Failed to push table: %v", err)
|
||||||
|
}
|
||||||
|
L.Pop(1)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("ToTable", func(b *testing.B) {
|
||||||
|
L := New()
|
||||||
|
if L == nil {
|
||||||
|
b.Fatal("Failed to create Lua state")
|
||||||
|
}
|
||||||
|
defer L.Close()
|
||||||
|
|
||||||
|
// Keep a table on the stack for repeated conversions
|
||||||
|
if err := L.PushTable(testData); err != nil {
|
||||||
|
b.Fatalf("Failed to push initial table: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
if _, err := L.ToTable(-1); err != nil {
|
||||||
|
b.Fatalf("Failed to convert table: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkValueConversion(b *testing.B) {
|
||||||
|
testValues := []struct {
|
||||||
|
name string
|
||||||
|
value interface{}
|
||||||
|
}{
|
||||||
|
{"Number", 42.0},
|
||||||
|
{"String", "hello world"},
|
||||||
|
{"Boolean", true},
|
||||||
|
{"Nil", nil},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tv := range testValues {
|
||||||
|
b.Run("Push"+tv.name, func(b *testing.B) {
|
||||||
|
L := New()
|
||||||
|
if L == nil {
|
||||||
|
b.Fatal("Failed to create Lua state")
|
||||||
|
}
|
||||||
|
defer L.Close()
|
||||||
|
|
||||||
|
// First verify we can push the value
|
||||||
|
if err := L.PushValue(tv.value); err != nil {
|
||||||
|
b.Fatalf("Failed to push initial value: %v", err)
|
||||||
|
}
|
||||||
|
L.Pop(1)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
if err := L.PushValue(tv.value); err != nil {
|
||||||
|
b.Fatalf("Failed to push value: %v", err)
|
||||||
|
}
|
||||||
|
L.Pop(1)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("To"+tv.name, func(b *testing.B) {
|
||||||
|
L := New()
|
||||||
|
if L == nil {
|
||||||
|
b.Fatal("Failed to create Lua state")
|
||||||
|
}
|
||||||
|
defer L.Close()
|
||||||
|
|
||||||
|
// Keep a value on the stack for repeated conversions
|
||||||
|
if err := L.PushValue(tv.value); err != nil {
|
||||||
|
b.Fatalf("Failed to push initial value: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
if _, err := L.ToValue(-1); err != nil {
|
||||||
|
b.Fatalf("Failed to convert value: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
611
wrapper_test.go
611
wrapper_test.go
|
@ -7,25 +7,151 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
type stateFactory struct {
|
|
||||||
name string
|
|
||||||
new func() *State
|
|
||||||
}
|
|
||||||
|
|
||||||
var factories = []stateFactory{
|
|
||||||
{"unsafe", New},
|
|
||||||
{"safe", NewSafe},
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNew(t *testing.T) {
|
func TestNew(t *testing.T) {
|
||||||
for _, f := range factories {
|
L := New()
|
||||||
t.Run(f.name, func(t *testing.T) {
|
if L == nil {
|
||||||
L := f.new()
|
t.Fatal("Failed to create Lua state")
|
||||||
if L == nil {
|
}
|
||||||
t.Fatal("Failed to create Lua state")
|
defer L.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadString(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
code string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid function",
|
||||||
|
code: "function add(a, b) return a + b end",
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid expression",
|
||||||
|
code: "return 1 + 1",
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "syntax error",
|
||||||
|
code: "function bad syntax",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
L := New()
|
||||||
|
if L == nil {
|
||||||
|
t.Fatal("Failed to create Lua state")
|
||||||
|
}
|
||||||
|
defer L.Close()
|
||||||
|
|
||||||
|
err := L.LoadString(tt.code)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("LoadString() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !tt.wantErr {
|
||||||
|
// Verify the function is on the stack
|
||||||
|
if L.GetTop() != 1 {
|
||||||
|
t.Error("LoadString() did not leave exactly one value on stack")
|
||||||
}
|
}
|
||||||
defer L.Close()
|
if !L.IsFunction(-1) {
|
||||||
})
|
t.Error("LoadString() did not leave a function on the stack")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExecuteString(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
code string
|
||||||
|
wantResults int
|
||||||
|
checkResults func(*State) error
|
||||||
|
wantErr bool
|
||||||
|
wantStackSize int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no results",
|
||||||
|
code: "local x = 1",
|
||||||
|
wantResults: 0,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single result",
|
||||||
|
code: "return 42",
|
||||||
|
wantResults: 1,
|
||||||
|
checkResults: func(L *State) error {
|
||||||
|
if n := L.ToNumber(-1); n != 42 {
|
||||||
|
return fmt.Errorf("got %v, want 42", n)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple results",
|
||||||
|
code: "return 1, 'test', true",
|
||||||
|
wantResults: 3,
|
||||||
|
checkResults: func(L *State) error {
|
||||||
|
if n := L.ToNumber(-3); n != 1 {
|
||||||
|
return fmt.Errorf("first result: got %v, want 1", n)
|
||||||
|
}
|
||||||
|
if s := L.ToString(-2); s != "test" {
|
||||||
|
return fmt.Errorf("second result: got %v, want 'test'", s)
|
||||||
|
}
|
||||||
|
if b := L.ToBoolean(-1); !b {
|
||||||
|
return fmt.Errorf("third result: got %v, want true", b)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "syntax error",
|
||||||
|
code: "this is not valid lua",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "runtime error",
|
||||||
|
code: "error('test error')",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
L := New()
|
||||||
|
if L == nil {
|
||||||
|
t.Fatal("Failed to create Lua state")
|
||||||
|
}
|
||||||
|
defer L.Close()
|
||||||
|
|
||||||
|
// Record initial stack size
|
||||||
|
initialStack := L.GetTop()
|
||||||
|
|
||||||
|
results, err := L.ExecuteString(tt.code)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("ExecuteString() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
if results != tt.wantResults {
|
||||||
|
t.Errorf("ExecuteString() returned %d results, want %d", results, tt.wantResults)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.checkResults != nil {
|
||||||
|
if err := tt.checkResults(L); err != nil {
|
||||||
|
t.Errorf("Result check failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify stack size matches expected results
|
||||||
|
if got := L.GetTop() - initialStack; got != tt.wantResults {
|
||||||
|
t.Errorf("Stack size grew by %d, want %d", got, tt.wantResults)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -41,20 +167,22 @@ func TestDoString(t *testing.T) {
|
||||||
{"runtime error", "error('test error')", true},
|
{"runtime error", "error('test error')", true},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, f := range factories {
|
for _, tt := range tests {
|
||||||
for _, tt := range tests {
|
L := New()
|
||||||
t.Run(f.name+"/"+tt.name, func(t *testing.T) {
|
if L == nil {
|
||||||
L := f.new()
|
t.Fatal("Failed to create Lua state")
|
||||||
if L == nil {
|
}
|
||||||
t.Fatal("Failed to create Lua state")
|
defer L.Close()
|
||||||
}
|
|
||||||
defer L.Close()
|
|
||||||
|
|
||||||
err := L.DoString(tt.code)
|
initialStack := L.GetTop()
|
||||||
if (err != nil) != tt.wantErr {
|
err := L.DoString(tt.code)
|
||||||
t.Errorf("DoString() error = %v, wantErr %v", err, tt.wantErr)
|
if (err != nil) != tt.wantErr {
|
||||||
}
|
t.Errorf("DoString() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
})
|
}
|
||||||
|
|
||||||
|
// Verify stack is unchanged
|
||||||
|
if finalStack := L.GetTop(); finalStack != initialStack {
|
||||||
|
t.Errorf("Stack size changed from %d to %d", initialStack, finalStack)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -107,95 +235,171 @@ func TestPushAndGetValues(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, f := range factories {
|
for _, v := range values {
|
||||||
for _, v := range values {
|
L := New()
|
||||||
t.Run(f.name+"/"+v.name, func(t *testing.T) {
|
if L == nil {
|
||||||
L := f.new()
|
t.Fatal("Failed to create Lua state")
|
||||||
if L == nil {
|
}
|
||||||
t.Fatal("Failed to create Lua state")
|
defer L.Close()
|
||||||
}
|
|
||||||
defer L.Close()
|
|
||||||
|
|
||||||
v.push(L)
|
v.push(L)
|
||||||
if err := v.check(L); err != nil {
|
if err := v.check(L); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestStackManipulation(t *testing.T) {
|
func TestStackManipulation(t *testing.T) {
|
||||||
for _, f := range factories {
|
L := New()
|
||||||
t.Run(f.name, func(t *testing.T) {
|
if L == nil {
|
||||||
L := f.new()
|
t.Fatal("Failed to create Lua state")
|
||||||
if L == nil {
|
}
|
||||||
t.Fatal("Failed to create Lua state")
|
defer L.Close()
|
||||||
}
|
|
||||||
defer L.Close()
|
|
||||||
|
|
||||||
// Push values
|
// Push values
|
||||||
values := []string{"first", "second", "third"}
|
values := []string{"first", "second", "third"}
|
||||||
for _, v := range values {
|
for _, v := range values {
|
||||||
L.PushString(v)
|
L.PushString(v)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check size
|
// Check size
|
||||||
if top := L.GetTop(); top != len(values) {
|
if top := L.GetTop(); top != len(values) {
|
||||||
t.Errorf("stack size = %d, want %d", top, len(values))
|
t.Errorf("stack size = %d, want %d", top, len(values))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Pop one value
|
// Pop one value
|
||||||
L.Pop(1)
|
L.Pop(1)
|
||||||
|
|
||||||
// Check new top
|
// Check new top
|
||||||
if str := L.ToString(-1); str != "second" {
|
if str := L.ToString(-1); str != "second" {
|
||||||
t.Errorf("top value = %q, want 'second'", str)
|
t.Errorf("top value = %q, want 'second'", str)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check new size
|
// Check new size
|
||||||
if top := L.GetTop(); top != len(values)-1 {
|
if top := L.GetTop(); top != len(values)-1 {
|
||||||
t.Errorf("stack size after pop = %d, want %d", top, len(values)-1)
|
t.Errorf("stack size after pop = %d, want %d", top, len(values)-1)
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGlobals(t *testing.T) {
|
func TestGlobals(t *testing.T) {
|
||||||
for _, f := range factories {
|
L := New()
|
||||||
t.Run(f.name, func(t *testing.T) {
|
if L == nil {
|
||||||
L := f.new()
|
t.Fatal("Failed to create Lua state")
|
||||||
if L == nil {
|
}
|
||||||
t.Fatal("Failed to create Lua state")
|
defer L.Close()
|
||||||
}
|
|
||||||
defer L.Close()
|
|
||||||
|
|
||||||
// Test via Lua
|
// Test via Lua
|
||||||
if err := L.DoString(`globalVar = "test"`); err != nil {
|
if err := L.DoString(`globalVar = "test"`); err != nil {
|
||||||
t.Fatalf("DoString error: %v", err)
|
t.Fatalf("DoString error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the global
|
// Get the global
|
||||||
L.GetGlobal("globalVar")
|
L.GetGlobal("globalVar")
|
||||||
if str := L.ToString(-1); str != "test" {
|
if str := L.ToString(-1); str != "test" {
|
||||||
t.Errorf("global value = %q, want 'test'", str)
|
t.Errorf("global value = %q, want 'test'", str)
|
||||||
}
|
}
|
||||||
L.Pop(1)
|
L.Pop(1)
|
||||||
|
|
||||||
// Set and get via API
|
// Set and get via API
|
||||||
L.PushNumber(42)
|
L.PushNumber(42)
|
||||||
L.SetGlobal("testNum")
|
L.SetGlobal("testNum")
|
||||||
|
|
||||||
L.GetGlobal("testNum")
|
L.GetGlobal("testNum")
|
||||||
if num := L.ToNumber(-1); num != 42 {
|
if num := L.ToNumber(-1); num != 42 {
|
||||||
t.Errorf("global number = %f, want 42", num)
|
t.Errorf("global number = %f, want 42", num)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCall(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
funcName string // Add explicit function name field
|
||||||
|
setup string
|
||||||
|
args []interface{}
|
||||||
|
nresults int
|
||||||
|
checkStack func(*State) error
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
funcName: "add",
|
||||||
|
setup: "function add(a, b) return a + b end",
|
||||||
|
args: []interface{}{float64(40), float64(2)},
|
||||||
|
nresults: 1,
|
||||||
|
checkStack: func(L *State) error {
|
||||||
|
if n := L.ToNumber(-1); n != 42 {
|
||||||
|
return fmt.Errorf("got %v, want 42", n)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
funcName: "multi",
|
||||||
|
setup: "function multi() return 1, 'test', true end",
|
||||||
|
args: []interface{}{},
|
||||||
|
nresults: 3,
|
||||||
|
checkStack: func(L *State) error {
|
||||||
|
if n := L.ToNumber(-3); n != 1 {
|
||||||
|
return fmt.Errorf("first result: got %v, want 1", n)
|
||||||
|
}
|
||||||
|
if s := L.ToString(-2); s != "test" {
|
||||||
|
return fmt.Errorf("second result: got %v, want 'test'", s)
|
||||||
|
}
|
||||||
|
if b := L.ToBoolean(-1); !b {
|
||||||
|
return fmt.Errorf("third result: got %v, want true", b)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
funcName: "err",
|
||||||
|
setup: "function err() error('test error') end",
|
||||||
|
args: []interface{}{},
|
||||||
|
nresults: 0,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
L := New()
|
||||||
|
if L == nil {
|
||||||
|
t.Fatal("Failed to create Lua state")
|
||||||
|
}
|
||||||
|
defer L.Close()
|
||||||
|
|
||||||
|
// Setup function
|
||||||
|
if err := L.DoString(tt.setup); err != nil {
|
||||||
|
t.Fatalf("Setup failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get function
|
||||||
|
L.GetGlobal(tt.funcName)
|
||||||
|
if !L.IsFunction(-1) {
|
||||||
|
t.Fatal("Failed to get function")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Push arguments
|
||||||
|
for _, arg := range tt.args {
|
||||||
|
if err := L.PushValue(arg); err != nil {
|
||||||
|
t.Fatalf("Failed to push argument: %v", err)
|
||||||
}
|
}
|
||||||
})
|
}
|
||||||
|
|
||||||
|
// Call function
|
||||||
|
err := L.Call(len(tt.args), tt.nresults)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("Call() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err == nil && tt.checkStack != nil {
|
||||||
|
if err := tt.checkStack(L); err != nil {
|
||||||
|
t.Errorf("Stack check failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDoFile(t *testing.T) {
|
func TestDoFile(t *testing.T) {
|
||||||
L := NewSafe()
|
L := New()
|
||||||
defer L.Close()
|
defer L.Close()
|
||||||
|
|
||||||
// Create test file
|
// Create test file
|
||||||
|
@ -223,7 +427,7 @@ func TestDoFile(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRequireAndPackagePath(t *testing.T) {
|
func TestRequireAndPackagePath(t *testing.T) {
|
||||||
L := NewSafe()
|
L := New()
|
||||||
defer L.Close()
|
defer L.Close()
|
||||||
|
|
||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
|
@ -260,7 +464,7 @@ func TestRequireAndPackagePath(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSetPackagePath(t *testing.T) {
|
func TestSetPackagePath(t *testing.T) {
|
||||||
L := NewSafe()
|
L := New()
|
||||||
defer L.Close()
|
defer L.Close()
|
||||||
|
|
||||||
customPath := "./custom/?.lua"
|
customPath := "./custom/?.lua"
|
||||||
|
@ -273,4 +477,217 @@ func TestSetPackagePath(t *testing.T) {
|
||||||
if path := L.ToString(-1); path != customPath {
|
if path := L.ToString(-1); path != customPath {
|
||||||
t.Errorf("Expected package.path=%q, got %q", customPath, path)
|
t.Errorf("Expected package.path=%q, got %q", customPath, path)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test that the old path is completely replaced
|
||||||
|
initialPath := L.ToString(-1)
|
||||||
|
anotherPath := "./another/?.lua"
|
||||||
|
if err := L.SetPackagePath(anotherPath); err != nil {
|
||||||
|
t.Fatalf("Second SetPackagePath failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
L.GetGlobal("package")
|
||||||
|
L.GetField(-1, "path")
|
||||||
|
if path := L.ToString(-1); path != anotherPath {
|
||||||
|
t.Errorf("Expected package.path=%q, got %q", anotherPath, path)
|
||||||
|
}
|
||||||
|
if path := L.ToString(-1); path == initialPath {
|
||||||
|
t.Error("SetPackagePath did not replace the old path")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStackDebug(t *testing.T) {
|
||||||
|
L := New()
|
||||||
|
if L == nil {
|
||||||
|
t.Fatal("Failed to create Lua state")
|
||||||
|
}
|
||||||
|
defer L.Close()
|
||||||
|
|
||||||
|
t.Log("Testing LoadString:")
|
||||||
|
initialTop := L.GetTop()
|
||||||
|
t.Logf("Initial stack size: %d", initialTop)
|
||||||
|
|
||||||
|
err := L.LoadString("return 42")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("LoadString failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
afterLoad := L.GetTop()
|
||||||
|
t.Logf("Stack size after load: %d", afterLoad)
|
||||||
|
t.Logf("Type of top element: %s", L.GetType(-1))
|
||||||
|
|
||||||
|
if L.IsFunction(-1) {
|
||||||
|
t.Log("Top element is a function")
|
||||||
|
} else {
|
||||||
|
t.Log("Top element is NOT a function")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clean up after LoadString test
|
||||||
|
L.SetTop(0)
|
||||||
|
|
||||||
|
t.Log("\nTesting ExecuteString:")
|
||||||
|
if err := L.DoString("function test() return 1, 'hello', true end"); err != nil {
|
||||||
|
t.Errorf("DoString failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
beforeExec := L.GetTop()
|
||||||
|
t.Logf("Stack size before execute: %d", beforeExec)
|
||||||
|
|
||||||
|
nresults, err := L.ExecuteString("return test()")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("ExecuteString failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
afterExec := L.GetTop()
|
||||||
|
t.Logf("Stack size after execute: %d", afterExec)
|
||||||
|
t.Logf("Reported number of results: %d", nresults)
|
||||||
|
|
||||||
|
// Print each stack element
|
||||||
|
for i := 1; i <= afterExec; i++ {
|
||||||
|
t.Logf("Stack[-%d] type: %s", i, L.GetType(-i))
|
||||||
|
}
|
||||||
|
|
||||||
|
if afterExec != nresults {
|
||||||
|
t.Errorf("Stack size (%d) doesn't match number of results (%d)", afterExec, nresults)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTemplateRendering(t *testing.T) {
|
||||||
|
L := New()
|
||||||
|
if L == nil {
|
||||||
|
t.Fatal("Failed to create Lua state")
|
||||||
|
}
|
||||||
|
defer L.Close()
|
||||||
|
defer L.Cleanup()
|
||||||
|
|
||||||
|
// Create a simple render.template function
|
||||||
|
renderFunc := func(s *State) int {
|
||||||
|
// Template will be at index 1, data at index 2
|
||||||
|
data, err := s.ToTable(2)
|
||||||
|
if err != nil {
|
||||||
|
s.PushString(fmt.Sprintf("failed to get data table: %v", err))
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
|
||||||
|
// Push data back as global for template access
|
||||||
|
if err := s.PushTable(data); err != nil {
|
||||||
|
s.PushString(fmt.Sprintf("failed to push data table: %v", err))
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
s.SetGlobal("data")
|
||||||
|
|
||||||
|
// Template processing code
|
||||||
|
luaCode := `
|
||||||
|
local result = {}
|
||||||
|
if data.user.logged_in then
|
||||||
|
table.insert(result, '<div class="profile">')
|
||||||
|
table.insert(result, string.format(' <h1>Welcome, %s!</h1>', tostring(data.user.name)))
|
||||||
|
table.insert(result, ' <div class="user-info">')
|
||||||
|
table.insert(result, string.format(' <p>Email: %s</p>', tostring(data.user.email)))
|
||||||
|
table.insert(result, string.format(' <p>Member since: %s</p>', tostring(data.user.joined_date)))
|
||||||
|
table.insert(result, ' </div>')
|
||||||
|
if data.user.is_admin then
|
||||||
|
table.insert(result, ' <div class="admin-panel">')
|
||||||
|
table.insert(result, ' <h2>Admin Controls</h2>')
|
||||||
|
table.insert(result, ' <!-- admin content -->')
|
||||||
|
table.insert(result, ' </div>')
|
||||||
|
end
|
||||||
|
table.insert(result, '</div>')
|
||||||
|
else
|
||||||
|
table.insert(result, '<div class="profile">')
|
||||||
|
table.insert(result, ' <h1>Please log in to view your profile</h1>')
|
||||||
|
table.insert(result, '</div>')
|
||||||
|
end
|
||||||
|
return table.concat(result, '\n')`
|
||||||
|
|
||||||
|
result, err := s.DoStringResult(luaCode)
|
||||||
|
if err != nil {
|
||||||
|
s.PushString(fmt.Sprintf("template execution failed: %v", err))
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
|
||||||
|
// Push the string result
|
||||||
|
if str, ok := result.(string); ok {
|
||||||
|
s.PushString(str)
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
s.PushString(fmt.Sprintf("expected string result, got %T", result))
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create render table and add template function
|
||||||
|
L.NewTable()
|
||||||
|
if err := L.PushGoFunction(renderFunc); err != nil {
|
||||||
|
t.Fatalf("Failed to create render function: %v", err)
|
||||||
|
}
|
||||||
|
L.SetField(-2, "template")
|
||||||
|
L.SetGlobal("render")
|
||||||
|
|
||||||
|
// Test with logged in admin user
|
||||||
|
testCode := `
|
||||||
|
local data = {
|
||||||
|
user = {
|
||||||
|
logged_in = true,
|
||||||
|
name = "John Doe",
|
||||||
|
email = "john@example.com",
|
||||||
|
joined_date = "2024-02-09",
|
||||||
|
is_admin = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return render.template("test.html", data)
|
||||||
|
`
|
||||||
|
|
||||||
|
result, err := L.DoStringResult(testCode)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to execute test: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
str, ok := result.(string)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("Expected string result, got %T", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedResult := `<div class="profile">
|
||||||
|
<h1>Welcome, John Doe!</h1>
|
||||||
|
<div class="user-info">
|
||||||
|
<p>Email: john@example.com</p>
|
||||||
|
<p>Member since: 2024-02-09</p>
|
||||||
|
</div>
|
||||||
|
<div class="admin-panel">
|
||||||
|
<h2>Admin Controls</h2>
|
||||||
|
<!-- admin content -->
|
||||||
|
</div>
|
||||||
|
</div>`
|
||||||
|
|
||||||
|
if str != expectedResult {
|
||||||
|
t.Errorf("\nExpected:\n%s\n\nGot:\n%s", expectedResult, str)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test with logged out user
|
||||||
|
testCode = `
|
||||||
|
local data = {
|
||||||
|
user = {
|
||||||
|
logged_in = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return render.template("test.html", data)
|
||||||
|
`
|
||||||
|
|
||||||
|
result, err = L.DoStringResult(testCode)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to execute logged out test: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
str, ok = result.(string)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("Expected string result, got %T", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedResult = `<div class="profile">
|
||||||
|
<h1>Please log in to view your profile</h1>
|
||||||
|
</div>`
|
||||||
|
|
||||||
|
if str != expectedResult {
|
||||||
|
t.Errorf("\nExpected:\n%s\n\nGot:\n%s", expectedResult, str)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user