Wrapper rewrite

This commit is contained in:
Sky Johnson 2025-02-26 07:00:01 -06:00
parent 865ac8859f
commit 143b9333c6
24 changed files with 2863 additions and 1855 deletions

245
DOCS.md
View File

@ -2,16 +2,8 @@
## State Management
### NewSafe() *State
Creates a new Lua state with stack safety enabled.
```go
L := luajit.NewSafe()
defer L.Close()
defer L.Cleanup()
```
### New() *State
Creates a new Lua state without stack safety checks.
Creates a new Lua state with all standard libraries loaded.
```go
L := luajit.New()
defer L.Close()
@ -38,6 +30,18 @@ Returns the index of the top element in the stack.
top := L.GetTop() // 0 for empty stack
```
### SetTop(index int)
Sets the stack top to a specific index.
```go
L.SetTop(2) // Truncate stack to 2 elements
```
### PushCopy(index int)
Pushes a copy of the value at the given index onto the stack.
```go
L.PushCopy(-1) // Duplicate the top element
```
### Pop(n int)
Removes n elements from the stack.
```go
@ -52,6 +56,9 @@ L.Remove(-1) // Remove top element
L.Remove(1) // Remove first element
```
### absIndex(index int) int
Internal function that converts a possibly negative index to its absolute position.
### checkStack(n int) error
Internal function that ensures there's enough space for n new elements.
```go
@ -70,9 +77,12 @@ if L.GetType(-1) == TypeString {
}
```
### IsFunction(index int) bool
### IsNil(index int) bool
### IsBoolean(index int) bool
### IsNumber(index int) bool
### IsString(index int) bool
### IsTable(index int) bool
### IsUserData(index int) bool
### IsFunction(index int) bool
Type checking functions for specific Lua types.
```go
if L.IsTable(-1) {
@ -118,6 +128,12 @@ if err != nil {
}
```
### GetTableLength(index int) int
Returns the length of a table at the given index.
```go
length := L.GetTableLength(-1)
```
## Value Pushing
### PushNil()
@ -148,16 +164,82 @@ data := map[string]interface{}{
err := L.PushTable(data)
```
## Function Registration
## Table Operations
### RegisterGoFunction(name string, fn GoFunction) error
Registers a Go function that can be called from Lua.
### CreateTable(narr, nrec int)
Creates a new table with pre-allocated space.
```go
adder := func(s *State) int {
L.CreateTable(10, 5) // Space for 10 array elements, 5 records
```
### NewTable()
Creates a new empty table and pushes it onto the stack.
```go
L.NewTable()
```
### GetTable(index int)
Gets a table field (t[k]) where t is at the given index and k is at the top of the stack.
```go
L.PushString("key")
L.GetTable(-2) // Gets table["key"]
```
### SetTable(index int)
Sets a table field (t[k] = v) where t is at the given index, k is at -2, and v is at -1.
```go
L.PushString("key")
L.PushString("value")
L.SetTable(-3) // table["key"] = "value"
```
### GetField(index int, key string)
Gets a table field t[k] and pushes it onto the stack.
```go
L.GetField(-1, "name") // gets table.name
```
### SetField(index int, key string)
Sets a table field t[k] = v, where v is the value at the top of the stack.
```go
L.PushString("value")
L.SetField(-2, "key") // table.key = "value"
```
### Next(index int) bool
Pops a key from the stack and pushes the next key-value pair from the table.
```go
L.PushNil() // Start iteration
for L.Next(-2) {
// Stack now has key at -2 and value at -1
key := L.ToString(-2)
value := L.ToString(-1)
L.Pop(1) // Remove value, keep key for next iteration
}
```
## Function Registration and Calling
### GoFunction
Type definition for Go functions callable from Lua.
```go
type GoFunction func(*State) int
```
### PushGoFunction(fn GoFunction) error
Wraps a Go function and pushes it onto the Lua stack.
```go
adder := func(s *luajit.State) int {
sum := s.ToNumber(1) + s.ToNumber(2)
s.PushNumber(sum)
return 1
}
err := L.PushGoFunction(adder)
```
### RegisterGoFunction(name string, fn GoFunction) error
Registers a Go function as a global Lua function.
```go
err := L.RegisterGoFunction("add", adder)
```
@ -167,23 +249,45 @@ Removes a previously registered function.
L.UnregisterGoFunction("add")
```
## Package Management
### SetPackagePath(path string) error
Sets the Lua package.path variable.
### Call(nargs, nresults int) error
Calls a function with the given number of arguments and results.
```go
err := L.SetPackagePath("./?.lua;/usr/local/share/lua/5.1/?.lua")
L.GetGlobal("myfunction")
L.PushNumber(1)
L.PushNumber(2)
err := L.Call(2, 1) // Call with 2 args, expect 1 result
```
### AddPackagePath(path string) error
Adds a path to the existing package.path.
## Global Operations
### GetGlobal(name string)
Gets a global variable and pushes it onto the stack.
```go
err := L.AddPackagePath("./modules/?.lua")
L.GetGlobal("myGlobal")
```
### SetGlobal(name string)
Sets a global variable from the value at the top of the stack.
```go
L.PushNumber(42)
L.SetGlobal("answer") // answer = 42
```
## Code Execution
### DoString(str string) error
### LoadString(code string) error
Loads a Lua chunk from a string without executing it.
```go
err := L.LoadString("return 42")
```
### LoadFile(filename string) error
Loads a Lua chunk from a file without executing it.
```go
err := L.LoadFile("script.lua")
```
### DoString(code string) error
Executes a string of Lua code.
```go
err := L.DoString(`
@ -199,32 +303,76 @@ Executes a Lua file.
err := L.DoFile("script.lua")
```
## Table Operations
### GetField(index int, key string)
Gets a field from a table at the given index.
### Execute(code string) (int, error)
Executes a Lua string and returns the number of results left on the stack.
```go
L.GetField(-1, "name") // gets table.name
nresults, err := L.Execute("return 1, 2, 3")
// nresults would be 3
```
### SetField(index int, key string)
Sets a field in a table at the given index.
### ExecuteWithResult(code string) (interface{}, error)
Executes a Lua string and returns the first result.
```go
L.PushString("value")
L.SetField(-2, "key") // table.key = "value"
result, err := L.ExecuteWithResult("return 'hello'")
// result would be "hello"
```
### GetGlobal(name string)
Gets a global variable.
## Bytecode Operations
### CompileBytecode(code string, name string) ([]byte, error)
Compiles a Lua chunk to bytecode without executing it.
```go
L.GetGlobal("myGlobal")
bytecode, err := L.CompileBytecode("return 42", "test")
```
### SetGlobal(name string)
Sets a global variable from the value at the top of the stack.
### LoadBytecode(bytecode []byte, name string) error
Loads precompiled bytecode without executing it.
```go
L.PushNumber(42)
L.SetGlobal("answer") // answer = 42
err := L.LoadBytecode(bytecode, "test")
```
### RunBytecode() error
Executes previously loaded bytecode with 0 results.
```go
err := L.RunBytecode()
```
### RunBytecodeWithResults(nresults int) error
Executes bytecode and keeps nresults on the stack.
```go
err := L.RunBytecodeWithResults(1)
```
### LoadAndRunBytecode(bytecode []byte, name string) error
Loads and executes bytecode.
```go
err := L.LoadAndRunBytecode(bytecode, "test")
```
### LoadAndRunBytecodeWithResults(bytecode []byte, name string, nresults int) error
Loads and executes bytecode, preserving results.
```go
err := L.LoadAndRunBytecodeWithResults(bytecode, "test", 1)
```
### CompileAndRun(code string, name string) error
Compiles and immediately executes Lua code.
```go
err := L.CompileAndRun("answer = 42", "test")
```
## Package Path Operations
### SetPackagePath(path string) error
Sets the Lua package.path.
```go
err := L.SetPackagePath("./?.lua;/usr/local/share/lua/5.1/?.lua")
```
### AddPackagePath(path string) error
Adds a path to package.path.
```go
err := L.AddPackagePath("./modules/?.lua")
```
## Error Handling
@ -238,13 +386,21 @@ type LuaError struct {
}
```
### getStackTrace() string
### GetStackTrace() string
Gets the current Lua stack trace.
```go
trace := L.getStackTrace()
trace := L.GetStackTrace()
fmt.Println(trace)
```
### safeCall(f func() C.int) error
Internal function that wraps a potentially dangerous C call with stack checking.
```go
err := s.safeCall(func() C.int {
return C.lua_pcall(s.L, C.int(nargs), C.int(nresults), 0)
})
```
## Thread Safety Notes
- The function registry is thread-safe
@ -256,14 +412,13 @@ fmt.Println(trace)
Always pair state creation with cleanup:
```go
L := luajit.NewSafe()
L := luajit.New()
defer L.Close()
defer L.Cleanup()
```
Stack management in unsafe mode requires manual attention:
Stack management requires manual attention:
```go
L := luajit.New()
L.PushString("hello")
// ... use the string
L.Pop(1) // Clean up when done

View File

@ -1,6 +1,6 @@
MIT License
Copyright (c) 2025 Sky
Copyright (c) 2025 Sharkk, Skylear Johnson
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

142
README.md
View File

@ -1,6 +1,6 @@
# LuaJIT Go Wrapper
Hey there! This is a Go wrapper for LuaJIT that makes it easy to embed Lua in your Go applications. We've focused on making it both safe and fast, while keeping the API clean and intuitive.
This is a Go wrapper for LuaJIT that makes it easy to embed Lua in your Go applications. We've focused on making it both performant and developer-friendly, with an API that feels natural to use.
## What's This For?
@ -22,51 +22,13 @@ You'll need LuaJIT's development files, but don't worry - we include libraries f
Here's the simplest thing you can do:
```go
L := luajit.NewSafe()
L := luajit.New()
defer L.Close()
defer L.Cleanup()
err := L.DoString(`print("Hey from Lua!")`)
```
## Stack Safety: Choose Your Adventure
One of the key decisions you'll make is whether to use stack-safe mode. Here's what that means:
### Stack-Safe Mode (NewSafe())
```go
L := luajit.NewSafe()
```
Think of this as driving with guardrails. It's perfect when:
- You're new to Lua or embedding scripting languages
- You're writing a server or long-running application
- You want to handle untrusted Lua code
- You'd rather have slightly slower code than mysterious crashes
The safe mode will:
- Prevent stack overflows
- Check types more thoroughly
- Clean up after messy Lua code
- Give you better error messages
### Non-Stack-Safe Mode (New())
```go
L := luajit.New()
```
This is like taking off the training wheels. Use it when:
- You know exactly how your Lua code behaves
- You've profiled your application and need more speed
- You're doing lots of rapid, simple Lua calls
- You're writing performance-critical code
The unsafe mode:
- Skips most safety checks
- Runs noticeably faster
- Gives you direct control over the stack
- Can crash spectacularly if you make a mistake
Most applications should start with stack-safe mode and only switch to unsafe mode if profiling shows it's necessary.
## Working with Bytecode
Need even more performance? You can compile your Lua code to bytecode and reuse it:
@ -82,28 +44,30 @@ bytecode, err := L.CompileBytecode(`
// Execute many times
for i := 0; i < 1000; i++ {
err := L.LoadBytecode(bytecode, "calc")
err := L.LoadAndRunBytecode(bytecode, "calc")
}
// Or do both at once
err := L.CompileAndLoad(`return "hello"`, "greeting")
err := L.CompileAndRun(`return "hello"`, "greeting")
```
### When to Use Bytecode
Bytecode execution is consistently faster than direct execution:
- Simple operations: 20-60% faster
- String operations: Up to 60% speedup
- Loop-heavy code: 10-15% improvement
- Table operations: 10-15% faster
Some benchmark results on a typical system:
```
Operation Direct Exec Bytecode Exec
----------------------------------------
Simple Math 1.5M ops/sec 2.4M ops/sec
String Ops 370K ops/sec 600K ops/sec
Table Creation 127K ops/sec 146K ops/sec
Benchmark Ops/sec Comparison
----------------------------------------------------------------------------
BenchmarkSimpleDoString 2,561,012 Base
BenchmarkSimplePrecompiledBytecode 3,828,841 +49.5% faster
BenchmarkFunctionCallDoString 2,021,098 Base
BenchmarkFunctionCallPrecompiled 3,482,074 +72.3% faster
BenchmarkLoopDoString 188,119 Base
BenchmarkLoopPrecompiled 211,081 +12.2% faster
BenchmarkTableOperationsDoString 84,086 Base
BenchmarkTableOperationsPrecompiled 93,655 +11.4% faster
BenchmarkComplexScript 33,133 Base
BenchmarkComplexScriptPrecompiled 41,044 +23.9% faster
```
Use bytecode when you:
@ -114,7 +78,7 @@ Use bytecode when you:
## Registering Go Functions
Want to call Go code from Lua? Easy:
Want to call Go code from Lua? It's straightforward:
```go
// This function adds two numbers and returns the result
adder := func(s *luajit.State) int {
@ -151,26 +115,84 @@ result, err := L.ToTable(-1)
## Error Handling
We try to give you useful errors instead of mysterious panics:
We provide useful errors instead of mysterious panics:
```go
if err := L.DoString("this isn't valid Lua!"); err != nil {
if luaErr, ok := err.(*luajit.LuaError); ok {
fmt.Printf("Oops: %s\n", luaErr.Message)
fmt.Printf("Error: %s\n", luaErr.Message)
}
}
```
## A Few Tips
## Memory Management
- Always use those `defer L.Close()` and `defer L.Cleanup()` calls - they prevent memory leaks
The wrapper uses a custom table pooling system to reduce GC pressure when handling many tables:
```go
// Tables are pooled and reused internally for better performance
for i := 0; i < 1000; i++ {
L.GetGlobal("table")
table, _ := L.ToTable(-1)
// Use table...
L.Pop(1)
// Table is automatically returned to pool
}
```
## Best Practices
- Always use `defer L.Close()` and `defer L.Cleanup()` to prevent memory leaks
- Each Lua state should stick to one goroutine
- For concurrent stuff, create multiple states
- For concurrent operations, create multiple states
- You can share functions between states safely
- Keep an eye on your stack in unsafe mode - it won't clean up after itself
- Start with stack-safe mode and measure before optimizing
- Keep an eye on your stack management - pop as many items as you push
- Use bytecode for frequently executed code paths
- Consider compiling critical Lua code to bytecode at startup
## Advanced Features
### Bytecode Serialization
You can serialize bytecode for distribution or caching:
```go
// Compile once
bytecode, _ := L.CompileBytecode(complexScript, "module")
// Save to file
ioutil.WriteFile("module.luac", bytecode, 0644)
// Later, load from file
bytecode, _ := ioutil.ReadFile("module.luac")
L.LoadAndRunBytecode(bytecode, "module")
```
### Closures and Upvalues
Bytecode properly preserves closures and upvalues:
```go
code := `
local counter = 0
return function()
counter = counter + 1
return counter
end
`
bytecode, _ := L.CompileBytecode(code, "counter")
L.LoadAndRunBytecodeWithResults(bytecode, "counter", 1)
L.SetGlobal("increment")
// Later...
L.GetGlobal("increment")
L.Call(0, 1) // Returns 1
L.Pop(1)
L.GetGlobal("increment")
L.Call(0, 1) // Returns 2
```
## Need Help?
Check out the tests in the repository - they're full of examples. If you're stuck, open an issue! We're here to help.

430
bench/bench_test.go Normal file
View File

@ -0,0 +1,430 @@
package luajit_bench
import (
"testing"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
// BenchmarkSimpleDoString benchmarks direct execution of a simple expression
func BenchmarkSimpleDoString(b *testing.B) {
state := luajit.New()
if state == nil {
b.Fatal("Failed to create Lua state")
}
defer state.Close()
code := "local x = 1 + 1"
b.ResetTimer()
for i := 0; i < b.N; i++ {
if err := state.DoString(code); err != nil {
b.Fatalf("DoString failed: %v", err)
}
}
}
// BenchmarkSimpleCompileAndRun benchmarks compile and run of a simple expression
func BenchmarkSimpleCompileAndRun(b *testing.B) {
state := luajit.New()
if state == nil {
b.Fatal("Failed to create Lua state")
}
defer state.Close()
code := "local x = 1 + 1"
b.ResetTimer()
for i := 0; i < b.N; i++ {
if err := state.CompileAndRun(code, "simple"); err != nil {
b.Fatalf("CompileAndRun failed: %v", err)
}
}
}
// BenchmarkSimpleCompileLoadRun benchmarks compile, load, and run of a simple expression
func BenchmarkSimpleCompileLoadRun(b *testing.B) {
state := luajit.New()
if state == nil {
b.Fatal("Failed to create Lua state")
}
defer state.Close()
code := "local x = 1 + 1"
b.ResetTimer()
for i := 0; i < b.N; i++ {
bytecode, err := state.CompileBytecode(code, "simple")
if err != nil {
b.Fatalf("CompileBytecode failed: %v", err)
}
if err := state.LoadAndRunBytecode(bytecode, "simple"); err != nil {
b.Fatalf("LoadAndRunBytecode failed: %v", err)
}
}
}
// BenchmarkSimplePrecompiledBytecode benchmarks running precompiled bytecode
func BenchmarkSimplePrecompiledBytecode(b *testing.B) {
state := luajit.New()
if state == nil {
b.Fatal("Failed to create Lua state")
}
defer state.Close()
code := "local x = 1 + 1"
bytecode, err := state.CompileBytecode(code, "simple")
if err != nil {
b.Fatalf("CompileBytecode failed: %v", err)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
if err := state.LoadAndRunBytecode(bytecode, "simple"); err != nil {
b.Fatalf("LoadAndRunBytecode failed: %v", err)
}
}
}
// BenchmarkFunctionCallDoString benchmarks direct execution of a function call
func BenchmarkFunctionCallDoString(b *testing.B) {
state := luajit.New()
if state == nil {
b.Fatal("Failed to create Lua state")
}
defer state.Close()
// Setup function
setupCode := `
function add(a, b)
return a + b
end
`
if err := state.DoString(setupCode); err != nil {
b.Fatalf("Failed to set up function: %v", err)
}
code := "local result = add(10, 20)"
b.ResetTimer()
for i := 0; i < b.N; i++ {
if err := state.DoString(code); err != nil {
b.Fatalf("DoString failed: %v", err)
}
}
}
// BenchmarkFunctionCallPrecompiled benchmarks precompiled function call
func BenchmarkFunctionCallPrecompiled(b *testing.B) {
state := luajit.New()
if state == nil {
b.Fatal("Failed to create Lua state")
}
defer state.Close()
// Setup function
setupCode := `
function add(a, b)
return a + b
end
`
if err := state.DoString(setupCode); err != nil {
b.Fatalf("Failed to set up function: %v", err)
}
code := "local result = add(10, 20)"
bytecode, err := state.CompileBytecode(code, "call")
if err != nil {
b.Fatalf("CompileBytecode failed: %v", err)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
if err := state.LoadAndRunBytecode(bytecode, "call"); err != nil {
b.Fatalf("LoadAndRunBytecode failed: %v", err)
}
}
}
// BenchmarkLoopDoString benchmarks direct execution of a loop
func BenchmarkLoopDoString(b *testing.B) {
state := luajit.New()
if state == nil {
b.Fatal("Failed to create Lua state")
}
defer state.Close()
code := `
local sum = 0
for i = 1, 1000 do
sum = sum + i
end
`
b.ResetTimer()
for i := 0; i < b.N; i++ {
if err := state.DoString(code); err != nil {
b.Fatalf("DoString failed: %v", err)
}
}
}
// BenchmarkLoopPrecompiled benchmarks precompiled loop execution
func BenchmarkLoopPrecompiled(b *testing.B) {
state := luajit.New()
if state == nil {
b.Fatal("Failed to create Lua state")
}
defer state.Close()
code := `
local sum = 0
for i = 1, 1000 do
sum = sum + i
end
`
bytecode, err := state.CompileBytecode(code, "loop")
if err != nil {
b.Fatalf("CompileBytecode failed: %v", err)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
if err := state.LoadAndRunBytecode(bytecode, "loop"); err != nil {
b.Fatalf("LoadAndRunBytecode failed: %v", err)
}
}
}
// BenchmarkTableOperationsDoString benchmarks direct execution of table operations
func BenchmarkTableOperationsDoString(b *testing.B) {
state := luajit.New()
if state == nil {
b.Fatal("Failed to create Lua state")
}
defer state.Close()
code := `
local t = {}
for i = 1, 100 do
t[i] = i * 2
end
local sum = 0
for i, v in ipairs(t) do
sum = sum + v
end
`
b.ResetTimer()
for i := 0; i < b.N; i++ {
if err := state.DoString(code); err != nil {
b.Fatalf("DoString failed: %v", err)
}
}
}
// BenchmarkTableOperationsPrecompiled benchmarks precompiled table operations
func BenchmarkTableOperationsPrecompiled(b *testing.B) {
state := luajit.New()
if state == nil {
b.Fatal("Failed to create Lua state")
}
defer state.Close()
code := `
local t = {}
for i = 1, 100 do
t[i] = i * 2
end
local sum = 0
for i, v in ipairs(t) do
sum = sum + v
end
`
bytecode, err := state.CompileBytecode(code, "table")
if err != nil {
b.Fatalf("CompileBytecode failed: %v", err)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
if err := state.LoadAndRunBytecode(bytecode, "table"); err != nil {
b.Fatalf("LoadAndRunBytecode failed: %v", err)
}
}
}
// BenchmarkGoFunctionCall benchmarks calling a Go function from Lua
func BenchmarkGoFunctionCall(b *testing.B) {
state := luajit.New()
if state == nil {
b.Fatal("Failed to create Lua state")
}
defer state.Close()
// Register a simple Go function
add := func(s *luajit.State) int {
a := s.ToNumber(1)
b := s.ToNumber(2)
s.PushNumber(a + b)
return 1
}
if err := state.RegisterGoFunction("add", add); err != nil {
b.Fatalf("RegisterGoFunction failed: %v", err)
}
code := "local result = add(10, 20)"
bytecode, err := state.CompileBytecode(code, "gofunc")
if err != nil {
b.Fatalf("CompileBytecode failed: %v", err)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
if err := state.LoadAndRunBytecode(bytecode, "gofunc"); err != nil {
b.Fatalf("LoadAndRunBytecode failed: %v", err)
}
}
}
// BenchmarkComplexScript benchmarks a more complex script
func BenchmarkComplexScript(b *testing.B) {
state := luajit.New()
if state == nil {
b.Fatal("Failed to create Lua state")
}
defer state.Close()
code := `
-- Define a simple class
local Class = {}
Class.__index = Class
function Class.new(x, y)
local self = setmetatable({}, Class)
self.x = x or 0
self.y = y or 0
return self
end
function Class:move(dx, dy)
self.x = self.x + dx
self.y = self.y + dy
return self
end
function Class:getPosition()
return self.x, self.y
end
-- Create instances and operate on them
local instances = {}
for i = 1, 50 do
instances[i] = Class.new(i, i*2)
end
local result = 0
for i, obj in ipairs(instances) do
obj:move(i, -i)
local x, y = obj:getPosition()
result = result + x + y
end
return result
`
b.ResetTimer()
for i := 0; i < b.N; i++ {
if _, err := state.ExecuteWithResult(code); err != nil {
b.Fatalf("ExecuteWithResult failed: %v", err)
}
}
}
// BenchmarkComplexScriptPrecompiled benchmarks a precompiled complex script
func BenchmarkComplexScriptPrecompiled(b *testing.B) {
state := luajit.New()
if state == nil {
b.Fatal("Failed to create Lua state")
}
defer state.Close()
code := `
-- Define a simple class
local Class = {}
Class.__index = Class
function Class.new(x, y)
local self = setmetatable({}, Class)
self.x = x or 0
self.y = y or 0
return self
end
function Class:move(dx, dy)
self.x = self.x + dx
self.y = self.y + dy
return self
end
function Class:getPosition()
return self.x, self.y
end
-- Create instances and operate on them
local instances = {}
for i = 1, 50 do
instances[i] = Class.new(i, i*2)
end
local result = 0
for i, obj in ipairs(instances) do
obj:move(i, -i)
local x, y = obj:getPosition()
result = result + x + y
end
return result
`
bytecode, err := state.CompileBytecode(code, "complex")
if err != nil {
b.Fatalf("CompileBytecode failed: %v", err)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
if err := state.LoadBytecode(bytecode, "complex"); err != nil {
b.Fatalf("LoadBytecode failed: %v", err)
}
if err := state.RunBytecodeWithResults(1); err != nil { // Assuming this method exists to get the return value
b.Fatalf("RunBytecodeWithResults failed: %v", err)
}
state.Pop(1) // Pop the result
}
}
// BenchmarkMultipleExecutions benchmarks executing the same bytecode multiple times
func BenchmarkMultipleExecutions(b *testing.B) {
state := luajit.New()
if state == nil {
b.Fatal("Failed to create Lua state")
}
defer state.Close()
// Setup a stateful environment
setupCode := `
counter = 0
function increment(amount)
counter = counter + (amount or 1)
return counter
end
`
if err := state.DoString(setupCode); err != nil {
b.Fatalf("Failed to set up environment: %v", err)
}
// Compile the function call
code := "return increment(5)"
bytecode, err := state.CompileBytecode(code, "increment")
if err != nil {
b.Fatalf("CompileBytecode failed: %v", err)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
if err := state.LoadBytecode(bytecode, "increment"); err != nil {
b.Fatalf("LoadBytecode failed: %v", err)
}
if err := state.RunBytecodeWithResults(1); err != nil { // Assuming this method exists
b.Fatalf("RunBytecodeWithResults failed: %v", err)
}
state.Pop(1) // Pop the result
}
}

133
bench/ezbench_test.go Normal file
View File

@ -0,0 +1,133 @@
package luajit_bench
import (
"testing"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
var benchCases = []struct {
name string
code string
}{
{
name: "SimpleAddition",
code: `return 1 + 1`,
},
{
name: "LoopSum",
code: `
local sum = 0
for i = 1, 1000 do
sum = sum + i
end
return sum
`,
},
{
name: "FunctionCall",
code: `
local result = 0
for i = 1, 100 do
result = result + i
end
return result
`,
},
{
name: "TableCreation",
code: `
local t = {}
for i = 1, 100 do
t[i] = i * 2
end
return t[50]
`,
},
{
name: "StringOperations",
code: `
local s = "hello"
for i = 1, 10 do
s = s .. " world"
end
return #s
`,
},
}
func BenchmarkLuaDirectExecution(b *testing.B) {
for _, bc := range benchCases {
b.Run(bc.name, func(b *testing.B) {
L := luajit.New()
if L == nil {
b.Fatal("Failed to create Lua state")
}
defer L.Close()
defer L.Cleanup()
// First verify we can execute the code
if err := L.DoString(bc.code); err != nil {
b.Fatalf("Failed to execute test code: %v", err)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
// Execute string and get results
nresults, err := L.Execute(bc.code)
if err != nil {
b.Fatalf("Failed to execute code: %v", err)
}
L.Pop(nresults) // Clean up any results
}
})
}
}
func BenchmarkLuaBytecodeExecution(b *testing.B) {
// First compile all bytecode
bytecodes := make(map[string][]byte)
for _, bc := range benchCases {
L := luajit.New()
if L == nil {
b.Fatal("Failed to create Lua state")
}
defer L.Cleanup()
bytecode, err := L.CompileBytecode(bc.code, bc.name)
if err != nil {
L.Close()
b.Fatalf("Error compiling bytecode for %s: %v", bc.name, err)
}
bytecodes[bc.name] = bytecode
L.Close()
}
for _, bc := range benchCases {
b.Run(bc.name, func(b *testing.B) {
L := luajit.New()
if L == nil {
b.Fatal("Failed to create Lua state")
}
defer L.Close()
defer L.Cleanup()
bytecode := bytecodes[bc.name]
// First verify we can execute the bytecode
if err := L.LoadAndRunBytecodeWithResults(bytecode, bc.name, 1); err != nil {
b.Fatalf("Failed to execute test bytecode: %v", err)
}
L.Pop(1) // Clean up the result
b.ResetTimer()
b.SetBytes(int64(len(bytecode))) // Track bytecode size in benchmarks
for i := 0; i < b.N; i++ {
if err := L.LoadAndRunBytecode(bytecode, bc.name); err != nil {
b.Fatalf("Error executing bytecode: %v", err)
}
}
})
}
}

View File

@ -21,7 +21,7 @@ static const char *bytecode_reader(lua_State *L, void *ud, size_t *size) {
return (const char *)r->buf;
}
static int load_bytecode_chunk(lua_State *L, const unsigned char *buf, size_t len, const char *name) {
static int load_bytecode(lua_State *L, const unsigned char *buf, size_t len, const char *name) {
BytecodeReader reader = {buf, len, name};
return lua_load(L, bytecode_reader, &reader, name);
}
@ -29,21 +29,37 @@ static int load_bytecode_chunk(lua_State *L, const unsigned char *buf, size_t le
typedef struct {
unsigned char *buf;
size_t len;
size_t capacity;
} BytecodeWriter;
int bytecode_writer(lua_State *L, const void *p, size_t sz, void *ud) {
static int bytecode_writer(lua_State *L, const void *p, size_t sz, void *ud) {
BytecodeWriter *w = (BytecodeWriter *)ud;
unsigned char *newbuf;
(void)L; // unused
newbuf = (unsigned char *)realloc(w->buf, w->len + sz);
if (newbuf == NULL) return 1;
// Check if we need to reallocate
if (w->len + sz > w->capacity) {
size_t new_capacity = w->capacity * 2;
if (new_capacity < w->len + sz) {
new_capacity = w->len + sz;
}
memcpy(newbuf + w->len, p, sz);
w->buf = newbuf;
newbuf = (unsigned char *)realloc(w->buf, new_capacity);
if (newbuf == NULL) return 1;
w->buf = newbuf;
w->capacity = new_capacity;
}
memcpy(w->buf + w->len, p, sz);
w->len += sz;
return 0;
}
// Wrapper function that calls lua_dump with bytecode_writer
static int dump_lua_function(lua_State *L, BytecodeWriter *w) {
return lua_dump(L, bytecode_writer, w);
}
*/
import "C"
import (
@ -53,89 +69,108 @@ import (
// CompileBytecode compiles a Lua chunk to bytecode without executing it
func (s *State) CompileBytecode(code string, name string) ([]byte, error) {
// First load the string but don't execute it
ccode := C.CString(code)
defer C.free(unsafe.Pointer(ccode))
cname := C.CString(name)
defer C.free(unsafe.Pointer(cname))
if C.luaL_loadstring(s.L, ccode) != 0 {
err := &LuaError{
Code: int(C.lua_status(s.L)),
Message: s.ToString(-1),
}
s.Pop(1)
if err := s.LoadString(code); err != nil {
return nil, fmt.Errorf("failed to load string: %w", err)
}
// Set up writer
// Set up writer with initial capacity
var writer C.BytecodeWriter
writer.buf = nil
writer.len = 0
writer.capacity = 0
// Initial allocation with a reasonable size
const initialSize = 4096
writer.buf = (*C.uchar)(C.malloc(initialSize))
if writer.buf == nil {
s.Pop(1) // Remove the loaded function
return nil, fmt.Errorf("failed to allocate memory for bytecode")
}
writer.capacity = initialSize
// Dump the function to bytecode
if C.lua_dump(s.L, (*[0]byte)(C.bytecode_writer), unsafe.Pointer(&writer)) != 0 {
if writer.buf != nil {
C.free(unsafe.Pointer(writer.buf))
}
s.Pop(1)
return nil, fmt.Errorf("failed to dump bytecode")
}
err := s.safeCall(func() C.int {
return C.dump_lua_function(s.L, (*C.BytecodeWriter)(unsafe.Pointer(&writer)))
})
// Copy to Go slice
// Copy bytecode to Go slice regardless of the result
bytecode := C.GoBytes(unsafe.Pointer(writer.buf), C.int(writer.len))
// Clean up
if writer.buf != nil {
C.free(unsafe.Pointer(writer.buf))
C.free(unsafe.Pointer(writer.buf))
s.Pop(1) // Remove the function from stack
if err != nil {
return nil, fmt.Errorf("failed to dump bytecode: %w", err)
}
s.Pop(1) // Remove the function
return bytecode, nil
}
// LoadBytecode loads precompiled bytecode and executes it
// LoadBytecode loads precompiled bytecode without executing it
func (s *State) LoadBytecode(bytecode []byte, name string) error {
if len(bytecode) == 0 {
return fmt.Errorf("empty bytecode")
}
cname := C.CString(name)
defer C.free(unsafe.Pointer(cname))
// Load the bytecode
status := C.load_bytecode_chunk(
s.L,
(*C.uchar)(unsafe.Pointer(&bytecode[0])),
C.size_t(len(bytecode)),
cname,
)
err := s.safeCall(func() C.int {
return C.load_bytecode(
s.L,
(*C.uchar)(unsafe.Pointer(&bytecode[0])),
C.size_t(len(bytecode)),
cname,
)
})
if status != 0 {
err := &LuaError{
Code: int(status),
Message: s.ToString(-1),
}
s.Pop(1)
if err != nil {
return fmt.Errorf("failed to load bytecode: %w", err)
}
// Execute the loaded chunk
if err := s.safeCall(func() C.int {
return C.lua_pcall(s.L, 0, 0, 0)
}); err != nil {
return fmt.Errorf("failed to execute bytecode: %w", err)
}
return nil
}
// Helper function to compile and immediately load/execute bytecode
func (s *State) CompileAndLoad(code string, name string) error {
// RunBytecode executes previously loaded bytecode with 0 results
func (s *State) RunBytecode() error {
return s.RunBytecodeWithResults(0)
}
// RunBytecodeWithResults executes bytecode and keeps nresults on the stack
// Use LUA_MULTRET (-1) to keep all results
func (s *State) RunBytecodeWithResults(nresults int) error {
return s.safeCall(func() C.int {
return C.lua_pcall(s.L, 0, C.int(nresults), 0)
})
}
// LoadAndRunBytecode loads and executes bytecode
func (s *State) LoadAndRunBytecode(bytecode []byte, name string) error {
if err := s.LoadBytecode(bytecode, name); err != nil {
return err
}
return s.RunBytecode()
}
// LoadAndRunBytecodeWithResults loads and executes bytecode, preserving results
func (s *State) LoadAndRunBytecodeWithResults(bytecode []byte, name string, nresults int) error {
if err := s.LoadBytecode(bytecode, name); err != nil {
return err
}
return s.RunBytecodeWithResults(nresults)
}
// CompileAndRun compiles and immediately executes Lua code
func (s *State) CompileAndRun(code string, name string) error {
bytecode, err := s.CompileBytecode(code, name)
if err != nil {
return fmt.Errorf("compile error: %w", err)
}
if err := s.LoadBytecode(bytecode, name); err != nil {
return fmt.Errorf("load error: %w", err)
if err := s.LoadAndRunBytecode(bytecode, name); err != nil {
return fmt.Errorf("execution error: %w", err)
}
return nil

View File

@ -1,162 +0,0 @@
package luajit
import (
"fmt"
"testing"
)
func TestBytecodeCompilation(t *testing.T) {
tests := []struct {
name string
code string
wantErr bool
}{
{
name: "simple assignment",
code: "x = 42",
wantErr: false,
},
{
name: "function definition",
code: "function add(a,b) return a+b end",
wantErr: false,
},
{
name: "syntax error",
code: "function bad syntax",
wantErr: true,
},
}
for _, tt := range tests {
L := New()
if L == nil {
t.Fatal("Failed to create Lua state")
}
defer L.Close()
bytecode, err := L.CompileBytecode(tt.code, "test")
if (err != nil) != tt.wantErr {
t.Errorf("CompileBytecode() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !tt.wantErr {
if len(bytecode) == 0 {
t.Error("CompileBytecode() returned empty bytecode")
}
}
}
}
func TestBytecodeExecution(t *testing.T) {
L := New()
if L == nil {
t.Fatal("Failed to create Lua state")
}
defer L.Close()
// Compile some test code
code := `
function add(a, b)
return a + b
end
result = add(40, 2)
`
bytecode, err := L.CompileBytecode(code, "test")
if err != nil {
t.Fatalf("CompileBytecode() error = %v", err)
}
// Load and execute the bytecode
if err := L.LoadBytecode(bytecode, "test"); err != nil {
t.Fatalf("LoadBytecode() error = %v", err)
}
// Verify the result
L.GetGlobal("result")
if result := L.ToNumber(-1); result != 42 {
t.Errorf("got result = %v, want 42", result)
}
}
func TestInvalidBytecode(t *testing.T) {
L := New()
if L == nil {
t.Fatal("Failed to create Lua state")
}
defer L.Close()
// Test with invalid bytecode
invalidBytecode := []byte("this is not valid bytecode")
if err := L.LoadBytecode(invalidBytecode, "test"); err == nil {
t.Error("LoadBytecode() expected error with invalid bytecode")
}
}
func TestBytecodeRoundTrip(t *testing.T) {
tests := []struct {
name string
code string
check func(*State) error
}{
{
name: "global variable",
code: "x = 42",
check: func(L *State) error {
L.GetGlobal("x")
if x := L.ToNumber(-1); x != 42 {
return fmt.Errorf("got x = %v, want 42", x)
}
return nil
},
},
{
name: "function definition",
code: "function test() return 'hello' end",
check: func(L *State) error {
if err := L.DoString("result = test()"); err != nil {
return err
}
L.GetGlobal("result")
if s := L.ToString(-1); s != "hello" {
return fmt.Errorf("got result = %q, want 'hello'", s)
}
return nil
},
},
}
for _, tt := range tests {
// First state for compilation
L1 := New()
if L1 == nil {
t.Fatal("Failed to create first Lua state")
}
defer L1.Close()
// Compile the code
bytecode, err := L1.CompileBytecode(tt.code, "test")
if err != nil {
t.Fatalf("CompileBytecode() error = %v", err)
}
// Second state for execution
L2 := New()
if L2 == nil {
t.Fatal("Failed to create second Lua state")
}
defer L2.Close()
// Load and execute the bytecode
if err := L2.LoadBytecode(bytecode, "test"); err != nil {
t.Fatalf("LoadBytecode() error = %v", err)
}
// Run the check function
if err := tt.check(L2); err != nil {
t.Errorf("check failed: %v", err)
}
}
}

70
example/main.go Normal file
View File

@ -0,0 +1,70 @@
package main
import (
"fmt"
"log"
"os"
"path/filepath"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
func main() {
if len(os.Args) < 2 {
fmt.Println("Usage: go run main.go script.lua")
os.Exit(1)
}
scriptPath := os.Args[1]
// Create a new Lua state
L := luajit.New()
if L == nil {
log.Fatal("Failed to create Lua state")
}
defer L.Close()
// Register a Go function to be called from Lua
L.RegisterGoFunction("printFromGo", func(s *luajit.State) int {
msg := s.ToString(1) // Get first argument
fmt.Printf("Go received from Lua: %s\n", msg)
// Return a value to Lua
s.PushString("Hello from Go!")
return 1 // Number of return values
})
// Add some values to the Lua environment
L.PushValue(map[string]interface{}{
"appName": "LuaJIT Example",
"version": 1.0,
"features": []float64{1, 2, 3},
})
L.SetGlobal("config")
// Get the directory of the script to properly handle requires
dir := filepath.Dir(scriptPath)
L.AddPackagePath(filepath.Join(dir, "?.lua"))
// Execute the script
fmt.Printf("Running Lua script: %s\n", scriptPath)
if err := L.DoFile(scriptPath); err != nil {
log.Fatalf("Error executing script: %v", err)
}
// Call a Lua function and get its result
L.GetGlobal("getResult")
if L.IsFunction(-1) {
if err := L.Call(0, 1); err != nil {
log.Fatalf("Error calling Lua function: %v", err)
}
result, err := L.ToValue(-1)
if err != nil {
log.Fatalf("Error converting Lua result: %v", err)
}
fmt.Printf("Result from Lua: %v\n", result)
L.Pop(1) // Clean up the result
}
}

35
example/script.lua Normal file
View File

@ -0,0 +1,35 @@
-- Example Lua script to demonstrate Go-Lua integration
-- Access the config table passed from Go
print("Script started")
print("App name:", config.appName)
print("Version:", config.version)
print("Features:", table.concat(config.features, ", "))
-- Call the Go function
local response = printFromGo("Hello from Lua!")
print("Response from Go:", response)
-- Function that will be called from Go
function getResult()
local result = {
status = "success",
calculations = {
sum = 10 + 20,
product = 5 * 7
},
message = "Calculation completed"
}
return result
end
-- Load external module (if available)
local success, utils = pcall(require, "utils")
if success then
print("Utils module loaded")
utils.doSomething()
else
print("Utils module not available:", utils)
end
print("Script completed")

19
example/utils.lua Normal file
View File

@ -0,0 +1,19 @@
-- Optional utility module
local utils = {}
function utils.doSomething()
print("Utils module function called")
return true
end
function utils.calculate(a, b)
return {
sum = a + b,
difference = a - b,
product = a * b,
quotient = a / b
}
end
return utils

View File

@ -7,8 +7,9 @@ package luajit
extern int goFunctionWrapper(lua_State* L);
// Helper function to access upvalues
static int get_upvalue_index(int i) {
return -10002 - i; // LUA_GLOBALSINDEX - i
return lua_upvalueindex(i);
}
*/
import "C"
@ -18,9 +19,11 @@ import (
"unsafe"
)
// GoFunction defines the signature for Go functions callable from Lua
type GoFunction func(*State) int
var (
// functionRegistry stores all registered Go functions
functionRegistry = struct {
sync.RWMutex
funcs map[unsafe.Pointer]GoFunction
@ -33,10 +36,10 @@ var (
func goFunctionWrapper(L *C.lua_State) C.int {
state := &State{L: L}
// Get upvalue using standard Lua 5.1 macro
// Get function pointer from the first upvalue
ptr := C.lua_touserdata(L, C.get_upvalue_index(1))
if ptr == nil {
state.PushString("error: function not found")
state.PushString("error: function pointer not found")
return -1
}
@ -49,49 +52,56 @@ func goFunctionWrapper(L *C.lua_State) C.int {
return -1
}
result := fn(state)
return C.int(result)
// Call the Go function
return C.int(fn(state))
}
// PushGoFunction wraps a Go function and pushes it onto the Lua stack
func (s *State) PushGoFunction(fn GoFunction) error {
// Allocate a pointer to use as the function key
ptr := C.malloc(1)
if ptr == nil {
return fmt.Errorf("failed to allocate memory for function pointer")
}
// Register the function
functionRegistry.Lock()
functionRegistry.funcs[ptr] = fn
functionRegistry.Unlock()
// Push the pointer as lightuserdata (first upvalue)
C.lua_pushlightuserdata(s.L, ptr)
// Create closure with the C wrapper and the upvalue
C.lua_pushcclosure(s.L, (*[0]byte)(C.goFunctionWrapper), 1)
return nil
}
// RegisterGoFunction registers a Go function as a global Lua function
func (s *State) RegisterGoFunction(name string, fn GoFunction) error {
if err := s.PushGoFunction(fn); err != nil {
return err
}
cname := C.CString(name)
defer C.free(unsafe.Pointer(cname))
C.lua_setfield(s.L, C.LUA_GLOBALSINDEX, cname)
s.SetGlobal(name)
return nil
}
// UnregisterGoFunction removes a global function
func (s *State) UnregisterGoFunction(name string) {
s.PushNil()
cname := C.CString(name)
defer C.free(unsafe.Pointer(cname))
C.lua_setfield(s.L, C.LUA_GLOBALSINDEX, cname)
s.SetGlobal(name)
}
// Cleanup frees all function pointers and clears the registry
func (s *State) Cleanup() {
functionRegistry.Lock()
defer functionRegistry.Unlock()
// Free all allocated pointers
for ptr := range functionRegistry.funcs {
C.free(ptr)
delete(functionRegistry.funcs, ptr)
}
functionRegistry.funcs = make(map[unsafe.Pointer]GoFunction)
}

View File

@ -1,107 +0,0 @@
package luajit
import (
"testing"
)
func TestGoFunctions(t *testing.T) {
L := New()
if L == nil {
t.Fatal("Failed to create Lua state")
}
defer L.Close()
defer L.Cleanup()
addFunc := func(s *State) int {
s.PushNumber(s.ToNumber(1) + s.ToNumber(2))
return 1
}
if err := L.RegisterGoFunction("add", addFunc); err != nil {
t.Fatalf("Failed to register function: %v", err)
}
// Test basic function call
if err := L.DoString("result = add(40, 2)"); err != nil {
t.Fatalf("Failed to call function: %v", err)
}
L.GetGlobal("result")
if result := L.ToNumber(-1); result != 42 {
t.Errorf("got %v, want 42", result)
}
L.Pop(1)
// Test multiple return values
multiFunc := func(s *State) int {
s.PushString("hello")
s.PushNumber(42)
s.PushBoolean(true)
return 3
}
if err := L.RegisterGoFunction("multi", multiFunc); err != nil {
t.Fatalf("Failed to register multi function: %v", err)
}
code := `
a, b, c = multi()
result = (a == "hello" and b == 42 and c == true)
`
if err := L.DoString(code); err != nil {
t.Fatalf("Failed to call multi function: %v", err)
}
L.GetGlobal("result")
if !L.ToBoolean(-1) {
t.Error("Multiple return values test failed")
}
L.Pop(1)
// Test error handling
errFunc := func(s *State) int {
s.PushString("test error")
return -1
}
if err := L.RegisterGoFunction("err", errFunc); err != nil {
t.Fatalf("Failed to register error function: %v", err)
}
if err := L.DoString("err()"); err == nil {
t.Error("Expected error from error function")
}
// Test unregistering
L.UnregisterGoFunction("add")
if err := L.DoString("add(1, 2)"); err == nil {
t.Error("Expected error calling unregistered function")
}
}
func TestStackSafety(t *testing.T) {
L := New()
if L == nil {
t.Fatal("Failed to create Lua state")
}
defer L.Close()
defer L.Cleanup()
// Test stack overflow protection
overflowFunc := func(s *State) int {
for i := 0; i < 100; i++ {
s.PushNumber(float64(i))
}
s.PushString("done")
return 101
}
if err := L.RegisterGoFunction("overflow", overflowFunc); err != nil {
t.Fatal(err)
}
if err := L.DoString("overflow()"); err != nil {
t.Logf("Got expected error: %v", err)
}
}

View File

@ -25,84 +25,23 @@ const (
LUA_GLOBALSINDEX = -10002 // Pseudo-index for globals table
)
// checkStack ensures there is enough space on the Lua stack
func (s *State) checkStack(n int) error {
if C.lua_checkstack(s.L, C.int(n)) == 0 {
return fmt.Errorf("stack overflow (cannot allocate %d slots)", n)
}
return nil
}
// safeCall wraps a potentially dangerous C call with stack checking
func (s *State) safeCall(f func() C.int) error {
// Save current stack size
top := s.GetTop()
// Ensure we have enough stack space (minimum 20 slots as per Lua standard)
if err := s.checkStack(LUA_MINSTACK); err != nil {
return err
// GetStackTrace returns the current Lua stack trace
func (s *State) GetStackTrace() string {
s.GetGlobal("debug")
if !s.IsTable(-1) {
s.Pop(1)
return "debug table not available"
}
// Make the call
status := f()
// Check for errors
if status != 0 {
err := &LuaError{
Code: int(status),
Message: s.ToString(-1),
}
s.Pop(1) // Remove error message
return err
s.GetField(-1, "traceback")
if !s.IsFunction(-1) {
s.Pop(2) // Remove debug table and non-function
return "debug.traceback not available"
}
// For lua_pcall, the function and arguments are popped before results are pushed
// So we don't consider it an underflow if the new top is less than the original
if status == 0 && s.GetType(-1) == TypeFunction {
// If we still have a function on the stack, restore original size
s.SetTop(top)
}
s.Call(0, 1)
trace := s.ToString(-1)
s.Pop(1) // Remove the trace
return nil
}
// stackGuard wraps a function with stack checking
func stackGuard[T any](s *State, f func() (T, error)) (T, error) {
// Save current stack size
top := s.GetTop()
defer func() {
// Only restore if stack is larger than original
if s.GetTop() > top {
s.SetTop(top)
}
}()
// Run the protected function
return f()
}
// stackGuardValue executes a function with stack protection
func stackGuardValue[T any](s *State, f func() (T, error)) (T, error) {
return stackGuard(s, f)
}
// stackGuardErr executes a function that only returns an error with stack protection
func stackGuardErr(s *State, f func() error) error {
// Save current stack size
top := s.GetTop()
defer func() {
// Only restore if stack is larger than original
if s.GetTop() > top {
s.SetTop(top)
}
}()
// Run the protected function
return f()
}
// getStackTrace returns the current Lua stack trace
func (s *State) getStackTrace() string {
// Same implementation...
return ""
return trace
}

105
table.go
View File

@ -6,29 +6,29 @@ package luajit
#include <lauxlib.h>
#include <stdlib.h>
static int get_table_length(lua_State *L, int index) {
static size_t get_table_length(lua_State *L, int index) {
return lua_objlen(L, index);
}
*/
import "C"
import (
"fmt"
"strconv"
"sync"
)
// Use a pool to reduce GC pressure when handling many tables
var tablePool = sync.Pool{
New: func() interface{} {
return make(map[string]interface{})
},
}
// TableValue represents any value that can be stored in a Lua table
type TableValue interface {
~string | ~float64 | ~bool | ~int | ~map[string]interface{} | ~[]float64 | ~[]interface{}
// GetTableLength returns the length of a table at the given index
func (s *State) GetTableLength(index int) int {
return int(C.get_table_length(s.L, C.int(index)))
}
func (s *State) GetTableLength(index int) int { return int(C.get_table_length(s.L, C.int(index))) }
// getTableFromPool gets a map from the pool and ensures it's empty
func getTableFromPool() map[string]interface{} {
table := tablePool.Get().(map[string]interface{})
@ -46,46 +46,68 @@ func putTableToPool(table map[string]interface{}) {
// PushTable pushes a Go map onto the Lua stack as a table
func (s *State) PushTable(table map[string]interface{}) error {
// Create table with appropriate capacity hints
s.CreateTable(0, len(table))
// Add each key-value pair
for k, v := range table {
// Push key
s.PushString(k)
// Push value
if err := s.PushValue(v); err != nil {
return err
}
s.SetField(-2, k)
// t[k] = v
s.SetTable(-3)
}
// If this is a pooled table, return it
if _, hasEmptyKey := table[""]; len(table) == 1 && hasEmptyKey {
// Return pooled tables to the pool
if isPooledTable(table) {
putTableToPool(table)
}
return nil
}
// ToTable converts a Lua table to a Go map
// isPooledTable detects if a table came from our pool
func isPooledTable(table map[string]interface{}) bool {
// Check for our special marker - used for array tables in the pool
_, hasEmptyKey := table[""]
return len(table) == 1 && hasEmptyKey
}
// ToTable converts a Lua table at the given index to a Go map
func (s *State) ToTable(index int) (map[string]interface{}, error) {
absIdx := s.absIndex(index)
table := getTableFromPool()
if !s.IsTable(absIdx) {
return nil, fmt.Errorf("value at index %d is not a table", index)
}
// Check if it's an array-like table
// Try to detect array-like tables first
length := s.GetTableLength(absIdx)
if length > 0 {
array := make([]float64, length)
// Check if this is an array-like table
isArray := true
array := make([]float64, length)
// Try to convert to array
for i := 1; i <= length; i++ {
s.PushNumber(float64(i))
s.GetTable(absIdx)
if s.GetType(-1) != TypeNumber {
if !s.IsNumber(-1) {
isArray = false
s.Pop(1)
break
}
array[i-1] = s.ToNumber(-1)
s.Pop(1)
}
if isArray {
putTableToPool(table) // Return unused table to pool
// Return array as a special pooled table with empty key
result := getTableFromPool()
result[""] = array
return result, nil
@ -93,24 +115,36 @@ func (s *State) ToTable(index int) (map[string]interface{}, error) {
}
// Handle regular table
s.PushNil()
for C.lua_next(s.L, C.int(absIdx)) != 0 {
key := ""
valueType := C.lua_type(s.L, -2)
if valueType == C.LUA_TSTRING {
table := getTableFromPool()
// Iterate through all key-value pairs
s.PushNil() // Start iteration with nil key
for s.Next(absIdx) {
// Stack now has key at -2 and value at -1
// Convert key to string
var key string
keyType := s.GetType(-2)
switch keyType {
case TypeString:
key = s.ToString(-2)
} else if valueType == C.LUA_TNUMBER {
case TypeNumber:
key = strconv.FormatFloat(s.ToNumber(-2), 'g', -1, 64)
default:
// Skip non-string/non-number keys
s.Pop(1) // Pop value, leave key for next iteration
continue
}
// Convert and store the value
value, err := s.ToValue(-1)
if err != nil {
s.Pop(1)
putTableToPool(table) // Return table to pool on error
s.Pop(2) // Pop both key and value
putTableToPool(table) // Return the table to the pool on error
return nil, err
}
// Handle nested array case
// Handle nested array tables
if m, ok := value.(map[string]interface{}); ok {
if arr, ok := m[""]; ok {
value = arr
@ -118,27 +152,8 @@ func (s *State) ToTable(index int) (map[string]interface{}, error) {
}
table[key] = value
s.Pop(1)
s.Pop(1) // Pop value, leave key for next iteration
}
return table, nil
}
// NewTable creates a new table and pushes it onto the stack
func (s *State) NewTable() {
C.lua_createtable(s.L, 0, 0)
}
// SetTable sets a table field
func (s *State) SetTable(index int) {
C.lua_settable(s.L, C.int(index))
}
// GetTable gets a table field
func (s *State) GetTable(index int) {
C.lua_gettable(s.L, C.int(index))
}
func (s *State) CreateTable(narr, nrec int) {
C.lua_createtable(s.L, C.int(narr), C.int(nrec))
}

View File

@ -1,93 +0,0 @@
package luajit
import (
"math"
"testing"
)
func TestTableOperations(t *testing.T) {
tests := []struct {
name string
data map[string]interface{}
}{
{
name: "empty",
data: map[string]interface{}{},
},
{
name: "primitives",
data: map[string]interface{}{
"str": "hello",
"num": 42.0,
"bool": true,
"array": []float64{1.1, 2.2, 3.3},
},
},
{
name: "nested",
data: map[string]interface{}{
"nested": map[string]interface{}{
"value": 123.0,
"array": []float64{4.4, 5.5},
},
},
},
}
for _, tt := range tests {
L := New()
if L == nil {
t.Fatal("Failed to create Lua state")
}
defer L.Close()
if err := L.PushTable(tt.data); err != nil {
t.Fatalf("PushTable() error = %v", err)
}
got, err := L.ToTable(-1)
if err != nil {
t.Fatalf("ToTable() error = %v", err)
}
if !tablesEqual(got, tt.data) {
t.Errorf("table mismatch\ngot = %v\nwant = %v", got, tt.data)
}
}
}
func tablesEqual(a, b map[string]interface{}) bool {
if len(a) != len(b) {
return false
}
for k, v1 := range a {
v2, ok := b[k]
if !ok {
return false
}
switch v1 := v1.(type) {
case map[string]interface{}:
v2, ok := v2.(map[string]interface{})
if !ok || !tablesEqual(v1, v2) {
return false
}
case []float64:
v2, ok := v2.([]float64)
if !ok || len(v1) != len(v2) {
return false
}
for i := range v1 {
if math.Abs(v1[i]-v2[i]) > 1e-10 {
return false
}
}
default:
if v1 != v2 {
return false
}
}
}
return true
}

443
tests/bytecode_test.go Normal file
View File

@ -0,0 +1,443 @@
package luajit_test
import (
"bytes"
"errors"
"testing"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
// TestCompileBytecode tests basic bytecode compilation
func TestCompileBytecode(t *testing.T) {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
code := "return 42"
bytecode, err := state.CompileBytecode(code, "test")
if err != nil {
t.Fatalf("CompileBytecode failed: %v", err)
}
if len(bytecode) == 0 {
t.Fatal("Expected non-empty bytecode")
}
}
// TestLoadBytecode tests loading precompiled bytecode
func TestLoadBytecode(t *testing.T) {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
// First compile some bytecode
code := "answer = 42"
bytecode, err := state.CompileBytecode(code, "test")
if err != nil {
t.Fatalf("CompileBytecode failed: %v", err)
}
// Then load it
err = state.LoadBytecode(bytecode, "test")
if err != nil {
t.Fatalf("LoadBytecode failed: %v", err)
}
// Verify a function is on the stack
if !state.IsFunction(-1) {
t.Fatal("Expected function at top of stack after LoadBytecode")
}
// Pop the function
state.Pop(1)
}
// TestRunBytecode tests running previously loaded bytecode
func TestRunBytecode(t *testing.T) {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
// First compile and load bytecode
code := "answer = 42"
bytecode, err := state.CompileBytecode(code, "test")
if err != nil {
t.Fatalf("CompileBytecode failed: %v", err)
}
err = state.LoadBytecode(bytecode, "test")
if err != nil {
t.Fatalf("LoadBytecode failed: %v", err)
}
// Run the bytecode
err = state.RunBytecode()
if err != nil {
t.Fatalf("RunBytecode failed: %v", err)
}
// Verify the code has executed correctly
state.GetGlobal("answer")
if !state.IsNumber(-1) || state.ToNumber(-1) != 42 {
t.Fatalf("Expected answer to be 42, got %v", state.ToNumber(-1))
}
state.Pop(1)
}
// TestLoadAndRunBytecode tests the combined load and run functionality
func TestLoadAndRunBytecode(t *testing.T) {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
// Compile bytecode
code := "answer = 42"
bytecode, err := state.CompileBytecode(code, "test")
if err != nil {
t.Fatalf("CompileBytecode failed: %v", err)
}
// Load and run in one step
err = state.LoadAndRunBytecode(bytecode, "test")
if err != nil {
t.Fatalf("LoadAndRunBytecode failed: %v", err)
}
// Verify execution
state.GetGlobal("answer")
if !state.IsNumber(-1) || state.ToNumber(-1) != 42 {
t.Fatalf("Expected answer to be 42, got %v", state.ToNumber(-1))
}
state.Pop(1)
}
// TestCompileAndRun tests compile and run functionality
func TestCompileAndRun(t *testing.T) {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
// Compile and run in one step
code := "answer = 42"
err := state.CompileAndRun(code, "test")
if err != nil {
t.Fatalf("CompileAndRun failed: %v", err)
}
// Verify execution
state.GetGlobal("answer")
if !state.IsNumber(-1) || state.ToNumber(-1) != 42 {
t.Fatalf("Expected answer to be 42, got %v", state.ToNumber(-1))
}
state.Pop(1)
}
// TestEmptyBytecode tests error handling for empty bytecode
func TestEmptyBytecode(t *testing.T) {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
// Try to load empty bytecode
err := state.LoadBytecode([]byte{}, "empty")
if err == nil {
t.Fatal("Expected error for empty bytecode, got nil")
}
}
// TestInvalidBytecode tests error handling for invalid bytecode
func TestInvalidBytecode(t *testing.T) {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
// Create some invalid bytecode
invalidBytecode := []byte("not valid bytecode")
// Try to load invalid bytecode
err := state.LoadBytecode(invalidBytecode, "invalid")
if err == nil {
t.Fatal("Expected error for invalid bytecode, got nil")
}
}
// TestBytecodeSerialization tests serializing and deserializing bytecode
func TestBytecodeSerialization(t *testing.T) {
// First state to compile
state1 := luajit.New()
if state1 == nil {
t.Fatal("Failed to create first Lua state")
}
defer state1.Close()
// Compile bytecode
code := `
function add(a, b)
return a + b
end
result = add(10, 20)
`
bytecode, err := state1.CompileBytecode(code, "test")
if err != nil {
t.Fatalf("CompileBytecode failed: %v", err)
}
// Second state to execute
state2 := luajit.New()
if state2 == nil {
t.Fatal("Failed to create second Lua state")
}
defer state2.Close()
// Load and run the bytecode in the second state
err = state2.LoadAndRunBytecode(bytecode, "test")
if err != nil {
t.Fatalf("LoadAndRunBytecode failed: %v", err)
}
// Verify execution
state2.GetGlobal("result")
if !state2.IsNumber(-1) || state2.ToNumber(-1) != 30 {
t.Fatalf("Expected result to be 30, got %v", state2.ToNumber(-1))
}
state2.Pop(1)
// Call the function to verify it was properly transferred
state2.GetGlobal("add")
if !state2.IsFunction(-1) {
t.Fatal("Expected add to be a function")
}
state2.PushNumber(5)
state2.PushNumber(7)
if err := state2.Call(2, 1); err != nil {
t.Fatalf("Failed to call function: %v", err)
}
if state2.ToNumber(-1) != 12 {
t.Fatalf("Expected add(5, 7) to return 12, got %v", state2.ToNumber(-1))
}
state2.Pop(1)
}
// TestCompilationError tests error handling for compilation errors
func TestCompilationError(t *testing.T) {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
// Invalid Lua code that should fail to compile
code := "function without end"
// Try to compile
_, err := state.CompileBytecode(code, "invalid")
if err == nil {
t.Fatal("Expected compilation error, got nil")
}
// Check error type
var luaErr *luajit.LuaError
if !errors.As(err, &luaErr) {
t.Fatalf("Expected error to wrap *luajit.LuaError, got %T", err)
}
}
// TestExecutionError tests error handling for runtime errors
func TestExecutionError(t *testing.T) {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
// Code that compiles but fails at runtime
code := "error('deliberate error')"
// Compile bytecode
bytecode, err := state.CompileBytecode(code, "error")
if err != nil {
t.Fatalf("CompileBytecode failed: %v", err)
}
// Try to execute
err = state.LoadAndRunBytecode(bytecode, "error")
if err == nil {
t.Fatal("Expected execution error, got nil")
}
// Check error type
if _, ok := err.(*luajit.LuaError); !ok {
t.Fatalf("Expected *luajit.LuaError, got %T", err)
}
}
// TestBytecodeEquivalence tests that bytecode execution produces the same results as direct execution
func TestBytecodeEquivalence(t *testing.T) {
code := `
local result = 0
for i = 1, 10 do
result = result + i
end
return result
`
// First, execute directly
state1 := luajit.New()
if state1 == nil {
t.Fatal("Failed to create first Lua state")
}
defer state1.Close()
directResult, err := state1.ExecuteWithResult(code)
if err != nil {
t.Fatalf("ExecuteWithResult failed: %v", err)
}
// Then, compile and execute bytecode
state2 := luajit.New()
if state2 == nil {
t.Fatal("Failed to create second Lua state")
}
defer state2.Close()
bytecode, err := state2.CompileBytecode(code, "test")
if err != nil {
t.Fatalf("CompileBytecode failed: %v", err)
}
err = state2.LoadBytecode(bytecode, "test")
if err != nil {
t.Fatalf("LoadBytecode failed: %v", err)
}
err = state2.Call(0, 1)
if err != nil {
t.Fatalf("Call failed: %v", err)
}
bytecodeResult, err := state2.ToValue(-1)
if err != nil {
t.Fatalf("ToValue failed: %v", err)
}
state2.Pop(1)
// Compare results
if directResult != bytecodeResult {
t.Fatalf("Results differ: direct=%v, bytecode=%v", directResult, bytecodeResult)
}
}
// TestBytecodeReuse tests reusing the same bytecode multiple times
func TestBytecodeReuse(t *testing.T) {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
// Create a function in bytecode
code := `
return function(x)
return x * 2
end
`
bytecode, err := state.CompileBytecode(code, "func")
if err != nil {
t.Fatalf("CompileBytecode failed: %v", err)
}
// Execute it several times
for i := 1; i <= 3; i++ {
// Load and run to get the function
err = state.LoadAndRunBytecodeWithResults(bytes.Clone(bytecode), "func", 1)
if err != nil {
t.Fatalf("LoadAndRunBytecodeWithResults failed: %v", err)
}
// Stack now has the function at the top
if !state.IsFunction(-1) {
t.Fatal("Expected function at top of stack")
}
// Call with parameter i
state.PushNumber(float64(i))
if err := state.Call(1, 1); err != nil {
t.Fatalf("Call failed: %v", err)
}
// Check result
expected := float64(i * 2)
if state.ToNumber(-1) != expected {
t.Fatalf("Expected %v, got %v", expected, state.ToNumber(-1))
}
// Pop the result
state.Pop(1)
}
}
// TestBytecodeClosure tests that bytecode properly handles closures and upvalues
func TestBytecodeClosure(t *testing.T) {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
// Create a closure
code := `
local counter = 0
return function()
counter = counter + 1
return counter
end
`
// Compile to bytecode
bytecode, err := state.CompileBytecode(code, "closure")
if err != nil {
t.Fatalf("CompileBytecode failed: %v", err)
}
// Load and run to get the counter function
err = state.LoadAndRunBytecodeWithResults(bytecode, "closure", 1)
if err != nil {
t.Fatalf("LoadAndRunBytecode failed: %v", err)
}
// Stack now has the function at the top
if !state.IsFunction(-1) {
t.Fatal("Expected function at top of stack")
}
// Store in a global
state.SetGlobal("counter_func")
// Call it multiple times and check the results
for i := 1; i <= 3; i++ {
state.GetGlobal("counter_func")
if err := state.Call(0, 1); err != nil {
t.Fatalf("Call failed: %v", err)
}
if state.ToNumber(-1) != float64(i) {
t.Fatalf("Expected counter to be %d, got %v", i, state.ToNumber(-1))
}
state.Pop(1)
}
}

178
tests/functions_test.go Normal file
View File

@ -0,0 +1,178 @@
package luajit_test
import (
"testing"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
func TestPushGoFunction(t *testing.T) {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
// Define a simple function that adds two numbers
add := func(s *luajit.State) int {
a := s.ToNumber(1)
b := s.ToNumber(2)
s.PushNumber(a + b)
return 1 // Return one result
}
// Push the function onto the stack
if err := state.PushGoFunction(add); err != nil {
t.Fatalf("PushGoFunction failed: %v", err)
}
// Verify that a function is on the stack
if !state.IsFunction(-1) {
t.Fatalf("Expected function at top of stack")
}
// Push arguments
state.PushNumber(3)
state.PushNumber(4)
// Call the function
if err := state.Call(2, 1); err != nil {
t.Fatalf("Failed to call function: %v", err)
}
// Check the result
if state.ToNumber(-1) != 7 {
t.Fatalf("Function returned %f, expected 7", state.ToNumber(-1))
}
state.Pop(1)
}
func TestRegisterGoFunction(t *testing.T) {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
// Define a function that squares a number
square := func(s *luajit.State) int {
x := s.ToNumber(1)
s.PushNumber(x * x)
return 1
}
// Register the function
if err := state.RegisterGoFunction("square", square); err != nil {
t.Fatalf("RegisterGoFunction failed: %v", err)
}
// Call the function from Lua
if err := state.DoString("result = square(5)"); err != nil {
t.Fatalf("Failed to call registered function: %v", err)
}
// Check the result
state.GetGlobal("result")
if state.ToNumber(-1) != 25 {
t.Fatalf("Function returned %f, expected 25", state.ToNumber(-1))
}
state.Pop(1)
// Test UnregisterGoFunction
state.UnregisterGoFunction("square")
// Function should no longer exist
err := state.DoString("result = square(5)")
if err == nil {
t.Fatalf("Expected error after unregistering function, got nil")
}
}
func TestGoFunctionWithErrorHandling(t *testing.T) {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
// Function that returns an error in Lua
errFunc := func(s *luajit.State) int {
s.PushString("error from Go function")
return -1 // Signal error
}
// Register the function
if err := state.RegisterGoFunction("errorFunc", errFunc); err != nil {
t.Fatalf("RegisterGoFunction failed: %v", err)
}
// Call the function expecting an error
err := state.DoString("result = errorFunc()")
if err == nil {
t.Fatalf("Expected error from function, got nil")
}
// Error message should contain our message
luaErr, ok := err.(*luajit.LuaError)
if !ok {
t.Fatalf("Expected LuaError, got %T: %v", err, err)
}
if luaErr.Message == "" {
t.Fatalf("Expected non-empty error message from Go function")
}
}
func TestCleanup(t *testing.T) {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
// Register several functions
for i := 0; i < 5; i++ {
dummy := func(s *luajit.State) int { return 0 }
if err := state.RegisterGoFunction("dummy", dummy); err != nil {
t.Fatalf("RegisterGoFunction failed: %v", err)
}
}
// Call Cleanup explicitly
state.Cleanup()
// Make sure we can still close the state
state.Close()
// Also test that Close can be called after Cleanup
state = luajit.New()
if state == nil {
t.Fatal("Failed to create second Lua state")
}
state.Close() // Should call Cleanup internally
}
func TestGoFunctionErrorPointer(t *testing.T) {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
// Create a Lua function that calls a non-existent Go function pointer
// This isn't a direct test of internal implementation, but tries to cover
// error cases in the goFunctionWrapper
code := `
function test()
-- This is a stub that doesn't actually call the wrapper,
-- but we're testing error handling in our State.DoString
return "test"
end
`
if err := state.DoString(code); err != nil {
t.Fatalf("Failed to define test function: %v", err)
}
// The real test is that Cleanup doesn't crash
state.Cleanup()
}

53
tests/stack_test.go Normal file
View File

@ -0,0 +1,53 @@
package luajit_test
import (
"strings"
"testing"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
func TestLuaError(t *testing.T) {
err := &luajit.LuaError{
Code: 123,
Message: "test error",
}
expected := "lua error (code=123): test error"
if err.Error() != expected {
t.Errorf("Expected error message %q, got %q", expected, err.Error())
}
}
func TestGetStackTrace(t *testing.T) {
s := luajit.New()
defer s.Close()
// Test with debug library available
trace := s.GetStackTrace()
if !strings.Contains(trace, "stack traceback:") {
t.Errorf("Expected trace to contain 'stack traceback:', got %q", trace)
}
// Test when debug table is not available
err := s.DoString("debug = nil")
if err != nil {
t.Fatalf("Failed to set debug to nil: %v", err)
}
trace = s.GetStackTrace()
if trace != "debug table not available" {
t.Errorf("Expected 'debug table not available', got %q", trace)
}
// Test when debug.traceback is not available
err = s.DoString("debug = {}")
if err != nil {
t.Fatalf("Failed to set debug to empty table: %v", err)
}
trace = s.GetStackTrace()
if trace != "debug.traceback not available" {
t.Errorf("Expected 'debug.traceback not available', got %q", trace)
}
}

246
tests/table_test.go Normal file
View File

@ -0,0 +1,246 @@
package luajit_test
import (
"reflect"
"testing"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
func TestGetTableLength(t *testing.T) {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
// Create a table with numeric indices
if err := state.DoString("t = {10, 20, 30, 40, 50}"); err != nil {
t.Fatalf("Failed to create test table: %v", err)
}
// Get the table
state.GetGlobal("t")
length := state.GetTableLength(-1)
if length != 5 {
t.Fatalf("Expected length 5, got %d", length)
}
state.Pop(1)
// Create a table with string keys
if err := state.DoString("t2 = {a=1, b=2, c=3}"); err != nil {
t.Fatalf("Failed to create test table: %v", err)
}
// Get the table
state.GetGlobal("t2")
length = state.GetTableLength(-1)
if length != 0 {
t.Fatalf("Expected length 0 for string-keyed table, got %d", length)
}
state.Pop(1)
}
func TestPushTable(t *testing.T) {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
// Create a test table
testTable := map[string]any{
"int": 42,
"float": 3.14,
"string": "hello",
"boolean": true,
"nil": nil,
}
// Push the table onto the stack
if err := state.PushTable(testTable); err != nil {
t.Fatalf("Failed to push table: %v", err)
}
// Execute Lua code to test the table contents
if err := state.DoString(`
function validate_table(t)
return t.int == 42 and
math.abs(t.float - 3.14) < 0.0001 and
t.string == "hello" and
t.boolean == true and
t["nil"] == nil
end
`); err != nil {
t.Fatalf("Failed to create validation function: %v", err)
}
// Call the validation function
state.GetGlobal("validate_table")
state.PushCopy(-2) // Copy the table to the top
if err := state.Call(1, 1); err != nil {
t.Fatalf("Failed to call validation function: %v", err)
}
if !state.ToBoolean(-1) {
t.Fatalf("Table validation failed")
}
state.Pop(2) // Pop the result and the table
}
func TestToTable(t *testing.T) {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
// Test regular table conversion
if err := state.DoString(`t = {a=1, b=2.5, c="test", d=true, e=nil}`); err != nil {
t.Fatalf("Failed to create test table: %v", err)
}
state.GetGlobal("t")
table, err := state.ToTable(-1)
if err != nil {
t.Fatalf("Failed to convert table: %v", err)
}
state.Pop(1)
expected := map[string]any{
"a": float64(1),
"b": 2.5,
"c": "test",
"d": true,
}
for k, v := range expected {
if table[k] != v {
t.Fatalf("Expected table[%s] = %v, got %v", k, v, table[k])
}
}
// Test array-like table conversion
if err := state.DoString(`arr = {10, 20, 30, 40, 50}`); err != nil {
t.Fatalf("Failed to create test array: %v", err)
}
state.GetGlobal("arr")
table, err = state.ToTable(-1)
if err != nil {
t.Fatalf("Failed to convert array table: %v", err)
}
state.Pop(1)
// For array tables, we should get a special format with an empty key
// and the array as the value
expectedArray := []float64{10, 20, 30, 40, 50}
if arr, ok := table[""].([]float64); !ok {
t.Fatalf("Expected array table to be converted with empty key, got: %v", table)
} else if !reflect.DeepEqual(arr, expectedArray) {
t.Fatalf("Expected %v, got %v", expectedArray, arr)
}
// Test invalid table index
_, err = state.ToTable(100)
if err == nil {
t.Fatalf("Expected error for invalid table index, got nil")
}
// Test non-table value
state.PushNumber(123)
_, err = state.ToTable(-1)
if err == nil {
t.Fatalf("Expected error for non-table value, got nil")
}
state.Pop(1)
// Test mixed array with non-numeric values
if err := state.DoString(`mixed = {10, 20, key="value", 30}`); err != nil {
t.Fatalf("Failed to create mixed table: %v", err)
}
state.GetGlobal("mixed")
table, err = state.ToTable(-1)
if err != nil {
t.Fatalf("Failed to convert mixed table: %v", err)
}
// Let's print the table for debugging
t.Logf("Table contents: %v", table)
state.Pop(1)
// Check if the array part is detected and stored with empty key
if arr, ok := table[""]; !ok {
t.Fatalf("Expected array-like part to be detected, got: %v", table)
} else {
// Verify the array contains the expected values
expectedArr := []float64{10, 20, 30}
actualArr := arr.([]float64)
if len(actualArr) != len(expectedArr) {
t.Fatalf("Expected array length %d, got %d", len(expectedArr), len(actualArr))
}
for i, v := range expectedArr {
if actualArr[i] != v {
t.Fatalf("Expected array[%d] = %v, got %v", i, v, actualArr[i])
}
}
}
// Based on the implementation, we need to create a separate test for string keys
if err := state.DoString(`dict = {foo="bar", baz="qux"}`); err != nil {
t.Fatalf("Failed to create dict table: %v", err)
}
state.GetGlobal("dict")
dictTable, err := state.ToTable(-1)
if err != nil {
t.Fatalf("Failed to convert dict table: %v", err)
}
state.Pop(1)
// Check the string keys
if val, ok := dictTable["foo"]; !ok || val != "bar" {
t.Fatalf("Expected dictTable[\"foo\"] = \"bar\", got: %v", val)
}
if val, ok := dictTable["baz"]; !ok || val != "qux" {
t.Fatalf("Expected dictTable[\"baz\"] = \"qux\", got: %v", val)
}
}
func TestTablePooling(t *testing.T) {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
// Create a Lua table and push it onto the stack
if err := state.DoString(`t = {a=1, b=2}`); err != nil {
t.Fatalf("Failed to create test table: %v", err)
}
state.GetGlobal("t")
// First conversion - should get a table from the pool
table1, err := state.ToTable(-1)
if err != nil {
t.Fatalf("Failed to convert table (1): %v", err)
}
// Second conversion - should get another table from the pool
table2, err := state.ToTable(-1)
if err != nil {
t.Fatalf("Failed to convert table (2): %v", err)
}
// Both tables should have the same content
if !reflect.DeepEqual(table1, table2) {
t.Fatalf("Tables should have the same content: %v vs %v", table1, table2)
}
// Clean up
state.Pop(1)
}

473
tests/wrapper_test.go Normal file
View File

@ -0,0 +1,473 @@
package luajit_test
import (
"os"
"reflect"
"testing"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
func TestStateLifecycle(t *testing.T) {
// Test creation
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
// Test close
state.Close()
// Test close is idempotent (doesn't crash)
state.Close()
}
func TestStackManipulation(t *testing.T) {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
// Test initial stack size
if state.GetTop() != 0 {
t.Fatalf("Expected empty stack, got %d elements", state.GetTop())
}
// Push values
state.PushNil()
state.PushBoolean(true)
state.PushNumber(42)
state.PushString("hello")
// Check stack size
if state.GetTop() != 4 {
t.Fatalf("Expected 4 elements, got %d", state.GetTop())
}
// Test SetTop
state.SetTop(2)
if state.GetTop() != 2 {
t.Fatalf("Expected 2 elements after SetTop, got %d", state.GetTop())
}
// Test PushCopy
state.PushCopy(2) // Copy the boolean
if !state.IsBoolean(-1) {
t.Fatalf("Expected boolean at top of stack")
}
// Test Pop
state.Pop(1)
if state.GetTop() != 2 {
t.Fatalf("Expected 2 elements after Pop, got %d", state.GetTop())
}
// Test Remove
state.PushNumber(99)
state.Remove(1) // Remove the first element (nil)
if state.GetTop() != 2 {
t.Fatalf("Expected 2 elements after Remove, got %d", state.GetTop())
}
// Verify first element is now boolean
if !state.IsBoolean(1) {
t.Fatalf("Expected boolean at index 1 after Remove")
}
}
func TestTypeChecking(t *testing.T) {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
// Push values of different types
state.PushNil()
state.PushBoolean(true)
state.PushNumber(42)
state.PushString("hello")
state.NewTable()
// Check types with GetType
if state.GetType(1) != luajit.TypeNil {
t.Fatalf("Expected nil type at index 1, got %s", state.GetType(1))
}
if state.GetType(2) != luajit.TypeBoolean {
t.Fatalf("Expected boolean type at index 2, got %s", state.GetType(2))
}
if state.GetType(3) != luajit.TypeNumber {
t.Fatalf("Expected number type at index 3, got %s", state.GetType(3))
}
if state.GetType(4) != luajit.TypeString {
t.Fatalf("Expected string type at index 4, got %s", state.GetType(4))
}
if state.GetType(5) != luajit.TypeTable {
t.Fatalf("Expected table type at index 5, got %s", state.GetType(5))
}
// Test individual type checking functions
if !state.IsNil(1) {
t.Fatalf("IsNil failed for nil value")
}
if !state.IsBoolean(2) {
t.Fatalf("IsBoolean failed for boolean value")
}
if !state.IsNumber(3) {
t.Fatalf("IsNumber failed for number value")
}
if !state.IsString(4) {
t.Fatalf("IsString failed for string value")
}
if !state.IsTable(5) {
t.Fatalf("IsTable failed for table value")
}
// Function test
state.DoString("function test() return true end")
state.GetGlobal("test")
if !state.IsFunction(-1) {
t.Fatalf("IsFunction failed for function value")
}
}
func TestValueConversion(t *testing.T) {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
// Push values
state.PushBoolean(true)
state.PushNumber(42.5)
state.PushString("hello")
// Test conversion
if !state.ToBoolean(1) {
t.Fatalf("ToBoolean failed")
}
if state.ToNumber(2) != 42.5 {
t.Fatalf("ToNumber failed, expected 42.5, got %f", state.ToNumber(2))
}
if state.ToString(3) != "hello" {
t.Fatalf("ToString failed, expected 'hello', got '%s'", state.ToString(3))
}
}
func TestTableOperations(t *testing.T) {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
// Test CreateTable
state.CreateTable(0, 3)
// Add fields using SetField
state.PushNumber(42)
state.SetField(-2, "answer")
state.PushString("hello")
state.SetField(-2, "greeting")
state.PushBoolean(true)
state.SetField(-2, "flag")
// Test GetField
state.GetField(-1, "answer")
if state.ToNumber(-1) != 42 {
t.Fatalf("GetField for 'answer' failed")
}
state.Pop(1)
state.GetField(-1, "greeting")
if state.ToString(-1) != "hello" {
t.Fatalf("GetField for 'greeting' failed")
}
state.Pop(1)
// Test Next for iteration
state.PushNil() // Start iteration
count := 0
for state.Next(-2) {
count++
state.Pop(1) // Pop value, leave key for next iteration
}
if count != 3 {
t.Fatalf("Expected 3 table entries, found %d", count)
}
// Clean up
state.Pop(1) // Pop the table
}
func TestGlobalOperations(t *testing.T) {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
// Set a global value
state.PushNumber(42)
state.SetGlobal("answer")
// Get the global value
state.GetGlobal("answer")
if state.ToNumber(-1) != 42 {
t.Fatalf("GetGlobal failed, expected 42, got %f", state.ToNumber(-1))
}
state.Pop(1)
// Test non-existent global (should be nil)
state.GetGlobal("nonexistent")
if !state.IsNil(-1) {
t.Fatalf("Expected nil for non-existent global")
}
state.Pop(1)
}
func TestCodeExecution(t *testing.T) {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
// Test LoadString
if err := state.LoadString("return 42"); err != nil {
t.Fatalf("LoadString failed: %v", err)
}
// Test Call
if err := state.Call(0, 1); err != nil {
t.Fatalf("Call failed: %v", err)
}
if state.ToNumber(-1) != 42 {
t.Fatalf("Call result incorrect, expected 42, got %f", state.ToNumber(-1))
}
state.Pop(1)
// Test DoString
if err := state.DoString("answer = 42 + 1"); err != nil {
t.Fatalf("DoString failed: %v", err)
}
state.GetGlobal("answer")
if state.ToNumber(-1) != 43 {
t.Fatalf("DoString execution incorrect, expected 43, got %f", state.ToNumber(-1))
}
state.Pop(1)
// Test Execute
nresults, err := state.Execute("return 5, 10, 15")
if err != nil {
t.Fatalf("Execute failed: %v", err)
}
if nresults != 3 {
t.Fatalf("Execute returned %d results, expected 3", nresults)
}
if state.ToNumber(-3) != 5 || state.ToNumber(-2) != 10 || state.ToNumber(-1) != 15 {
t.Fatalf("Execute results incorrect")
}
state.Pop(3)
// Test ExecuteWithResult
result, err := state.ExecuteWithResult("return 'hello'")
if err != nil {
t.Fatalf("ExecuteWithResult failed: %v", err)
}
if result != "hello" {
t.Fatalf("ExecuteWithResult returned %v, expected 'hello'", result)
}
// Test error handling
err = state.DoString("this is not valid lua code")
if err == nil {
t.Fatalf("Expected error for invalid code, got nil")
}
}
func TestDoFile(t *testing.T) {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
// Create a temporary Lua file
content := []byte("answer = 42")
tmpfile, err := os.CreateTemp("", "test-*.lua")
if err != nil {
t.Fatalf("Failed to create temp file: %v", err)
}
defer os.Remove(tmpfile.Name())
if _, err := tmpfile.Write(content); err != nil {
t.Fatalf("Failed to write to temp file: %v", err)
}
if err := tmpfile.Close(); err != nil {
t.Fatalf("Failed to close temp file: %v", err)
}
// Test LoadFile and DoFile
if err := state.LoadFile(tmpfile.Name()); err != nil {
t.Fatalf("LoadFile failed: %v", err)
}
if err := state.Call(0, 0); err != nil {
t.Fatalf("Call failed after LoadFile: %v", err)
}
state.GetGlobal("answer")
if state.ToNumber(-1) != 42 {
t.Fatalf("Incorrect result after LoadFile, expected 42, got %f", state.ToNumber(-1))
}
state.Pop(1)
// Reset global
if err := state.DoString("answer = nil"); err != nil {
t.Fatalf("Failed to reset answer: %v", err)
}
// Test DoFile
if err := state.DoFile(tmpfile.Name()); err != nil {
t.Fatalf("DoFile failed: %v", err)
}
state.GetGlobal("answer")
if state.ToNumber(-1) != 42 {
t.Fatalf("Incorrect result after DoFile, expected 42, got %f", state.ToNumber(-1))
}
state.Pop(1)
}
func TestPackagePath(t *testing.T) {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
// Test SetPackagePath
testPath := "/test/path/?.lua"
if err := state.SetPackagePath(testPath); err != nil {
t.Fatalf("SetPackagePath failed: %v", err)
}
result, err := state.ExecuteWithResult("return package.path")
if err != nil {
t.Fatalf("Failed to get package.path: %v", err)
}
if result != testPath {
t.Fatalf("Expected package.path to be '%s', got '%s'", testPath, result)
}
// Test AddPackagePath
addPath := "/another/path/?.lua"
if err := state.AddPackagePath(addPath); err != nil {
t.Fatalf("AddPackagePath failed: %v", err)
}
result, err = state.ExecuteWithResult("return package.path")
if err != nil {
t.Fatalf("Failed to get package.path: %v", err)
}
expected := testPath + ";" + addPath
if result != expected {
t.Fatalf("Expected package.path to be '%s', got '%s'", expected, result)
}
}
func TestPushValueAndToValue(t *testing.T) {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
testCases := []struct {
value any
}{
{nil},
{true},
{false},
{42},
{42.5},
{"hello"},
{[]float64{1, 2, 3, 4, 5}},
{[]any{1, "test", true}},
{map[string]any{"a": 1, "b": "test", "c": true}},
}
for i, tc := range testCases {
// Push value
err := state.PushValue(tc.value)
if err != nil {
t.Fatalf("PushValue failed for testCase %d: %v", i, err)
}
// Check stack
if state.GetTop() != i+1 {
t.Fatalf("Stack size incorrect after push, expected %d, got %d", i+1, state.GetTop())
}
}
// Test conversion back to Go
for i := range testCases {
index := len(testCases) - i
value, err := state.ToValue(index)
if err != nil {
t.Fatalf("ToValue failed for index %d: %v", index, err)
}
// For tables, we need special handling due to how Go types are stored
switch expected := testCases[index-1].value.(type) {
case []float64:
// Arrays come back as map[string]any with empty key
if m, ok := value.(map[string]any); ok {
if arr, ok := m[""].([]float64); ok {
if !reflect.DeepEqual(arr, expected) {
t.Fatalf("Value mismatch for testCase %d: expected %v, got %v", index-1, expected, arr)
}
} else {
t.Fatalf("Invalid array conversion for testCase %d", index-1)
}
} else {
t.Fatalf("Expected map for array value in testCase %d, got %T", index-1, value)
}
case int:
if num, ok := value.(float64); ok {
if float64(expected) == num {
continue // Values match after type conversion
}
}
case []any:
// Skip detailed comparison for mixed arrays
case map[string]any:
// Skip detailed comparison for maps
default:
if !reflect.DeepEqual(value, testCases[index-1].value) {
t.Fatalf("Value mismatch for testCase %d: expected %v, got %v",
index-1, testCases[index-1].value, value)
}
}
}
// Test unsupported type
complex := complex(1, 2)
err := state.PushValue(complex)
if err == nil {
t.Fatalf("Expected error for unsupported type")
}
}

View File

@ -9,7 +9,7 @@ import "C"
type LuaType int
const (
// These constants must match lua.h's LUA_T* values
// These constants match lua.h's LUA_T* values
TypeNone LuaType = -1
TypeNil LuaType = 0
TypeBoolean LuaType = 1

View File

@ -9,61 +9,30 @@ package luajit
#include <lualib.h>
#include <lauxlib.h>
#include <stdlib.h>
#include <string.h>
// Simple wrapper around luaL_loadstring
static int load_chunk(lua_State *L, const char *s) {
return luaL_loadstring(L, s);
}
// Direct wrapper around lua_pcall
static int protected_call(lua_State *L, int nargs, int nresults, int errfunc) {
return lua_pcall(L, nargs, nresults, errfunc);
}
// Combined load and execute with no results
static int do_string(lua_State *L, const char *s) {
return luaL_dostring(L, s);
}
// Combined load and execute file
static int do_file(lua_State *L, const char *filename) {
return luaL_dofile(L, filename);
}
// Execute string with multiple returns
static int execute_string(lua_State *L, const char *s) {
int base = lua_gettop(L); // Save stack position
int status = luaL_loadstring(L, s);
if (status) return -status; // Return negative status for load errors
status = lua_pcall(L, 0, LUA_MULTRET, 0);
if (status) return -status; // Return negative status for runtime errors
return lua_gettop(L) - base; // Return number of results
}
// Get absolute stack index (converts negative indices)
// Helper to simplify some common operations
static int get_abs_index(lua_State *L, int idx) {
if (idx > 0 || idx <= LUA_REGISTRYINDEX) return idx;
return lua_gettop(L) + idx + 1;
}
// Stack manipulation helpers
static int check_stack(lua_State *L, int n) {
return lua_checkstack(L, n);
// Combined load and execute with no results
static int do_string(lua_State *L, const char *s) {
int status = luaL_loadstring(L, s);
if (status == 0) {
status = lua_pcall(L, 0, 0, 0);
}
return status;
}
static void remove_stack(lua_State *L, int idx) {
lua_remove(L, idx);
}
static int get_field_helper(lua_State *L, int idx, const char *k) {
lua_getfield(L, idx, k);
return lua_type(L, -1);
}
static void set_field_helper(lua_State *L, int idx, const char *k) {
lua_setfield(L, idx, k);
// Combined load and execute file
static int do_file(lua_State *L, const char *filename) {
int status = luaL_loadfile(L, filename);
if (status == 0) {
status = lua_pcall(L, 0, 0, 0);
}
return status;
}
*/
import "C"
@ -78,7 +47,7 @@ type State struct {
L *C.lua_State
}
// New creates a new Lua state
// New creates a new Lua state with all standard libraries loaded
func New() *State {
L := C.luaL_newstate()
if L == nil {
@ -88,51 +57,195 @@ func New() *State {
return &State{L: L}
}
// Close closes the Lua state
// Close closes the Lua state and frees resources
func (s *State) Close() {
if s.L != nil {
s.Cleanup() // Clean up Go function registry
C.lua_close(s.L)
s.L = nil
}
}
// DoString executes a Lua string.
func (s *State) DoString(str string) error {
// Save initial stack size
top := s.GetTop()
// Stack manipulation methods
// Load the string
if err := s.LoadString(str); err != nil {
return err
}
// Execute and check for errors
if err := s.Call(0, 0); err != nil {
return err
}
// Restore stack to initial size to clean up any leftovers
s.SetTop(top)
return nil
// GetTop returns the index of the top element in the stack
func (s *State) GetTop() int {
return int(C.lua_gettop(s.L))
}
// PushValue pushes a Go value onto the stack
// SetTop sets the stack top to a specific index
func (s *State) SetTop(index int) {
C.lua_settop(s.L, C.int(index))
}
// PushValue pushes a copy of the value at the given index onto the stack
func (s *State) PushCopy(index int) {
C.lua_pushvalue(s.L, C.int(index))
}
// Pop pops n elements from the stack
func (s *State) Pop(n int) {
C.lua_settop(s.L, C.int(-n-1))
}
// Remove removes the element at the given valid index
func (s *State) Remove(index int) {
C.lua_remove(s.L, C.int(index))
}
// absIndex converts a possibly negative index to its absolute position
func (s *State) absIndex(index int) int {
if index > 0 || index <= LUA_REGISTRYINDEX {
return index
}
return s.GetTop() + index + 1
}
// Type checking methods
// GetType returns the type of the value at the given index
func (s *State) GetType(index int) LuaType {
return LuaType(C.lua_type(s.L, C.int(index)))
}
// IsNil checks if the value at the given index is nil
func (s *State) IsNil(index int) bool {
return s.GetType(index) == TypeNil
}
// IsBoolean checks if the value at the given index is a boolean
func (s *State) IsBoolean(index int) bool {
return s.GetType(index) == TypeBoolean
}
// IsNumber checks if the value at the given index is a number
func (s *State) IsNumber(index int) bool {
return C.lua_isnumber(s.L, C.int(index)) != 0
}
// IsString checks if the value at the given index is a string
func (s *State) IsString(index int) bool {
return C.lua_isstring(s.L, C.int(index)) != 0
}
// IsTable checks if the value at the given index is a table
func (s *State) IsTable(index int) bool {
return s.GetType(index) == TypeTable
}
// IsFunction checks if the value at the given index is a function
func (s *State) IsFunction(index int) bool {
return s.GetType(index) == TypeFunction
}
// Value conversion methods
// ToBoolean returns the value at the given index as a boolean
func (s *State) ToBoolean(index int) bool {
return C.lua_toboolean(s.L, C.int(index)) != 0
}
// ToNumber returns the value at the given index as a number
func (s *State) ToNumber(index int) float64 {
return float64(C.lua_tonumber(s.L, C.int(index)))
}
// ToString returns the value at the given index as a string
func (s *State) ToString(index int) string {
var length C.size_t
cstr := C.lua_tolstring(s.L, C.int(index), &length)
if cstr == nil {
return ""
}
return C.GoStringN(cstr, C.int(length))
}
// Push methods
// PushNil pushes a nil value onto the stack
func (s *State) PushNil() {
C.lua_pushnil(s.L)
}
// PushBoolean pushes a boolean value onto the stack
func (s *State) PushBoolean(b bool) {
var value C.int
if b {
value = 1
}
C.lua_pushboolean(s.L, value)
}
// PushNumber pushes a number value onto the stack
func (s *State) PushNumber(n float64) {
C.lua_pushnumber(s.L, C.lua_Number(n))
}
// PushString pushes a string value onto the stack
func (s *State) PushString(str string) {
cstr := C.CString(str)
defer C.free(unsafe.Pointer(cstr))
C.lua_pushlstring(s.L, cstr, C.size_t(len(str)))
}
// Table operations
// CreateTable creates a new table and pushes it onto the stack
func (s *State) CreateTable(narr, nrec int) {
C.lua_createtable(s.L, C.int(narr), C.int(nrec))
}
// NewTable creates a new empty table and pushes it onto the stack
func (s *State) NewTable() {
C.lua_createtable(s.L, 0, 0)
}
// GetTable gets a table field (t[k]) where t is at the given index and k is at the top of the stack
func (s *State) GetTable(index int) {
C.lua_gettable(s.L, C.int(index))
}
// SetTable sets a table field (t[k] = v) where t is at the given index, k is at -2, and v is at -1
func (s *State) SetTable(index int) {
C.lua_settable(s.L, C.int(index))
}
// GetField gets a table field t[k] and pushes it onto the stack
func (s *State) GetField(index int, key string) {
ckey := C.CString(key)
defer C.free(unsafe.Pointer(ckey))
C.lua_getfield(s.L, C.int(index), ckey)
}
// SetField sets a table field t[k] = v, where v is the value at the top of the stack
func (s *State) SetField(index int, key string) {
ckey := C.CString(key)
defer C.free(unsafe.Pointer(ckey))
C.lua_setfield(s.L, C.int(index), ckey)
}
// Next pops a key from the stack and pushes the next key-value pair from the table at the given index
func (s *State) Next(index int) bool {
return C.lua_next(s.L, C.int(index)) != 0
}
// PushValue pushes a Go value onto the stack with proper type conversion
func (s *State) PushValue(v interface{}) error {
switch v := v.(type) {
case nil:
s.PushNil()
case bool:
s.PushBoolean(v)
case float64:
s.PushNumber(v)
case int:
s.PushNumber(float64(v))
case float64:
s.PushNumber(v)
case string:
s.PushString(v)
case map[string]interface{}:
// Special case: handle array stored in map
if arr, ok := v[""].([]float64); ok {
s.NewTable()
s.CreateTable(len(arr), 0)
for i, elem := range arr {
s.PushNumber(float64(i + 1))
s.PushNumber(elem)
@ -142,14 +255,14 @@ func (s *State) PushValue(v interface{}) error {
}
return s.PushTable(v)
case []float64:
s.NewTable()
s.CreateTable(len(v), 0)
for i, elem := range v {
s.PushNumber(float64(i + 1))
s.PushNumber(elem)
s.SetTable(-3)
}
case []interface{}:
s.NewTable()
s.CreateTable(len(v), 0)
for i, elem := range v {
s.PushNumber(float64(i + 1))
if err := s.PushValue(elem); err != nil {
@ -163,9 +276,10 @@ func (s *State) PushValue(v interface{}) error {
return nil
}
// ToValue converts a Lua value to a Go value
// ToValue converts a Lua value at the given index to a Go value
func (s *State) ToValue(index int) (interface{}, error) {
switch s.GetType(index) {
luaType := s.GetType(index)
switch luaType {
case TypeNil:
return nil, nil
case TypeBoolean:
@ -175,231 +289,161 @@ func (s *State) ToValue(index int) (interface{}, error) {
case TypeString:
return s.ToString(index), nil
case TypeTable:
if !s.IsTable(index) {
return nil, fmt.Errorf("not a table at index %d", index)
}
return s.ToTable(index)
default:
return nil, fmt.Errorf("unsupported type: %s", s.GetType(index))
return nil, fmt.Errorf("unsupported type: %s", luaType)
}
}
// Simple operations remain unchanged as they don't need stack protection
// Global operations
func (s *State) GetType(index int) LuaType { return LuaType(C.lua_type(s.L, C.int(index))) }
func (s *State) IsString(index int) bool { return s.GetType(index) == TypeString }
func (s *State) IsNumber(index int) bool { return s.GetType(index) == TypeNumber }
func (s *State) IsFunction(index int) bool { return s.GetType(index) == TypeFunction }
func (s *State) IsTable(index int) bool { return s.GetType(index) == TypeTable }
func (s *State) IsNil(index int) bool { return s.GetType(index) == TypeNil }
func (s *State) ToBoolean(index int) bool { return C.lua_toboolean(s.L, C.int(index)) != 0 }
func (s *State) ToNumber(index int) float64 { return float64(C.lua_tonumber(s.L, C.int(index))) }
func (s *State) ToString(index int) string {
return C.GoString(C.lua_tolstring(s.L, C.int(index), nil))
}
func (s *State) GetTop() int { return int(C.lua_gettop(s.L)) }
func (s *State) Pop(n int) { C.lua_settop(s.L, C.int(-n-1)) }
func (s *State) SetTop(index int) { C.lua_settop(s.L, C.int(index)) }
// Push operations
func (s *State) PushNil() { C.lua_pushnil(s.L) }
func (s *State) PushBoolean(b bool) { C.lua_pushboolean(s.L, C.int(bool2int(b))) }
func (s *State) PushNumber(n float64) { C.lua_pushnumber(s.L, C.double(n)) }
func (s *State) PushString(str string) {
cstr := C.CString(str)
defer C.free(unsafe.Pointer(cstr))
C.lua_pushstring(s.L, cstr)
}
func (s *State) Next(index int) bool {
return C.lua_next(s.L, C.int(index)) != 0
}
// Helper functions
func bool2int(b bool) int {
if b {
return 1
}
return 0
}
func (s *State) absIndex(index int) int {
if index > 0 || index <= LUA_REGISTRYINDEX {
return index
}
return s.GetTop() + index + 1
}
// SetField sets a field in a table at the given index
func (s *State) SetField(index int, key string) {
cstr := C.CString(key)
defer C.free(unsafe.Pointer(cstr))
C.lua_setfield(s.L, C.int(index), cstr)
}
// GetField gets a field from a table
func (s *State) GetField(index int, key string) {
cstr := C.CString(key)
defer C.free(unsafe.Pointer(cstr))
C.lua_getfield(s.L, C.int(index), cstr)
}
// GetGlobal gets a global variable and pushes it onto the stack
// GetGlobal pushes the global variable with the given name onto the stack
func (s *State) GetGlobal(name string) {
cname := C.CString(name)
defer C.free(unsafe.Pointer(cname))
C.get_field_helper(s.L, C.LUA_GLOBALSINDEX, cname)
s.GetField(LUA_GLOBALSINDEX, name)
}
// SetGlobal sets a global variable from the value at the top of the stack
// SetGlobal sets the global variable with the given name to the value at the top of the stack
func (s *State) SetGlobal(name string) {
cstr := C.CString(name)
defer C.free(unsafe.Pointer(cstr))
C.lua_setfield(s.L, C.LUA_GLOBALSINDEX, cstr)
s.SetField(LUA_GLOBALSINDEX, name)
}
// Remove removes element with cached absolute index
func (s *State) Remove(index int) {
absIdx := index
C.lua_remove(s.L, C.int(absIdx))
// Code execution methods
// LoadString loads a Lua chunk from a string without executing it
func (s *State) LoadString(code string) error {
ccode := C.CString(code)
defer C.free(unsafe.Pointer(ccode))
return s.safeCall(func() C.int {
return C.luaL_loadstring(s.L, ccode)
})
}
// DoFile executes a Lua file with appropriate stack management
// LoadFile loads a Lua chunk from a file without executing it
func (s *State) LoadFile(filename string) error {
cfilename := C.CString(filename)
defer C.free(unsafe.Pointer(cfilename))
return s.safeCall(func() C.int {
return C.luaL_loadfile(s.L, cfilename)
})
}
// Call calls a function with the given number of arguments and results
func (s *State) Call(nargs, nresults int) error {
return s.safeCall(func() C.int {
return C.lua_pcall(s.L, C.int(nargs), C.int(nresults), 0)
})
}
// DoString executes a Lua string and cleans up the stack
func (s *State) DoString(code string) error {
ccode := C.CString(code)
defer C.free(unsafe.Pointer(ccode))
return s.safeCall(func() C.int {
return C.do_string(s.L, ccode)
})
}
// DoFile executes a Lua file and cleans up the stack
func (s *State) DoFile(filename string) error {
cfilename := C.CString(filename)
defer C.free(unsafe.Pointer(cfilename))
status := C.do_file(s.L, cfilename)
if status != 0 {
return &LuaError{
Code: int(status),
Message: s.ToString(-1),
return s.safeCall(func() C.int {
return C.do_file(s.L, cfilename)
})
}
// Execute executes a Lua string and returns the number of results left on the stack
func (s *State) Execute(code string) (int, error) {
baseTop := s.GetTop()
ccode := C.CString(code)
defer C.free(unsafe.Pointer(ccode))
var nresults int
err := s.safeCall(func() C.int {
status := C.luaL_loadstring(s.L, ccode)
if status != 0 {
return status
}
}
return nil
}
// SetPackagePath sets the Lua package.path
func (s *State) SetPackagePath(path string) error {
path = strings.ReplaceAll(path, "\\", "/") // Convert Windows paths
cmd := fmt.Sprintf(`package.path = %q`, path)
return s.DoString(cmd)
}
// AddPackagePath adds a path to package.path
func (s *State) AddPackagePath(path string) error {
path = strings.ReplaceAll(path, "\\", "/") // Convert Windows paths
cmd := fmt.Sprintf(`package.path = package.path .. ";%s"`, path)
return s.DoString(cmd)
}
// Call executes a function on the stack with the given number of arguments and results.
// The function and arguments should already be on the stack in the correct order
// (function first, then args from left to right).
func (s *State) Call(nargs, nresults int) error {
if !s.IsFunction(-nargs - 1) {
return fmt.Errorf("attempt to call a non-function")
}
status := C.protected_call(s.L, C.int(nargs), C.int(nresults), 0)
if status != 0 {
err := &LuaError{
Code: int(status),
Message: s.ToString(-1),
status = C.lua_pcall(s.L, 0, C.LUA_MULTRET, 0)
if status == 0 {
nresults = s.GetTop() - baseTop
}
s.Pop(1)
return err
}
return nil
}
return status
})
// LoadString loads but does not execute a string of Lua code.
// The compiled code chunk is left on the stack.
func (s *State) LoadString(str string) error {
cstr := C.CString(str)
defer C.free(unsafe.Pointer(cstr))
status := C.load_chunk(s.L, cstr)
if status != 0 {
err := &LuaError{
Code: int(status),
Message: s.ToString(-1),
}
s.Pop(1)
return err
}
if !s.IsFunction(-1) {
s.Pop(1)
return fmt.Errorf("failed to load function")
}
return nil
}
// ExecuteString executes a string of Lua code and returns the number of results.
// The results are left on the stack.
func (s *State) ExecuteString(str string) (int, error) {
base := s.GetTop()
// First load the string
if err := s.LoadString(str); err != nil {
return 0, err
}
// Now execute it
if err := s.Call(0, C.LUA_MULTRET); err != nil {
return 0, err
}
return s.GetTop() - base, nil
}
// ExecuteStringResult executes a Lua string and returns its first result as a Go value.
// It's a convenience wrapper around ExecuteString for the common case of wanting
// a single return value. The stack is restored to its original state after execution.
func (s *State) ExecuteStringResult(code string) (interface{}, error) {
top := s.GetTop()
defer s.SetTop(top) // Restore stack when we're done
nresults, err := s.ExecuteString(code)
if err != nil {
return nil, fmt.Errorf("execution error: %w", err)
return 0, err
}
return nresults, nil
}
// ExecuteWithResult executes a Lua string and returns the first result
func (s *State) ExecuteWithResult(code string) (interface{}, error) {
top := s.GetTop()
defer s.SetTop(top) // Restore stack when done
nresults, err := s.Execute(code)
if err != nil {
return nil, err
}
if nresults == 0 {
return nil, nil
}
// Get the result
result, err := s.ToValue(-nresults) // Get first result
if err != nil {
return nil, fmt.Errorf("error converting result: %w", err)
}
return result, nil
return s.ToValue(-nresults)
}
// DoStringResult executes a Lua string and expects a single return value.
// Unlike ExecuteStringResult, this function specifically expects exactly one
// return value and will return an error if the code returns 0 or multiple values.
func (s *State) DoStringResult(code string) (interface{}, error) {
top := s.GetTop()
defer s.SetTop(top) // Restore stack when we're done
// Package path operations
nresults, err := s.ExecuteString(code)
if err != nil {
return nil, fmt.Errorf("execution error: %w", err)
}
if nresults != 1 {
return nil, fmt.Errorf("expected 1 return value, got %d", nresults)
}
// Get the result
result, err := s.ToValue(-1)
if err != nil {
return nil, fmt.Errorf("error converting result: %w", err)
}
return result, nil
// SetPackagePath sets the Lua package.path
func (s *State) SetPackagePath(path string) error {
path = strings.ReplaceAll(path, "\\", "/") // Convert Windows paths
return s.DoString(fmt.Sprintf(`package.path = %q`, path))
}
// AddPackagePath adds a path to package.path
func (s *State) AddPackagePath(path string) error {
path = strings.ReplaceAll(path, "\\", "/") // Convert Windows paths
return s.DoString(fmt.Sprintf(`package.path = package.path .. ";%s"`, path))
}
// Helper functions
// checkStack ensures there is enough space on the Lua stack
func (s *State) checkStack(n int) error {
if C.lua_checkstack(s.L, C.int(n)) == 0 {
return fmt.Errorf("stack overflow (cannot allocate %d slots)", n)
}
return nil
}
// safeCall wraps a potentially dangerous C call with stack checking
func (s *State) safeCall(f func() C.int) error {
// Ensure we have enough stack space
if err := s.checkStack(LUA_MINSTACK); err != nil {
return err
}
// Make the call
status := f()
// Check for errors
if status != 0 {
err := &LuaError{
Code: int(status),
Message: s.ToString(-1),
}
s.Pop(1) // Remove error message
return err
}
return nil
}

View File

@ -1,237 +0,0 @@
package luajit
import (
"testing"
)
var benchCases = []struct {
name string
code string
}{
{
name: "SimpleAddition",
code: `return 1 + 1`,
},
{
name: "LoopSum",
code: `
local sum = 0
for i = 1, 1000 do
sum = sum + i
end
return sum
`,
},
{
name: "FunctionCall",
code: `
local result = 0
for i = 1, 100 do
result = result + i
end
return result
`,
},
{
name: "TableCreation",
code: `
local t = {}
for i = 1, 100 do
t[i] = i * 2
end
return t[50]
`,
},
{
name: "StringOperations",
code: `
local s = "hello"
for i = 1, 10 do
s = s .. " world"
end
return #s
`,
},
}
func BenchmarkLuaDirectExecution(b *testing.B) {
for _, bc := range benchCases {
b.Run(bc.name, func(b *testing.B) {
L := New()
if L == nil {
b.Fatal("Failed to create Lua state")
}
defer L.Close()
// First verify we can execute the code
if err := L.DoString(bc.code); err != nil {
b.Fatalf("Failed to execute test code: %v", err)
}
L.Pop(1) // Clean up the result
b.ResetTimer()
for i := 0; i < b.N; i++ {
// Execute string and get result
nresults, err := L.ExecuteString(bc.code)
if err != nil {
b.Fatalf("Failed to execute code: %v", err)
}
L.Pop(nresults) // Clean up any results
}
})
}
}
func BenchmarkLuaBytecodeExecution(b *testing.B) {
// First compile all bytecode
bytecodes := make(map[string][]byte)
for _, bc := range benchCases {
L := New()
if L == nil {
b.Fatal("Failed to create Lua state")
}
bytecode, err := L.CompileBytecode(bc.code, bc.name)
if err != nil {
L.Close()
b.Fatalf("Error compiling bytecode for %s: %v", bc.name, err)
}
bytecodes[bc.name] = bytecode
L.Close()
}
for _, bc := range benchCases {
b.Run(bc.name, func(b *testing.B) {
L := New()
if L == nil {
b.Fatal("Failed to create Lua state")
}
defer L.Close()
bytecode := bytecodes[bc.name]
// First verify we can execute the bytecode
if err := L.LoadBytecode(bytecode, bc.name); err != nil {
b.Fatalf("Failed to execute test bytecode: %v", err)
}
b.ResetTimer()
b.SetBytes(int64(len(bytecode))) // Track bytecode size in benchmarks
for i := 0; i < b.N; i++ {
if err := L.LoadBytecode(bytecode, bc.name); err != nil {
b.Fatalf("Error executing bytecode: %v", err)
}
}
})
}
}
func BenchmarkTableOperations(b *testing.B) {
testData := map[string]interface{}{
"number": 42.0,
"string": "hello",
"bool": true,
"nested": map[string]interface{}{
"value": 123.0,
"array": []float64{1.1, 2.2, 3.3},
},
}
b.Run("PushTable", func(b *testing.B) {
L := New()
if L == nil {
b.Fatal("Failed to create Lua state")
}
defer L.Close()
// First verify we can push the table
if err := L.PushTable(testData); err != nil {
b.Fatalf("Failed to push initial table: %v", err)
}
L.Pop(1)
b.ResetTimer()
for i := 0; i < b.N; i++ {
if err := L.PushTable(testData); err != nil {
b.Fatalf("Failed to push table: %v", err)
}
L.Pop(1)
}
})
b.Run("ToTable", func(b *testing.B) {
L := New()
if L == nil {
b.Fatal("Failed to create Lua state")
}
defer L.Close()
// Keep a table on the stack for repeated conversions
if err := L.PushTable(testData); err != nil {
b.Fatalf("Failed to push initial table: %v", err)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
if _, err := L.ToTable(-1); err != nil {
b.Fatalf("Failed to convert table: %v", err)
}
}
})
}
func BenchmarkValueConversion(b *testing.B) {
testValues := []struct {
name string
value interface{}
}{
{"Number", 42.0},
{"String", "hello world"},
{"Boolean", true},
{"Nil", nil},
}
for _, tv := range testValues {
b.Run("Push"+tv.name, func(b *testing.B) {
L := New()
if L == nil {
b.Fatal("Failed to create Lua state")
}
defer L.Close()
// First verify we can push the value
if err := L.PushValue(tv.value); err != nil {
b.Fatalf("Failed to push initial value: %v", err)
}
L.Pop(1)
b.ResetTimer()
for i := 0; i < b.N; i++ {
if err := L.PushValue(tv.value); err != nil {
b.Fatalf("Failed to push value: %v", err)
}
L.Pop(1)
}
})
b.Run("To"+tv.name, func(b *testing.B) {
L := New()
if L == nil {
b.Fatal("Failed to create Lua state")
}
defer L.Close()
// Keep a value on the stack for repeated conversions
if err := L.PushValue(tv.value); err != nil {
b.Fatalf("Failed to push initial value: %v", err)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
if _, err := L.ToValue(-1); err != nil {
b.Fatalf("Failed to convert value: %v", err)
}
}
})
}
}

View File

@ -1,693 +0,0 @@
package luajit
import (
"fmt"
"os"
"path/filepath"
"testing"
)
func TestNew(t *testing.T) {
L := New()
if L == nil {
t.Fatal("Failed to create Lua state")
}
defer L.Close()
}
func TestLoadString(t *testing.T) {
tests := []struct {
name string
code string
wantErr bool
}{
{
name: "valid function",
code: "function add(a, b) return a + b end",
wantErr: false,
},
{
name: "valid expression",
code: "return 1 + 1",
wantErr: false,
},
{
name: "syntax error",
code: "function bad syntax",
wantErr: true,
},
}
for _, tt := range tests {
L := New()
if L == nil {
t.Fatal("Failed to create Lua state")
}
defer L.Close()
err := L.LoadString(tt.code)
if (err != nil) != tt.wantErr {
t.Errorf("LoadString() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !tt.wantErr {
// Verify the function is on the stack
if L.GetTop() != 1 {
t.Error("LoadString() did not leave exactly one value on stack")
}
if !L.IsFunction(-1) {
t.Error("LoadString() did not leave a function on the stack")
}
}
}
}
func TestExecuteString(t *testing.T) {
tests := []struct {
name string
code string
wantResults int
checkResults func(*State) error
wantErr bool
wantStackSize int
}{
{
name: "no results",
code: "local x = 1",
wantResults: 0,
wantErr: false,
},
{
name: "single result",
code: "return 42",
wantResults: 1,
checkResults: func(L *State) error {
if n := L.ToNumber(-1); n != 42 {
return fmt.Errorf("got %v, want 42", n)
}
return nil
},
wantErr: false,
},
{
name: "multiple results",
code: "return 1, 'test', true",
wantResults: 3,
checkResults: func(L *State) error {
if n := L.ToNumber(-3); n != 1 {
return fmt.Errorf("first result: got %v, want 1", n)
}
if s := L.ToString(-2); s != "test" {
return fmt.Errorf("second result: got %v, want 'test'", s)
}
if b := L.ToBoolean(-1); !b {
return fmt.Errorf("third result: got %v, want true", b)
}
return nil
},
wantErr: false,
},
{
name: "syntax error",
code: "this is not valid lua",
wantErr: true,
},
{
name: "runtime error",
code: "error('test error')",
wantErr: true,
},
}
for _, tt := range tests {
L := New()
if L == nil {
t.Fatal("Failed to create Lua state")
}
defer L.Close()
// Record initial stack size
initialStack := L.GetTop()
results, err := L.ExecuteString(tt.code)
if (err != nil) != tt.wantErr {
t.Errorf("ExecuteString() error = %v, wantErr %v", err, tt.wantErr)
return
}
if err == nil {
if results != tt.wantResults {
t.Errorf("ExecuteString() returned %d results, want %d", results, tt.wantResults)
}
if tt.checkResults != nil {
if err := tt.checkResults(L); err != nil {
t.Errorf("Result check failed: %v", err)
}
}
// Verify stack size matches expected results
if got := L.GetTop() - initialStack; got != tt.wantResults {
t.Errorf("Stack size grew by %d, want %d", got, tt.wantResults)
}
}
}
}
func TestDoString(t *testing.T) {
tests := []struct {
name string
code string
wantErr bool
}{
{"simple addition", "return 1 + 1", false},
{"set global", "test = 42", false},
{"syntax error", "this is not valid lua", true},
{"runtime error", "error('test error')", true},
}
for _, tt := range tests {
L := New()
if L == nil {
t.Fatal("Failed to create Lua state")
}
defer L.Close()
initialStack := L.GetTop()
err := L.DoString(tt.code)
if (err != nil) != tt.wantErr {
t.Errorf("DoString() error = %v, wantErr %v", err, tt.wantErr)
}
// Verify stack is unchanged
if finalStack := L.GetTop(); finalStack != initialStack {
t.Errorf("Stack size changed from %d to %d", initialStack, finalStack)
}
}
}
func TestPushAndGetValues(t *testing.T) {
values := []struct {
name string
push func(*State)
check func(*State) error
}{
{
name: "string",
push: func(L *State) { L.PushString("hello") },
check: func(L *State) error {
if got := L.ToString(-1); got != "hello" {
return fmt.Errorf("got %q, want %q", got, "hello")
}
return nil
},
},
{
name: "number",
push: func(L *State) { L.PushNumber(42.5) },
check: func(L *State) error {
if got := L.ToNumber(-1); got != 42.5 {
return fmt.Errorf("got %f, want %f", got, 42.5)
}
return nil
},
},
{
name: "boolean",
push: func(L *State) { L.PushBoolean(true) },
check: func(L *State) error {
if got := L.ToBoolean(-1); !got {
return fmt.Errorf("got %v, want true", got)
}
return nil
},
},
{
name: "nil",
push: func(L *State) { L.PushNil() },
check: func(L *State) error {
if typ := L.GetType(-1); typ != TypeNil {
return fmt.Errorf("got type %v, want TypeNil", typ)
}
return nil
},
},
}
for _, v := range values {
L := New()
if L == nil {
t.Fatal("Failed to create Lua state")
}
defer L.Close()
v.push(L)
if err := v.check(L); err != nil {
t.Error(err)
}
}
}
func TestStackManipulation(t *testing.T) {
L := New()
if L == nil {
t.Fatal("Failed to create Lua state")
}
defer L.Close()
// Push values
values := []string{"first", "second", "third"}
for _, v := range values {
L.PushString(v)
}
// Check size
if top := L.GetTop(); top != len(values) {
t.Errorf("stack size = %d, want %d", top, len(values))
}
// Pop one value
L.Pop(1)
// Check new top
if str := L.ToString(-1); str != "second" {
t.Errorf("top value = %q, want 'second'", str)
}
// Check new size
if top := L.GetTop(); top != len(values)-1 {
t.Errorf("stack size after pop = %d, want %d", top, len(values)-1)
}
}
func TestGlobals(t *testing.T) {
L := New()
if L == nil {
t.Fatal("Failed to create Lua state")
}
defer L.Close()
// Test via Lua
if err := L.DoString(`globalVar = "test"`); err != nil {
t.Fatalf("DoString error: %v", err)
}
// Get the global
L.GetGlobal("globalVar")
if str := L.ToString(-1); str != "test" {
t.Errorf("global value = %q, want 'test'", str)
}
L.Pop(1)
// Set and get via API
L.PushNumber(42)
L.SetGlobal("testNum")
L.GetGlobal("testNum")
if num := L.ToNumber(-1); num != 42 {
t.Errorf("global number = %f, want 42", num)
}
}
func TestCall(t *testing.T) {
tests := []struct {
funcName string // Add explicit function name field
setup string
args []interface{}
nresults int
checkStack func(*State) error
wantErr bool
}{
{
funcName: "add",
setup: "function add(a, b) return a + b end",
args: []interface{}{float64(40), float64(2)},
nresults: 1,
checkStack: func(L *State) error {
if n := L.ToNumber(-1); n != 42 {
return fmt.Errorf("got %v, want 42", n)
}
return nil
},
},
{
funcName: "multi",
setup: "function multi() return 1, 'test', true end",
args: []interface{}{},
nresults: 3,
checkStack: func(L *State) error {
if n := L.ToNumber(-3); n != 1 {
return fmt.Errorf("first result: got %v, want 1", n)
}
if s := L.ToString(-2); s != "test" {
return fmt.Errorf("second result: got %v, want 'test'", s)
}
if b := L.ToBoolean(-1); !b {
return fmt.Errorf("third result: got %v, want true", b)
}
return nil
},
},
{
funcName: "err",
setup: "function err() error('test error') end",
args: []interface{}{},
nresults: 0,
wantErr: true,
},
}
for _, tt := range tests {
L := New()
if L == nil {
t.Fatal("Failed to create Lua state")
}
defer L.Close()
// Setup function
if err := L.DoString(tt.setup); err != nil {
t.Fatalf("Setup failed: %v", err)
}
// Get function
L.GetGlobal(tt.funcName)
if !L.IsFunction(-1) {
t.Fatal("Failed to get function")
}
// Push arguments
for _, arg := range tt.args {
if err := L.PushValue(arg); err != nil {
t.Fatalf("Failed to push argument: %v", err)
}
}
// Call function
err := L.Call(len(tt.args), tt.nresults)
if (err != nil) != tt.wantErr {
t.Errorf("Call() error = %v, wantErr %v", err, tt.wantErr)
return
}
if err == nil && tt.checkStack != nil {
if err := tt.checkStack(L); err != nil {
t.Errorf("Stack check failed: %v", err)
}
}
}
}
func TestDoFile(t *testing.T) {
L := New()
defer L.Close()
// Create test file
content := []byte(`
function add(a, b)
return a + b
end
result = add(40, 2)
`)
tmpDir := t.TempDir()
filename := filepath.Join(tmpDir, "test.lua")
if err := os.WriteFile(filename, content, 0644); err != nil {
t.Fatalf("Failed to create test file: %v", err)
}
if err := L.DoFile(filename); err != nil {
t.Fatalf("DoFile failed: %v", err)
}
L.GetGlobal("result")
if result := L.ToNumber(-1); result != 42 {
t.Errorf("Expected result=42, got %v", result)
}
}
func TestRequireAndPackagePath(t *testing.T) {
L := New()
defer L.Close()
tmpDir := t.TempDir()
// Create module file
moduleContent := []byte(`
local M = {}
function M.multiply(a, b)
return a * b
end
return M
`)
if err := os.WriteFile(filepath.Join(tmpDir, "mathmod.lua"), moduleContent, 0644); err != nil {
t.Fatalf("Failed to create module file: %v", err)
}
// Add module path and test require
if err := L.AddPackagePath(filepath.Join(tmpDir, "?.lua")); err != nil {
t.Fatalf("AddPackagePath failed: %v", err)
}
if err := L.DoString(`
local math = require("mathmod")
result = math.multiply(6, 7)
`); err != nil {
t.Fatalf("Failed to require module: %v", err)
}
L.GetGlobal("result")
if result := L.ToNumber(-1); result != 42 {
t.Errorf("Expected result=42, got %v", result)
}
}
func TestSetPackagePath(t *testing.T) {
L := New()
defer L.Close()
customPath := "./custom/?.lua"
if err := L.SetPackagePath(customPath); err != nil {
t.Fatalf("SetPackagePath failed: %v", err)
}
L.GetGlobal("package")
L.GetField(-1, "path")
if path := L.ToString(-1); path != customPath {
t.Errorf("Expected package.path=%q, got %q", customPath, path)
}
// Test that the old path is completely replaced
initialPath := L.ToString(-1)
anotherPath := "./another/?.lua"
if err := L.SetPackagePath(anotherPath); err != nil {
t.Fatalf("Second SetPackagePath failed: %v", err)
}
L.GetGlobal("package")
L.GetField(-1, "path")
if path := L.ToString(-1); path != anotherPath {
t.Errorf("Expected package.path=%q, got %q", anotherPath, path)
}
if path := L.ToString(-1); path == initialPath {
t.Error("SetPackagePath did not replace the old path")
}
}
func TestStackDebug(t *testing.T) {
L := New()
if L == nil {
t.Fatal("Failed to create Lua state")
}
defer L.Close()
t.Log("Testing LoadString:")
initialTop := L.GetTop()
t.Logf("Initial stack size: %d", initialTop)
err := L.LoadString("return 42")
if err != nil {
t.Errorf("LoadString failed: %v", err)
}
afterLoad := L.GetTop()
t.Logf("Stack size after load: %d", afterLoad)
t.Logf("Type of top element: %s", L.GetType(-1))
if L.IsFunction(-1) {
t.Log("Top element is a function")
} else {
t.Log("Top element is NOT a function")
}
// Clean up after LoadString test
L.SetTop(0)
t.Log("\nTesting ExecuteString:")
if err := L.DoString("function test() return 1, 'hello', true end"); err != nil {
t.Errorf("DoString failed: %v", err)
}
beforeExec := L.GetTop()
t.Logf("Stack size before execute: %d", beforeExec)
nresults, err := L.ExecuteString("return test()")
if err != nil {
t.Errorf("ExecuteString failed: %v", err)
}
afterExec := L.GetTop()
t.Logf("Stack size after execute: %d", afterExec)
t.Logf("Reported number of results: %d", nresults)
// Print each stack element
for i := 1; i <= afterExec; i++ {
t.Logf("Stack[-%d] type: %s", i, L.GetType(-i))
}
if afterExec != nresults {
t.Errorf("Stack size (%d) doesn't match number of results (%d)", afterExec, nresults)
}
}
func TestTemplateRendering(t *testing.T) {
L := New()
if L == nil {
t.Fatal("Failed to create Lua state")
}
defer L.Close()
defer L.Cleanup()
// Create a simple render.template function
renderFunc := func(s *State) int {
// Template will be at index 1, data at index 2
data, err := s.ToTable(2)
if err != nil {
s.PushString(fmt.Sprintf("failed to get data table: %v", err))
return -1
}
// Push data back as global for template access
if err := s.PushTable(data); err != nil {
s.PushString(fmt.Sprintf("failed to push data table: %v", err))
return -1
}
s.SetGlobal("data")
// Template processing code
luaCode := `
local result = {}
if data.user.logged_in then
table.insert(result, '<div class="profile">')
table.insert(result, string.format(' <h1>Welcome, %s!</h1>', tostring(data.user.name)))
table.insert(result, ' <div class="user-info">')
table.insert(result, string.format(' <p>Email: %s</p>', tostring(data.user.email)))
table.insert(result, string.format(' <p>Member since: %s</p>', tostring(data.user.joined_date)))
table.insert(result, ' </div>')
if data.user.is_admin then
table.insert(result, ' <div class="admin-panel">')
table.insert(result, ' <h2>Admin Controls</h2>')
table.insert(result, ' <!-- admin content -->')
table.insert(result, ' </div>')
end
table.insert(result, '</div>')
else
table.insert(result, '<div class="profile">')
table.insert(result, ' <h1>Please log in to view your profile</h1>')
table.insert(result, '</div>')
end
return table.concat(result, '\n')`
result, err := s.DoStringResult(luaCode)
if err != nil {
s.PushString(fmt.Sprintf("template execution failed: %v", err))
return -1
}
// Push the string result
if str, ok := result.(string); ok {
s.PushString(str)
return 1
}
s.PushString(fmt.Sprintf("expected string result, got %T", result))
return -1
}
// Create render table and add template function
L.NewTable()
if err := L.PushGoFunction(renderFunc); err != nil {
t.Fatalf("Failed to create render function: %v", err)
}
L.SetField(-2, "template")
L.SetGlobal("render")
// Test with logged in admin user
testCode := `
local data = {
user = {
logged_in = true,
name = "John Doe",
email = "john@example.com",
joined_date = "2024-02-09",
is_admin = true
}
}
return render.template("test.html", data)
`
result, err := L.DoStringResult(testCode)
if err != nil {
t.Fatalf("Failed to execute test: %v", err)
}
str, ok := result.(string)
if !ok {
t.Fatalf("Expected string result, got %T", result)
}
expectedResult := `<div class="profile">
<h1>Welcome, John Doe!</h1>
<div class="user-info">
<p>Email: john@example.com</p>
<p>Member since: 2024-02-09</p>
</div>
<div class="admin-panel">
<h2>Admin Controls</h2>
<!-- admin content -->
</div>
</div>`
if str != expectedResult {
t.Errorf("\nExpected:\n%s\n\nGot:\n%s", expectedResult, str)
}
// Test with logged out user
testCode = `
local data = {
user = {
logged_in = false
}
}
return render.template("test.html", data)
`
result, err = L.DoStringResult(testCode)
if err != nil {
t.Fatalf("Failed to execute logged out test: %v", err)
}
str, ok = result.(string)
if !ok {
t.Fatalf("Expected string result, got %T", result)
}
expectedResult = `<div class="profile">
<h1>Please log in to view your profile</h1>
</div>`
if str != expectedResult {
t.Errorf("\nExpected:\n%s\n\nGot:\n%s", expectedResult, str)
}
}