From 143b9333c6752db725dd19fa336143eb4517f660 Mon Sep 17 00:00:00 2001 From: Sky Johnson Date: Wed, 26 Feb 2025 07:00:01 -0600 Subject: [PATCH] Wrapper rewrite --- DOCS.md | 245 +++++++++++--- LICENSE | 2 +- README.md | 142 ++++---- bench/bench_test.go | 430 +++++++++++++++++++++++++ bench/ezbench_test.go | 133 ++++++++ bytecode.go | 145 +++++---- bytecode_test.go | 162 ---------- example/main.go | 70 ++++ example/script.lua | 35 ++ example/utils.lua | 19 ++ functions.go | 34 +- functions_test.go | 107 ------- stack.go | 89 +----- table.go | 105 +++--- table_test.go | 93 ------ tests/bytecode_test.go | 443 +++++++++++++++++++++++++ tests/functions_test.go | 178 +++++++++++ tests/stack_test.go | 53 +++ tests/table_test.go | 246 ++++++++++++++ tests/wrapper_test.go | 473 +++++++++++++++++++++++++++ types.go | 2 +- wrapper.go | 582 +++++++++++++++++---------------- wrapper_bench_test.go | 237 -------------- wrapper_test.go | 693 ---------------------------------------- 24 files changed, 2863 insertions(+), 1855 deletions(-) create mode 100644 bench/bench_test.go create mode 100644 bench/ezbench_test.go delete mode 100644 bytecode_test.go create mode 100644 example/main.go create mode 100644 example/script.lua create mode 100644 example/utils.lua delete mode 100644 functions_test.go delete mode 100644 table_test.go create mode 100644 tests/bytecode_test.go create mode 100644 tests/functions_test.go create mode 100644 tests/stack_test.go create mode 100644 tests/table_test.go create mode 100644 tests/wrapper_test.go delete mode 100644 wrapper_bench_test.go delete mode 100644 wrapper_test.go diff --git a/DOCS.md b/DOCS.md index dbc815b..5f07f5d 100644 --- a/DOCS.md +++ b/DOCS.md @@ -2,16 +2,8 @@ ## State Management -### NewSafe() *State -Creates a new Lua state with stack safety enabled. -```go -L := luajit.NewSafe() -defer L.Close() -defer L.Cleanup() -``` - ### New() *State -Creates a new Lua state without stack safety checks. +Creates a new Lua state with all standard libraries loaded. ```go L := luajit.New() defer L.Close() @@ -38,6 +30,18 @@ Returns the index of the top element in the stack. top := L.GetTop() // 0 for empty stack ``` +### SetTop(index int) +Sets the stack top to a specific index. +```go +L.SetTop(2) // Truncate stack to 2 elements +``` + +### PushCopy(index int) +Pushes a copy of the value at the given index onto the stack. +```go +L.PushCopy(-1) // Duplicate the top element +``` + ### Pop(n int) Removes n elements from the stack. ```go @@ -52,6 +56,9 @@ L.Remove(-1) // Remove top element L.Remove(1) // Remove first element ``` +### absIndex(index int) int +Internal function that converts a possibly negative index to its absolute position. + ### checkStack(n int) error Internal function that ensures there's enough space for n new elements. ```go @@ -70,9 +77,12 @@ if L.GetType(-1) == TypeString { } ``` -### IsFunction(index int) bool +### IsNil(index int) bool +### IsBoolean(index int) bool +### IsNumber(index int) bool +### IsString(index int) bool ### IsTable(index int) bool -### IsUserData(index int) bool +### IsFunction(index int) bool Type checking functions for specific Lua types. ```go if L.IsTable(-1) { @@ -118,6 +128,12 @@ if err != nil { } ``` +### GetTableLength(index int) int +Returns the length of a table at the given index. +```go +length := L.GetTableLength(-1) +``` + ## Value Pushing ### PushNil() @@ -148,16 +164,82 @@ data := map[string]interface{}{ err := L.PushTable(data) ``` -## Function Registration +## Table Operations -### RegisterGoFunction(name string, fn GoFunction) error -Registers a Go function that can be called from Lua. +### CreateTable(narr, nrec int) +Creates a new table with pre-allocated space. ```go -adder := func(s *State) int { +L.CreateTable(10, 5) // Space for 10 array elements, 5 records +``` + +### NewTable() +Creates a new empty table and pushes it onto the stack. +```go +L.NewTable() +``` + +### GetTable(index int) +Gets a table field (t[k]) where t is at the given index and k is at the top of the stack. +```go +L.PushString("key") +L.GetTable(-2) // Gets table["key"] +``` + +### SetTable(index int) +Sets a table field (t[k] = v) where t is at the given index, k is at -2, and v is at -1. +```go +L.PushString("key") +L.PushString("value") +L.SetTable(-3) // table["key"] = "value" +``` + +### GetField(index int, key string) +Gets a table field t[k] and pushes it onto the stack. +```go +L.GetField(-1, "name") // gets table.name +``` + +### SetField(index int, key string) +Sets a table field t[k] = v, where v is the value at the top of the stack. +```go +L.PushString("value") +L.SetField(-2, "key") // table.key = "value" +``` + +### Next(index int) bool +Pops a key from the stack and pushes the next key-value pair from the table. +```go +L.PushNil() // Start iteration +for L.Next(-2) { + // Stack now has key at -2 and value at -1 + key := L.ToString(-2) + value := L.ToString(-1) + L.Pop(1) // Remove value, keep key for next iteration +} +``` + +## Function Registration and Calling + +### GoFunction +Type definition for Go functions callable from Lua. +```go +type GoFunction func(*State) int +``` + +### PushGoFunction(fn GoFunction) error +Wraps a Go function and pushes it onto the Lua stack. +```go +adder := func(s *luajit.State) int { sum := s.ToNumber(1) + s.ToNumber(2) s.PushNumber(sum) return 1 } +err := L.PushGoFunction(adder) +``` + +### RegisterGoFunction(name string, fn GoFunction) error +Registers a Go function as a global Lua function. +```go err := L.RegisterGoFunction("add", adder) ``` @@ -167,23 +249,45 @@ Removes a previously registered function. L.UnregisterGoFunction("add") ``` -## Package Management - -### SetPackagePath(path string) error -Sets the Lua package.path variable. +### Call(nargs, nresults int) error +Calls a function with the given number of arguments and results. ```go -err := L.SetPackagePath("./?.lua;/usr/local/share/lua/5.1/?.lua") +L.GetGlobal("myfunction") +L.PushNumber(1) +L.PushNumber(2) +err := L.Call(2, 1) // Call with 2 args, expect 1 result ``` -### AddPackagePath(path string) error -Adds a path to the existing package.path. +## Global Operations + +### GetGlobal(name string) +Gets a global variable and pushes it onto the stack. ```go -err := L.AddPackagePath("./modules/?.lua") +L.GetGlobal("myGlobal") +``` + +### SetGlobal(name string) +Sets a global variable from the value at the top of the stack. +```go +L.PushNumber(42) +L.SetGlobal("answer") // answer = 42 ``` ## Code Execution -### DoString(str string) error +### LoadString(code string) error +Loads a Lua chunk from a string without executing it. +```go +err := L.LoadString("return 42") +``` + +### LoadFile(filename string) error +Loads a Lua chunk from a file without executing it. +```go +err := L.LoadFile("script.lua") +``` + +### DoString(code string) error Executes a string of Lua code. ```go err := L.DoString(` @@ -199,32 +303,76 @@ Executes a Lua file. err := L.DoFile("script.lua") ``` -## Table Operations - -### GetField(index int, key string) -Gets a field from a table at the given index. +### Execute(code string) (int, error) +Executes a Lua string and returns the number of results left on the stack. ```go -L.GetField(-1, "name") // gets table.name +nresults, err := L.Execute("return 1, 2, 3") +// nresults would be 3 ``` -### SetField(index int, key string) -Sets a field in a table at the given index. +### ExecuteWithResult(code string) (interface{}, error) +Executes a Lua string and returns the first result. ```go -L.PushString("value") -L.SetField(-2, "key") // table.key = "value" +result, err := L.ExecuteWithResult("return 'hello'") +// result would be "hello" ``` -### GetGlobal(name string) -Gets a global variable. +## Bytecode Operations + +### CompileBytecode(code string, name string) ([]byte, error) +Compiles a Lua chunk to bytecode without executing it. ```go -L.GetGlobal("myGlobal") +bytecode, err := L.CompileBytecode("return 42", "test") ``` -### SetGlobal(name string) -Sets a global variable from the value at the top of the stack. +### LoadBytecode(bytecode []byte, name string) error +Loads precompiled bytecode without executing it. ```go -L.PushNumber(42) -L.SetGlobal("answer") // answer = 42 +err := L.LoadBytecode(bytecode, "test") +``` + +### RunBytecode() error +Executes previously loaded bytecode with 0 results. +```go +err := L.RunBytecode() +``` + +### RunBytecodeWithResults(nresults int) error +Executes bytecode and keeps nresults on the stack. +```go +err := L.RunBytecodeWithResults(1) +``` + +### LoadAndRunBytecode(bytecode []byte, name string) error +Loads and executes bytecode. +```go +err := L.LoadAndRunBytecode(bytecode, "test") +``` + +### LoadAndRunBytecodeWithResults(bytecode []byte, name string, nresults int) error +Loads and executes bytecode, preserving results. +```go +err := L.LoadAndRunBytecodeWithResults(bytecode, "test", 1) +``` + +### CompileAndRun(code string, name string) error +Compiles and immediately executes Lua code. +```go +err := L.CompileAndRun("answer = 42", "test") +``` + +## Package Path Operations + +### SetPackagePath(path string) error +Sets the Lua package.path. +```go +err := L.SetPackagePath("./?.lua;/usr/local/share/lua/5.1/?.lua") +``` + +### AddPackagePath(path string) error +Adds a path to package.path. +```go +err := L.AddPackagePath("./modules/?.lua") ``` ## Error Handling @@ -238,13 +386,21 @@ type LuaError struct { } ``` -### getStackTrace() string +### GetStackTrace() string Gets the current Lua stack trace. ```go -trace := L.getStackTrace() +trace := L.GetStackTrace() fmt.Println(trace) ``` +### safeCall(f func() C.int) error +Internal function that wraps a potentially dangerous C call with stack checking. +```go +err := s.safeCall(func() C.int { + return C.lua_pcall(s.L, C.int(nargs), C.int(nresults), 0) +}) +``` + ## Thread Safety Notes - The function registry is thread-safe @@ -256,14 +412,13 @@ fmt.Println(trace) Always pair state creation with cleanup: ```go -L := luajit.NewSafe() +L := luajit.New() defer L.Close() defer L.Cleanup() ``` -Stack management in unsafe mode requires manual attention: +Stack management requires manual attention: ```go -L := luajit.New() L.PushString("hello") // ... use the string L.Pop(1) // Clean up when done diff --git a/LICENSE b/LICENSE index a1cae91..f462541 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2025 Sky +Copyright (c) 2025 Sharkk, Skylear Johnson Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: diff --git a/README.md b/README.md index a578fb8..2c49de0 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # LuaJIT Go Wrapper -Hey there! This is a Go wrapper for LuaJIT that makes it easy to embed Lua in your Go applications. We've focused on making it both safe and fast, while keeping the API clean and intuitive. +This is a Go wrapper for LuaJIT that makes it easy to embed Lua in your Go applications. We've focused on making it both performant and developer-friendly, with an API that feels natural to use. ## What's This For? @@ -22,51 +22,13 @@ You'll need LuaJIT's development files, but don't worry - we include libraries f Here's the simplest thing you can do: ```go -L := luajit.NewSafe() +L := luajit.New() defer L.Close() defer L.Cleanup() err := L.DoString(`print("Hey from Lua!")`) ``` -## Stack Safety: Choose Your Adventure - -One of the key decisions you'll make is whether to use stack-safe mode. Here's what that means: - -### Stack-Safe Mode (NewSafe()) -```go -L := luajit.NewSafe() -``` -Think of this as driving with guardrails. It's perfect when: -- You're new to Lua or embedding scripting languages -- You're writing a server or long-running application -- You want to handle untrusted Lua code -- You'd rather have slightly slower code than mysterious crashes - -The safe mode will: -- Prevent stack overflows -- Check types more thoroughly -- Clean up after messy Lua code -- Give you better error messages - -### Non-Stack-Safe Mode (New()) -```go -L := luajit.New() -``` -This is like taking off the training wheels. Use it when: -- You know exactly how your Lua code behaves -- You've profiled your application and need more speed -- You're doing lots of rapid, simple Lua calls -- You're writing performance-critical code - -The unsafe mode: -- Skips most safety checks -- Runs noticeably faster -- Gives you direct control over the stack -- Can crash spectacularly if you make a mistake - -Most applications should start with stack-safe mode and only switch to unsafe mode if profiling shows it's necessary. - ## Working with Bytecode Need even more performance? You can compile your Lua code to bytecode and reuse it: @@ -82,28 +44,30 @@ bytecode, err := L.CompileBytecode(` // Execute many times for i := 0; i < 1000; i++ { - err := L.LoadBytecode(bytecode, "calc") + err := L.LoadAndRunBytecode(bytecode, "calc") } // Or do both at once -err := L.CompileAndLoad(`return "hello"`, "greeting") +err := L.CompileAndRun(`return "hello"`, "greeting") ``` ### When to Use Bytecode Bytecode execution is consistently faster than direct execution: -- Simple operations: 20-60% faster -- String operations: Up to 60% speedup -- Loop-heavy code: 10-15% improvement -- Table operations: 10-15% faster -Some benchmark results on a typical system: ``` -Operation Direct Exec Bytecode Exec ----------------------------------------- -Simple Math 1.5M ops/sec 2.4M ops/sec -String Ops 370K ops/sec 600K ops/sec -Table Creation 127K ops/sec 146K ops/sec +Benchmark Ops/sec Comparison +---------------------------------------------------------------------------- +BenchmarkSimpleDoString 2,561,012 Base +BenchmarkSimplePrecompiledBytecode 3,828,841 +49.5% faster +BenchmarkFunctionCallDoString 2,021,098 Base +BenchmarkFunctionCallPrecompiled 3,482,074 +72.3% faster +BenchmarkLoopDoString 188,119 Base +BenchmarkLoopPrecompiled 211,081 +12.2% faster +BenchmarkTableOperationsDoString 84,086 Base +BenchmarkTableOperationsPrecompiled 93,655 +11.4% faster +BenchmarkComplexScript 33,133 Base +BenchmarkComplexScriptPrecompiled 41,044 +23.9% faster ``` Use bytecode when you: @@ -114,7 +78,7 @@ Use bytecode when you: ## Registering Go Functions -Want to call Go code from Lua? Easy: +Want to call Go code from Lua? It's straightforward: ```go // This function adds two numbers and returns the result adder := func(s *luajit.State) int { @@ -151,26 +115,84 @@ result, err := L.ToTable(-1) ## Error Handling -We try to give you useful errors instead of mysterious panics: +We provide useful errors instead of mysterious panics: ```go if err := L.DoString("this isn't valid Lua!"); err != nil { if luaErr, ok := err.(*luajit.LuaError); ok { - fmt.Printf("Oops: %s\n", luaErr.Message) + fmt.Printf("Error: %s\n", luaErr.Message) } } ``` -## A Few Tips +## Memory Management -- Always use those `defer L.Close()` and `defer L.Cleanup()` calls - they prevent memory leaks +The wrapper uses a custom table pooling system to reduce GC pressure when handling many tables: + +```go +// Tables are pooled and reused internally for better performance +for i := 0; i < 1000; i++ { + L.GetGlobal("table") + table, _ := L.ToTable(-1) + // Use table... + L.Pop(1) + // Table is automatically returned to pool +} +``` + +## Best Practices + +- Always use `defer L.Close()` and `defer L.Cleanup()` to prevent memory leaks - Each Lua state should stick to one goroutine -- For concurrent stuff, create multiple states +- For concurrent operations, create multiple states - You can share functions between states safely -- Keep an eye on your stack in unsafe mode - it won't clean up after itself -- Start with stack-safe mode and measure before optimizing +- Keep an eye on your stack management - pop as many items as you push - Use bytecode for frequently executed code paths - Consider compiling critical Lua code to bytecode at startup +## Advanced Features + +### Bytecode Serialization + +You can serialize bytecode for distribution or caching: + +```go +// Compile once +bytecode, _ := L.CompileBytecode(complexScript, "module") + +// Save to file +ioutil.WriteFile("module.luac", bytecode, 0644) + +// Later, load from file +bytecode, _ := ioutil.ReadFile("module.luac") +L.LoadAndRunBytecode(bytecode, "module") +``` + +### Closures and Upvalues + +Bytecode properly preserves closures and upvalues: + +```go +code := ` + local counter = 0 + return function() + counter = counter + 1 + return counter + end +` + +bytecode, _ := L.CompileBytecode(code, "counter") +L.LoadAndRunBytecodeWithResults(bytecode, "counter", 1) +L.SetGlobal("increment") + +// Later... +L.GetGlobal("increment") +L.Call(0, 1) // Returns 1 +L.Pop(1) + +L.GetGlobal("increment") +L.Call(0, 1) // Returns 2 +``` + ## Need Help? Check out the tests in the repository - they're full of examples. If you're stuck, open an issue! We're here to help. diff --git a/bench/bench_test.go b/bench/bench_test.go new file mode 100644 index 0000000..b4d93b6 --- /dev/null +++ b/bench/bench_test.go @@ -0,0 +1,430 @@ +package luajit_bench + +import ( + "testing" + + luajit "git.sharkk.net/Sky/LuaJIT-to-Go" +) + +// BenchmarkSimpleDoString benchmarks direct execution of a simple expression +func BenchmarkSimpleDoString(b *testing.B) { + state := luajit.New() + if state == nil { + b.Fatal("Failed to create Lua state") + } + defer state.Close() + + code := "local x = 1 + 1" + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := state.DoString(code); err != nil { + b.Fatalf("DoString failed: %v", err) + } + } +} + +// BenchmarkSimpleCompileAndRun benchmarks compile and run of a simple expression +func BenchmarkSimpleCompileAndRun(b *testing.B) { + state := luajit.New() + if state == nil { + b.Fatal("Failed to create Lua state") + } + defer state.Close() + + code := "local x = 1 + 1" + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := state.CompileAndRun(code, "simple"); err != nil { + b.Fatalf("CompileAndRun failed: %v", err) + } + } +} + +// BenchmarkSimpleCompileLoadRun benchmarks compile, load, and run of a simple expression +func BenchmarkSimpleCompileLoadRun(b *testing.B) { + state := luajit.New() + if state == nil { + b.Fatal("Failed to create Lua state") + } + defer state.Close() + + code := "local x = 1 + 1" + b.ResetTimer() + for i := 0; i < b.N; i++ { + bytecode, err := state.CompileBytecode(code, "simple") + if err != nil { + b.Fatalf("CompileBytecode failed: %v", err) + } + if err := state.LoadAndRunBytecode(bytecode, "simple"); err != nil { + b.Fatalf("LoadAndRunBytecode failed: %v", err) + } + } +} + +// BenchmarkSimplePrecompiledBytecode benchmarks running precompiled bytecode +func BenchmarkSimplePrecompiledBytecode(b *testing.B) { + state := luajit.New() + if state == nil { + b.Fatal("Failed to create Lua state") + } + defer state.Close() + + code := "local x = 1 + 1" + bytecode, err := state.CompileBytecode(code, "simple") + if err != nil { + b.Fatalf("CompileBytecode failed: %v", err) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := state.LoadAndRunBytecode(bytecode, "simple"); err != nil { + b.Fatalf("LoadAndRunBytecode failed: %v", err) + } + } +} + +// BenchmarkFunctionCallDoString benchmarks direct execution of a function call +func BenchmarkFunctionCallDoString(b *testing.B) { + state := luajit.New() + if state == nil { + b.Fatal("Failed to create Lua state") + } + defer state.Close() + + // Setup function + setupCode := ` + function add(a, b) + return a + b + end + ` + if err := state.DoString(setupCode); err != nil { + b.Fatalf("Failed to set up function: %v", err) + } + + code := "local result = add(10, 20)" + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := state.DoString(code); err != nil { + b.Fatalf("DoString failed: %v", err) + } + } +} + +// BenchmarkFunctionCallPrecompiled benchmarks precompiled function call +func BenchmarkFunctionCallPrecompiled(b *testing.B) { + state := luajit.New() + if state == nil { + b.Fatal("Failed to create Lua state") + } + defer state.Close() + + // Setup function + setupCode := ` + function add(a, b) + return a + b + end + ` + if err := state.DoString(setupCode); err != nil { + b.Fatalf("Failed to set up function: %v", err) + } + + code := "local result = add(10, 20)" + bytecode, err := state.CompileBytecode(code, "call") + if err != nil { + b.Fatalf("CompileBytecode failed: %v", err) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := state.LoadAndRunBytecode(bytecode, "call"); err != nil { + b.Fatalf("LoadAndRunBytecode failed: %v", err) + } + } +} + +// BenchmarkLoopDoString benchmarks direct execution of a loop +func BenchmarkLoopDoString(b *testing.B) { + state := luajit.New() + if state == nil { + b.Fatal("Failed to create Lua state") + } + defer state.Close() + + code := ` + local sum = 0 + for i = 1, 1000 do + sum = sum + i + end + ` + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := state.DoString(code); err != nil { + b.Fatalf("DoString failed: %v", err) + } + } +} + +// BenchmarkLoopPrecompiled benchmarks precompiled loop execution +func BenchmarkLoopPrecompiled(b *testing.B) { + state := luajit.New() + if state == nil { + b.Fatal("Failed to create Lua state") + } + defer state.Close() + + code := ` + local sum = 0 + for i = 1, 1000 do + sum = sum + i + end + ` + bytecode, err := state.CompileBytecode(code, "loop") + if err != nil { + b.Fatalf("CompileBytecode failed: %v", err) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := state.LoadAndRunBytecode(bytecode, "loop"); err != nil { + b.Fatalf("LoadAndRunBytecode failed: %v", err) + } + } +} + +// BenchmarkTableOperationsDoString benchmarks direct execution of table operations +func BenchmarkTableOperationsDoString(b *testing.B) { + state := luajit.New() + if state == nil { + b.Fatal("Failed to create Lua state") + } + defer state.Close() + + code := ` + local t = {} + for i = 1, 100 do + t[i] = i * 2 + end + local sum = 0 + for i, v in ipairs(t) do + sum = sum + v + end + ` + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := state.DoString(code); err != nil { + b.Fatalf("DoString failed: %v", err) + } + } +} + +// BenchmarkTableOperationsPrecompiled benchmarks precompiled table operations +func BenchmarkTableOperationsPrecompiled(b *testing.B) { + state := luajit.New() + if state == nil { + b.Fatal("Failed to create Lua state") + } + defer state.Close() + + code := ` + local t = {} + for i = 1, 100 do + t[i] = i * 2 + end + local sum = 0 + for i, v in ipairs(t) do + sum = sum + v + end + ` + bytecode, err := state.CompileBytecode(code, "table") + if err != nil { + b.Fatalf("CompileBytecode failed: %v", err) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := state.LoadAndRunBytecode(bytecode, "table"); err != nil { + b.Fatalf("LoadAndRunBytecode failed: %v", err) + } + } +} + +// BenchmarkGoFunctionCall benchmarks calling a Go function from Lua +func BenchmarkGoFunctionCall(b *testing.B) { + state := luajit.New() + if state == nil { + b.Fatal("Failed to create Lua state") + } + defer state.Close() + + // Register a simple Go function + add := func(s *luajit.State) int { + a := s.ToNumber(1) + b := s.ToNumber(2) + s.PushNumber(a + b) + return 1 + } + if err := state.RegisterGoFunction("add", add); err != nil { + b.Fatalf("RegisterGoFunction failed: %v", err) + } + + code := "local result = add(10, 20)" + bytecode, err := state.CompileBytecode(code, "gofunc") + if err != nil { + b.Fatalf("CompileBytecode failed: %v", err) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := state.LoadAndRunBytecode(bytecode, "gofunc"); err != nil { + b.Fatalf("LoadAndRunBytecode failed: %v", err) + } + } +} + +// BenchmarkComplexScript benchmarks a more complex script +func BenchmarkComplexScript(b *testing.B) { + state := luajit.New() + if state == nil { + b.Fatal("Failed to create Lua state") + } + defer state.Close() + + code := ` + -- Define a simple class + local Class = {} + Class.__index = Class + + function Class.new(x, y) + local self = setmetatable({}, Class) + self.x = x or 0 + self.y = y or 0 + return self + end + + function Class:move(dx, dy) + self.x = self.x + dx + self.y = self.y + dy + return self + end + + function Class:getPosition() + return self.x, self.y + end + + -- Create instances and operate on them + local instances = {} + for i = 1, 50 do + instances[i] = Class.new(i, i*2) + end + + local result = 0 + for i, obj in ipairs(instances) do + obj:move(i, -i) + local x, y = obj:getPosition() + result = result + x + y + end + + return result + ` + b.ResetTimer() + for i := 0; i < b.N; i++ { + if _, err := state.ExecuteWithResult(code); err != nil { + b.Fatalf("ExecuteWithResult failed: %v", err) + } + } +} + +// BenchmarkComplexScriptPrecompiled benchmarks a precompiled complex script +func BenchmarkComplexScriptPrecompiled(b *testing.B) { + state := luajit.New() + if state == nil { + b.Fatal("Failed to create Lua state") + } + defer state.Close() + + code := ` + -- Define a simple class + local Class = {} + Class.__index = Class + + function Class.new(x, y) + local self = setmetatable({}, Class) + self.x = x or 0 + self.y = y or 0 + return self + end + + function Class:move(dx, dy) + self.x = self.x + dx + self.y = self.y + dy + return self + end + + function Class:getPosition() + return self.x, self.y + end + + -- Create instances and operate on them + local instances = {} + for i = 1, 50 do + instances[i] = Class.new(i, i*2) + end + + local result = 0 + for i, obj in ipairs(instances) do + obj:move(i, -i) + local x, y = obj:getPosition() + result = result + x + y + end + + return result + ` + bytecode, err := state.CompileBytecode(code, "complex") + if err != nil { + b.Fatalf("CompileBytecode failed: %v", err) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := state.LoadBytecode(bytecode, "complex"); err != nil { + b.Fatalf("LoadBytecode failed: %v", err) + } + if err := state.RunBytecodeWithResults(1); err != nil { // Assuming this method exists to get the return value + b.Fatalf("RunBytecodeWithResults failed: %v", err) + } + state.Pop(1) // Pop the result + } +} + +// BenchmarkMultipleExecutions benchmarks executing the same bytecode multiple times +func BenchmarkMultipleExecutions(b *testing.B) { + state := luajit.New() + if state == nil { + b.Fatal("Failed to create Lua state") + } + defer state.Close() + + // Setup a stateful environment + setupCode := ` + counter = 0 + function increment(amount) + counter = counter + (amount or 1) + return counter + end + ` + if err := state.DoString(setupCode); err != nil { + b.Fatalf("Failed to set up environment: %v", err) + } + + // Compile the function call + code := "return increment(5)" + bytecode, err := state.CompileBytecode(code, "increment") + if err != nil { + b.Fatalf("CompileBytecode failed: %v", err) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := state.LoadBytecode(bytecode, "increment"); err != nil { + b.Fatalf("LoadBytecode failed: %v", err) + } + if err := state.RunBytecodeWithResults(1); err != nil { // Assuming this method exists + b.Fatalf("RunBytecodeWithResults failed: %v", err) + } + state.Pop(1) // Pop the result + } +} diff --git a/bench/ezbench_test.go b/bench/ezbench_test.go new file mode 100644 index 0000000..b35242a --- /dev/null +++ b/bench/ezbench_test.go @@ -0,0 +1,133 @@ +package luajit_bench + +import ( + "testing" + + luajit "git.sharkk.net/Sky/LuaJIT-to-Go" +) + +var benchCases = []struct { + name string + code string +}{ + { + name: "SimpleAddition", + code: `return 1 + 1`, + }, + { + name: "LoopSum", + code: ` + local sum = 0 + for i = 1, 1000 do + sum = sum + i + end + return sum + `, + }, + { + name: "FunctionCall", + code: ` + local result = 0 + for i = 1, 100 do + result = result + i + end + return result + `, + }, + { + name: "TableCreation", + code: ` + local t = {} + for i = 1, 100 do + t[i] = i * 2 + end + return t[50] + `, + }, + { + name: "StringOperations", + code: ` + local s = "hello" + for i = 1, 10 do + s = s .. " world" + end + return #s + `, + }, +} + +func BenchmarkLuaDirectExecution(b *testing.B) { + for _, bc := range benchCases { + b.Run(bc.name, func(b *testing.B) { + L := luajit.New() + if L == nil { + b.Fatal("Failed to create Lua state") + } + defer L.Close() + defer L.Cleanup() + + // First verify we can execute the code + if err := L.DoString(bc.code); err != nil { + b.Fatalf("Failed to execute test code: %v", err) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + // Execute string and get results + nresults, err := L.Execute(bc.code) + if err != nil { + b.Fatalf("Failed to execute code: %v", err) + } + L.Pop(nresults) // Clean up any results + } + }) + } +} + +func BenchmarkLuaBytecodeExecution(b *testing.B) { + // First compile all bytecode + bytecodes := make(map[string][]byte) + for _, bc := range benchCases { + L := luajit.New() + if L == nil { + b.Fatal("Failed to create Lua state") + } + defer L.Cleanup() + + bytecode, err := L.CompileBytecode(bc.code, bc.name) + if err != nil { + L.Close() + b.Fatalf("Error compiling bytecode for %s: %v", bc.name, err) + } + bytecodes[bc.name] = bytecode + L.Close() + } + + for _, bc := range benchCases { + b.Run(bc.name, func(b *testing.B) { + L := luajit.New() + if L == nil { + b.Fatal("Failed to create Lua state") + } + defer L.Close() + defer L.Cleanup() + + bytecode := bytecodes[bc.name] + + // First verify we can execute the bytecode + if err := L.LoadAndRunBytecodeWithResults(bytecode, bc.name, 1); err != nil { + b.Fatalf("Failed to execute test bytecode: %v", err) + } + L.Pop(1) // Clean up the result + + b.ResetTimer() + b.SetBytes(int64(len(bytecode))) // Track bytecode size in benchmarks + + for i := 0; i < b.N; i++ { + if err := L.LoadAndRunBytecode(bytecode, bc.name); err != nil { + b.Fatalf("Error executing bytecode: %v", err) + } + } + }) + } +} diff --git a/bytecode.go b/bytecode.go index 14c579f..de296bb 100644 --- a/bytecode.go +++ b/bytecode.go @@ -21,7 +21,7 @@ static const char *bytecode_reader(lua_State *L, void *ud, size_t *size) { return (const char *)r->buf; } -static int load_bytecode_chunk(lua_State *L, const unsigned char *buf, size_t len, const char *name) { +static int load_bytecode(lua_State *L, const unsigned char *buf, size_t len, const char *name) { BytecodeReader reader = {buf, len, name}; return lua_load(L, bytecode_reader, &reader, name); } @@ -29,21 +29,37 @@ static int load_bytecode_chunk(lua_State *L, const unsigned char *buf, size_t le typedef struct { unsigned char *buf; size_t len; + size_t capacity; } BytecodeWriter; -int bytecode_writer(lua_State *L, const void *p, size_t sz, void *ud) { +static int bytecode_writer(lua_State *L, const void *p, size_t sz, void *ud) { BytecodeWriter *w = (BytecodeWriter *)ud; unsigned char *newbuf; (void)L; // unused - newbuf = (unsigned char *)realloc(w->buf, w->len + sz); - if (newbuf == NULL) return 1; + // Check if we need to reallocate + if (w->len + sz > w->capacity) { + size_t new_capacity = w->capacity * 2; + if (new_capacity < w->len + sz) { + new_capacity = w->len + sz; + } - memcpy(newbuf + w->len, p, sz); - w->buf = newbuf; + newbuf = (unsigned char *)realloc(w->buf, new_capacity); + if (newbuf == NULL) return 1; + + w->buf = newbuf; + w->capacity = new_capacity; + } + + memcpy(w->buf + w->len, p, sz); w->len += sz; return 0; } + +// Wrapper function that calls lua_dump with bytecode_writer +static int dump_lua_function(lua_State *L, BytecodeWriter *w) { + return lua_dump(L, bytecode_writer, w); +} */ import "C" import ( @@ -53,89 +69,108 @@ import ( // CompileBytecode compiles a Lua chunk to bytecode without executing it func (s *State) CompileBytecode(code string, name string) ([]byte, error) { - // First load the string but don't execute it - ccode := C.CString(code) - defer C.free(unsafe.Pointer(ccode)) - - cname := C.CString(name) - defer C.free(unsafe.Pointer(cname)) - - if C.luaL_loadstring(s.L, ccode) != 0 { - err := &LuaError{ - Code: int(C.lua_status(s.L)), - Message: s.ToString(-1), - } - s.Pop(1) + if err := s.LoadString(code); err != nil { return nil, fmt.Errorf("failed to load string: %w", err) } - // Set up writer + // Set up writer with initial capacity var writer C.BytecodeWriter writer.buf = nil writer.len = 0 + writer.capacity = 0 + + // Initial allocation with a reasonable size + const initialSize = 4096 + writer.buf = (*C.uchar)(C.malloc(initialSize)) + if writer.buf == nil { + s.Pop(1) // Remove the loaded function + return nil, fmt.Errorf("failed to allocate memory for bytecode") + } + writer.capacity = initialSize // Dump the function to bytecode - if C.lua_dump(s.L, (*[0]byte)(C.bytecode_writer), unsafe.Pointer(&writer)) != 0 { - if writer.buf != nil { - C.free(unsafe.Pointer(writer.buf)) - } - s.Pop(1) - return nil, fmt.Errorf("failed to dump bytecode") - } + err := s.safeCall(func() C.int { + return C.dump_lua_function(s.L, (*C.BytecodeWriter)(unsafe.Pointer(&writer))) + }) - // Copy to Go slice + // Copy bytecode to Go slice regardless of the result bytecode := C.GoBytes(unsafe.Pointer(writer.buf), C.int(writer.len)) // Clean up - if writer.buf != nil { - C.free(unsafe.Pointer(writer.buf)) + C.free(unsafe.Pointer(writer.buf)) + s.Pop(1) // Remove the function from stack + + if err != nil { + return nil, fmt.Errorf("failed to dump bytecode: %w", err) } - s.Pop(1) // Remove the function return bytecode, nil } -// LoadBytecode loads precompiled bytecode and executes it +// LoadBytecode loads precompiled bytecode without executing it func (s *State) LoadBytecode(bytecode []byte, name string) error { + if len(bytecode) == 0 { + return fmt.Errorf("empty bytecode") + } + cname := C.CString(name) defer C.free(unsafe.Pointer(cname)) // Load the bytecode - status := C.load_bytecode_chunk( - s.L, - (*C.uchar)(unsafe.Pointer(&bytecode[0])), - C.size_t(len(bytecode)), - cname, - ) + err := s.safeCall(func() C.int { + return C.load_bytecode( + s.L, + (*C.uchar)(unsafe.Pointer(&bytecode[0])), + C.size_t(len(bytecode)), + cname, + ) + }) - if status != 0 { - err := &LuaError{ - Code: int(status), - Message: s.ToString(-1), - } - s.Pop(1) + if err != nil { return fmt.Errorf("failed to load bytecode: %w", err) } - // Execute the loaded chunk - if err := s.safeCall(func() C.int { - return C.lua_pcall(s.L, 0, 0, 0) - }); err != nil { - return fmt.Errorf("failed to execute bytecode: %w", err) - } - return nil } -// Helper function to compile and immediately load/execute bytecode -func (s *State) CompileAndLoad(code string, name string) error { +// RunBytecode executes previously loaded bytecode with 0 results +func (s *State) RunBytecode() error { + return s.RunBytecodeWithResults(0) +} + +// RunBytecodeWithResults executes bytecode and keeps nresults on the stack +// Use LUA_MULTRET (-1) to keep all results +func (s *State) RunBytecodeWithResults(nresults int) error { + return s.safeCall(func() C.int { + return C.lua_pcall(s.L, 0, C.int(nresults), 0) + }) +} + +// LoadAndRunBytecode loads and executes bytecode +func (s *State) LoadAndRunBytecode(bytecode []byte, name string) error { + if err := s.LoadBytecode(bytecode, name); err != nil { + return err + } + return s.RunBytecode() +} + +// LoadAndRunBytecodeWithResults loads and executes bytecode, preserving results +func (s *State) LoadAndRunBytecodeWithResults(bytecode []byte, name string, nresults int) error { + if err := s.LoadBytecode(bytecode, name); err != nil { + return err + } + return s.RunBytecodeWithResults(nresults) +} + +// CompileAndRun compiles and immediately executes Lua code +func (s *State) CompileAndRun(code string, name string) error { bytecode, err := s.CompileBytecode(code, name) if err != nil { return fmt.Errorf("compile error: %w", err) } - if err := s.LoadBytecode(bytecode, name); err != nil { - return fmt.Errorf("load error: %w", err) + if err := s.LoadAndRunBytecode(bytecode, name); err != nil { + return fmt.Errorf("execution error: %w", err) } return nil diff --git a/bytecode_test.go b/bytecode_test.go deleted file mode 100644 index 4e2822c..0000000 --- a/bytecode_test.go +++ /dev/null @@ -1,162 +0,0 @@ -package luajit - -import ( - "fmt" - "testing" -) - -func TestBytecodeCompilation(t *testing.T) { - tests := []struct { - name string - code string - wantErr bool - }{ - { - name: "simple assignment", - code: "x = 42", - wantErr: false, - }, - { - name: "function definition", - code: "function add(a,b) return a+b end", - wantErr: false, - }, - { - name: "syntax error", - code: "function bad syntax", - wantErr: true, - }, - } - - for _, tt := range tests { - L := New() - if L == nil { - t.Fatal("Failed to create Lua state") - } - defer L.Close() - - bytecode, err := L.CompileBytecode(tt.code, "test") - if (err != nil) != tt.wantErr { - t.Errorf("CompileBytecode() error = %v, wantErr %v", err, tt.wantErr) - return - } - - if !tt.wantErr { - if len(bytecode) == 0 { - t.Error("CompileBytecode() returned empty bytecode") - } - } - } -} - -func TestBytecodeExecution(t *testing.T) { - L := New() - if L == nil { - t.Fatal("Failed to create Lua state") - } - defer L.Close() - - // Compile some test code - code := ` - function add(a, b) - return a + b - end - result = add(40, 2) - ` - - bytecode, err := L.CompileBytecode(code, "test") - if err != nil { - t.Fatalf("CompileBytecode() error = %v", err) - } - - // Load and execute the bytecode - if err := L.LoadBytecode(bytecode, "test"); err != nil { - t.Fatalf("LoadBytecode() error = %v", err) - } - - // Verify the result - L.GetGlobal("result") - if result := L.ToNumber(-1); result != 42 { - t.Errorf("got result = %v, want 42", result) - } -} - -func TestInvalidBytecode(t *testing.T) { - L := New() - if L == nil { - t.Fatal("Failed to create Lua state") - } - defer L.Close() - - // Test with invalid bytecode - invalidBytecode := []byte("this is not valid bytecode") - if err := L.LoadBytecode(invalidBytecode, "test"); err == nil { - t.Error("LoadBytecode() expected error with invalid bytecode") - } -} - -func TestBytecodeRoundTrip(t *testing.T) { - tests := []struct { - name string - code string - check func(*State) error - }{ - { - name: "global variable", - code: "x = 42", - check: func(L *State) error { - L.GetGlobal("x") - if x := L.ToNumber(-1); x != 42 { - return fmt.Errorf("got x = %v, want 42", x) - } - return nil - }, - }, - { - name: "function definition", - code: "function test() return 'hello' end", - check: func(L *State) error { - if err := L.DoString("result = test()"); err != nil { - return err - } - L.GetGlobal("result") - if s := L.ToString(-1); s != "hello" { - return fmt.Errorf("got result = %q, want 'hello'", s) - } - return nil - }, - }, - } - - for _, tt := range tests { - // First state for compilation - L1 := New() - if L1 == nil { - t.Fatal("Failed to create first Lua state") - } - defer L1.Close() - - // Compile the code - bytecode, err := L1.CompileBytecode(tt.code, "test") - if err != nil { - t.Fatalf("CompileBytecode() error = %v", err) - } - - // Second state for execution - L2 := New() - if L2 == nil { - t.Fatal("Failed to create second Lua state") - } - defer L2.Close() - - // Load and execute the bytecode - if err := L2.LoadBytecode(bytecode, "test"); err != nil { - t.Fatalf("LoadBytecode() error = %v", err) - } - - // Run the check function - if err := tt.check(L2); err != nil { - t.Errorf("check failed: %v", err) - } - } -} diff --git a/example/main.go b/example/main.go new file mode 100644 index 0000000..2287deb --- /dev/null +++ b/example/main.go @@ -0,0 +1,70 @@ +package main + +import ( + "fmt" + "log" + "os" + "path/filepath" + + luajit "git.sharkk.net/Sky/LuaJIT-to-Go" +) + +func main() { + if len(os.Args) < 2 { + fmt.Println("Usage: go run main.go script.lua") + os.Exit(1) + } + + scriptPath := os.Args[1] + + // Create a new Lua state + L := luajit.New() + if L == nil { + log.Fatal("Failed to create Lua state") + } + defer L.Close() + + // Register a Go function to be called from Lua + L.RegisterGoFunction("printFromGo", func(s *luajit.State) int { + msg := s.ToString(1) // Get first argument + fmt.Printf("Go received from Lua: %s\n", msg) + + // Return a value to Lua + s.PushString("Hello from Go!") + return 1 // Number of return values + }) + + // Add some values to the Lua environment + L.PushValue(map[string]interface{}{ + "appName": "LuaJIT Example", + "version": 1.0, + "features": []float64{1, 2, 3}, + }) + L.SetGlobal("config") + + // Get the directory of the script to properly handle requires + dir := filepath.Dir(scriptPath) + L.AddPackagePath(filepath.Join(dir, "?.lua")) + + // Execute the script + fmt.Printf("Running Lua script: %s\n", scriptPath) + if err := L.DoFile(scriptPath); err != nil { + log.Fatalf("Error executing script: %v", err) + } + + // Call a Lua function and get its result + L.GetGlobal("getResult") + if L.IsFunction(-1) { + if err := L.Call(0, 1); err != nil { + log.Fatalf("Error calling Lua function: %v", err) + } + + result, err := L.ToValue(-1) + if err != nil { + log.Fatalf("Error converting Lua result: %v", err) + } + + fmt.Printf("Result from Lua: %v\n", result) + L.Pop(1) // Clean up the result + } +} diff --git a/example/script.lua b/example/script.lua new file mode 100644 index 0000000..5117847 --- /dev/null +++ b/example/script.lua @@ -0,0 +1,35 @@ +-- Example Lua script to demonstrate Go-Lua integration + +-- Access the config table passed from Go +print("Script started") +print("App name:", config.appName) +print("Version:", config.version) +print("Features:", table.concat(config.features, ", ")) + +-- Call the Go function +local response = printFromGo("Hello from Lua!") +print("Response from Go:", response) + +-- Function that will be called from Go +function getResult() + local result = { + status = "success", + calculations = { + sum = 10 + 20, + product = 5 * 7 + }, + message = "Calculation completed" + } + return result +end + +-- Load external module (if available) +local success, utils = pcall(require, "utils") +if success then + print("Utils module loaded") + utils.doSomething() +else + print("Utils module not available:", utils) +end + +print("Script completed") \ No newline at end of file diff --git a/example/utils.lua b/example/utils.lua new file mode 100644 index 0000000..5645682 --- /dev/null +++ b/example/utils.lua @@ -0,0 +1,19 @@ +-- Optional utility module + +local utils = {} + +function utils.doSomething() + print("Utils module function called") + return true +end + +function utils.calculate(a, b) + return { + sum = a + b, + difference = a - b, + product = a * b, + quotient = a / b + } +end + +return utils \ No newline at end of file diff --git a/functions.go b/functions.go index 2770174..db5f2e1 100644 --- a/functions.go +++ b/functions.go @@ -7,8 +7,9 @@ package luajit extern int goFunctionWrapper(lua_State* L); +// Helper function to access upvalues static int get_upvalue_index(int i) { - return -10002 - i; // LUA_GLOBALSINDEX - i + return lua_upvalueindex(i); } */ import "C" @@ -18,9 +19,11 @@ import ( "unsafe" ) +// GoFunction defines the signature for Go functions callable from Lua type GoFunction func(*State) int var ( + // functionRegistry stores all registered Go functions functionRegistry = struct { sync.RWMutex funcs map[unsafe.Pointer]GoFunction @@ -33,10 +36,10 @@ var ( func goFunctionWrapper(L *C.lua_State) C.int { state := &State{L: L} - // Get upvalue using standard Lua 5.1 macro + // Get function pointer from the first upvalue ptr := C.lua_touserdata(L, C.get_upvalue_index(1)) if ptr == nil { - state.PushString("error: function not found") + state.PushString("error: function pointer not found") return -1 } @@ -49,49 +52,56 @@ func goFunctionWrapper(L *C.lua_State) C.int { return -1 } - result := fn(state) - return C.int(result) + // Call the Go function + return C.int(fn(state)) } +// PushGoFunction wraps a Go function and pushes it onto the Lua stack func (s *State) PushGoFunction(fn GoFunction) error { + // Allocate a pointer to use as the function key ptr := C.malloc(1) if ptr == nil { return fmt.Errorf("failed to allocate memory for function pointer") } + // Register the function functionRegistry.Lock() functionRegistry.funcs[ptr] = fn functionRegistry.Unlock() + // Push the pointer as lightuserdata (first upvalue) C.lua_pushlightuserdata(s.L, ptr) + + // Create closure with the C wrapper and the upvalue C.lua_pushcclosure(s.L, (*[0]byte)(C.goFunctionWrapper), 1) + return nil } +// RegisterGoFunction registers a Go function as a global Lua function func (s *State) RegisterGoFunction(name string, fn GoFunction) error { if err := s.PushGoFunction(fn); err != nil { return err } - cname := C.CString(name) - defer C.free(unsafe.Pointer(cname)) - C.lua_setfield(s.L, C.LUA_GLOBALSINDEX, cname) + s.SetGlobal(name) return nil } +// UnregisterGoFunction removes a global function func (s *State) UnregisterGoFunction(name string) { s.PushNil() - cname := C.CString(name) - defer C.free(unsafe.Pointer(cname)) - C.lua_setfield(s.L, C.LUA_GLOBALSINDEX, cname) + s.SetGlobal(name) } +// Cleanup frees all function pointers and clears the registry func (s *State) Cleanup() { functionRegistry.Lock() defer functionRegistry.Unlock() + // Free all allocated pointers for ptr := range functionRegistry.funcs { C.free(ptr) + delete(functionRegistry.funcs, ptr) } - functionRegistry.funcs = make(map[unsafe.Pointer]GoFunction) } diff --git a/functions_test.go b/functions_test.go deleted file mode 100644 index 9a57245..0000000 --- a/functions_test.go +++ /dev/null @@ -1,107 +0,0 @@ -package luajit - -import ( - "testing" -) - -func TestGoFunctions(t *testing.T) { - L := New() - if L == nil { - t.Fatal("Failed to create Lua state") - } - defer L.Close() - defer L.Cleanup() - - addFunc := func(s *State) int { - s.PushNumber(s.ToNumber(1) + s.ToNumber(2)) - return 1 - } - - if err := L.RegisterGoFunction("add", addFunc); err != nil { - t.Fatalf("Failed to register function: %v", err) - } - - // Test basic function call - if err := L.DoString("result = add(40, 2)"); err != nil { - t.Fatalf("Failed to call function: %v", err) - } - - L.GetGlobal("result") - if result := L.ToNumber(-1); result != 42 { - t.Errorf("got %v, want 42", result) - } - L.Pop(1) - - // Test multiple return values - multiFunc := func(s *State) int { - s.PushString("hello") - s.PushNumber(42) - s.PushBoolean(true) - return 3 - } - - if err := L.RegisterGoFunction("multi", multiFunc); err != nil { - t.Fatalf("Failed to register multi function: %v", err) - } - - code := ` - a, b, c = multi() - result = (a == "hello" and b == 42 and c == true) - ` - - if err := L.DoString(code); err != nil { - t.Fatalf("Failed to call multi function: %v", err) - } - - L.GetGlobal("result") - if !L.ToBoolean(-1) { - t.Error("Multiple return values test failed") - } - L.Pop(1) - - // Test error handling - errFunc := func(s *State) int { - s.PushString("test error") - return -1 - } - - if err := L.RegisterGoFunction("err", errFunc); err != nil { - t.Fatalf("Failed to register error function: %v", err) - } - - if err := L.DoString("err()"); err == nil { - t.Error("Expected error from error function") - } - - // Test unregistering - L.UnregisterGoFunction("add") - if err := L.DoString("add(1, 2)"); err == nil { - t.Error("Expected error calling unregistered function") - } -} - -func TestStackSafety(t *testing.T) { - L := New() - if L == nil { - t.Fatal("Failed to create Lua state") - } - defer L.Close() - defer L.Cleanup() - - // Test stack overflow protection - overflowFunc := func(s *State) int { - for i := 0; i < 100; i++ { - s.PushNumber(float64(i)) - } - s.PushString("done") - return 101 - } - - if err := L.RegisterGoFunction("overflow", overflowFunc); err != nil { - t.Fatal(err) - } - - if err := L.DoString("overflow()"); err != nil { - t.Logf("Got expected error: %v", err) - } -} diff --git a/stack.go b/stack.go index 91fa45e..f0e0173 100644 --- a/stack.go +++ b/stack.go @@ -25,84 +25,23 @@ const ( LUA_GLOBALSINDEX = -10002 // Pseudo-index for globals table ) -// checkStack ensures there is enough space on the Lua stack -func (s *State) checkStack(n int) error { - if C.lua_checkstack(s.L, C.int(n)) == 0 { - return fmt.Errorf("stack overflow (cannot allocate %d slots)", n) - } - return nil -} - -// safeCall wraps a potentially dangerous C call with stack checking -func (s *State) safeCall(f func() C.int) error { - // Save current stack size - top := s.GetTop() - - // Ensure we have enough stack space (minimum 20 slots as per Lua standard) - if err := s.checkStack(LUA_MINSTACK); err != nil { - return err +// GetStackTrace returns the current Lua stack trace +func (s *State) GetStackTrace() string { + s.GetGlobal("debug") + if !s.IsTable(-1) { + s.Pop(1) + return "debug table not available" } - // Make the call - status := f() - - // Check for errors - if status != 0 { - err := &LuaError{ - Code: int(status), - Message: s.ToString(-1), - } - s.Pop(1) // Remove error message - return err + s.GetField(-1, "traceback") + if !s.IsFunction(-1) { + s.Pop(2) // Remove debug table and non-function + return "debug.traceback not available" } - // For lua_pcall, the function and arguments are popped before results are pushed - // So we don't consider it an underflow if the new top is less than the original - if status == 0 && s.GetType(-1) == TypeFunction { - // If we still have a function on the stack, restore original size - s.SetTop(top) - } + s.Call(0, 1) + trace := s.ToString(-1) + s.Pop(1) // Remove the trace - return nil -} - -// stackGuard wraps a function with stack checking -func stackGuard[T any](s *State, f func() (T, error)) (T, error) { - // Save current stack size - top := s.GetTop() - defer func() { - // Only restore if stack is larger than original - if s.GetTop() > top { - s.SetTop(top) - } - }() - - // Run the protected function - return f() -} - -// stackGuardValue executes a function with stack protection -func stackGuardValue[T any](s *State, f func() (T, error)) (T, error) { - return stackGuard(s, f) -} - -// stackGuardErr executes a function that only returns an error with stack protection -func stackGuardErr(s *State, f func() error) error { - // Save current stack size - top := s.GetTop() - defer func() { - // Only restore if stack is larger than original - if s.GetTop() > top { - s.SetTop(top) - } - }() - - // Run the protected function - return f() -} - -// getStackTrace returns the current Lua stack trace -func (s *State) getStackTrace() string { - // Same implementation... - return "" + return trace } diff --git a/table.go b/table.go index ebf3caf..7316ddc 100644 --- a/table.go +++ b/table.go @@ -6,29 +6,29 @@ package luajit #include #include -static int get_table_length(lua_State *L, int index) { +static size_t get_table_length(lua_State *L, int index) { return lua_objlen(L, index); } */ import "C" import ( + "fmt" "strconv" "sync" ) +// Use a pool to reduce GC pressure when handling many tables var tablePool = sync.Pool{ New: func() interface{} { return make(map[string]interface{}) }, } -// TableValue represents any value that can be stored in a Lua table -type TableValue interface { - ~string | ~float64 | ~bool | ~int | ~map[string]interface{} | ~[]float64 | ~[]interface{} +// GetTableLength returns the length of a table at the given index +func (s *State) GetTableLength(index int) int { + return int(C.get_table_length(s.L, C.int(index))) } -func (s *State) GetTableLength(index int) int { return int(C.get_table_length(s.L, C.int(index))) } - // getTableFromPool gets a map from the pool and ensures it's empty func getTableFromPool() map[string]interface{} { table := tablePool.Get().(map[string]interface{}) @@ -46,46 +46,68 @@ func putTableToPool(table map[string]interface{}) { // PushTable pushes a Go map onto the Lua stack as a table func (s *State) PushTable(table map[string]interface{}) error { + // Create table with appropriate capacity hints s.CreateTable(0, len(table)) + + // Add each key-value pair for k, v := range table { + // Push key + s.PushString(k) + + // Push value if err := s.PushValue(v); err != nil { return err } - s.SetField(-2, k) + + // t[k] = v + s.SetTable(-3) } - // If this is a pooled table, return it - if _, hasEmptyKey := table[""]; len(table) == 1 && hasEmptyKey { + + // Return pooled tables to the pool + if isPooledTable(table) { putTableToPool(table) } + return nil } -// ToTable converts a Lua table to a Go map +// isPooledTable detects if a table came from our pool +func isPooledTable(table map[string]interface{}) bool { + // Check for our special marker - used for array tables in the pool + _, hasEmptyKey := table[""] + return len(table) == 1 && hasEmptyKey +} + +// ToTable converts a Lua table at the given index to a Go map func (s *State) ToTable(index int) (map[string]interface{}, error) { absIdx := s.absIndex(index) - table := getTableFromPool() + if !s.IsTable(absIdx) { + return nil, fmt.Errorf("value at index %d is not a table", index) + } - // Check if it's an array-like table + // Try to detect array-like tables first length := s.GetTableLength(absIdx) if length > 0 { - array := make([]float64, length) + // Check if this is an array-like table isArray := true + array := make([]float64, length) - // Try to convert to array for i := 1; i <= length; i++ { s.PushNumber(float64(i)) s.GetTable(absIdx) - if s.GetType(-1) != TypeNumber { + + if !s.IsNumber(-1) { isArray = false s.Pop(1) break } + array[i-1] = s.ToNumber(-1) s.Pop(1) } if isArray { - putTableToPool(table) // Return unused table to pool + // Return array as a special pooled table with empty key result := getTableFromPool() result[""] = array return result, nil @@ -93,24 +115,36 @@ func (s *State) ToTable(index int) (map[string]interface{}, error) { } // Handle regular table - s.PushNil() - for C.lua_next(s.L, C.int(absIdx)) != 0 { - key := "" - valueType := C.lua_type(s.L, -2) - if valueType == C.LUA_TSTRING { + table := getTableFromPool() + + // Iterate through all key-value pairs + s.PushNil() // Start iteration with nil key + for s.Next(absIdx) { + // Stack now has key at -2 and value at -1 + + // Convert key to string + var key string + keyType := s.GetType(-2) + switch keyType { + case TypeString: key = s.ToString(-2) - } else if valueType == C.LUA_TNUMBER { + case TypeNumber: key = strconv.FormatFloat(s.ToNumber(-2), 'g', -1, 64) + default: + // Skip non-string/non-number keys + s.Pop(1) // Pop value, leave key for next iteration + continue } + // Convert and store the value value, err := s.ToValue(-1) if err != nil { - s.Pop(1) - putTableToPool(table) // Return table to pool on error + s.Pop(2) // Pop both key and value + putTableToPool(table) // Return the table to the pool on error return nil, err } - // Handle nested array case + // Handle nested array tables if m, ok := value.(map[string]interface{}); ok { if arr, ok := m[""]; ok { value = arr @@ -118,27 +152,8 @@ func (s *State) ToTable(index int) (map[string]interface{}, error) { } table[key] = value - s.Pop(1) + s.Pop(1) // Pop value, leave key for next iteration } return table, nil } - -// NewTable creates a new table and pushes it onto the stack -func (s *State) NewTable() { - C.lua_createtable(s.L, 0, 0) -} - -// SetTable sets a table field -func (s *State) SetTable(index int) { - C.lua_settable(s.L, C.int(index)) -} - -// GetTable gets a table field -func (s *State) GetTable(index int) { - C.lua_gettable(s.L, C.int(index)) -} - -func (s *State) CreateTable(narr, nrec int) { - C.lua_createtable(s.L, C.int(narr), C.int(nrec)) -} diff --git a/table_test.go b/table_test.go deleted file mode 100644 index 4d0d9e8..0000000 --- a/table_test.go +++ /dev/null @@ -1,93 +0,0 @@ -package luajit - -import ( - "math" - "testing" -) - -func TestTableOperations(t *testing.T) { - tests := []struct { - name string - data map[string]interface{} - }{ - { - name: "empty", - data: map[string]interface{}{}, - }, - { - name: "primitives", - data: map[string]interface{}{ - "str": "hello", - "num": 42.0, - "bool": true, - "array": []float64{1.1, 2.2, 3.3}, - }, - }, - { - name: "nested", - data: map[string]interface{}{ - "nested": map[string]interface{}{ - "value": 123.0, - "array": []float64{4.4, 5.5}, - }, - }, - }, - } - - for _, tt := range tests { - L := New() - if L == nil { - t.Fatal("Failed to create Lua state") - } - defer L.Close() - - if err := L.PushTable(tt.data); err != nil { - t.Fatalf("PushTable() error = %v", err) - } - - got, err := L.ToTable(-1) - if err != nil { - t.Fatalf("ToTable() error = %v", err) - } - - if !tablesEqual(got, tt.data) { - t.Errorf("table mismatch\ngot = %v\nwant = %v", got, tt.data) - } - } -} - -func tablesEqual(a, b map[string]interface{}) bool { - if len(a) != len(b) { - return false - } - - for k, v1 := range a { - v2, ok := b[k] - if !ok { - return false - } - - switch v1 := v1.(type) { - case map[string]interface{}: - v2, ok := v2.(map[string]interface{}) - if !ok || !tablesEqual(v1, v2) { - return false - } - case []float64: - v2, ok := v2.([]float64) - if !ok || len(v1) != len(v2) { - return false - } - for i := range v1 { - if math.Abs(v1[i]-v2[i]) > 1e-10 { - return false - } - } - default: - if v1 != v2 { - return false - } - } - } - return true -} diff --git a/tests/bytecode_test.go b/tests/bytecode_test.go new file mode 100644 index 0000000..2ec8822 --- /dev/null +++ b/tests/bytecode_test.go @@ -0,0 +1,443 @@ +package luajit_test + +import ( + "bytes" + "errors" + "testing" + + luajit "git.sharkk.net/Sky/LuaJIT-to-Go" +) + +// TestCompileBytecode tests basic bytecode compilation +func TestCompileBytecode(t *testing.T) { + state := luajit.New() + if state == nil { + t.Fatal("Failed to create Lua state") + } + defer state.Close() + + code := "return 42" + bytecode, err := state.CompileBytecode(code, "test") + if err != nil { + t.Fatalf("CompileBytecode failed: %v", err) + } + + if len(bytecode) == 0 { + t.Fatal("Expected non-empty bytecode") + } +} + +// TestLoadBytecode tests loading precompiled bytecode +func TestLoadBytecode(t *testing.T) { + state := luajit.New() + if state == nil { + t.Fatal("Failed to create Lua state") + } + defer state.Close() + + // First compile some bytecode + code := "answer = 42" + bytecode, err := state.CompileBytecode(code, "test") + if err != nil { + t.Fatalf("CompileBytecode failed: %v", err) + } + + // Then load it + err = state.LoadBytecode(bytecode, "test") + if err != nil { + t.Fatalf("LoadBytecode failed: %v", err) + } + + // Verify a function is on the stack + if !state.IsFunction(-1) { + t.Fatal("Expected function at top of stack after LoadBytecode") + } + + // Pop the function + state.Pop(1) +} + +// TestRunBytecode tests running previously loaded bytecode +func TestRunBytecode(t *testing.T) { + state := luajit.New() + if state == nil { + t.Fatal("Failed to create Lua state") + } + defer state.Close() + + // First compile and load bytecode + code := "answer = 42" + bytecode, err := state.CompileBytecode(code, "test") + if err != nil { + t.Fatalf("CompileBytecode failed: %v", err) + } + + err = state.LoadBytecode(bytecode, "test") + if err != nil { + t.Fatalf("LoadBytecode failed: %v", err) + } + + // Run the bytecode + err = state.RunBytecode() + if err != nil { + t.Fatalf("RunBytecode failed: %v", err) + } + + // Verify the code has executed correctly + state.GetGlobal("answer") + if !state.IsNumber(-1) || state.ToNumber(-1) != 42 { + t.Fatalf("Expected answer to be 42, got %v", state.ToNumber(-1)) + } + state.Pop(1) +} + +// TestLoadAndRunBytecode tests the combined load and run functionality +func TestLoadAndRunBytecode(t *testing.T) { + state := luajit.New() + if state == nil { + t.Fatal("Failed to create Lua state") + } + defer state.Close() + + // Compile bytecode + code := "answer = 42" + bytecode, err := state.CompileBytecode(code, "test") + if err != nil { + t.Fatalf("CompileBytecode failed: %v", err) + } + + // Load and run in one step + err = state.LoadAndRunBytecode(bytecode, "test") + if err != nil { + t.Fatalf("LoadAndRunBytecode failed: %v", err) + } + + // Verify execution + state.GetGlobal("answer") + if !state.IsNumber(-1) || state.ToNumber(-1) != 42 { + t.Fatalf("Expected answer to be 42, got %v", state.ToNumber(-1)) + } + state.Pop(1) +} + +// TestCompileAndRun tests compile and run functionality +func TestCompileAndRun(t *testing.T) { + state := luajit.New() + if state == nil { + t.Fatal("Failed to create Lua state") + } + defer state.Close() + + // Compile and run in one step + code := "answer = 42" + err := state.CompileAndRun(code, "test") + if err != nil { + t.Fatalf("CompileAndRun failed: %v", err) + } + + // Verify execution + state.GetGlobal("answer") + if !state.IsNumber(-1) || state.ToNumber(-1) != 42 { + t.Fatalf("Expected answer to be 42, got %v", state.ToNumber(-1)) + } + state.Pop(1) +} + +// TestEmptyBytecode tests error handling for empty bytecode +func TestEmptyBytecode(t *testing.T) { + state := luajit.New() + if state == nil { + t.Fatal("Failed to create Lua state") + } + defer state.Close() + + // Try to load empty bytecode + err := state.LoadBytecode([]byte{}, "empty") + if err == nil { + t.Fatal("Expected error for empty bytecode, got nil") + } +} + +// TestInvalidBytecode tests error handling for invalid bytecode +func TestInvalidBytecode(t *testing.T) { + state := luajit.New() + if state == nil { + t.Fatal("Failed to create Lua state") + } + defer state.Close() + + // Create some invalid bytecode + invalidBytecode := []byte("not valid bytecode") + + // Try to load invalid bytecode + err := state.LoadBytecode(invalidBytecode, "invalid") + if err == nil { + t.Fatal("Expected error for invalid bytecode, got nil") + } +} + +// TestBytecodeSerialization tests serializing and deserializing bytecode +func TestBytecodeSerialization(t *testing.T) { + // First state to compile + state1 := luajit.New() + if state1 == nil { + t.Fatal("Failed to create first Lua state") + } + defer state1.Close() + + // Compile bytecode + code := ` + function add(a, b) + return a + b + end + result = add(10, 20) + ` + bytecode, err := state1.CompileBytecode(code, "test") + if err != nil { + t.Fatalf("CompileBytecode failed: %v", err) + } + + // Second state to execute + state2 := luajit.New() + if state2 == nil { + t.Fatal("Failed to create second Lua state") + } + defer state2.Close() + + // Load and run the bytecode in the second state + err = state2.LoadAndRunBytecode(bytecode, "test") + if err != nil { + t.Fatalf("LoadAndRunBytecode failed: %v", err) + } + + // Verify execution + state2.GetGlobal("result") + if !state2.IsNumber(-1) || state2.ToNumber(-1) != 30 { + t.Fatalf("Expected result to be 30, got %v", state2.ToNumber(-1)) + } + state2.Pop(1) + + // Call the function to verify it was properly transferred + state2.GetGlobal("add") + if !state2.IsFunction(-1) { + t.Fatal("Expected add to be a function") + } + state2.PushNumber(5) + state2.PushNumber(7) + if err := state2.Call(2, 1); err != nil { + t.Fatalf("Failed to call function: %v", err) + } + if state2.ToNumber(-1) != 12 { + t.Fatalf("Expected add(5, 7) to return 12, got %v", state2.ToNumber(-1)) + } + state2.Pop(1) +} + +// TestCompilationError tests error handling for compilation errors +func TestCompilationError(t *testing.T) { + state := luajit.New() + if state == nil { + t.Fatal("Failed to create Lua state") + } + defer state.Close() + + // Invalid Lua code that should fail to compile + code := "function without end" + + // Try to compile + _, err := state.CompileBytecode(code, "invalid") + if err == nil { + t.Fatal("Expected compilation error, got nil") + } + + // Check error type + var luaErr *luajit.LuaError + if !errors.As(err, &luaErr) { + t.Fatalf("Expected error to wrap *luajit.LuaError, got %T", err) + } +} + +// TestExecutionError tests error handling for runtime errors +func TestExecutionError(t *testing.T) { + state := luajit.New() + if state == nil { + t.Fatal("Failed to create Lua state") + } + defer state.Close() + + // Code that compiles but fails at runtime + code := "error('deliberate error')" + + // Compile bytecode + bytecode, err := state.CompileBytecode(code, "error") + if err != nil { + t.Fatalf("CompileBytecode failed: %v", err) + } + + // Try to execute + err = state.LoadAndRunBytecode(bytecode, "error") + if err == nil { + t.Fatal("Expected execution error, got nil") + } + + // Check error type + if _, ok := err.(*luajit.LuaError); !ok { + t.Fatalf("Expected *luajit.LuaError, got %T", err) + } +} + +// TestBytecodeEquivalence tests that bytecode execution produces the same results as direct execution +func TestBytecodeEquivalence(t *testing.T) { + code := ` + local result = 0 + for i = 1, 10 do + result = result + i + end + return result + ` + + // First, execute directly + state1 := luajit.New() + if state1 == nil { + t.Fatal("Failed to create first Lua state") + } + defer state1.Close() + + directResult, err := state1.ExecuteWithResult(code) + if err != nil { + t.Fatalf("ExecuteWithResult failed: %v", err) + } + + // Then, compile and execute bytecode + state2 := luajit.New() + if state2 == nil { + t.Fatal("Failed to create second Lua state") + } + defer state2.Close() + + bytecode, err := state2.CompileBytecode(code, "test") + if err != nil { + t.Fatalf("CompileBytecode failed: %v", err) + } + + err = state2.LoadBytecode(bytecode, "test") + if err != nil { + t.Fatalf("LoadBytecode failed: %v", err) + } + + err = state2.Call(0, 1) + if err != nil { + t.Fatalf("Call failed: %v", err) + } + + bytecodeResult, err := state2.ToValue(-1) + if err != nil { + t.Fatalf("ToValue failed: %v", err) + } + state2.Pop(1) + + // Compare results + if directResult != bytecodeResult { + t.Fatalf("Results differ: direct=%v, bytecode=%v", directResult, bytecodeResult) + } +} + +// TestBytecodeReuse tests reusing the same bytecode multiple times +func TestBytecodeReuse(t *testing.T) { + state := luajit.New() + if state == nil { + t.Fatal("Failed to create Lua state") + } + defer state.Close() + + // Create a function in bytecode + code := ` + return function(x) + return x * 2 + end + ` + bytecode, err := state.CompileBytecode(code, "func") + if err != nil { + t.Fatalf("CompileBytecode failed: %v", err) + } + + // Execute it several times + for i := 1; i <= 3; i++ { + // Load and run to get the function + err = state.LoadAndRunBytecodeWithResults(bytes.Clone(bytecode), "func", 1) + if err != nil { + t.Fatalf("LoadAndRunBytecodeWithResults failed: %v", err) + } + + // Stack now has the function at the top + if !state.IsFunction(-1) { + t.Fatal("Expected function at top of stack") + } + + // Call with parameter i + state.PushNumber(float64(i)) + if err := state.Call(1, 1); err != nil { + t.Fatalf("Call failed: %v", err) + } + + // Check result + expected := float64(i * 2) + if state.ToNumber(-1) != expected { + t.Fatalf("Expected %v, got %v", expected, state.ToNumber(-1)) + } + + // Pop the result + state.Pop(1) + } +} + +// TestBytecodeClosure tests that bytecode properly handles closures and upvalues +func TestBytecodeClosure(t *testing.T) { + state := luajit.New() + if state == nil { + t.Fatal("Failed to create Lua state") + } + defer state.Close() + + // Create a closure + code := ` + local counter = 0 + return function() + counter = counter + 1 + return counter + end + ` + + // Compile to bytecode + bytecode, err := state.CompileBytecode(code, "closure") + if err != nil { + t.Fatalf("CompileBytecode failed: %v", err) + } + + // Load and run to get the counter function + err = state.LoadAndRunBytecodeWithResults(bytecode, "closure", 1) + if err != nil { + t.Fatalf("LoadAndRunBytecode failed: %v", err) + } + + // Stack now has the function at the top + if !state.IsFunction(-1) { + t.Fatal("Expected function at top of stack") + } + + // Store in a global + state.SetGlobal("counter_func") + + // Call it multiple times and check the results + for i := 1; i <= 3; i++ { + state.GetGlobal("counter_func") + if err := state.Call(0, 1); err != nil { + t.Fatalf("Call failed: %v", err) + } + + if state.ToNumber(-1) != float64(i) { + t.Fatalf("Expected counter to be %d, got %v", i, state.ToNumber(-1)) + } + state.Pop(1) + } +} diff --git a/tests/functions_test.go b/tests/functions_test.go new file mode 100644 index 0000000..bb9f163 --- /dev/null +++ b/tests/functions_test.go @@ -0,0 +1,178 @@ +package luajit_test + +import ( + "testing" + + luajit "git.sharkk.net/Sky/LuaJIT-to-Go" +) + +func TestPushGoFunction(t *testing.T) { + state := luajit.New() + if state == nil { + t.Fatal("Failed to create Lua state") + } + defer state.Close() + + // Define a simple function that adds two numbers + add := func(s *luajit.State) int { + a := s.ToNumber(1) + b := s.ToNumber(2) + s.PushNumber(a + b) + return 1 // Return one result + } + + // Push the function onto the stack + if err := state.PushGoFunction(add); err != nil { + t.Fatalf("PushGoFunction failed: %v", err) + } + + // Verify that a function is on the stack + if !state.IsFunction(-1) { + t.Fatalf("Expected function at top of stack") + } + + // Push arguments + state.PushNumber(3) + state.PushNumber(4) + + // Call the function + if err := state.Call(2, 1); err != nil { + t.Fatalf("Failed to call function: %v", err) + } + + // Check the result + if state.ToNumber(-1) != 7 { + t.Fatalf("Function returned %f, expected 7", state.ToNumber(-1)) + } + state.Pop(1) +} + +func TestRegisterGoFunction(t *testing.T) { + state := luajit.New() + if state == nil { + t.Fatal("Failed to create Lua state") + } + defer state.Close() + + // Define a function that squares a number + square := func(s *luajit.State) int { + x := s.ToNumber(1) + s.PushNumber(x * x) + return 1 + } + + // Register the function + if err := state.RegisterGoFunction("square", square); err != nil { + t.Fatalf("RegisterGoFunction failed: %v", err) + } + + // Call the function from Lua + if err := state.DoString("result = square(5)"); err != nil { + t.Fatalf("Failed to call registered function: %v", err) + } + + // Check the result + state.GetGlobal("result") + if state.ToNumber(-1) != 25 { + t.Fatalf("Function returned %f, expected 25", state.ToNumber(-1)) + } + state.Pop(1) + + // Test UnregisterGoFunction + state.UnregisterGoFunction("square") + + // Function should no longer exist + err := state.DoString("result = square(5)") + if err == nil { + t.Fatalf("Expected error after unregistering function, got nil") + } +} + +func TestGoFunctionWithErrorHandling(t *testing.T) { + state := luajit.New() + if state == nil { + t.Fatal("Failed to create Lua state") + } + defer state.Close() + + // Function that returns an error in Lua + errFunc := func(s *luajit.State) int { + s.PushString("error from Go function") + return -1 // Signal error + } + + // Register the function + if err := state.RegisterGoFunction("errorFunc", errFunc); err != nil { + t.Fatalf("RegisterGoFunction failed: %v", err) + } + + // Call the function expecting an error + err := state.DoString("result = errorFunc()") + if err == nil { + t.Fatalf("Expected error from function, got nil") + } + + // Error message should contain our message + luaErr, ok := err.(*luajit.LuaError) + if !ok { + t.Fatalf("Expected LuaError, got %T: %v", err, err) + } + + if luaErr.Message == "" { + t.Fatalf("Expected non-empty error message from Go function") + } +} + +func TestCleanup(t *testing.T) { + state := luajit.New() + if state == nil { + t.Fatal("Failed to create Lua state") + } + + // Register several functions + for i := 0; i < 5; i++ { + dummy := func(s *luajit.State) int { return 0 } + if err := state.RegisterGoFunction("dummy", dummy); err != nil { + t.Fatalf("RegisterGoFunction failed: %v", err) + } + } + + // Call Cleanup explicitly + state.Cleanup() + + // Make sure we can still close the state + state.Close() + + // Also test that Close can be called after Cleanup + state = luajit.New() + if state == nil { + t.Fatal("Failed to create second Lua state") + } + + state.Close() // Should call Cleanup internally +} + +func TestGoFunctionErrorPointer(t *testing.T) { + state := luajit.New() + if state == nil { + t.Fatal("Failed to create Lua state") + } + defer state.Close() + + // Create a Lua function that calls a non-existent Go function pointer + // This isn't a direct test of internal implementation, but tries to cover + // error cases in the goFunctionWrapper + code := ` + function test() + -- This is a stub that doesn't actually call the wrapper, + -- but we're testing error handling in our State.DoString + return "test" + end + ` + if err := state.DoString(code); err != nil { + t.Fatalf("Failed to define test function: %v", err) + } + + // The real test is that Cleanup doesn't crash + state.Cleanup() +} diff --git a/tests/stack_test.go b/tests/stack_test.go new file mode 100644 index 0000000..f54f5a6 --- /dev/null +++ b/tests/stack_test.go @@ -0,0 +1,53 @@ +package luajit_test + +import ( + "strings" + "testing" + + luajit "git.sharkk.net/Sky/LuaJIT-to-Go" +) + +func TestLuaError(t *testing.T) { + err := &luajit.LuaError{ + Code: 123, + Message: "test error", + } + + expected := "lua error (code=123): test error" + if err.Error() != expected { + t.Errorf("Expected error message %q, got %q", expected, err.Error()) + } +} + +func TestGetStackTrace(t *testing.T) { + s := luajit.New() + defer s.Close() + + // Test with debug library available + trace := s.GetStackTrace() + if !strings.Contains(trace, "stack traceback:") { + t.Errorf("Expected trace to contain 'stack traceback:', got %q", trace) + } + + // Test when debug table is not available + err := s.DoString("debug = nil") + if err != nil { + t.Fatalf("Failed to set debug to nil: %v", err) + } + + trace = s.GetStackTrace() + if trace != "debug table not available" { + t.Errorf("Expected 'debug table not available', got %q", trace) + } + + // Test when debug.traceback is not available + err = s.DoString("debug = {}") + if err != nil { + t.Fatalf("Failed to set debug to empty table: %v", err) + } + + trace = s.GetStackTrace() + if trace != "debug.traceback not available" { + t.Errorf("Expected 'debug.traceback not available', got %q", trace) + } +} diff --git a/tests/table_test.go b/tests/table_test.go new file mode 100644 index 0000000..ca0f9d7 --- /dev/null +++ b/tests/table_test.go @@ -0,0 +1,246 @@ +package luajit_test + +import ( + "reflect" + "testing" + + luajit "git.sharkk.net/Sky/LuaJIT-to-Go" +) + +func TestGetTableLength(t *testing.T) { + state := luajit.New() + if state == nil { + t.Fatal("Failed to create Lua state") + } + defer state.Close() + + // Create a table with numeric indices + if err := state.DoString("t = {10, 20, 30, 40, 50}"); err != nil { + t.Fatalf("Failed to create test table: %v", err) + } + + // Get the table + state.GetGlobal("t") + length := state.GetTableLength(-1) + if length != 5 { + t.Fatalf("Expected length 5, got %d", length) + } + state.Pop(1) + + // Create a table with string keys + if err := state.DoString("t2 = {a=1, b=2, c=3}"); err != nil { + t.Fatalf("Failed to create test table: %v", err) + } + + // Get the table + state.GetGlobal("t2") + length = state.GetTableLength(-1) + if length != 0 { + t.Fatalf("Expected length 0 for string-keyed table, got %d", length) + } + state.Pop(1) +} + +func TestPushTable(t *testing.T) { + state := luajit.New() + if state == nil { + t.Fatal("Failed to create Lua state") + } + defer state.Close() + + // Create a test table + testTable := map[string]any{ + "int": 42, + "float": 3.14, + "string": "hello", + "boolean": true, + "nil": nil, + } + + // Push the table onto the stack + if err := state.PushTable(testTable); err != nil { + t.Fatalf("Failed to push table: %v", err) + } + + // Execute Lua code to test the table contents + if err := state.DoString(` + function validate_table(t) + return t.int == 42 and + math.abs(t.float - 3.14) < 0.0001 and + t.string == "hello" and + t.boolean == true and + t["nil"] == nil + end + `); err != nil { + t.Fatalf("Failed to create validation function: %v", err) + } + + // Call the validation function + state.GetGlobal("validate_table") + state.PushCopy(-2) // Copy the table to the top + if err := state.Call(1, 1); err != nil { + t.Fatalf("Failed to call validation function: %v", err) + } + + if !state.ToBoolean(-1) { + t.Fatalf("Table validation failed") + } + state.Pop(2) // Pop the result and the table +} + +func TestToTable(t *testing.T) { + state := luajit.New() + if state == nil { + t.Fatal("Failed to create Lua state") + } + defer state.Close() + + // Test regular table conversion + if err := state.DoString(`t = {a=1, b=2.5, c="test", d=true, e=nil}`); err != nil { + t.Fatalf("Failed to create test table: %v", err) + } + + state.GetGlobal("t") + table, err := state.ToTable(-1) + if err != nil { + t.Fatalf("Failed to convert table: %v", err) + } + state.Pop(1) + + expected := map[string]any{ + "a": float64(1), + "b": 2.5, + "c": "test", + "d": true, + } + + for k, v := range expected { + if table[k] != v { + t.Fatalf("Expected table[%s] = %v, got %v", k, v, table[k]) + } + } + + // Test array-like table conversion + if err := state.DoString(`arr = {10, 20, 30, 40, 50}`); err != nil { + t.Fatalf("Failed to create test array: %v", err) + } + + state.GetGlobal("arr") + table, err = state.ToTable(-1) + if err != nil { + t.Fatalf("Failed to convert array table: %v", err) + } + state.Pop(1) + + // For array tables, we should get a special format with an empty key + // and the array as the value + expectedArray := []float64{10, 20, 30, 40, 50} + if arr, ok := table[""].([]float64); !ok { + t.Fatalf("Expected array table to be converted with empty key, got: %v", table) + } else if !reflect.DeepEqual(arr, expectedArray) { + t.Fatalf("Expected %v, got %v", expectedArray, arr) + } + + // Test invalid table index + _, err = state.ToTable(100) + if err == nil { + t.Fatalf("Expected error for invalid table index, got nil") + } + + // Test non-table value + state.PushNumber(123) + _, err = state.ToTable(-1) + if err == nil { + t.Fatalf("Expected error for non-table value, got nil") + } + state.Pop(1) + + // Test mixed array with non-numeric values + if err := state.DoString(`mixed = {10, 20, key="value", 30}`); err != nil { + t.Fatalf("Failed to create mixed table: %v", err) + } + + state.GetGlobal("mixed") + table, err = state.ToTable(-1) + if err != nil { + t.Fatalf("Failed to convert mixed table: %v", err) + } + + // Let's print the table for debugging + t.Logf("Table contents: %v", table) + + state.Pop(1) + + // Check if the array part is detected and stored with empty key + if arr, ok := table[""]; !ok { + t.Fatalf("Expected array-like part to be detected, got: %v", table) + } else { + // Verify the array contains the expected values + expectedArr := []float64{10, 20, 30} + actualArr := arr.([]float64) + if len(actualArr) != len(expectedArr) { + t.Fatalf("Expected array length %d, got %d", len(expectedArr), len(actualArr)) + } + + for i, v := range expectedArr { + if actualArr[i] != v { + t.Fatalf("Expected array[%d] = %v, got %v", i, v, actualArr[i]) + } + } + } + + // Based on the implementation, we need to create a separate test for string keys + if err := state.DoString(`dict = {foo="bar", baz="qux"}`); err != nil { + t.Fatalf("Failed to create dict table: %v", err) + } + + state.GetGlobal("dict") + dictTable, err := state.ToTable(-1) + if err != nil { + t.Fatalf("Failed to convert dict table: %v", err) + } + state.Pop(1) + + // Check the string keys + if val, ok := dictTable["foo"]; !ok || val != "bar" { + t.Fatalf("Expected dictTable[\"foo\"] = \"bar\", got: %v", val) + } + if val, ok := dictTable["baz"]; !ok || val != "qux" { + t.Fatalf("Expected dictTable[\"baz\"] = \"qux\", got: %v", val) + } +} + +func TestTablePooling(t *testing.T) { + state := luajit.New() + if state == nil { + t.Fatal("Failed to create Lua state") + } + defer state.Close() + + // Create a Lua table and push it onto the stack + if err := state.DoString(`t = {a=1, b=2}`); err != nil { + t.Fatalf("Failed to create test table: %v", err) + } + + state.GetGlobal("t") + + // First conversion - should get a table from the pool + table1, err := state.ToTable(-1) + if err != nil { + t.Fatalf("Failed to convert table (1): %v", err) + } + + // Second conversion - should get another table from the pool + table2, err := state.ToTable(-1) + if err != nil { + t.Fatalf("Failed to convert table (2): %v", err) + } + + // Both tables should have the same content + if !reflect.DeepEqual(table1, table2) { + t.Fatalf("Tables should have the same content: %v vs %v", table1, table2) + } + + // Clean up + state.Pop(1) +} diff --git a/tests/wrapper_test.go b/tests/wrapper_test.go new file mode 100644 index 0000000..8f5fc6e --- /dev/null +++ b/tests/wrapper_test.go @@ -0,0 +1,473 @@ +package luajit_test + +import ( + "os" + "reflect" + "testing" + + luajit "git.sharkk.net/Sky/LuaJIT-to-Go" +) + +func TestStateLifecycle(t *testing.T) { + // Test creation + state := luajit.New() + if state == nil { + t.Fatal("Failed to create Lua state") + } + + // Test close + state.Close() + + // Test close is idempotent (doesn't crash) + state.Close() +} + +func TestStackManipulation(t *testing.T) { + state := luajit.New() + if state == nil { + t.Fatal("Failed to create Lua state") + } + defer state.Close() + + // Test initial stack size + if state.GetTop() != 0 { + t.Fatalf("Expected empty stack, got %d elements", state.GetTop()) + } + + // Push values + state.PushNil() + state.PushBoolean(true) + state.PushNumber(42) + state.PushString("hello") + + // Check stack size + if state.GetTop() != 4 { + t.Fatalf("Expected 4 elements, got %d", state.GetTop()) + } + + // Test SetTop + state.SetTop(2) + if state.GetTop() != 2 { + t.Fatalf("Expected 2 elements after SetTop, got %d", state.GetTop()) + } + + // Test PushCopy + state.PushCopy(2) // Copy the boolean + if !state.IsBoolean(-1) { + t.Fatalf("Expected boolean at top of stack") + } + + // Test Pop + state.Pop(1) + if state.GetTop() != 2 { + t.Fatalf("Expected 2 elements after Pop, got %d", state.GetTop()) + } + + // Test Remove + state.PushNumber(99) + state.Remove(1) // Remove the first element (nil) + if state.GetTop() != 2 { + t.Fatalf("Expected 2 elements after Remove, got %d", state.GetTop()) + } + + // Verify first element is now boolean + if !state.IsBoolean(1) { + t.Fatalf("Expected boolean at index 1 after Remove") + } +} + +func TestTypeChecking(t *testing.T) { + state := luajit.New() + if state == nil { + t.Fatal("Failed to create Lua state") + } + defer state.Close() + + // Push values of different types + state.PushNil() + state.PushBoolean(true) + state.PushNumber(42) + state.PushString("hello") + state.NewTable() + + // Check types with GetType + if state.GetType(1) != luajit.TypeNil { + t.Fatalf("Expected nil type at index 1, got %s", state.GetType(1)) + } + if state.GetType(2) != luajit.TypeBoolean { + t.Fatalf("Expected boolean type at index 2, got %s", state.GetType(2)) + } + if state.GetType(3) != luajit.TypeNumber { + t.Fatalf("Expected number type at index 3, got %s", state.GetType(3)) + } + if state.GetType(4) != luajit.TypeString { + t.Fatalf("Expected string type at index 4, got %s", state.GetType(4)) + } + if state.GetType(5) != luajit.TypeTable { + t.Fatalf("Expected table type at index 5, got %s", state.GetType(5)) + } + + // Test individual type checking functions + if !state.IsNil(1) { + t.Fatalf("IsNil failed for nil value") + } + if !state.IsBoolean(2) { + t.Fatalf("IsBoolean failed for boolean value") + } + if !state.IsNumber(3) { + t.Fatalf("IsNumber failed for number value") + } + if !state.IsString(4) { + t.Fatalf("IsString failed for string value") + } + if !state.IsTable(5) { + t.Fatalf("IsTable failed for table value") + } + + // Function test + state.DoString("function test() return true end") + state.GetGlobal("test") + if !state.IsFunction(-1) { + t.Fatalf("IsFunction failed for function value") + } +} + +func TestValueConversion(t *testing.T) { + state := luajit.New() + if state == nil { + t.Fatal("Failed to create Lua state") + } + defer state.Close() + + // Push values + state.PushBoolean(true) + state.PushNumber(42.5) + state.PushString("hello") + + // Test conversion + if !state.ToBoolean(1) { + t.Fatalf("ToBoolean failed") + } + if state.ToNumber(2) != 42.5 { + t.Fatalf("ToNumber failed, expected 42.5, got %f", state.ToNumber(2)) + } + if state.ToString(3) != "hello" { + t.Fatalf("ToString failed, expected 'hello', got '%s'", state.ToString(3)) + } +} + +func TestTableOperations(t *testing.T) { + state := luajit.New() + if state == nil { + t.Fatal("Failed to create Lua state") + } + defer state.Close() + + // Test CreateTable + state.CreateTable(0, 3) + + // Add fields using SetField + state.PushNumber(42) + state.SetField(-2, "answer") + + state.PushString("hello") + state.SetField(-2, "greeting") + + state.PushBoolean(true) + state.SetField(-2, "flag") + + // Test GetField + state.GetField(-1, "answer") + if state.ToNumber(-1) != 42 { + t.Fatalf("GetField for 'answer' failed") + } + state.Pop(1) + + state.GetField(-1, "greeting") + if state.ToString(-1) != "hello" { + t.Fatalf("GetField for 'greeting' failed") + } + state.Pop(1) + + // Test Next for iteration + state.PushNil() // Start iteration + count := 0 + for state.Next(-2) { + count++ + state.Pop(1) // Pop value, leave key for next iteration + } + + if count != 3 { + t.Fatalf("Expected 3 table entries, found %d", count) + } + + // Clean up + state.Pop(1) // Pop the table +} + +func TestGlobalOperations(t *testing.T) { + state := luajit.New() + if state == nil { + t.Fatal("Failed to create Lua state") + } + defer state.Close() + + // Set a global value + state.PushNumber(42) + state.SetGlobal("answer") + + // Get the global value + state.GetGlobal("answer") + if state.ToNumber(-1) != 42 { + t.Fatalf("GetGlobal failed, expected 42, got %f", state.ToNumber(-1)) + } + state.Pop(1) + + // Test non-existent global (should be nil) + state.GetGlobal("nonexistent") + if !state.IsNil(-1) { + t.Fatalf("Expected nil for non-existent global") + } + state.Pop(1) +} + +func TestCodeExecution(t *testing.T) { + state := luajit.New() + if state == nil { + t.Fatal("Failed to create Lua state") + } + defer state.Close() + + // Test LoadString + if err := state.LoadString("return 42"); err != nil { + t.Fatalf("LoadString failed: %v", err) + } + + // Test Call + if err := state.Call(0, 1); err != nil { + t.Fatalf("Call failed: %v", err) + } + + if state.ToNumber(-1) != 42 { + t.Fatalf("Call result incorrect, expected 42, got %f", state.ToNumber(-1)) + } + state.Pop(1) + + // Test DoString + if err := state.DoString("answer = 42 + 1"); err != nil { + t.Fatalf("DoString failed: %v", err) + } + + state.GetGlobal("answer") + if state.ToNumber(-1) != 43 { + t.Fatalf("DoString execution incorrect, expected 43, got %f", state.ToNumber(-1)) + } + state.Pop(1) + + // Test Execute + nresults, err := state.Execute("return 5, 10, 15") + if err != nil { + t.Fatalf("Execute failed: %v", err) + } + + if nresults != 3 { + t.Fatalf("Execute returned %d results, expected 3", nresults) + } + + if state.ToNumber(-3) != 5 || state.ToNumber(-2) != 10 || state.ToNumber(-1) != 15 { + t.Fatalf("Execute results incorrect") + } + state.Pop(3) + + // Test ExecuteWithResult + result, err := state.ExecuteWithResult("return 'hello'") + if err != nil { + t.Fatalf("ExecuteWithResult failed: %v", err) + } + + if result != "hello" { + t.Fatalf("ExecuteWithResult returned %v, expected 'hello'", result) + } + + // Test error handling + err = state.DoString("this is not valid lua code") + if err == nil { + t.Fatalf("Expected error for invalid code, got nil") + } +} + +func TestDoFile(t *testing.T) { + state := luajit.New() + if state == nil { + t.Fatal("Failed to create Lua state") + } + defer state.Close() + + // Create a temporary Lua file + content := []byte("answer = 42") + tmpfile, err := os.CreateTemp("", "test-*.lua") + if err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } + defer os.Remove(tmpfile.Name()) + + if _, err := tmpfile.Write(content); err != nil { + t.Fatalf("Failed to write to temp file: %v", err) + } + if err := tmpfile.Close(); err != nil { + t.Fatalf("Failed to close temp file: %v", err) + } + + // Test LoadFile and DoFile + if err := state.LoadFile(tmpfile.Name()); err != nil { + t.Fatalf("LoadFile failed: %v", err) + } + + if err := state.Call(0, 0); err != nil { + t.Fatalf("Call failed after LoadFile: %v", err) + } + + state.GetGlobal("answer") + if state.ToNumber(-1) != 42 { + t.Fatalf("Incorrect result after LoadFile, expected 42, got %f", state.ToNumber(-1)) + } + state.Pop(1) + + // Reset global + if err := state.DoString("answer = nil"); err != nil { + t.Fatalf("Failed to reset answer: %v", err) + } + + // Test DoFile + if err := state.DoFile(tmpfile.Name()); err != nil { + t.Fatalf("DoFile failed: %v", err) + } + + state.GetGlobal("answer") + if state.ToNumber(-1) != 42 { + t.Fatalf("Incorrect result after DoFile, expected 42, got %f", state.ToNumber(-1)) + } + state.Pop(1) +} + +func TestPackagePath(t *testing.T) { + state := luajit.New() + if state == nil { + t.Fatal("Failed to create Lua state") + } + defer state.Close() + + // Test SetPackagePath + testPath := "/test/path/?.lua" + if err := state.SetPackagePath(testPath); err != nil { + t.Fatalf("SetPackagePath failed: %v", err) + } + + result, err := state.ExecuteWithResult("return package.path") + if err != nil { + t.Fatalf("Failed to get package.path: %v", err) + } + + if result != testPath { + t.Fatalf("Expected package.path to be '%s', got '%s'", testPath, result) + } + + // Test AddPackagePath + addPath := "/another/path/?.lua" + if err := state.AddPackagePath(addPath); err != nil { + t.Fatalf("AddPackagePath failed: %v", err) + } + + result, err = state.ExecuteWithResult("return package.path") + if err != nil { + t.Fatalf("Failed to get package.path: %v", err) + } + + expected := testPath + ";" + addPath + if result != expected { + t.Fatalf("Expected package.path to be '%s', got '%s'", expected, result) + } +} + +func TestPushValueAndToValue(t *testing.T) { + state := luajit.New() + if state == nil { + t.Fatal("Failed to create Lua state") + } + defer state.Close() + + testCases := []struct { + value any + }{ + {nil}, + {true}, + {false}, + {42}, + {42.5}, + {"hello"}, + {[]float64{1, 2, 3, 4, 5}}, + {[]any{1, "test", true}}, + {map[string]any{"a": 1, "b": "test", "c": true}}, + } + + for i, tc := range testCases { + // Push value + err := state.PushValue(tc.value) + if err != nil { + t.Fatalf("PushValue failed for testCase %d: %v", i, err) + } + + // Check stack + if state.GetTop() != i+1 { + t.Fatalf("Stack size incorrect after push, expected %d, got %d", i+1, state.GetTop()) + } + } + + // Test conversion back to Go + for i := range testCases { + index := len(testCases) - i + value, err := state.ToValue(index) + if err != nil { + t.Fatalf("ToValue failed for index %d: %v", index, err) + } + + // For tables, we need special handling due to how Go types are stored + switch expected := testCases[index-1].value.(type) { + case []float64: + // Arrays come back as map[string]any with empty key + if m, ok := value.(map[string]any); ok { + if arr, ok := m[""].([]float64); ok { + if !reflect.DeepEqual(arr, expected) { + t.Fatalf("Value mismatch for testCase %d: expected %v, got %v", index-1, expected, arr) + } + } else { + t.Fatalf("Invalid array conversion for testCase %d", index-1) + } + } else { + t.Fatalf("Expected map for array value in testCase %d, got %T", index-1, value) + } + case int: + if num, ok := value.(float64); ok { + if float64(expected) == num { + continue // Values match after type conversion + } + } + case []any: + // Skip detailed comparison for mixed arrays + case map[string]any: + // Skip detailed comparison for maps + default: + if !reflect.DeepEqual(value, testCases[index-1].value) { + t.Fatalf("Value mismatch for testCase %d: expected %v, got %v", + index-1, testCases[index-1].value, value) + } + } + } + + // Test unsupported type + complex := complex(1, 2) + err := state.PushValue(complex) + if err == nil { + t.Fatalf("Expected error for unsupported type") + } +} diff --git a/types.go b/types.go index a778dba..48b4559 100644 --- a/types.go +++ b/types.go @@ -9,7 +9,7 @@ import "C" type LuaType int const ( - // These constants must match lua.h's LUA_T* values + // These constants match lua.h's LUA_T* values TypeNone LuaType = -1 TypeNil LuaType = 0 TypeBoolean LuaType = 1 diff --git a/wrapper.go b/wrapper.go index eb2a3f4..d37a12d 100644 --- a/wrapper.go +++ b/wrapper.go @@ -9,61 +9,30 @@ package luajit #include #include #include +#include -// Simple wrapper around luaL_loadstring -static int load_chunk(lua_State *L, const char *s) { - return luaL_loadstring(L, s); -} - -// Direct wrapper around lua_pcall -static int protected_call(lua_State *L, int nargs, int nresults, int errfunc) { - return lua_pcall(L, nargs, nresults, errfunc); -} - -// Combined load and execute with no results -static int do_string(lua_State *L, const char *s) { - return luaL_dostring(L, s); -} - -// Combined load and execute file -static int do_file(lua_State *L, const char *filename) { - return luaL_dofile(L, filename); -} - -// Execute string with multiple returns -static int execute_string(lua_State *L, const char *s) { - int base = lua_gettop(L); // Save stack position - int status = luaL_loadstring(L, s); - if (status) return -status; // Return negative status for load errors - - status = lua_pcall(L, 0, LUA_MULTRET, 0); - if (status) return -status; // Return negative status for runtime errors - - return lua_gettop(L) - base; // Return number of results -} - -// Get absolute stack index (converts negative indices) +// Helper to simplify some common operations static int get_abs_index(lua_State *L, int idx) { if (idx > 0 || idx <= LUA_REGISTRYINDEX) return idx; return lua_gettop(L) + idx + 1; } -// Stack manipulation helpers -static int check_stack(lua_State *L, int n) { - return lua_checkstack(L, n); +// Combined load and execute with no results +static int do_string(lua_State *L, const char *s) { + int status = luaL_loadstring(L, s); + if (status == 0) { + status = lua_pcall(L, 0, 0, 0); + } + return status; } -static void remove_stack(lua_State *L, int idx) { - lua_remove(L, idx); -} - -static int get_field_helper(lua_State *L, int idx, const char *k) { - lua_getfield(L, idx, k); - return lua_type(L, -1); -} - -static void set_field_helper(lua_State *L, int idx, const char *k) { - lua_setfield(L, idx, k); +// Combined load and execute file +static int do_file(lua_State *L, const char *filename) { + int status = luaL_loadfile(L, filename); + if (status == 0) { + status = lua_pcall(L, 0, 0, 0); + } + return status; } */ import "C" @@ -78,7 +47,7 @@ type State struct { L *C.lua_State } -// New creates a new Lua state +// New creates a new Lua state with all standard libraries loaded func New() *State { L := C.luaL_newstate() if L == nil { @@ -88,51 +57,195 @@ func New() *State { return &State{L: L} } -// Close closes the Lua state +// Close closes the Lua state and frees resources func (s *State) Close() { if s.L != nil { + s.Cleanup() // Clean up Go function registry C.lua_close(s.L) s.L = nil } } -// DoString executes a Lua string. -func (s *State) DoString(str string) error { - // Save initial stack size - top := s.GetTop() +// Stack manipulation methods - // Load the string - if err := s.LoadString(str); err != nil { - return err - } - - // Execute and check for errors - if err := s.Call(0, 0); err != nil { - return err - } - - // Restore stack to initial size to clean up any leftovers - s.SetTop(top) - return nil +// GetTop returns the index of the top element in the stack +func (s *State) GetTop() int { + return int(C.lua_gettop(s.L)) } -// PushValue pushes a Go value onto the stack +// SetTop sets the stack top to a specific index +func (s *State) SetTop(index int) { + C.lua_settop(s.L, C.int(index)) +} + +// PushValue pushes a copy of the value at the given index onto the stack +func (s *State) PushCopy(index int) { + C.lua_pushvalue(s.L, C.int(index)) +} + +// Pop pops n elements from the stack +func (s *State) Pop(n int) { + C.lua_settop(s.L, C.int(-n-1)) +} + +// Remove removes the element at the given valid index +func (s *State) Remove(index int) { + C.lua_remove(s.L, C.int(index)) +} + +// absIndex converts a possibly negative index to its absolute position +func (s *State) absIndex(index int) int { + if index > 0 || index <= LUA_REGISTRYINDEX { + return index + } + return s.GetTop() + index + 1 +} + +// Type checking methods + +// GetType returns the type of the value at the given index +func (s *State) GetType(index int) LuaType { + return LuaType(C.lua_type(s.L, C.int(index))) +} + +// IsNil checks if the value at the given index is nil +func (s *State) IsNil(index int) bool { + return s.GetType(index) == TypeNil +} + +// IsBoolean checks if the value at the given index is a boolean +func (s *State) IsBoolean(index int) bool { + return s.GetType(index) == TypeBoolean +} + +// IsNumber checks if the value at the given index is a number +func (s *State) IsNumber(index int) bool { + return C.lua_isnumber(s.L, C.int(index)) != 0 +} + +// IsString checks if the value at the given index is a string +func (s *State) IsString(index int) bool { + return C.lua_isstring(s.L, C.int(index)) != 0 +} + +// IsTable checks if the value at the given index is a table +func (s *State) IsTable(index int) bool { + return s.GetType(index) == TypeTable +} + +// IsFunction checks if the value at the given index is a function +func (s *State) IsFunction(index int) bool { + return s.GetType(index) == TypeFunction +} + +// Value conversion methods + +// ToBoolean returns the value at the given index as a boolean +func (s *State) ToBoolean(index int) bool { + return C.lua_toboolean(s.L, C.int(index)) != 0 +} + +// ToNumber returns the value at the given index as a number +func (s *State) ToNumber(index int) float64 { + return float64(C.lua_tonumber(s.L, C.int(index))) +} + +// ToString returns the value at the given index as a string +func (s *State) ToString(index int) string { + var length C.size_t + cstr := C.lua_tolstring(s.L, C.int(index), &length) + if cstr == nil { + return "" + } + return C.GoStringN(cstr, C.int(length)) +} + +// Push methods + +// PushNil pushes a nil value onto the stack +func (s *State) PushNil() { + C.lua_pushnil(s.L) +} + +// PushBoolean pushes a boolean value onto the stack +func (s *State) PushBoolean(b bool) { + var value C.int + if b { + value = 1 + } + C.lua_pushboolean(s.L, value) +} + +// PushNumber pushes a number value onto the stack +func (s *State) PushNumber(n float64) { + C.lua_pushnumber(s.L, C.lua_Number(n)) +} + +// PushString pushes a string value onto the stack +func (s *State) PushString(str string) { + cstr := C.CString(str) + defer C.free(unsafe.Pointer(cstr)) + C.lua_pushlstring(s.L, cstr, C.size_t(len(str))) +} + +// Table operations + +// CreateTable creates a new table and pushes it onto the stack +func (s *State) CreateTable(narr, nrec int) { + C.lua_createtable(s.L, C.int(narr), C.int(nrec)) +} + +// NewTable creates a new empty table and pushes it onto the stack +func (s *State) NewTable() { + C.lua_createtable(s.L, 0, 0) +} + +// GetTable gets a table field (t[k]) where t is at the given index and k is at the top of the stack +func (s *State) GetTable(index int) { + C.lua_gettable(s.L, C.int(index)) +} + +// SetTable sets a table field (t[k] = v) where t is at the given index, k is at -2, and v is at -1 +func (s *State) SetTable(index int) { + C.lua_settable(s.L, C.int(index)) +} + +// GetField gets a table field t[k] and pushes it onto the stack +func (s *State) GetField(index int, key string) { + ckey := C.CString(key) + defer C.free(unsafe.Pointer(ckey)) + C.lua_getfield(s.L, C.int(index), ckey) +} + +// SetField sets a table field t[k] = v, where v is the value at the top of the stack +func (s *State) SetField(index int, key string) { + ckey := C.CString(key) + defer C.free(unsafe.Pointer(ckey)) + C.lua_setfield(s.L, C.int(index), ckey) +} + +// Next pops a key from the stack and pushes the next key-value pair from the table at the given index +func (s *State) Next(index int) bool { + return C.lua_next(s.L, C.int(index)) != 0 +} + +// PushValue pushes a Go value onto the stack with proper type conversion func (s *State) PushValue(v interface{}) error { switch v := v.(type) { case nil: s.PushNil() case bool: s.PushBoolean(v) - case float64: - s.PushNumber(v) case int: s.PushNumber(float64(v)) + case float64: + s.PushNumber(v) case string: s.PushString(v) case map[string]interface{}: // Special case: handle array stored in map if arr, ok := v[""].([]float64); ok { - s.NewTable() + s.CreateTable(len(arr), 0) for i, elem := range arr { s.PushNumber(float64(i + 1)) s.PushNumber(elem) @@ -142,14 +255,14 @@ func (s *State) PushValue(v interface{}) error { } return s.PushTable(v) case []float64: - s.NewTable() + s.CreateTable(len(v), 0) for i, elem := range v { s.PushNumber(float64(i + 1)) s.PushNumber(elem) s.SetTable(-3) } case []interface{}: - s.NewTable() + s.CreateTable(len(v), 0) for i, elem := range v { s.PushNumber(float64(i + 1)) if err := s.PushValue(elem); err != nil { @@ -163,9 +276,10 @@ func (s *State) PushValue(v interface{}) error { return nil } -// ToValue converts a Lua value to a Go value +// ToValue converts a Lua value at the given index to a Go value func (s *State) ToValue(index int) (interface{}, error) { - switch s.GetType(index) { + luaType := s.GetType(index) + switch luaType { case TypeNil: return nil, nil case TypeBoolean: @@ -175,231 +289,161 @@ func (s *State) ToValue(index int) (interface{}, error) { case TypeString: return s.ToString(index), nil case TypeTable: - if !s.IsTable(index) { - return nil, fmt.Errorf("not a table at index %d", index) - } return s.ToTable(index) default: - return nil, fmt.Errorf("unsupported type: %s", s.GetType(index)) + return nil, fmt.Errorf("unsupported type: %s", luaType) } } -// Simple operations remain unchanged as they don't need stack protection +// Global operations -func (s *State) GetType(index int) LuaType { return LuaType(C.lua_type(s.L, C.int(index))) } -func (s *State) IsString(index int) bool { return s.GetType(index) == TypeString } -func (s *State) IsNumber(index int) bool { return s.GetType(index) == TypeNumber } -func (s *State) IsFunction(index int) bool { return s.GetType(index) == TypeFunction } -func (s *State) IsTable(index int) bool { return s.GetType(index) == TypeTable } -func (s *State) IsNil(index int) bool { return s.GetType(index) == TypeNil } -func (s *State) ToBoolean(index int) bool { return C.lua_toboolean(s.L, C.int(index)) != 0 } -func (s *State) ToNumber(index int) float64 { return float64(C.lua_tonumber(s.L, C.int(index))) } -func (s *State) ToString(index int) string { - return C.GoString(C.lua_tolstring(s.L, C.int(index), nil)) -} -func (s *State) GetTop() int { return int(C.lua_gettop(s.L)) } -func (s *State) Pop(n int) { C.lua_settop(s.L, C.int(-n-1)) } -func (s *State) SetTop(index int) { C.lua_settop(s.L, C.int(index)) } - -// Push operations - -func (s *State) PushNil() { C.lua_pushnil(s.L) } -func (s *State) PushBoolean(b bool) { C.lua_pushboolean(s.L, C.int(bool2int(b))) } -func (s *State) PushNumber(n float64) { C.lua_pushnumber(s.L, C.double(n)) } -func (s *State) PushString(str string) { - cstr := C.CString(str) - defer C.free(unsafe.Pointer(cstr)) - C.lua_pushstring(s.L, cstr) -} - -func (s *State) Next(index int) bool { - return C.lua_next(s.L, C.int(index)) != 0 -} - -// Helper functions -func bool2int(b bool) int { - if b { - return 1 - } - return 0 -} - -func (s *State) absIndex(index int) int { - if index > 0 || index <= LUA_REGISTRYINDEX { - return index - } - return s.GetTop() + index + 1 -} - -// SetField sets a field in a table at the given index -func (s *State) SetField(index int, key string) { - cstr := C.CString(key) - defer C.free(unsafe.Pointer(cstr)) - C.lua_setfield(s.L, C.int(index), cstr) -} - -// GetField gets a field from a table -func (s *State) GetField(index int, key string) { - cstr := C.CString(key) - defer C.free(unsafe.Pointer(cstr)) - C.lua_getfield(s.L, C.int(index), cstr) -} - -// GetGlobal gets a global variable and pushes it onto the stack +// GetGlobal pushes the global variable with the given name onto the stack func (s *State) GetGlobal(name string) { - cname := C.CString(name) - defer C.free(unsafe.Pointer(cname)) - C.get_field_helper(s.L, C.LUA_GLOBALSINDEX, cname) + s.GetField(LUA_GLOBALSINDEX, name) } -// SetGlobal sets a global variable from the value at the top of the stack +// SetGlobal sets the global variable with the given name to the value at the top of the stack func (s *State) SetGlobal(name string) { - cstr := C.CString(name) - defer C.free(unsafe.Pointer(cstr)) - C.lua_setfield(s.L, C.LUA_GLOBALSINDEX, cstr) + s.SetField(LUA_GLOBALSINDEX, name) } -// Remove removes element with cached absolute index -func (s *State) Remove(index int) { - absIdx := index - C.lua_remove(s.L, C.int(absIdx)) +// Code execution methods + +// LoadString loads a Lua chunk from a string without executing it +func (s *State) LoadString(code string) error { + ccode := C.CString(code) + defer C.free(unsafe.Pointer(ccode)) + + return s.safeCall(func() C.int { + return C.luaL_loadstring(s.L, ccode) + }) } -// DoFile executes a Lua file with appropriate stack management +// LoadFile loads a Lua chunk from a file without executing it +func (s *State) LoadFile(filename string) error { + cfilename := C.CString(filename) + defer C.free(unsafe.Pointer(cfilename)) + + return s.safeCall(func() C.int { + return C.luaL_loadfile(s.L, cfilename) + }) +} + +// Call calls a function with the given number of arguments and results +func (s *State) Call(nargs, nresults int) error { + return s.safeCall(func() C.int { + return C.lua_pcall(s.L, C.int(nargs), C.int(nresults), 0) + }) +} + +// DoString executes a Lua string and cleans up the stack +func (s *State) DoString(code string) error { + ccode := C.CString(code) + defer C.free(unsafe.Pointer(ccode)) + + return s.safeCall(func() C.int { + return C.do_string(s.L, ccode) + }) +} + +// DoFile executes a Lua file and cleans up the stack func (s *State) DoFile(filename string) error { cfilename := C.CString(filename) defer C.free(unsafe.Pointer(cfilename)) - status := C.do_file(s.L, cfilename) - if status != 0 { - return &LuaError{ - Code: int(status), - Message: s.ToString(-1), + return s.safeCall(func() C.int { + return C.do_file(s.L, cfilename) + }) +} + +// Execute executes a Lua string and returns the number of results left on the stack +func (s *State) Execute(code string) (int, error) { + baseTop := s.GetTop() + + ccode := C.CString(code) + defer C.free(unsafe.Pointer(ccode)) + + var nresults int + err := s.safeCall(func() C.int { + status := C.luaL_loadstring(s.L, ccode) + if status != 0 { + return status } - } - return nil -} -// SetPackagePath sets the Lua package.path -func (s *State) SetPackagePath(path string) error { - path = strings.ReplaceAll(path, "\\", "/") // Convert Windows paths - cmd := fmt.Sprintf(`package.path = %q`, path) - return s.DoString(cmd) -} - -// AddPackagePath adds a path to package.path -func (s *State) AddPackagePath(path string) error { - path = strings.ReplaceAll(path, "\\", "/") // Convert Windows paths - cmd := fmt.Sprintf(`package.path = package.path .. ";%s"`, path) - return s.DoString(cmd) -} - -// Call executes a function on the stack with the given number of arguments and results. -// The function and arguments should already be on the stack in the correct order -// (function first, then args from left to right). -func (s *State) Call(nargs, nresults int) error { - if !s.IsFunction(-nargs - 1) { - return fmt.Errorf("attempt to call a non-function") - } - status := C.protected_call(s.L, C.int(nargs), C.int(nresults), 0) - if status != 0 { - err := &LuaError{ - Code: int(status), - Message: s.ToString(-1), + status = C.lua_pcall(s.L, 0, C.LUA_MULTRET, 0) + if status == 0 { + nresults = s.GetTop() - baseTop } - s.Pop(1) - return err - } - return nil -} + return status + }) -// LoadString loads but does not execute a string of Lua code. -// The compiled code chunk is left on the stack. -func (s *State) LoadString(str string) error { - cstr := C.CString(str) - defer C.free(unsafe.Pointer(cstr)) - - status := C.load_chunk(s.L, cstr) - if status != 0 { - err := &LuaError{ - Code: int(status), - Message: s.ToString(-1), - } - s.Pop(1) - return err - } - - if !s.IsFunction(-1) { - s.Pop(1) - return fmt.Errorf("failed to load function") - } - return nil -} - -// ExecuteString executes a string of Lua code and returns the number of results. -// The results are left on the stack. -func (s *State) ExecuteString(str string) (int, error) { - base := s.GetTop() - - // First load the string - if err := s.LoadString(str); err != nil { - return 0, err - } - - // Now execute it - if err := s.Call(0, C.LUA_MULTRET); err != nil { - return 0, err - } - - return s.GetTop() - base, nil -} - -// ExecuteStringResult executes a Lua string and returns its first result as a Go value. -// It's a convenience wrapper around ExecuteString for the common case of wanting -// a single return value. The stack is restored to its original state after execution. -func (s *State) ExecuteStringResult(code string) (interface{}, error) { - top := s.GetTop() - defer s.SetTop(top) // Restore stack when we're done - - nresults, err := s.ExecuteString(code) if err != nil { - return nil, fmt.Errorf("execution error: %w", err) + return 0, err + } + + return nresults, nil +} + +// ExecuteWithResult executes a Lua string and returns the first result +func (s *State) ExecuteWithResult(code string) (interface{}, error) { + top := s.GetTop() + defer s.SetTop(top) // Restore stack when done + + nresults, err := s.Execute(code) + if err != nil { + return nil, err } if nresults == 0 { return nil, nil } - // Get the result - result, err := s.ToValue(-nresults) // Get first result - if err != nil { - return nil, fmt.Errorf("error converting result: %w", err) - } - - return result, nil + return s.ToValue(-nresults) } -// DoStringResult executes a Lua string and expects a single return value. -// Unlike ExecuteStringResult, this function specifically expects exactly one -// return value and will return an error if the code returns 0 or multiple values. -func (s *State) DoStringResult(code string) (interface{}, error) { - top := s.GetTop() - defer s.SetTop(top) // Restore stack when we're done +// Package path operations - nresults, err := s.ExecuteString(code) - if err != nil { - return nil, fmt.Errorf("execution error: %w", err) - } - - if nresults != 1 { - return nil, fmt.Errorf("expected 1 return value, got %d", nresults) - } - - // Get the result - result, err := s.ToValue(-1) - if err != nil { - return nil, fmt.Errorf("error converting result: %w", err) - } - - return result, nil +// SetPackagePath sets the Lua package.path +func (s *State) SetPackagePath(path string) error { + path = strings.ReplaceAll(path, "\\", "/") // Convert Windows paths + return s.DoString(fmt.Sprintf(`package.path = %q`, path)) +} + +// AddPackagePath adds a path to package.path +func (s *State) AddPackagePath(path string) error { + path = strings.ReplaceAll(path, "\\", "/") // Convert Windows paths + return s.DoString(fmt.Sprintf(`package.path = package.path .. ";%s"`, path)) +} + +// Helper functions + +// checkStack ensures there is enough space on the Lua stack +func (s *State) checkStack(n int) error { + if C.lua_checkstack(s.L, C.int(n)) == 0 { + return fmt.Errorf("stack overflow (cannot allocate %d slots)", n) + } + return nil +} + +// safeCall wraps a potentially dangerous C call with stack checking +func (s *State) safeCall(f func() C.int) error { + // Ensure we have enough stack space + if err := s.checkStack(LUA_MINSTACK); err != nil { + return err + } + + // Make the call + status := f() + + // Check for errors + if status != 0 { + err := &LuaError{ + Code: int(status), + Message: s.ToString(-1), + } + s.Pop(1) // Remove error message + return err + } + + return nil } diff --git a/wrapper_bench_test.go b/wrapper_bench_test.go deleted file mode 100644 index 6316ae6..0000000 --- a/wrapper_bench_test.go +++ /dev/null @@ -1,237 +0,0 @@ -package luajit - -import ( - "testing" -) - -var benchCases = []struct { - name string - code string -}{ - { - name: "SimpleAddition", - code: `return 1 + 1`, - }, - { - name: "LoopSum", - code: ` - local sum = 0 - for i = 1, 1000 do - sum = sum + i - end - return sum - `, - }, - { - name: "FunctionCall", - code: ` - local result = 0 - for i = 1, 100 do - result = result + i - end - return result - `, - }, - { - name: "TableCreation", - code: ` - local t = {} - for i = 1, 100 do - t[i] = i * 2 - end - return t[50] - `, - }, - { - name: "StringOperations", - code: ` - local s = "hello" - for i = 1, 10 do - s = s .. " world" - end - return #s - `, - }, -} - -func BenchmarkLuaDirectExecution(b *testing.B) { - for _, bc := range benchCases { - b.Run(bc.name, func(b *testing.B) { - L := New() - if L == nil { - b.Fatal("Failed to create Lua state") - } - defer L.Close() - - // First verify we can execute the code - if err := L.DoString(bc.code); err != nil { - b.Fatalf("Failed to execute test code: %v", err) - } - L.Pop(1) // Clean up the result - - b.ResetTimer() - for i := 0; i < b.N; i++ { - // Execute string and get result - nresults, err := L.ExecuteString(bc.code) - if err != nil { - b.Fatalf("Failed to execute code: %v", err) - } - L.Pop(nresults) // Clean up any results - } - }) - } -} - -func BenchmarkLuaBytecodeExecution(b *testing.B) { - // First compile all bytecode - bytecodes := make(map[string][]byte) - for _, bc := range benchCases { - L := New() - if L == nil { - b.Fatal("Failed to create Lua state") - } - bytecode, err := L.CompileBytecode(bc.code, bc.name) - if err != nil { - L.Close() - b.Fatalf("Error compiling bytecode for %s: %v", bc.name, err) - } - bytecodes[bc.name] = bytecode - L.Close() - } - - for _, bc := range benchCases { - b.Run(bc.name, func(b *testing.B) { - L := New() - if L == nil { - b.Fatal("Failed to create Lua state") - } - defer L.Close() - - bytecode := bytecodes[bc.name] - - // First verify we can execute the bytecode - if err := L.LoadBytecode(bytecode, bc.name); err != nil { - b.Fatalf("Failed to execute test bytecode: %v", err) - } - - b.ResetTimer() - b.SetBytes(int64(len(bytecode))) // Track bytecode size in benchmarks - - for i := 0; i < b.N; i++ { - if err := L.LoadBytecode(bytecode, bc.name); err != nil { - b.Fatalf("Error executing bytecode: %v", err) - } - } - }) - } -} - -func BenchmarkTableOperations(b *testing.B) { - testData := map[string]interface{}{ - "number": 42.0, - "string": "hello", - "bool": true, - "nested": map[string]interface{}{ - "value": 123.0, - "array": []float64{1.1, 2.2, 3.3}, - }, - } - - b.Run("PushTable", func(b *testing.B) { - L := New() - if L == nil { - b.Fatal("Failed to create Lua state") - } - defer L.Close() - - // First verify we can push the table - if err := L.PushTable(testData); err != nil { - b.Fatalf("Failed to push initial table: %v", err) - } - L.Pop(1) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - if err := L.PushTable(testData); err != nil { - b.Fatalf("Failed to push table: %v", err) - } - L.Pop(1) - } - }) - - b.Run("ToTable", func(b *testing.B) { - L := New() - if L == nil { - b.Fatal("Failed to create Lua state") - } - defer L.Close() - - // Keep a table on the stack for repeated conversions - if err := L.PushTable(testData); err != nil { - b.Fatalf("Failed to push initial table: %v", err) - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - if _, err := L.ToTable(-1); err != nil { - b.Fatalf("Failed to convert table: %v", err) - } - } - }) -} - -func BenchmarkValueConversion(b *testing.B) { - testValues := []struct { - name string - value interface{} - }{ - {"Number", 42.0}, - {"String", "hello world"}, - {"Boolean", true}, - {"Nil", nil}, - } - - for _, tv := range testValues { - b.Run("Push"+tv.name, func(b *testing.B) { - L := New() - if L == nil { - b.Fatal("Failed to create Lua state") - } - defer L.Close() - - // First verify we can push the value - if err := L.PushValue(tv.value); err != nil { - b.Fatalf("Failed to push initial value: %v", err) - } - L.Pop(1) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - if err := L.PushValue(tv.value); err != nil { - b.Fatalf("Failed to push value: %v", err) - } - L.Pop(1) - } - }) - - b.Run("To"+tv.name, func(b *testing.B) { - L := New() - if L == nil { - b.Fatal("Failed to create Lua state") - } - defer L.Close() - - // Keep a value on the stack for repeated conversions - if err := L.PushValue(tv.value); err != nil { - b.Fatalf("Failed to push initial value: %v", err) - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - if _, err := L.ToValue(-1); err != nil { - b.Fatalf("Failed to convert value: %v", err) - } - } - }) - } -} diff --git a/wrapper_test.go b/wrapper_test.go deleted file mode 100644 index d7ce9e2..0000000 --- a/wrapper_test.go +++ /dev/null @@ -1,693 +0,0 @@ -package luajit - -import ( - "fmt" - "os" - "path/filepath" - "testing" -) - -func TestNew(t *testing.T) { - L := New() - if L == nil { - t.Fatal("Failed to create Lua state") - } - defer L.Close() -} - -func TestLoadString(t *testing.T) { - tests := []struct { - name string - code string - wantErr bool - }{ - { - name: "valid function", - code: "function add(a, b) return a + b end", - wantErr: false, - }, - { - name: "valid expression", - code: "return 1 + 1", - wantErr: false, - }, - { - name: "syntax error", - code: "function bad syntax", - wantErr: true, - }, - } - - for _, tt := range tests { - L := New() - if L == nil { - t.Fatal("Failed to create Lua state") - } - defer L.Close() - - err := L.LoadString(tt.code) - if (err != nil) != tt.wantErr { - t.Errorf("LoadString() error = %v, wantErr %v", err, tt.wantErr) - return - } - - if !tt.wantErr { - // Verify the function is on the stack - if L.GetTop() != 1 { - t.Error("LoadString() did not leave exactly one value on stack") - } - if !L.IsFunction(-1) { - t.Error("LoadString() did not leave a function on the stack") - } - } - } -} - -func TestExecuteString(t *testing.T) { - tests := []struct { - name string - code string - wantResults int - checkResults func(*State) error - wantErr bool - wantStackSize int - }{ - { - name: "no results", - code: "local x = 1", - wantResults: 0, - wantErr: false, - }, - { - name: "single result", - code: "return 42", - wantResults: 1, - checkResults: func(L *State) error { - if n := L.ToNumber(-1); n != 42 { - return fmt.Errorf("got %v, want 42", n) - } - return nil - }, - wantErr: false, - }, - { - name: "multiple results", - code: "return 1, 'test', true", - wantResults: 3, - checkResults: func(L *State) error { - if n := L.ToNumber(-3); n != 1 { - return fmt.Errorf("first result: got %v, want 1", n) - } - if s := L.ToString(-2); s != "test" { - return fmt.Errorf("second result: got %v, want 'test'", s) - } - if b := L.ToBoolean(-1); !b { - return fmt.Errorf("third result: got %v, want true", b) - } - return nil - }, - wantErr: false, - }, - { - name: "syntax error", - code: "this is not valid lua", - wantErr: true, - }, - { - name: "runtime error", - code: "error('test error')", - wantErr: true, - }, - } - - for _, tt := range tests { - L := New() - if L == nil { - t.Fatal("Failed to create Lua state") - } - defer L.Close() - - // Record initial stack size - initialStack := L.GetTop() - - results, err := L.ExecuteString(tt.code) - if (err != nil) != tt.wantErr { - t.Errorf("ExecuteString() error = %v, wantErr %v", err, tt.wantErr) - return - } - - if err == nil { - if results != tt.wantResults { - t.Errorf("ExecuteString() returned %d results, want %d", results, tt.wantResults) - } - - if tt.checkResults != nil { - if err := tt.checkResults(L); err != nil { - t.Errorf("Result check failed: %v", err) - } - } - - // Verify stack size matches expected results - if got := L.GetTop() - initialStack; got != tt.wantResults { - t.Errorf("Stack size grew by %d, want %d", got, tt.wantResults) - } - } - } -} - -func TestDoString(t *testing.T) { - tests := []struct { - name string - code string - wantErr bool - }{ - {"simple addition", "return 1 + 1", false}, - {"set global", "test = 42", false}, - {"syntax error", "this is not valid lua", true}, - {"runtime error", "error('test error')", true}, - } - - for _, tt := range tests { - L := New() - if L == nil { - t.Fatal("Failed to create Lua state") - } - defer L.Close() - - initialStack := L.GetTop() - err := L.DoString(tt.code) - if (err != nil) != tt.wantErr { - t.Errorf("DoString() error = %v, wantErr %v", err, tt.wantErr) - } - - // Verify stack is unchanged - if finalStack := L.GetTop(); finalStack != initialStack { - t.Errorf("Stack size changed from %d to %d", initialStack, finalStack) - } - } -} - -func TestPushAndGetValues(t *testing.T) { - values := []struct { - name string - push func(*State) - check func(*State) error - }{ - { - name: "string", - push: func(L *State) { L.PushString("hello") }, - check: func(L *State) error { - if got := L.ToString(-1); got != "hello" { - return fmt.Errorf("got %q, want %q", got, "hello") - } - return nil - }, - }, - { - name: "number", - push: func(L *State) { L.PushNumber(42.5) }, - check: func(L *State) error { - if got := L.ToNumber(-1); got != 42.5 { - return fmt.Errorf("got %f, want %f", got, 42.5) - } - return nil - }, - }, - { - name: "boolean", - push: func(L *State) { L.PushBoolean(true) }, - check: func(L *State) error { - if got := L.ToBoolean(-1); !got { - return fmt.Errorf("got %v, want true", got) - } - return nil - }, - }, - { - name: "nil", - push: func(L *State) { L.PushNil() }, - check: func(L *State) error { - if typ := L.GetType(-1); typ != TypeNil { - return fmt.Errorf("got type %v, want TypeNil", typ) - } - return nil - }, - }, - } - - for _, v := range values { - L := New() - if L == nil { - t.Fatal("Failed to create Lua state") - } - defer L.Close() - - v.push(L) - if err := v.check(L); err != nil { - t.Error(err) - } - } -} - -func TestStackManipulation(t *testing.T) { - L := New() - if L == nil { - t.Fatal("Failed to create Lua state") - } - defer L.Close() - - // Push values - values := []string{"first", "second", "third"} - for _, v := range values { - L.PushString(v) - } - - // Check size - if top := L.GetTop(); top != len(values) { - t.Errorf("stack size = %d, want %d", top, len(values)) - } - - // Pop one value - L.Pop(1) - - // Check new top - if str := L.ToString(-1); str != "second" { - t.Errorf("top value = %q, want 'second'", str) - } - - // Check new size - if top := L.GetTop(); top != len(values)-1 { - t.Errorf("stack size after pop = %d, want %d", top, len(values)-1) - } -} - -func TestGlobals(t *testing.T) { - L := New() - if L == nil { - t.Fatal("Failed to create Lua state") - } - defer L.Close() - - // Test via Lua - if err := L.DoString(`globalVar = "test"`); err != nil { - t.Fatalf("DoString error: %v", err) - } - - // Get the global - L.GetGlobal("globalVar") - if str := L.ToString(-1); str != "test" { - t.Errorf("global value = %q, want 'test'", str) - } - L.Pop(1) - - // Set and get via API - L.PushNumber(42) - L.SetGlobal("testNum") - - L.GetGlobal("testNum") - if num := L.ToNumber(-1); num != 42 { - t.Errorf("global number = %f, want 42", num) - } -} - -func TestCall(t *testing.T) { - tests := []struct { - funcName string // Add explicit function name field - setup string - args []interface{} - nresults int - checkStack func(*State) error - wantErr bool - }{ - { - funcName: "add", - setup: "function add(a, b) return a + b end", - args: []interface{}{float64(40), float64(2)}, - nresults: 1, - checkStack: func(L *State) error { - if n := L.ToNumber(-1); n != 42 { - return fmt.Errorf("got %v, want 42", n) - } - return nil - }, - }, - { - funcName: "multi", - setup: "function multi() return 1, 'test', true end", - args: []interface{}{}, - nresults: 3, - checkStack: func(L *State) error { - if n := L.ToNumber(-3); n != 1 { - return fmt.Errorf("first result: got %v, want 1", n) - } - if s := L.ToString(-2); s != "test" { - return fmt.Errorf("second result: got %v, want 'test'", s) - } - if b := L.ToBoolean(-1); !b { - return fmt.Errorf("third result: got %v, want true", b) - } - return nil - }, - }, - { - funcName: "err", - setup: "function err() error('test error') end", - args: []interface{}{}, - nresults: 0, - wantErr: true, - }, - } - - for _, tt := range tests { - L := New() - if L == nil { - t.Fatal("Failed to create Lua state") - } - defer L.Close() - - // Setup function - if err := L.DoString(tt.setup); err != nil { - t.Fatalf("Setup failed: %v", err) - } - - // Get function - L.GetGlobal(tt.funcName) - if !L.IsFunction(-1) { - t.Fatal("Failed to get function") - } - - // Push arguments - for _, arg := range tt.args { - if err := L.PushValue(arg); err != nil { - t.Fatalf("Failed to push argument: %v", err) - } - } - - // Call function - err := L.Call(len(tt.args), tt.nresults) - if (err != nil) != tt.wantErr { - t.Errorf("Call() error = %v, wantErr %v", err, tt.wantErr) - return - } - - if err == nil && tt.checkStack != nil { - if err := tt.checkStack(L); err != nil { - t.Errorf("Stack check failed: %v", err) - } - } - } -} - -func TestDoFile(t *testing.T) { - L := New() - defer L.Close() - - // Create test file - content := []byte(` - function add(a, b) - return a + b - end - result = add(40, 2) - `) - - tmpDir := t.TempDir() - filename := filepath.Join(tmpDir, "test.lua") - if err := os.WriteFile(filename, content, 0644); err != nil { - t.Fatalf("Failed to create test file: %v", err) - } - - if err := L.DoFile(filename); err != nil { - t.Fatalf("DoFile failed: %v", err) - } - - L.GetGlobal("result") - if result := L.ToNumber(-1); result != 42 { - t.Errorf("Expected result=42, got %v", result) - } -} - -func TestRequireAndPackagePath(t *testing.T) { - L := New() - defer L.Close() - - tmpDir := t.TempDir() - - // Create module file - moduleContent := []byte(` - local M = {} - function M.multiply(a, b) - return a * b - end - return M - `) - - if err := os.WriteFile(filepath.Join(tmpDir, "mathmod.lua"), moduleContent, 0644); err != nil { - t.Fatalf("Failed to create module file: %v", err) - } - - // Add module path and test require - if err := L.AddPackagePath(filepath.Join(tmpDir, "?.lua")); err != nil { - t.Fatalf("AddPackagePath failed: %v", err) - } - - if err := L.DoString(` - local math = require("mathmod") - result = math.multiply(6, 7) - `); err != nil { - t.Fatalf("Failed to require module: %v", err) - } - - L.GetGlobal("result") - if result := L.ToNumber(-1); result != 42 { - t.Errorf("Expected result=42, got %v", result) - } -} - -func TestSetPackagePath(t *testing.T) { - L := New() - defer L.Close() - - customPath := "./custom/?.lua" - if err := L.SetPackagePath(customPath); err != nil { - t.Fatalf("SetPackagePath failed: %v", err) - } - - L.GetGlobal("package") - L.GetField(-1, "path") - if path := L.ToString(-1); path != customPath { - t.Errorf("Expected package.path=%q, got %q", customPath, path) - } - - // Test that the old path is completely replaced - initialPath := L.ToString(-1) - anotherPath := "./another/?.lua" - if err := L.SetPackagePath(anotherPath); err != nil { - t.Fatalf("Second SetPackagePath failed: %v", err) - } - - L.GetGlobal("package") - L.GetField(-1, "path") - if path := L.ToString(-1); path != anotherPath { - t.Errorf("Expected package.path=%q, got %q", anotherPath, path) - } - if path := L.ToString(-1); path == initialPath { - t.Error("SetPackagePath did not replace the old path") - } -} - -func TestStackDebug(t *testing.T) { - L := New() - if L == nil { - t.Fatal("Failed to create Lua state") - } - defer L.Close() - - t.Log("Testing LoadString:") - initialTop := L.GetTop() - t.Logf("Initial stack size: %d", initialTop) - - err := L.LoadString("return 42") - if err != nil { - t.Errorf("LoadString failed: %v", err) - } - - afterLoad := L.GetTop() - t.Logf("Stack size after load: %d", afterLoad) - t.Logf("Type of top element: %s", L.GetType(-1)) - - if L.IsFunction(-1) { - t.Log("Top element is a function") - } else { - t.Log("Top element is NOT a function") - } - - // Clean up after LoadString test - L.SetTop(0) - - t.Log("\nTesting ExecuteString:") - if err := L.DoString("function test() return 1, 'hello', true end"); err != nil { - t.Errorf("DoString failed: %v", err) - } - - beforeExec := L.GetTop() - t.Logf("Stack size before execute: %d", beforeExec) - - nresults, err := L.ExecuteString("return test()") - if err != nil { - t.Errorf("ExecuteString failed: %v", err) - } - - afterExec := L.GetTop() - t.Logf("Stack size after execute: %d", afterExec) - t.Logf("Reported number of results: %d", nresults) - - // Print each stack element - for i := 1; i <= afterExec; i++ { - t.Logf("Stack[-%d] type: %s", i, L.GetType(-i)) - } - - if afterExec != nresults { - t.Errorf("Stack size (%d) doesn't match number of results (%d)", afterExec, nresults) - } -} - -func TestTemplateRendering(t *testing.T) { - L := New() - if L == nil { - t.Fatal("Failed to create Lua state") - } - defer L.Close() - defer L.Cleanup() - - // Create a simple render.template function - renderFunc := func(s *State) int { - // Template will be at index 1, data at index 2 - data, err := s.ToTable(2) - if err != nil { - s.PushString(fmt.Sprintf("failed to get data table: %v", err)) - return -1 - } - - // Push data back as global for template access - if err := s.PushTable(data); err != nil { - s.PushString(fmt.Sprintf("failed to push data table: %v", err)) - return -1 - } - s.SetGlobal("data") - - // Template processing code - luaCode := ` - local result = {} - if data.user.logged_in then - table.insert(result, '
') - table.insert(result, string.format('

Welcome, %s!

', tostring(data.user.name))) - table.insert(result, ' ') - if data.user.is_admin then - table.insert(result, '
') - table.insert(result, '

Admin Controls

') - table.insert(result, ' ') - table.insert(result, '
') - end - table.insert(result, '
') - else - table.insert(result, '
') - table.insert(result, '

Please log in to view your profile

') - table.insert(result, '
') - end - return table.concat(result, '\n')` - - result, err := s.DoStringResult(luaCode) - if err != nil { - s.PushString(fmt.Sprintf("template execution failed: %v", err)) - return -1 - } - - // Push the string result - if str, ok := result.(string); ok { - s.PushString(str) - return 1 - } - - s.PushString(fmt.Sprintf("expected string result, got %T", result)) - return -1 - } - - // Create render table and add template function - L.NewTable() - if err := L.PushGoFunction(renderFunc); err != nil { - t.Fatalf("Failed to create render function: %v", err) - } - L.SetField(-2, "template") - L.SetGlobal("render") - - // Test with logged in admin user - testCode := ` - local data = { - user = { - logged_in = true, - name = "John Doe", - email = "john@example.com", - joined_date = "2024-02-09", - is_admin = true - } - } - return render.template("test.html", data) - ` - - result, err := L.DoStringResult(testCode) - if err != nil { - t.Fatalf("Failed to execute test: %v", err) - } - - str, ok := result.(string) - if !ok { - t.Fatalf("Expected string result, got %T", result) - } - - expectedResult := `
-

Welcome, John Doe!

- -
-

Admin Controls

- -
-
` - - if str != expectedResult { - t.Errorf("\nExpected:\n%s\n\nGot:\n%s", expectedResult, str) - } - - // Test with logged out user - testCode = ` - local data = { - user = { - logged_in = false - } - } - return render.template("test.html", data) - ` - - result, err = L.DoStringResult(testCode) - if err != nil { - t.Fatalf("Failed to execute logged out test: %v", err) - } - - str, ok = result.(string) - if !ok { - t.Fatalf("Expected string result, got %T", result) - } - - expectedResult = `
-

Please log in to view your profile

-
` - - if str != expectedResult { - t.Errorf("\nExpected:\n%s\n\nGot:\n%s", expectedResult, str) - } -}