Compare commits

..

26 Commits

Author SHA1 Message Date
6b9e2a0e20 op 1 2025-04-04 21:32:17 -05:00
e58f9a6028 add bench profiling 2025-04-04 21:32:17 -05:00
a2b4b1c927 add type utilities 2025-03-29 09:08:02 -05:00
44337fffe3 add flag for std libs 2025-03-29 08:32:45 -05:00
0756cabcaa remove sandbox docs 2025-03-28 20:31:29 -05:00
656ac1a703 remove sandbox 2025-03-28 20:27:51 -05:00
5774808064 update docs 2025-03-27 22:05:09 -05:00
875abee366 optimize sandbox 2025-03-27 21:58:56 -05:00
4ad87f81f3 move sandbox benches 2025-03-27 21:51:31 -05:00
9e5092acdb optimize wrapper 2025-03-27 21:46:23 -05:00
b83f77d7a6 optimize functions 2025-03-27 21:45:39 -05:00
29679349ef optimize table 2025-03-27 21:42:58 -05:00
fed0c2ad34 optimize bytecode 2025-03-27 21:41:06 -05:00
faab0a2d08 sandbox 2 2025-03-27 21:31:41 -05:00
f106dfd9ea sandbox 1 2025-03-27 18:25:09 -05:00
936e4ccdc2 remove unregistration change 2025-03-27 14:04:17 -05:00
075b45768f more robust unregistergofunction 2025-03-27 13:56:02 -05:00
13686b3e66 Use pkg-config on *nix 2025-03-12 11:50:37 -05:00
98ca857d73 interface{} to any 2025-03-07 07:25:34 -06:00
143b9333c6 Wrapper rewrite 2025-02-26 07:00:01 -06:00
865ac8859f added table pooling and some micro-ops 2025-02-13 07:11:13 -06:00
4dc266201f BIG changes; no "safe" mode, function updates, etc 2025-02-12 19:17:11 -06:00
7c79616cac Small update to DoString for stack discipline 2025-02-08 09:57:59 -06:00
146b0a51db Add Call and IsNil 2025-02-03 19:09:30 -06:00
c74ad4bbc9 Add IsNumber 2025-02-03 18:55:34 -06:00
229884ba97 Add IsString and Next helpers 2025-02-03 18:52:26 -06:00
27 changed files with 3420 additions and 1400 deletions

2
.gitignore vendored
View File

@ -21,3 +21,5 @@
go.work
.idea
bench/profile_results

267
DOCS.md
View File

@ -1,19 +1,11 @@
# LuaJIT Go Wrapper API Documentation
# API Documentation
## State Management
### NewSafe() *State
Creates a new Lua state with stack safety enabled.
```go
L := luajit.NewSafe()
defer L.Close()
defer L.Cleanup()
```
### New() *State
Creates a new Lua state without stack safety checks.
New creates a new Lua state with optional standard libraries; true if not specified
```go
L := luajit.New()
L := luajit.New() // or luajit.New(false)
defer L.Close()
defer L.Cleanup()
```
@ -38,6 +30,18 @@ Returns the index of the top element in the stack.
top := L.GetTop() // 0 for empty stack
```
### SetTop(index int)
Sets the stack top to a specific index.
```go
L.SetTop(2) // Truncate stack to 2 elements
```
### PushCopy(index int)
Pushes a copy of the value at the given index onto the stack.
```go
L.PushCopy(-1) // Duplicate the top element
```
### Pop(n int)
Removes n elements from the stack.
```go
@ -52,6 +56,9 @@ L.Remove(-1) // Remove top element
L.Remove(1) // Remove first element
```
### absIndex(index int) int
Internal function that converts a possibly negative index to its absolute position.
### checkStack(n int) error
Internal function that ensures there's enough space for n new elements.
```go
@ -70,9 +77,12 @@ if L.GetType(-1) == TypeString {
}
```
### IsFunction(index int) bool
### IsNil(index int) bool
### IsBoolean(index int) bool
### IsNumber(index int) bool
### IsString(index int) bool
### IsTable(index int) bool
### IsUserData(index int) bool
### IsFunction(index int) bool
Type checking functions for specific Lua types.
```go
if L.IsTable(-1) {
@ -100,7 +110,7 @@ Converts the value to a boolean.
bool := L.ToBoolean(-1)
```
### ToValue(index int) (interface{}, error)
### ToValue(index int) (any, error)
Converts any Lua value to its Go equivalent.
```go
val, err := L.ToValue(-1)
@ -109,7 +119,7 @@ if err != nil {
}
```
### ToTable(index int) (map[string]interface{}, error)
### ToTable(index int) (map[string]any, error)
Converts a Lua table to a Go map.
```go
table, err := L.ToTable(-1)
@ -118,6 +128,12 @@ if err != nil {
}
```
### GetTableLength(index int) int
Returns the length of a table at the given index.
```go
length := L.GetTableLength(-1)
```
## Value Pushing
### PushNil()
@ -132,32 +148,98 @@ L.PushBoolean(true)
L.PushNil()
```
### PushValue(v interface{}) error
### PushValue(v any) error
Pushes any Go value onto the stack.
```go
err := L.PushValue(myValue)
```
### PushTable(table map[string]interface{}) error
### PushTable(table map[string]any) error
Pushes a Go map as a Lua table.
```go
data := map[string]interface{}{
data := map[string]any{
"key": "value",
"numbers": []float64{1, 2, 3},
}
err := L.PushTable(data)
```
## Function Registration
## Table Operations
### RegisterGoFunction(name string, fn GoFunction) error
Registers a Go function that can be called from Lua.
### CreateTable(narr, nrec int)
Creates a new table with pre-allocated space.
```go
adder := func(s *State) int {
L.CreateTable(10, 5) // Space for 10 array elements, 5 records
```
### NewTable()
Creates a new empty table and pushes it onto the stack.
```go
L.NewTable()
```
### GetTable(index int)
Gets a table field (t[k]) where t is at the given index and k is at the top of the stack.
```go
L.PushString("key")
L.GetTable(-2) // Gets table["key"]
```
### SetTable(index int)
Sets a table field (t[k] = v) where t is at the given index, k is at -2, and v is at -1.
```go
L.PushString("key")
L.PushString("value")
L.SetTable(-3) // table["key"] = "value"
```
### GetField(index int, key string)
Gets a table field t[k] and pushes it onto the stack.
```go
L.GetField(-1, "name") // gets table.name
```
### SetField(index int, key string)
Sets a table field t[k] = v, where v is the value at the top of the stack.
```go
L.PushString("value")
L.SetField(-2, "key") // table.key = "value"
```
### Next(index int) bool
Pops a key from the stack and pushes the next key-value pair from the table.
```go
L.PushNil() // Start iteration
for L.Next(-2) {
// Stack now has key at -2 and value at -1
key := L.ToString(-2)
value := L.ToString(-1)
L.Pop(1) // Remove value, keep key for next iteration
}
```
## Function Registration and Calling
### GoFunction
Type definition for Go functions callable from Lua.
```go
type GoFunction func(*State) int
```
### PushGoFunction(fn GoFunction) error
Wraps a Go function and pushes it onto the Lua stack.
```go
adder := func(s *luajit.State) int {
sum := s.ToNumber(1) + s.ToNumber(2)
s.PushNumber(sum)
return 1
}
err := L.PushGoFunction(adder)
```
### RegisterGoFunction(name string, fn GoFunction) error
Registers a Go function as a global Lua function.
```go
err := L.RegisterGoFunction("add", adder)
```
@ -167,23 +249,45 @@ Removes a previously registered function.
L.UnregisterGoFunction("add")
```
## Package Management
### SetPackagePath(path string) error
Sets the Lua package.path variable.
### Call(nargs, nresults int) error
Calls a function with the given number of arguments and results.
```go
err := L.SetPackagePath("./?.lua;/usr/local/share/lua/5.1/?.lua")
L.GetGlobal("myfunction")
L.PushNumber(1)
L.PushNumber(2)
err := L.Call(2, 1) // Call with 2 args, expect 1 result
```
### AddPackagePath(path string) error
Adds a path to the existing package.path.
## Global Operations
### GetGlobal(name string)
Gets a global variable and pushes it onto the stack.
```go
err := L.AddPackagePath("./modules/?.lua")
L.GetGlobal("myGlobal")
```
### SetGlobal(name string)
Sets a global variable from the value at the top of the stack.
```go
L.PushNumber(42)
L.SetGlobal("answer") // answer = 42
```
## Code Execution
### DoString(str string) error
### LoadString(code string) error
Loads a Lua chunk from a string without executing it.
```go
err := L.LoadString("return 42")
```
### LoadFile(filename string) error
Loads a Lua chunk from a file without executing it.
```go
err := L.LoadFile("script.lua")
```
### DoString(code string) error
Executes a string of Lua code.
```go
err := L.DoString(`
@ -199,32 +303,76 @@ Executes a Lua file.
err := L.DoFile("script.lua")
```
## Table Operations
### GetField(index int, key string)
Gets a field from a table at the given index.
### Execute(code string) (int, error)
Executes a Lua string and returns the number of results left on the stack.
```go
L.GetField(-1, "name") // gets table.name
nresults, err := L.Execute("return 1, 2, 3")
// nresults would be 3
```
### SetField(index int, key string)
Sets a field in a table at the given index.
### ExecuteWithResult(code string) (any, error)
Executes a Lua string and returns the first result.
```go
L.PushString("value")
L.SetField(-2, "key") // table.key = "value"
result, err := L.ExecuteWithResult("return 'hello'")
// result would be "hello"
```
### GetGlobal(name string)
Gets a global variable.
## Bytecode Operations
### CompileBytecode(code string, name string) ([]byte, error)
Compiles a Lua chunk to bytecode without executing it.
```go
L.GetGlobal("myGlobal")
bytecode, err := L.CompileBytecode("return 42", "test")
```
### SetGlobal(name string)
Sets a global variable from the value at the top of the stack.
### LoadBytecode(bytecode []byte, name string) error
Loads precompiled bytecode without executing it.
```go
L.PushNumber(42)
L.SetGlobal("answer") // answer = 42
err := L.LoadBytecode(bytecode, "test")
```
### RunBytecode() error
Executes previously loaded bytecode with 0 results.
```go
err := L.RunBytecode()
```
### RunBytecodeWithResults(nresults int) error
Executes bytecode and keeps nresults on the stack.
```go
err := L.RunBytecodeWithResults(1)
```
### LoadAndRunBytecode(bytecode []byte, name string) error
Loads and executes bytecode.
```go
err := L.LoadAndRunBytecode(bytecode, "test")
```
### LoadAndRunBytecodeWithResults(bytecode []byte, name string, nresults int) error
Loads and executes bytecode, preserving results.
```go
err := L.LoadAndRunBytecodeWithResults(bytecode, "test", 1)
```
### CompileAndRun(code string, name string) error
Compiles and immediately executes Lua code.
```go
err := L.CompileAndRun("answer = 42", "test")
```
## Package Path Operations
### SetPackagePath(path string) error
Sets the Lua package.path.
```go
err := L.SetPackagePath("./?.lua;/usr/local/share/lua/5.1/?.lua")
```
### AddPackagePath(path string) error
Adds a path to package.path.
```go
err := L.AddPackagePath("./modules/?.lua")
```
## Error Handling
@ -238,13 +386,21 @@ type LuaError struct {
}
```
### getStackTrace() string
### GetStackTrace() string
Gets the current Lua stack trace.
```go
trace := L.getStackTrace()
trace := L.GetStackTrace()
fmt.Println(trace)
```
### safeCall(f func() C.int) error
Internal function that wraps a potentially dangerous C call with stack checking.
```go
err := s.safeCall(func() C.int {
return C.lua_pcall(s.L, C.int(nargs), C.int(nresults), 0)
})
```
## Thread Safety Notes
- The function registry is thread-safe
@ -256,15 +412,20 @@ fmt.Println(trace)
Always pair state creation with cleanup:
```go
L := luajit.NewSafe()
L := luajit.New()
defer L.Close()
defer L.Cleanup()
```
Stack management in unsafe mode requires manual attention:
Stack management requires manual attention:
```go
L := luajit.New()
L.PushString("hello")
// ... use the string
L.Pop(1) // Clean up when done
```
```
Sandbox management:
```go
sandbox := luajit.NewSandbox()
defer sandbox.Close()
```

View File

@ -1,6 +1,6 @@
MIT License
Copyright (c) 2025 Sky
Copyright (c) 2025 Sharkk, Skylear Johnson
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

159
README.md
View File

@ -1,6 +1,6 @@
# LuaJIT Go Wrapper
Hey there! This is a Go wrapper for LuaJIT that makes it easy to embed Lua in your Go applications. We've focused on making it both safe and fast, while keeping the API clean and intuitive.
This is a Go wrapper for LuaJIT that makes it easy to embed Lua in your Go applications. We've focused on making it both performant and developer-friendly, with an API that feels natural to use.
## What's This For?
@ -22,51 +22,13 @@ You'll need LuaJIT's development files, but don't worry - we include libraries f
Here's the simplest thing you can do:
```go
L := luajit.NewSafe()
L := luajit.New() // pass false to not load standard libs
defer L.Close()
defer L.Cleanup()
err := L.DoString(`print("Hey from Lua!")`)
```
## Stack Safety: Choose Your Adventure
One of the key decisions you'll make is whether to use stack-safe mode. Here's what that means:
### Stack-Safe Mode (NewSafe())
```go
L := luajit.NewSafe()
```
Think of this as driving with guardrails. It's perfect when:
- You're new to Lua or embedding scripting languages
- You're writing a server or long-running application
- You want to handle untrusted Lua code
- You'd rather have slightly slower code than mysterious crashes
The safe mode will:
- Prevent stack overflows
- Check types more thoroughly
- Clean up after messy Lua code
- Give you better error messages
### Non-Stack-Safe Mode (New())
```go
L := luajit.New()
```
This is like taking off the training wheels. Use it when:
- You know exactly how your Lua code behaves
- You've profiled your application and need more speed
- You're doing lots of rapid, simple Lua calls
- You're writing performance-critical code
The unsafe mode:
- Skips most safety checks
- Runs noticeably faster
- Gives you direct control over the stack
- Can crash spectacularly if you make a mistake
Most applications should start with stack-safe mode and only switch to unsafe mode if profiling shows it's necessary.
## Working with Bytecode
Need even more performance? You can compile your Lua code to bytecode and reuse it:
@ -82,28 +44,30 @@ bytecode, err := L.CompileBytecode(`
// Execute many times
for i := 0; i < 1000; i++ {
err := L.LoadBytecode(bytecode, "calc")
err := L.LoadAndRunBytecode(bytecode, "calc")
}
// Or do both at once
err := L.CompileAndLoad(`return "hello"`, "greeting")
err := L.CompileAndRun(`return "hello"`, "greeting")
```
### When to Use Bytecode
Bytecode execution is consistently faster than direct execution:
- Simple operations: 20-60% faster
- String operations: Up to 60% speedup
- Loop-heavy code: 10-15% improvement
- Table operations: 10-15% faster
Some benchmark results on a typical system:
```
Operation Direct Exec Bytecode Exec
----------------------------------------
Simple Math 1.5M ops/sec 2.4M ops/sec
String Ops 370K ops/sec 600K ops/sec
Table Creation 127K ops/sec 146K ops/sec
Benchmark Ops/sec Comparison
----------------------------------------------------------------------------
BenchmarkSimpleDoString 2,561,012 Base
BenchmarkSimplePrecompiledBytecode 3,828,841 +49.5% faster
BenchmarkFunctionCallDoString 2,021,098 Base
BenchmarkFunctionCallPrecompiled 3,482,074 +72.3% faster
BenchmarkLoopDoString 188,119 Base
BenchmarkLoopPrecompiled 211,081 +12.2% faster
BenchmarkTableOperationsDoString 84,086 Base
BenchmarkTableOperationsPrecompiled 93,655 +11.4% faster
BenchmarkComplexScript 33,133 Base
BenchmarkComplexScriptPrecompiled 41,044 +23.9% faster
```
Use bytecode when you:
@ -114,7 +78,7 @@ Use bytecode when you:
## Registering Go Functions
Want to call Go code from Lua? Easy:
Want to call Go code from Lua? It's straightforward:
```go
// This function adds two numbers and returns the result
adder := func(s *luajit.State) int {
@ -137,7 +101,7 @@ Lua tables are pretty powerful - they're like a mix of Go's maps and slices. We
```go
// Go → Lua
stuff := map[string]interface{}{
stuff := map[string]any{
"name": "Arthur Dent",
"age": 30,
"items": []float64{1, 2, 3},
@ -151,25 +115,96 @@ result, err := L.ToTable(-1)
## Error Handling
We try to give you useful errors instead of mysterious panics:
We provide useful errors instead of mysterious panics:
```go
if err := L.DoString("this isn't valid Lua!"); err != nil {
if luaErr, ok := err.(*luajit.LuaError); ok {
fmt.Printf("Oops: %s\n", luaErr.Message)
fmt.Printf("Error: %s\n", luaErr.Message)
}
}
```
## A Few Tips
## Memory Management
- Always use those `defer L.Close()` and `defer L.Cleanup()` calls - they prevent memory leaks
The wrapper uses a custom table pooling system to reduce GC pressure when handling many tables:
```go
// Tables are pooled and reused internally for better performance
for i := 0; i < 1000; i++ {
L.GetGlobal("table")
table, _ := L.ToTable(-1)
// Use table...
L.Pop(1)
// Table is automatically returned to pool
}
```
The sandbox also manages its environment efficiently:
```go
// Environment objects are pooled and reused
for i := 0; i < 1000; i++ {
result, _ := sandbox.Run("return i + 1")
}
```
## Best Practices
### State Management
- Always use `defer L.Close()` and `defer L.Cleanup()` to prevent memory leaks
- Each Lua state should stick to one goroutine
- For concurrent stuff, create multiple states
- For concurrent operations, create multiple states
- You can share functions between states safely
- Keep an eye on your stack in unsafe mode - it won't clean up after itself
- Start with stack-safe mode and measure before optimizing
- Keep an eye on your stack management - pop as many items as you push
### Bytecode Optimization
- Use bytecode for frequently executed code paths
- Consider compiling critical Lua code to bytecode at startup
- For small scripts (< 1024 bytes), direct execution might be faster
## Advanced Features
### Bytecode Serialization
You can serialize bytecode for distribution or caching:
```go
// Compile once
bytecode, _ := L.CompileBytecode(complexScript, "module")
// Save to file
ioutil.WriteFile("module.luac", bytecode, 0644)
// Later, load from file
bytecode, _ := ioutil.ReadFile("module.luac")
L.LoadAndRunBytecode(bytecode, "module")
```
### Closures and Upvalues
Bytecode properly preserves closures and upvalues:
```go
code := `
local counter = 0
return function()
counter = counter + 1
return counter
end
`
bytecode, _ := L.CompileBytecode(code, "counter")
L.LoadAndRunBytecodeWithResults(bytecode, "counter", 1)
L.SetGlobal("increment")
// Later...
L.GetGlobal("increment")
L.Call(0, 1) // Returns 1
L.Pop(1)
L.GetGlobal("increment")
L.Call(0, 1) // Returns 2
```
## Need Help?
@ -177,4 +212,4 @@ Check out the tests in the repository - they're full of examples. If you're stuc
## License
MIT Licensed - do whatever you want with it!
MIT Licensed - do whatever you want with it!

171
bench/bench_profile.go Normal file
View File

@ -0,0 +1,171 @@
package luajit_bench
import (
"flag"
"fmt"
"os"
"runtime"
"runtime/pprof"
"testing"
)
// Profiling flags
var (
cpuProfile = flag.String("cpuprofile", "", "write cpu profile to `file`")
memProfile = flag.String("memprofile", "", "write memory profile to `file`")
memProfileGC = flag.Bool("memprofilegc", false, "force GC before writing memory profile")
blockProfile = flag.String("blockprofile", "", "write block profile to `file`")
mutexProfile = flag.String("mutexprofile", "", "write mutex profile to `file`")
)
// setupTestMain configures profiling for benchmarks
func setupTestMain() {
// Make sure the flags are parsed
if !flag.Parsed() {
flag.Parse()
}
// CPU profiling
if *cpuProfile != "" {
f, err := os.Create(*cpuProfile)
if err != nil {
fmt.Fprintf(os.Stderr, "Failed to create CPU profile: %v\n", err)
os.Exit(1)
}
if err := pprof.StartCPUProfile(f); err != nil {
fmt.Fprintf(os.Stderr, "Failed to start CPU profile: %v\n", err)
os.Exit(1)
}
fmt.Println("CPU profiling enabled")
}
// Block profiling (goroutine blocking)
if *blockProfile != "" {
runtime.SetBlockProfileRate(1)
fmt.Println("Block profiling enabled")
}
// Mutex profiling (lock contention)
if *mutexProfile != "" {
runtime.SetMutexProfileFraction(1)
fmt.Println("Mutex profiling enabled")
}
}
// teardownTestMain completes profiling and writes output files
func teardownTestMain() {
// Stop CPU profile
if *cpuProfile != "" {
pprof.StopCPUProfile()
fmt.Println("CPU profile written to", *cpuProfile)
}
// Write memory profile
if *memProfile != "" {
f, err := os.Create(*memProfile)
if err != nil {
fmt.Fprintf(os.Stderr, "Failed to create memory profile: %v\n", err)
os.Exit(1)
}
defer f.Close()
// Force garbage collection before writing memory profile if requested
if *memProfileGC {
runtime.GC()
}
if err := pprof.WriteHeapProfile(f); err != nil {
fmt.Fprintf(os.Stderr, "Failed to write memory profile: %v\n", err)
os.Exit(1)
}
fmt.Println("Memory profile written to", *memProfile)
}
// Write block profile
if *blockProfile != "" {
f, err := os.Create(*blockProfile)
if err != nil {
fmt.Fprintf(os.Stderr, "Failed to create block profile: %v\n", err)
os.Exit(1)
}
defer f.Close()
if err := pprof.Lookup("block").WriteTo(f, 0); err != nil {
fmt.Fprintf(os.Stderr, "Failed to write block profile: %v\n", err)
os.Exit(1)
}
fmt.Println("Block profile written to", *blockProfile)
}
// Write mutex profile
if *mutexProfile != "" {
f, err := os.Create(*mutexProfile)
if err != nil {
fmt.Fprintf(os.Stderr, "Failed to create mutex profile: %v\n", err)
os.Exit(1)
}
defer f.Close()
if err := pprof.Lookup("mutex").WriteTo(f, 0); err != nil {
fmt.Fprintf(os.Stderr, "Failed to write mutex profile: %v\n", err)
os.Exit(1)
}
fmt.Println("Mutex profile written to", *mutexProfile)
}
}
// TestMain is the entry point for all tests in this package
func TestMain(m *testing.M) {
setupTestMain()
code := m.Run()
teardownTestMain()
os.Exit(code)
}
// MemStats captures a snapshot of memory statistics
type MemStats struct {
Alloc uint64
TotalAlloc uint64
Sys uint64
Mallocs uint64
Frees uint64
HeapAlloc uint64
}
// CaptureMemStats returns current memory statistics
func CaptureMemStats() MemStats {
var m runtime.MemStats
runtime.ReadMemStats(&m)
return MemStats{
Alloc: m.Alloc,
TotalAlloc: m.TotalAlloc,
Sys: m.Sys,
Mallocs: m.Mallocs,
Frees: m.Frees,
HeapAlloc: m.HeapAlloc,
}
}
// TrackMemoryUsage runs fn and reports memory usage before and after
func TrackMemoryUsage(b *testing.B, name string, fn func()) {
b.Helper()
// Force GC before measurement
runtime.GC()
// Capture memory stats before
before := CaptureMemStats()
// Run the function
fn()
// Force GC after measurement to get accurate stats
runtime.GC()
// Capture memory stats after
after := CaptureMemStats()
// Report stats
b.ReportMetric(float64(after.Mallocs-before.Mallocs), name+"-mallocs")
b.ReportMetric(float64(after.TotalAlloc-before.TotalAlloc)/float64(b.N), name+"-bytes/op")
}

472
bench/bench_test.go Normal file
View File

@ -0,0 +1,472 @@
package luajit_bench
import (
"testing"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
// BenchmarkSimpleDoString benchmarks direct execution of a simple expression
func BenchmarkSimpleDoString(b *testing.B) {
state := luajit.New()
if state == nil {
b.Fatal("Failed to create Lua state")
}
defer state.Close()
code := "local x = 1 + 1"
b.ResetTimer()
TrackMemoryUsage(b, "dostring", func() {
for i := 0; i < b.N; i++ {
if err := state.DoString(code); err != nil {
b.Fatalf("DoString failed: %v", err)
}
}
})
}
// BenchmarkSimpleCompileAndRun benchmarks compile and run of a simple expression
func BenchmarkSimpleCompileAndRun(b *testing.B) {
state := luajit.New()
if state == nil {
b.Fatal("Failed to create Lua state")
}
defer state.Close()
code := "local x = 1 + 1"
b.ResetTimer()
TrackMemoryUsage(b, "compile-run", func() {
for i := 0; i < b.N; i++ {
if err := state.CompileAndRun(code, "simple"); err != nil {
b.Fatalf("CompileAndRun failed: %v", err)
}
}
})
}
// BenchmarkSimpleCompileLoadRun benchmarks compile, load, and run of a simple expression
func BenchmarkSimpleCompileLoadRun(b *testing.B) {
state := luajit.New()
if state == nil {
b.Fatal("Failed to create Lua state")
}
defer state.Close()
code := "local x = 1 + 1"
b.ResetTimer()
TrackMemoryUsage(b, "compile-load-run", func() {
for i := 0; i < b.N; i++ {
bytecode, err := state.CompileBytecode(code, "simple")
if err != nil {
b.Fatalf("CompileBytecode failed: %v", err)
}
if err := state.LoadAndRunBytecode(bytecode, "simple"); err != nil {
b.Fatalf("LoadAndRunBytecode failed: %v", err)
}
}
})
}
// BenchmarkSimplePrecompiledBytecode benchmarks running precompiled bytecode
func BenchmarkSimplePrecompiledBytecode(b *testing.B) {
state := luajit.New()
if state == nil {
b.Fatal("Failed to create Lua state")
}
defer state.Close()
code := "local x = 1 + 1"
bytecode, err := state.CompileBytecode(code, "simple")
if err != nil {
b.Fatalf("CompileBytecode failed: %v", err)
}
b.ResetTimer()
TrackMemoryUsage(b, "precompiled", func() {
for i := 0; i < b.N; i++ {
if err := state.LoadAndRunBytecode(bytecode, "simple"); err != nil {
b.Fatalf("LoadAndRunBytecode failed: %v", err)
}
}
})
}
// BenchmarkFunctionCallDoString benchmarks direct execution of a function call
func BenchmarkFunctionCallDoString(b *testing.B) {
state := luajit.New()
if state == nil {
b.Fatal("Failed to create Lua state")
}
defer state.Close()
// Setup function
setupCode := `
function add(a, b)
return a + b
end
`
if err := state.DoString(setupCode); err != nil {
b.Fatalf("Failed to set up function: %v", err)
}
code := "local result = add(10, 20)"
b.ResetTimer()
TrackMemoryUsage(b, "func-dostring", func() {
for i := 0; i < b.N; i++ {
if err := state.DoString(code); err != nil {
b.Fatalf("DoString failed: %v", err)
}
}
})
}
// BenchmarkFunctionCallPrecompiled benchmarks precompiled function call
func BenchmarkFunctionCallPrecompiled(b *testing.B) {
state := luajit.New()
if state == nil {
b.Fatal("Failed to create Lua state")
}
defer state.Close()
// Setup function
setupCode := `
function add(a, b)
return a + b
end
`
if err := state.DoString(setupCode); err != nil {
b.Fatalf("Failed to set up function: %v", err)
}
code := "local result = add(10, 20)"
bytecode, err := state.CompileBytecode(code, "call")
if err != nil {
b.Fatalf("CompileBytecode failed: %v", err)
}
b.ResetTimer()
TrackMemoryUsage(b, "func-precompiled", func() {
for i := 0; i < b.N; i++ {
if err := state.LoadAndRunBytecode(bytecode, "call"); err != nil {
b.Fatalf("LoadAndRunBytecode failed: %v", err)
}
}
})
}
// BenchmarkLoopDoString benchmarks direct execution of a loop
func BenchmarkLoopDoString(b *testing.B) {
state := luajit.New()
if state == nil {
b.Fatal("Failed to create Lua state")
}
defer state.Close()
code := `
local sum = 0
for i = 1, 1000 do
sum = sum + i
end
`
b.ResetTimer()
TrackMemoryUsage(b, "loop-dostring", func() {
for i := 0; i < b.N; i++ {
if err := state.DoString(code); err != nil {
b.Fatalf("DoString failed: %v", err)
}
}
})
}
// BenchmarkLoopPrecompiled benchmarks precompiled loop execution
func BenchmarkLoopPrecompiled(b *testing.B) {
state := luajit.New()
if state == nil {
b.Fatal("Failed to create Lua state")
}
defer state.Close()
code := `
local sum = 0
for i = 1, 1000 do
sum = sum + i
end
`
bytecode, err := state.CompileBytecode(code, "loop")
if err != nil {
b.Fatalf("CompileBytecode failed: %v", err)
}
b.ResetTimer()
TrackMemoryUsage(b, "loop-precompiled", func() {
for i := 0; i < b.N; i++ {
if err := state.LoadAndRunBytecode(bytecode, "loop"); err != nil {
b.Fatalf("LoadAndRunBytecode failed: %v", err)
}
}
})
}
// BenchmarkTableOperationsDoString benchmarks direct execution of table operations
func BenchmarkTableOperationsDoString(b *testing.B) {
state := luajit.New()
if state == nil {
b.Fatal("Failed to create Lua state")
}
defer state.Close()
code := `
local t = {}
for i = 1, 100 do
t[i] = i * 2
end
local sum = 0
for i, v in ipairs(t) do
sum = sum + v
end
`
b.ResetTimer()
TrackMemoryUsage(b, "table-dostring", func() {
for i := 0; i < b.N; i++ {
if err := state.DoString(code); err != nil {
b.Fatalf("DoString failed: %v", err)
}
}
})
}
// BenchmarkTableOperationsPrecompiled benchmarks precompiled table operations
func BenchmarkTableOperationsPrecompiled(b *testing.B) {
state := luajit.New()
if state == nil {
b.Fatal("Failed to create Lua state")
}
defer state.Close()
code := `
local t = {}
for i = 1, 100 do
t[i] = i * 2
end
local sum = 0
for i, v in ipairs(t) do
sum = sum + v
end
`
bytecode, err := state.CompileBytecode(code, "table")
if err != nil {
b.Fatalf("CompileBytecode failed: %v", err)
}
b.ResetTimer()
TrackMemoryUsage(b, "table-precompiled", func() {
for i := 0; i < b.N; i++ {
if err := state.LoadAndRunBytecode(bytecode, "table"); err != nil {
b.Fatalf("LoadAndRunBytecode failed: %v", err)
}
}
})
}
// BenchmarkGoFunctionCall benchmarks calling a Go function from Lua
func BenchmarkGoFunctionCall(b *testing.B) {
state := luajit.New()
if state == nil {
b.Fatal("Failed to create Lua state")
}
defer state.Close()
// Register a simple Go function
add := func(s *luajit.State) int {
a := s.ToNumber(1)
b := s.ToNumber(2)
s.PushNumber(a + b)
return 1
}
if err := state.RegisterGoFunction("add", add); err != nil {
b.Fatalf("RegisterGoFunction failed: %v", err)
}
code := "local result = add(10, 20)"
bytecode, err := state.CompileBytecode(code, "gofunc")
if err != nil {
b.Fatalf("CompileBytecode failed: %v", err)
}
b.ResetTimer()
TrackMemoryUsage(b, "go-func-call", func() {
for i := 0; i < b.N; i++ {
if err := state.LoadAndRunBytecode(bytecode, "gofunc"); err != nil {
b.Fatalf("LoadAndRunBytecode failed: %v", err)
}
}
})
}
// BenchmarkComplexScript benchmarks a more complex script
func BenchmarkComplexScript(b *testing.B) {
state := luajit.New()
if state == nil {
b.Fatal("Failed to create Lua state")
}
defer state.Close()
code := `
-- Define a simple class
local Class = {}
Class.__index = Class
function Class.new(x, y)
local self = setmetatable({}, Class)
self.x = x or 0
self.y = y or 0
return self
end
function Class:move(dx, dy)
self.x = self.x + dx
self.y = self.y + dy
return self
end
function Class:getPosition()
return self.x, self.y
end
-- Create instances and operate on them
local instances = {}
for i = 1, 50 do
instances[i] = Class.new(i, i*2)
end
local result = 0
for i, obj in ipairs(instances) do
obj:move(i, -i)
local x, y = obj:getPosition()
result = result + x + y
end
return result
`
b.ResetTimer()
TrackMemoryUsage(b, "complex-script", func() {
for i := 0; i < b.N; i++ {
if _, err := state.ExecuteWithResult(code); err != nil {
b.Fatalf("ExecuteWithResult failed: %v", err)
}
}
})
}
// BenchmarkComplexScriptPrecompiled benchmarks a precompiled complex script
func BenchmarkComplexScriptPrecompiled(b *testing.B) {
state := luajit.New()
if state == nil {
b.Fatal("Failed to create Lua state")
}
defer state.Close()
code := `
-- Define a simple class
local Class = {}
Class.__index = Class
function Class.new(x, y)
local self = setmetatable({}, Class)
self.x = x or 0
self.y = y or 0
return self
end
function Class:move(dx, dy)
self.x = self.x + dx
self.y = self.y + dy
return self
end
function Class:getPosition()
return self.x, self.y
end
-- Create instances and operate on them
local instances = {}
for i = 1, 50 do
instances[i] = Class.new(i, i*2)
end
local result = 0
for i, obj in ipairs(instances) do
obj:move(i, -i)
local x, y = obj:getPosition()
result = result + x + y
end
return result
`
bytecode, err := state.CompileBytecode(code, "complex")
if err != nil {
b.Fatalf("CompileBytecode failed: %v", err)
}
b.ResetTimer()
TrackMemoryUsage(b, "complex-precompiled", func() {
for i := 0; i < b.N; i++ {
if err := state.LoadBytecode(bytecode, "complex"); err != nil {
b.Fatalf("LoadBytecode failed: %v", err)
}
if err := state.RunBytecodeWithResults(1); err != nil {
b.Fatalf("RunBytecodeWithResults failed: %v", err)
}
state.Pop(1) // Pop the result
}
})
}
// BenchmarkMultipleExecutions benchmarks executing the same bytecode multiple times
func BenchmarkMultipleExecutions(b *testing.B) {
state := luajit.New()
if state == nil {
b.Fatal("Failed to create Lua state")
}
defer state.Close()
// Setup a stateful environment
setupCode := `
counter = 0
function increment(amount)
counter = counter + (amount or 1)
return counter
end
`
if err := state.DoString(setupCode); err != nil {
b.Fatalf("Failed to set up environment: %v", err)
}
// Compile the function call
code := "return increment(5)"
bytecode, err := state.CompileBytecode(code, "increment")
if err != nil {
b.Fatalf("CompileBytecode failed: %v", err)
}
b.ResetTimer()
TrackMemoryUsage(b, "multiple-executions", func() {
for i := 0; i < b.N; i++ {
if err := state.LoadBytecode(bytecode, "increment"); err != nil {
b.Fatalf("LoadBytecode failed: %v", err)
}
if err := state.RunBytecodeWithResults(1); err != nil {
b.Fatalf("RunBytecodeWithResults failed: %v", err)
}
state.Pop(1) // Pop the result
}
})
}

138
bench/ezbench_test.go Normal file
View File

@ -0,0 +1,138 @@
package luajit_bench
import (
"testing"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
var benchCases = []struct {
name string
code string
}{
{
name: "SimpleAddition",
code: `return 1 + 1`,
},
{
name: "LoopSum",
code: `
local sum = 0
for i = 1, 1000 do
sum = sum + i
end
return sum
`,
},
{
name: "FunctionCall",
code: `
local result = 0
for i = 1, 100 do
result = result + i
end
return result
`,
},
{
name: "TableCreation",
code: `
local t = {}
for i = 1, 100 do
t[i] = i * 2
end
return t[50]
`,
},
{
name: "StringOperations",
code: `
local s = "hello"
for i = 1, 10 do
s = s .. " world"
end
return #s
`,
},
}
func BenchmarkLuaDirectExecution(b *testing.B) {
for _, bc := range benchCases {
b.Run(bc.name, func(b *testing.B) {
L := luajit.New()
if L == nil {
b.Fatal("Failed to create Lua state")
}
defer L.Close()
defer L.Cleanup()
// First verify we can execute the code
if err := L.DoString(bc.code); err != nil {
b.Fatalf("Failed to execute test code: %v", err)
}
b.ResetTimer()
TrackMemoryUsage(b, "direct-"+bc.name, func() {
for i := 0; i < b.N; i++ {
// Execute string and get results
nresults, err := L.Execute(bc.code)
if err != nil {
b.Fatalf("Failed to execute code: %v", err)
}
L.Pop(nresults) // Clean up any results
}
})
})
}
}
func BenchmarkLuaBytecodeExecution(b *testing.B) {
// First compile all bytecode
bytecodes := make(map[string][]byte)
for _, bc := range benchCases {
L := luajit.New()
if L == nil {
b.Fatal("Failed to create Lua state")
}
defer L.Cleanup()
bytecode, err := L.CompileBytecode(bc.code, bc.name)
if err != nil {
L.Close()
b.Fatalf("Error compiling bytecode for %s: %v", bc.name, err)
}
bytecodes[bc.name] = bytecode
L.Close()
}
for _, bc := range benchCases {
b.Run(bc.name, func(b *testing.B) {
L := luajit.New()
if L == nil {
b.Fatal("Failed to create Lua state")
}
defer L.Close()
defer L.Cleanup()
bytecode := bytecodes[bc.name]
// First verify we can execute the bytecode
if err := L.LoadAndRunBytecodeWithResults(bytecode, bc.name, 1); err != nil {
b.Fatalf("Failed to execute test bytecode: %v", err)
}
L.Pop(1) // Clean up the result
b.ResetTimer()
b.SetBytes(int64(len(bytecode))) // Track bytecode size in benchmarks
TrackMemoryUsage(b, "bytecode-"+bc.name, func() {
for i := 0; i < b.N; i++ {
if err := L.LoadAndRunBytecode(bytecode, bc.name); err != nil {
b.Fatalf("Error executing bytecode: %v", err)
}
}
})
})
}
}

78
bench/profile.sh Executable file
View File

@ -0,0 +1,78 @@
#!/bin/bash
# Easy script to run benchmarks with profiling enabled
# Usage: ./profile_benchmarks.sh [benchmark_pattern]
set -e
# Default values
BENCHMARK=${1:-"."}
OUTPUT_DIR="./profile_results"
CPU_PROFILE="$OUTPUT_DIR/cpu.prof"
MEM_PROFILE="$OUTPUT_DIR/mem.prof"
BLOCK_PROFILE="$OUTPUT_DIR/block.prof"
MUTEX_PROFILE="$OUTPUT_DIR/mutex.prof"
TRACE_FILE="$OUTPUT_DIR/trace.out"
HTML_OUTPUT="$OUTPUT_DIR/profile_report.html"
# Create output directory
mkdir -p "$OUTPUT_DIR"
echo "Running benchmarks with profiling enabled..."
# Run benchmarks with profiling flags
go test -bench="$BENCHMARK" -benchmem -cpuprofile="$CPU_PROFILE" -memprofile="$MEM_PROFILE" -blockprofile="$BLOCK_PROFILE" -mutexprofile="$MUTEX_PROFILE" -count=5 -timeout=30m
echo "Generating CPU profile analysis..."
go tool pprof -http=":1880" -output="$OUTPUT_DIR/cpu_graph.svg" "$CPU_PROFILE"
echo "Generating memory profile analysis..."
go tool pprof -http=":1880" -output="$OUTPUT_DIR/mem_graph.svg" "$MEM_PROFILE"
# Generate a simple HTML report
cat > "$HTML_OUTPUT" << EOF
<!DOCTYPE html>
<html>
<head>
<title>LuaJIT Benchmark Profiling Results</title>
<style>
body { font-family: Arial, sans-serif; margin: 20px; }
h1, h2 { color: #333; }
.profile { margin-bottom: 30px; }
img { max-width: 100%; border: 1px solid #ddd; }
</style>
</head>
<body>
<h1>LuaJIT Benchmark Profiling Results</h1>
<p>Generated on: $(date)</p>
<div class="profile">
<h2>CPU Profile</h2>
<img src="cpu_graph.svg" alt="CPU Profile Graph">
<p>Command to explore: <code>go tool pprof $CPU_PROFILE</code></p>
</div>
<div class="profile">
<h2>Memory Profile</h2>
<img src="mem_graph.svg" alt="Memory Profile Graph">
<p>Command to explore: <code>go tool pprof $MEM_PROFILE</code></p>
</div>
<div class="profile">
<h2>Tips for Profile Analysis</h2>
<ul>
<li>Use <code>go tool pprof -http=:8080 $CPU_PROFILE</code> for interactive web UI</li>
<li>Use <code>top10</code> in pprof to see the top 10 functions by CPU/memory usage</li>
<li>Use <code>list FunctionName</code> to see line-by-line stats for a specific function</li>
</ul>
</div>
</body>
</html>
EOF
echo "Profiling complete! Results available in $OUTPUT_DIR"
echo "View the HTML report at $HTML_OUTPUT"
echo ""
echo "For detailed interactive analysis, run:"
echo " go tool pprof -http=:1880 $CPU_PROFILE # For CPU profile"
echo " go tool pprof -http=:1880 $MEM_PROFILE # For memory profile"

View File

@ -1,148 +0,0 @@
package main
import (
"fmt"
"time"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
type benchCase struct {
name string
code string
}
var cases = []benchCase{
{
name: "Simple Addition",
code: `return 1 + 1`,
},
{
name: "Loop Sum",
code: `
local sum = 0
for i = 1, 1000 do
sum = sum + i
end
return sum
`,
},
{
name: "Function Call",
code: `
local result = 0
for i = 1, 100 do
result = result + i
end
return result
`,
},
{
name: "Table Creation",
code: `
local t = {}
for i = 1, 100 do
t[i] = i * 2
end
return t[50]
`,
},
{
name: "String Operations",
code: `
local s = "hello"
for i = 1, 10 do
s = s .. " world"
end
return #s
`,
},
}
func runBenchmark(L *luajit.State, code string, duration time.Duration) (time.Duration, int64) {
start := time.Now()
deadline := start.Add(duration)
var ops int64
for time.Now().Before(deadline) {
if err := L.DoString(code); err != nil {
fmt.Printf("Error executing code: %v\n", err)
return 0, 0
}
L.Pop(1)
ops++
}
return time.Since(start), ops
}
func runBytecodeTest(L *luajit.State, code string, duration time.Duration) (time.Duration, int64) {
// First compile the bytecode
bytecode, err := L.CompileBytecode(code, "bench")
if err != nil {
fmt.Printf("Error compiling bytecode: %v\n", err)
return 0, 0
}
start := time.Now()
deadline := start.Add(duration)
var ops int64
for time.Now().Before(deadline) {
if err := L.LoadBytecode(bytecode, "bench"); err != nil {
fmt.Printf("Error executing bytecode: %v\n", err)
return 0, 0
}
ops++
}
return time.Since(start), ops
}
func benchmarkCase(newState func() *luajit.State, bc benchCase) {
fmt.Printf("\n%s:\n", bc.name)
// Direct execution benchmark
L := newState()
if L == nil {
fmt.Printf(" Failed to create Lua state\n")
return
}
execTime, ops := runBenchmark(L, bc.code, 2*time.Second)
L.Close()
if ops > 0 {
opsPerSec := float64(ops) / execTime.Seconds()
fmt.Printf(" Direct: %.0f ops/sec\n", opsPerSec)
}
// Bytecode execution benchmark
L = newState()
if L == nil {
fmt.Printf(" Failed to create Lua state\n")
return
}
execTime, ops = runBytecodeTest(L, bc.code, 2*time.Second)
L.Close()
if ops > 0 {
opsPerSec := float64(ops) / execTime.Seconds()
fmt.Printf(" Bytecode: %.0f ops/sec\n", opsPerSec)
}
}
func main() {
modes := []struct {
name string
newState func() *luajit.State
}{
{"Safe", luajit.NewSafe},
{"Unsafe", luajit.New},
}
for _, mode := range modes {
fmt.Printf("\n=== %s Mode ===\n", mode.name)
for _, c := range cases {
benchmarkCase(mode.newState, c)
}
}
}

View File

@ -12,7 +12,7 @@ typedef struct {
const char *name;
} BytecodeReader;
static const char *bytecode_reader(lua_State *L, void *ud, size_t *size) {
const char *bytecode_reader(lua_State *L, void *ud, size_t *size) {
BytecodeReader *r = (BytecodeReader *)ud;
(void)L; // unused
if (r->size == 0) return NULL;
@ -21,29 +21,33 @@ static const char *bytecode_reader(lua_State *L, void *ud, size_t *size) {
return (const char *)r->buf;
}
static int load_bytecode_chunk(lua_State *L, const unsigned char *buf, size_t len, const char *name) {
int load_bytecode(lua_State *L, const unsigned char *buf, size_t len, const char *name) {
BytecodeReader reader = {buf, len, name};
return lua_load(L, bytecode_reader, &reader, name);
}
typedef struct {
unsigned char *buf;
size_t len;
} BytecodeWriter;
int bytecode_writer(lua_State *L, const void *p, size_t sz, void *ud) {
BytecodeWriter *w = (BytecodeWriter *)ud;
unsigned char *newbuf;
(void)L; // unused
newbuf = (unsigned char *)realloc(w->buf, w->len + sz);
// Direct bytecode dumping without intermediate buffer - more efficient
int direct_bytecode_writer(lua_State *L, const void *p, size_t sz, void *ud) {
void **data = (void **)ud;
size_t current_size = (size_t)data[1];
void *newbuf = realloc(data[0], current_size + sz);
if (newbuf == NULL) return 1;
memcpy(newbuf + w->len, p, sz);
w->buf = newbuf;
w->len += sz;
memcpy((unsigned char*)newbuf + current_size, p, sz);
data[0] = newbuf;
data[1] = (void*)(current_size + sz);
return 0;
}
// Combined load and run bytecode in a single call
int load_and_run_bytecode(lua_State *L, const unsigned char *buf, size_t len,
const char *name, int nresults) {
BytecodeReader reader = {buf, len, name};
int status = lua_load(L, bytecode_reader, &reader, name);
if (status != 0) return status;
return lua_pcall(L, 0, nresults, 0);
}
*/
import "C"
import (
@ -51,55 +55,46 @@ import (
"unsafe"
)
func (s *State) compileBytecodeUnsafe(code string, name string) ([]byte, error) {
// First load the string but don't execute it
ccode := C.CString(code)
defer C.free(unsafe.Pointer(ccode))
cname := C.CString(name)
defer C.free(unsafe.Pointer(cname))
if C.luaL_loadstring(s.L, ccode) != 0 {
err := &LuaError{
Code: int(C.lua_status(s.L)),
Message: s.ToString(-1),
}
s.Pop(1)
// CompileBytecode compiles a Lua chunk to bytecode without executing it
func (s *State) CompileBytecode(code string, name string) ([]byte, error) {
if err := s.LoadString(code); err != nil {
return nil, fmt.Errorf("failed to load string: %w", err)
}
// Set up writer
var writer C.BytecodeWriter
writer.buf = nil
writer.len = 0
// Use a simpler direct writer with just two pointers
data := [2]unsafe.Pointer{nil, nil}
// Dump the function to bytecode
if C.lua_dump(s.L, (*[0]byte)(C.bytecode_writer), unsafe.Pointer(&writer)) != 0 {
if writer.buf != nil {
C.free(unsafe.Pointer(writer.buf))
}
s.Pop(1)
return nil, fmt.Errorf("failed to dump bytecode")
status := C.lua_dump(s.L, (*[0]byte)(unsafe.Pointer(C.direct_bytecode_writer)), unsafe.Pointer(&data))
if status != 0 {
return nil, fmt.Errorf("failed to dump bytecode: status %d", status)
}
// Copy to Go slice
bytecode := C.GoBytes(unsafe.Pointer(writer.buf), C.int(writer.len))
// Clean up
if writer.buf != nil {
C.free(unsafe.Pointer(writer.buf))
// Get result
var bytecode []byte
if data[0] != nil {
// Create Go slice that references the C memory
length := uintptr(data[1])
bytecode = C.GoBytes(data[0], C.int(length))
C.free(data[0])
}
s.Pop(1) // Remove the function
s.Pop(1) // Remove the function from stack
return bytecode, nil
}
func (s *State) loadBytecodeUnsafe(bytecode []byte, name string) error {
// LoadBytecode loads precompiled bytecode without executing it
func (s *State) LoadBytecode(bytecode []byte, name string) error {
if len(bytecode) == 0 {
return fmt.Errorf("empty bytecode")
}
cname := C.CString(name)
defer C.free(unsafe.Pointer(cname))
// Load the bytecode
status := C.load_bytecode_chunk(
status := C.load_bytecode(
s.L,
(*C.uchar)(unsafe.Pointer(&bytecode[0])),
C.size_t(len(bytecode)),
@ -111,49 +106,107 @@ func (s *State) loadBytecodeUnsafe(bytecode []byte, name string) error {
Code: int(status),
Message: s.ToString(-1),
}
s.Pop(1)
return fmt.Errorf("failed to load bytecode: %w", err)
}
// Execute the loaded chunk
if err := s.safeCall(func() C.int {
return C.lua_pcall(s.L, 0, 0, 0)
}); err != nil {
return fmt.Errorf("failed to execute bytecode: %w", err)
s.Pop(1) // Remove error message
return err
}
return nil
}
// CompileBytecode compiles a Lua chunk to bytecode without executing it
func (s *State) CompileBytecode(code string, name string) ([]byte, error) {
if s.safeStack {
return stackGuardValue[[]byte](s, func() ([]byte, error) {
return s.compileBytecodeUnsafe(code, name)
})
}
return s.compileBytecodeUnsafe(code, name)
// RunBytecode executes previously loaded bytecode with 0 results
func (s *State) RunBytecode() error {
return s.RunBytecodeWithResults(0)
}
// LoadBytecode loads precompiled bytecode and executes it
func (s *State) LoadBytecode(bytecode []byte, name string) error {
if s.safeStack {
return stackGuardErr(s, func() error {
return s.loadBytecodeUnsafe(bytecode, name)
})
// RunBytecodeWithResults executes bytecode and keeps nresults on the stack
// Use LUA_MULTRET (-1) to keep all results
func (s *State) RunBytecodeWithResults(nresults int) error {
status := C.lua_pcall(s.L, 0, C.int(nresults), 0)
if status != 0 {
err := &LuaError{
Code: int(status),
Message: s.ToString(-1),
}
s.Pop(1) // Remove error message
return err
}
return s.loadBytecodeUnsafe(bytecode, name)
return nil
}
// Helper function to compile and immediately load/execute bytecode
func (s *State) CompileAndLoad(code string, name string) error {
// LoadAndRunBytecode loads and executes bytecode in a single CGO transition
func (s *State) LoadAndRunBytecode(bytecode []byte, name string) error {
if len(bytecode) == 0 {
return fmt.Errorf("empty bytecode")
}
cname := C.CString(name)
defer C.free(unsafe.Pointer(cname))
// Use combined load and run function
status := C.load_and_run_bytecode(
s.L,
(*C.uchar)(unsafe.Pointer(&bytecode[0])),
C.size_t(len(bytecode)),
cname,
0, // No results
)
if status != 0 {
err := &LuaError{
Code: int(status),
Message: s.ToString(-1),
}
s.Pop(1) // Remove error message
return err
}
return nil
}
// LoadAndRunBytecodeWithResults loads and executes bytecode, preserving results
func (s *State) LoadAndRunBytecodeWithResults(bytecode []byte, name string, nresults int) error {
if len(bytecode) == 0 {
return fmt.Errorf("empty bytecode")
}
cname := C.CString(name)
defer C.free(unsafe.Pointer(cname))
// Use combined load and run function
status := C.load_and_run_bytecode(
s.L,
(*C.uchar)(unsafe.Pointer(&bytecode[0])),
C.size_t(len(bytecode)),
cname,
C.int(nresults),
)
if status != 0 {
err := &LuaError{
Code: int(status),
Message: s.ToString(-1),
}
s.Pop(1) // Remove error message
return err
}
return nil
}
// CompileAndRun compiles and immediately executes Lua code
func (s *State) CompileAndRun(code string, name string) error {
// Skip bytecode step for small scripts - direct execution is faster
if len(code) < 1024 {
return s.DoString(code)
}
bytecode, err := s.CompileBytecode(code, name)
if err != nil {
return fmt.Errorf("compile error: %w", err)
}
if err := s.LoadBytecode(bytecode, name); err != nil {
return fmt.Errorf("load error: %w", err)
if err := s.LoadAndRunBytecode(bytecode, name); err != nil {
return fmt.Errorf("execution error: %w", err)
}
return nil

View File

@ -1,178 +0,0 @@
package luajit
import (
"fmt"
"testing"
)
func TestBytecodeCompilation(t *testing.T) {
tests := []struct {
name string
code string
wantErr bool
}{
{
name: "simple assignment",
code: "x = 42",
wantErr: false,
},
{
name: "function definition",
code: "function add(a,b) return a+b end",
wantErr: false,
},
{
name: "syntax error",
code: "function bad syntax",
wantErr: true,
},
}
for _, f := range factories {
for _, tt := range tests {
t.Run(f.name+"/"+tt.name, func(t *testing.T) {
L := f.new()
if L == nil {
t.Fatal("Failed to create Lua state")
}
defer L.Close()
bytecode, err := L.CompileBytecode(tt.code, "test")
if (err != nil) != tt.wantErr {
t.Errorf("CompileBytecode() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !tt.wantErr {
if len(bytecode) == 0 {
t.Error("CompileBytecode() returned empty bytecode")
}
}
})
}
}
}
func TestBytecodeExecution(t *testing.T) {
for _, f := range factories {
t.Run(f.name, func(t *testing.T) {
L := f.new()
if L == nil {
t.Fatal("Failed to create Lua state")
}
defer L.Close()
// Compile some test code
code := `
function add(a, b)
return a + b
end
result = add(40, 2)
`
bytecode, err := L.CompileBytecode(code, "test")
if err != nil {
t.Fatalf("CompileBytecode() error = %v", err)
}
// Load and execute the bytecode
if err := L.LoadBytecode(bytecode, "test"); err != nil {
t.Fatalf("LoadBytecode() error = %v", err)
}
// Verify the result
L.GetGlobal("result")
if result := L.ToNumber(-1); result != 42 {
t.Errorf("got result = %v, want 42", result)
}
})
}
}
func TestInvalidBytecode(t *testing.T) {
for _, f := range factories {
t.Run(f.name, func(t *testing.T) {
L := f.new()
if L == nil {
t.Fatal("Failed to create Lua state")
}
defer L.Close()
// Test with invalid bytecode
invalidBytecode := []byte("this is not valid bytecode")
if err := L.LoadBytecode(invalidBytecode, "test"); err == nil {
t.Error("LoadBytecode() expected error with invalid bytecode")
}
})
}
}
func TestBytecodeRoundTrip(t *testing.T) {
tests := []struct {
name string
code string
check func(*State) error
}{
{
name: "global variable",
code: "x = 42",
check: func(L *State) error {
L.GetGlobal("x")
if x := L.ToNumber(-1); x != 42 {
return fmt.Errorf("got x = %v, want 42", x)
}
return nil
},
},
{
name: "function definition",
code: "function test() return 'hello' end",
check: func(L *State) error {
if err := L.DoString("result = test()"); err != nil {
return err
}
L.GetGlobal("result")
if s := L.ToString(-1); s != "hello" {
return fmt.Errorf("got result = %q, want 'hello'", s)
}
return nil
},
},
}
for _, f := range factories {
for _, tt := range tests {
t.Run(f.name+"/"+tt.name, func(t *testing.T) {
// First state for compilation
L1 := f.new()
if L1 == nil {
t.Fatal("Failed to create first Lua state")
}
defer L1.Close()
// Compile the code
bytecode, err := L1.CompileBytecode(tt.code, "test")
if err != nil {
t.Fatalf("CompileBytecode() error = %v", err)
}
// Second state for execution
L2 := f.new()
if L2 == nil {
t.Fatal("Failed to create second Lua state")
}
defer L2.Close()
// Load and execute the bytecode
if err := L2.LoadBytecode(bytecode, "test"); err != nil {
t.Fatalf("LoadBytecode() error = %v", err)
}
// Run the check function
if err := tt.check(L2); err != nil {
t.Errorf("check failed: %v", err)
}
})
}
}
}

70
example/main.go Normal file
View File

@ -0,0 +1,70 @@
package main
import (
"fmt"
"log"
"os"
"path/filepath"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
func main() {
if len(os.Args) < 2 {
fmt.Println("Usage: go run main.go script.lua")
os.Exit(1)
}
scriptPath := os.Args[1]
// Create a new Lua state
L := luajit.New()
if L == nil {
log.Fatal("Failed to create Lua state")
}
defer L.Close()
// Register a Go function to be called from Lua
L.RegisterGoFunction("printFromGo", func(s *luajit.State) int {
msg := s.ToString(1) // Get first argument
fmt.Printf("Go received from Lua: %s\n", msg)
// Return a value to Lua
s.PushString("Hello from Go!")
return 1 // Number of return values
})
// Add some values to the Lua environment
L.PushValue(map[string]any{
"appName": "LuaJIT Example",
"version": 1.0,
"features": []float64{1, 2, 3},
})
L.SetGlobal("config")
// Get the directory of the script to properly handle requires
dir := filepath.Dir(scriptPath)
L.AddPackagePath(filepath.Join(dir, "?.lua"))
// Execute the script
fmt.Printf("Running Lua script: %s\n", scriptPath)
if err := L.DoFile(scriptPath); err != nil {
log.Fatalf("Error executing script: %v", err)
}
// Call a Lua function and get its result
L.GetGlobal("getResult")
if L.IsFunction(-1) {
if err := L.Call(0, 1); err != nil {
log.Fatalf("Error calling Lua function: %v", err)
}
result, err := L.ToValue(-1)
if err != nil {
log.Fatalf("Error converting Lua result: %v", err)
}
fmt.Printf("Result from Lua: %v\n", result)
L.Pop(1) // Clean up the result
}
}

35
example/script.lua Normal file
View File

@ -0,0 +1,35 @@
-- Example Lua script to demonstrate Go-Lua integration
-- Access the config table passed from Go
print("Script started")
print("App name:", config.appName)
print("Version:", config.version)
print("Features:", table.concat(config.features, ", "))
-- Call the Go function
local response = printFromGo("Hello from Lua!")
print("Response from Go:", response)
-- Function that will be called from Go
function getResult()
local result = {
status = "success",
calculations = {
sum = 10 + 20,
product = 5 * 7
},
message = "Calculation completed"
}
return result
end
-- Load external module (if available)
local success, utils = pcall(require, "utils")
if success then
print("Utils module loaded")
utils.doSomething()
else
print("Utils module not available:", utils)
end
print("Script completed")

19
example/utils.lua Normal file
View File

@ -0,0 +1,19 @@
-- Optional utility module
local utils = {}
function utils.doSomething()
print("Utils module function called")
return true
end
function utils.calculate(a, b)
return {
sum = a + b,
difference = a - b,
product = a * b,
quotient = a / b
}
end
return utils

View File

@ -7,8 +7,9 @@ package luajit
extern int goFunctionWrapper(lua_State* L);
// Helper function to access upvalues
static int get_upvalue_index(int i) {
return -10002 - i; // LUA_GLOBALSINDEX - i
return lua_upvalueindex(i);
}
*/
import "C"
@ -18,28 +19,35 @@ import (
"unsafe"
)
// GoFunction defines the signature for Go functions callable from Lua
type GoFunction func(*State) int
// Static registry size reduces resizing operations
const initialRegistrySize = 64
var (
// functionRegistry stores all registered Go functions
functionRegistry = struct {
sync.RWMutex
funcs map[unsafe.Pointer]GoFunction
funcs map[unsafe.Pointer]GoFunction
initOnce sync.Once
}{
funcs: make(map[unsafe.Pointer]GoFunction),
funcs: make(map[unsafe.Pointer]GoFunction, initialRegistrySize),
}
)
//export goFunctionWrapper
func goFunctionWrapper(L *C.lua_State) C.int {
state := &State{L: L, safeStack: true}
state := &State{L: L}
// Get upvalue using standard Lua 5.1 macro
// Get function pointer from the first upvalue
ptr := C.lua_touserdata(L, C.get_upvalue_index(1))
if ptr == nil {
state.PushString("error: function not found")
state.PushString("error: function pointer not found")
return -1
}
// Use read-lock for better concurrency
functionRegistry.RLock()
fn, ok := functionRegistry.funcs[ptr]
functionRegistry.RUnlock()
@ -49,50 +57,56 @@ func goFunctionWrapper(L *C.lua_State) C.int {
return -1
}
result := fn(state)
return C.int(result)
// Call the Go function
return C.int(fn(state))
}
// PushGoFunction wraps a Go function and pushes it onto the Lua stack
func (s *State) PushGoFunction(fn GoFunction) error {
// Push lightuserdata as upvalue and create closure
// Allocate a pointer to use as the function key
ptr := C.malloc(1)
if ptr == nil {
return fmt.Errorf("failed to allocate memory for function pointer")
}
// Register the function
functionRegistry.Lock()
functionRegistry.funcs[ptr] = fn
functionRegistry.Unlock()
// Push the pointer as lightuserdata (first upvalue)
C.lua_pushlightuserdata(s.L, ptr)
// Create closure with the C wrapper and the upvalue
C.lua_pushcclosure(s.L, (*[0]byte)(C.goFunctionWrapper), 1)
return nil
}
// RegisterGoFunction registers a Go function as a global Lua function
func (s *State) RegisterGoFunction(name string, fn GoFunction) error {
if err := s.PushGoFunction(fn); err != nil {
return err
}
cname := C.CString(name)
defer C.free(unsafe.Pointer(cname))
C.lua_setfield(s.L, C.LUA_GLOBALSINDEX, cname)
s.SetGlobal(name)
return nil
}
// UnregisterGoFunction removes a global function
func (s *State) UnregisterGoFunction(name string) {
s.PushNil()
cname := C.CString(name)
defer C.free(unsafe.Pointer(cname))
C.lua_setfield(s.L, C.LUA_GLOBALSINDEX, cname)
s.SetGlobal(name)
}
// Cleanup frees all function pointers and clears the registry
func (s *State) Cleanup() {
functionRegistry.Lock()
defer functionRegistry.Unlock()
// Free all allocated pointers
for ptr := range functionRegistry.funcs {
C.free(ptr)
delete(functionRegistry.funcs, ptr)
}
functionRegistry.funcs = make(map[unsafe.Pointer]GoFunction)
}

View File

@ -1,109 +0,0 @@
package luajit
import "testing"
func TestGoFunctions(t *testing.T) {
for _, f := range factories {
t.Run(f.name, func(t *testing.T) {
L := f.new()
if L == nil {
t.Fatal("Failed to create Lua state")
}
defer L.Close()
defer L.Cleanup()
addFunc := func(s *State) int {
s.PushNumber(s.ToNumber(1) + s.ToNumber(2))
return 1
}
if err := L.RegisterGoFunction("add", addFunc); err != nil {
t.Fatalf("Failed to register function: %v", err)
}
// Test basic function call
if err := L.DoString("result = add(40, 2)"); err != nil {
t.Fatalf("Failed to call function: %v", err)
}
L.GetGlobal("result")
if result := L.ToNumber(-1); result != 42 {
t.Errorf("got %v, want 42", result)
}
L.Pop(1)
// Test multiple return values
multiFunc := func(s *State) int {
s.PushString("hello")
s.PushNumber(42)
s.PushBoolean(true)
return 3
}
if err := L.RegisterGoFunction("multi", multiFunc); err != nil {
t.Fatalf("Failed to register multi function: %v", err)
}
code := `
a, b, c = multi()
result = (a == "hello" and b == 42 and c == true)
`
if err := L.DoString(code); err != nil {
t.Fatalf("Failed to call multi function: %v", err)
}
L.GetGlobal("result")
if !L.ToBoolean(-1) {
t.Error("Multiple return values test failed")
}
L.Pop(1)
// Test error handling
errFunc := func(s *State) int {
s.PushString("test error")
return -1
}
if err := L.RegisterGoFunction("err", errFunc); err != nil {
t.Fatalf("Failed to register error function: %v", err)
}
if err := L.DoString("err()"); err == nil {
t.Error("Expected error from error function")
}
// Test unregistering
L.UnregisterGoFunction("add")
if err := L.DoString("add(1, 2)"); err == nil {
t.Error("Expected error calling unregistered function")
}
})
}
}
func TestStackSafety(t *testing.T) {
L := NewSafe()
if L == nil {
t.Fatal("Failed to create Lua state")
}
defer L.Close()
defer L.Cleanup()
// Test stack overflow protection
overflowFunc := func(s *State) int {
for i := 0; i < 100; i++ {
s.PushNumber(float64(i))
}
s.PushString("done")
return 101
}
if err := L.RegisterGoFunction("overflow", overflowFunc); err != nil {
t.Fatal(err)
}
if err := L.DoString("overflow()"); err != nil {
t.Logf("Got expected error: %v", err)
}
}

View File

@ -25,84 +25,23 @@ const (
LUA_GLOBALSINDEX = -10002 // Pseudo-index for globals table
)
// checkStack ensures there is enough space on the Lua stack
func (s *State) checkStack(n int) error {
if C.lua_checkstack(s.L, C.int(n)) == 0 {
return fmt.Errorf("stack overflow (cannot allocate %d slots)", n)
}
return nil
}
// safeCall wraps a potentially dangerous C call with stack checking
func (s *State) safeCall(f func() C.int) error {
// Save current stack size
top := s.GetTop()
// Ensure we have enough stack space (minimum 20 slots as per Lua standard)
if err := s.checkStack(LUA_MINSTACK); err != nil {
return err
// GetStackTrace returns the current Lua stack trace
func (s *State) GetStackTrace() string {
s.GetGlobal("debug")
if !s.IsTable(-1) {
s.Pop(1)
return "debug table not available"
}
// Make the call
status := f()
// Check for errors
if status != 0 {
err := &LuaError{
Code: int(status),
Message: s.ToString(-1),
}
s.Pop(1) // Remove error message
return err
s.GetField(-1, "traceback")
if !s.IsFunction(-1) {
s.Pop(2) // Remove debug table and non-function
return "debug.traceback not available"
}
// For lua_pcall, the function and arguments are popped before results are pushed
// So we don't consider it an underflow if the new top is less than the original
if status == 0 && s.GetType(-1) == TypeFunction {
// If we still have a function on the stack, restore original size
s.SetTop(top)
}
s.Call(0, 1)
trace := s.ToString(-1)
s.Pop(1) // Remove the trace
return nil
}
// stackGuard wraps a function with stack checking
func stackGuard[T any](s *State, f func() (T, error)) (T, error) {
// Save current stack size
top := s.GetTop()
defer func() {
// Only restore if stack is larger than original
if s.GetTop() > top {
s.SetTop(top)
}
}()
// Run the protected function
return f()
}
// stackGuardValue executes a function with stack protection
func stackGuardValue[T any](s *State, f func() (T, error)) (T, error) {
return stackGuard(s, f)
}
// stackGuardErr executes a function that only returns an error with stack protection
func stackGuardErr(s *State, f func() error) error {
// Save current stack size
top := s.GetTop()
defer func() {
// Only restore if stack is larger than original
if s.GetTop() > top {
s.SetTop(top)
}
}()
// Run the protected function
return f()
}
// getStackTrace returns the current Lua stack trace
func (s *State) getStackTrace() string {
// Same implementation...
return ""
return trace
}

207
table.go
View File

@ -6,172 +6,159 @@ package luajit
#include <lauxlib.h>
#include <stdlib.h>
static int get_table_length(lua_State *L, int index) {
// Simple direct length check
size_t get_table_length(lua_State *L, int index) {
return lua_objlen(L, index);
}
*/
import "C"
import (
"fmt"
"strconv"
)
// TableValue represents any value that can be stored in a Lua table
type TableValue interface {
~string | ~float64 | ~bool | ~int | ~map[string]interface{} | ~[]float64 | ~[]interface{}
// GetTableLength returns the length of a table at the given index
func (s *State) GetTableLength(index int) int {
return int(C.get_table_length(s.L, C.int(index)))
}
func (s *State) GetTableLength(index int) int { return int(C.get_table_length(s.L, C.int(index))) }
// ToTable converts a Lua table to a Go map
func (s *State) ToTable(index int) (map[string]interface{}, error) {
if s.safeStack {
return stackGuardValue[map[string]interface{}](s, func() (map[string]interface{}, error) {
if !s.IsTable(index) {
return nil, fmt.Errorf("not a table at index %d", index)
// PushTable pushes a Go map onto the Lua stack as a table
func (s *State) PushTable(table map[string]any) error {
// Fast path for array tables
if arr, ok := table[""]; ok {
if floatArr, ok := arr.([]float64); ok {
s.CreateTable(len(floatArr), 0)
for i, v := range floatArr {
s.PushNumber(float64(i + 1))
s.PushNumber(v)
s.SetTable(-3)
}
return s.toTableUnsafe(index)
})
}
if !s.IsTable(index) {
return nil, fmt.Errorf("not a table at index %d", index)
}
return s.toTableUnsafe(index)
}
func (s *State) pushTableSafe(table map[string]interface{}) error {
size := 2
if err := s.checkStack(size); err != nil {
return fmt.Errorf("insufficient stack space: %w", err)
return nil
} else if anyArr, ok := arr.([]any); ok {
s.CreateTable(len(anyArr), 0)
for i, v := range anyArr {
s.PushNumber(float64(i + 1))
if err := s.PushValue(v); err != nil {
return err
}
s.SetTable(-3)
}
return nil
}
}
s.NewTable()
// Regular table case - optimize capacity hint
s.CreateTable(0, len(table))
// Add each key-value pair directly
for k, v := range table {
if err := s.pushValueSafe(v); err != nil {
s.PushString(k)
if err := s.PushValue(v); err != nil {
return err
}
s.SetField(-2, k)
s.SetTable(-3)
}
return nil
}
func (s *State) pushTableUnsafe(table map[string]interface{}) error {
s.NewTable()
for k, v := range table {
if err := s.pushValueUnsafe(v); err != nil {
return err
}
s.SetField(-2, k)
}
return nil
}
func (s *State) toTableSafe(index int) (map[string]interface{}, error) {
if err := s.checkStack(2); err != nil {
return nil, err
}
return s.toTableUnsafe(index)
}
func (s *State) toTableUnsafe(index int) (map[string]interface{}, error) {
// ToTable converts a Lua table at the given index to a Go map
func (s *State) ToTable(index int) (map[string]any, error) {
absIdx := s.absIndex(index)
table := make(map[string]interface{})
if !s.IsTable(absIdx) {
return nil, fmt.Errorf("value at index %d is not a table", index)
}
// Check if it's an array-like table
// Try to detect array-like tables first
length := s.GetTableLength(absIdx)
if length > 0 {
array := make([]float64, length)
isArray := true
// Fast path for common array case
allNumbers := true
// Try to convert to array
for i := 1; i <= length; i++ {
// Sample first few values to check if it's likely an array of numbers
for i := 1; i <= min(length, 5); i++ {
s.PushNumber(float64(i))
s.GetTable(absIdx)
if s.GetType(-1) != TypeNumber {
isArray = false
if !s.IsNumber(-1) {
allNumbers = false
s.Pop(1)
break
}
array[i-1] = s.ToNumber(-1)
s.Pop(1)
}
if isArray {
return map[string]interface{}{"": array}, nil
if allNumbers {
// Efficiently extract array values
array := make([]float64, length)
for i := 1; i <= length; i++ {
s.PushNumber(float64(i))
s.GetTable(absIdx)
array[i-1] = s.ToNumber(-1)
s.Pop(1)
}
// Return array as a special table with empty key
result := make(map[string]any, 1)
result[""] = array
return result, nil
}
}
// Handle regular table
s.PushNil()
for C.lua_next(s.L, C.int(absIdx)) != 0 {
key := ""
valueType := C.lua_type(s.L, -2)
if valueType == C.LUA_TSTRING {
// Handle regular table with pre-allocated capacity
table := make(map[string]any, max(length, 8))
// Iterate through all key-value pairs
s.PushNil() // Start iteration with nil key
for s.Next(absIdx) {
// Stack now has key at -2 and value at -1
// Convert key to string
var key string
keyType := s.GetType(-2)
switch keyType {
case TypeString:
key = s.ToString(-2)
} else if valueType == C.LUA_TNUMBER {
key = fmt.Sprintf("%g", s.ToNumber(-2))
case TypeNumber:
key = strconv.FormatFloat(s.ToNumber(-2), 'g', -1, 64)
default:
// Skip non-string/non-number keys
s.Pop(1) // Pop value, leave key for next iteration
continue
}
value, err := s.toValueUnsafe(-1)
// Convert and store the value
value, err := s.ToValue(-1)
if err != nil {
s.Pop(1)
s.Pop(2) // Pop both key and value
return nil, err
}
// Handle nested array case
if m, ok := value.(map[string]interface{}); ok {
// Unwrap nested array tables
if m, ok := value.(map[string]any); ok {
if arr, ok := m[""]; ok {
value = arr
}
}
table[key] = value
s.Pop(1)
s.Pop(1) // Pop value, leave key for next iteration
}
return table, nil
}
// NewTable creates a new table and pushes it onto the stack
func (s *State) NewTable() {
if s.safeStack {
if err := s.checkStack(1); err != nil {
// Since we can't return an error, we'll push nil instead
s.PushNil()
return
}
// Helper functions for min/max operations
func min(a, b int) int {
if a < b {
return a
}
C.lua_createtable(s.L, 0, 0)
return b
}
// SetTable sets a table field with cached absolute index
func (s *State) SetTable(index int) {
absIdx := index
if s.safeStack && (index < 0 && index > LUA_REGISTRYINDEX) {
absIdx = s.GetTop() + index + 1
func max(a, b int) int {
if a > b {
return a
}
C.lua_settable(s.L, C.int(absIdx))
}
// GetTable gets a table field with cached absolute index
func (s *State) GetTable(index int) {
absIdx := index
if s.safeStack && (index < 0 && index > LUA_REGISTRYINDEX) {
absIdx = s.GetTop() + index + 1
}
if s.safeStack {
if err := s.checkStack(1); err != nil {
s.PushNil()
return
}
}
C.lua_gettable(s.L, C.int(absIdx))
}
// PushTable pushes a Go map onto the Lua stack as a table with stack checking
func (s *State) PushTable(table map[string]interface{}) error {
if s.safeStack {
return s.pushTableSafe(table)
}
return s.pushTableUnsafe(table)
return b
}

View File

@ -1,97 +0,0 @@
package luajit
import (
"math"
"testing"
)
func TestTableOperations(t *testing.T) {
tests := []struct {
name string
data map[string]interface{}
}{
{
name: "empty",
data: map[string]interface{}{},
},
{
name: "primitives",
data: map[string]interface{}{
"str": "hello",
"num": 42.0,
"bool": true,
"array": []float64{1.1, 2.2, 3.3},
},
},
{
name: "nested",
data: map[string]interface{}{
"nested": map[string]interface{}{
"value": 123.0,
"array": []float64{4.4, 5.5},
},
},
},
}
for _, f := range factories {
for _, tt := range tests {
t.Run(f.name+"/"+tt.name, func(t *testing.T) {
L := f.new()
if L == nil {
t.Fatal("Failed to create Lua state")
}
defer L.Close()
if err := L.PushTable(tt.data); err != nil {
t.Fatalf("PushTable() error = %v", err)
}
got, err := L.ToTable(-1)
if err != nil {
t.Fatalf("ToTable() error = %v", err)
}
if !tablesEqual(got, tt.data) {
t.Errorf("table mismatch\ngot = %v\nwant = %v", got, tt.data)
}
})
}
}
}
func tablesEqual(a, b map[string]interface{}) bool {
if len(a) != len(b) {
return false
}
for k, v1 := range a {
v2, ok := b[k]
if !ok {
return false
}
switch v1 := v1.(type) {
case map[string]interface{}:
v2, ok := v2.(map[string]interface{})
if !ok || !tablesEqual(v1, v2) {
return false
}
case []float64:
v2, ok := v2.([]float64)
if !ok || len(v1) != len(v2) {
return false
}
for i := range v1 {
if math.Abs(v1[i]-v2[i]) > 1e-10 {
return false
}
}
default:
if v1 != v2 {
return false
}
}
}
return true
}

443
tests/bytecode_test.go Normal file
View File

@ -0,0 +1,443 @@
package luajit_test
import (
"bytes"
"errors"
"testing"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
// TestCompileBytecode tests basic bytecode compilation
func TestCompileBytecode(t *testing.T) {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
code := "return 42"
bytecode, err := state.CompileBytecode(code, "test")
if err != nil {
t.Fatalf("CompileBytecode failed: %v", err)
}
if len(bytecode) == 0 {
t.Fatal("Expected non-empty bytecode")
}
}
// TestLoadBytecode tests loading precompiled bytecode
func TestLoadBytecode(t *testing.T) {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
// First compile some bytecode
code := "answer = 42"
bytecode, err := state.CompileBytecode(code, "test")
if err != nil {
t.Fatalf("CompileBytecode failed: %v", err)
}
// Then load it
err = state.LoadBytecode(bytecode, "test")
if err != nil {
t.Fatalf("LoadBytecode failed: %v", err)
}
// Verify a function is on the stack
if !state.IsFunction(-1) {
t.Fatal("Expected function at top of stack after LoadBytecode")
}
// Pop the function
state.Pop(1)
}
// TestRunBytecode tests running previously loaded bytecode
func TestRunBytecode(t *testing.T) {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
// First compile and load bytecode
code := "answer = 42"
bytecode, err := state.CompileBytecode(code, "test")
if err != nil {
t.Fatalf("CompileBytecode failed: %v", err)
}
err = state.LoadBytecode(bytecode, "test")
if err != nil {
t.Fatalf("LoadBytecode failed: %v", err)
}
// Run the bytecode
err = state.RunBytecode()
if err != nil {
t.Fatalf("RunBytecode failed: %v", err)
}
// Verify the code has executed correctly
state.GetGlobal("answer")
if !state.IsNumber(-1) || state.ToNumber(-1) != 42 {
t.Fatalf("Expected answer to be 42, got %v", state.ToNumber(-1))
}
state.Pop(1)
}
// TestLoadAndRunBytecode tests the combined load and run functionality
func TestLoadAndRunBytecode(t *testing.T) {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
// Compile bytecode
code := "answer = 42"
bytecode, err := state.CompileBytecode(code, "test")
if err != nil {
t.Fatalf("CompileBytecode failed: %v", err)
}
// Load and run in one step
err = state.LoadAndRunBytecode(bytecode, "test")
if err != nil {
t.Fatalf("LoadAndRunBytecode failed: %v", err)
}
// Verify execution
state.GetGlobal("answer")
if !state.IsNumber(-1) || state.ToNumber(-1) != 42 {
t.Fatalf("Expected answer to be 42, got %v", state.ToNumber(-1))
}
state.Pop(1)
}
// TestCompileAndRun tests compile and run functionality
func TestCompileAndRun(t *testing.T) {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
// Compile and run in one step
code := "answer = 42"
err := state.CompileAndRun(code, "test")
if err != nil {
t.Fatalf("CompileAndRun failed: %v", err)
}
// Verify execution
state.GetGlobal("answer")
if !state.IsNumber(-1) || state.ToNumber(-1) != 42 {
t.Fatalf("Expected answer to be 42, got %v", state.ToNumber(-1))
}
state.Pop(1)
}
// TestEmptyBytecode tests error handling for empty bytecode
func TestEmptyBytecode(t *testing.T) {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
// Try to load empty bytecode
err := state.LoadBytecode([]byte{}, "empty")
if err == nil {
t.Fatal("Expected error for empty bytecode, got nil")
}
}
// TestInvalidBytecode tests error handling for invalid bytecode
func TestInvalidBytecode(t *testing.T) {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
// Create some invalid bytecode
invalidBytecode := []byte("not valid bytecode")
// Try to load invalid bytecode
err := state.LoadBytecode(invalidBytecode, "invalid")
if err == nil {
t.Fatal("Expected error for invalid bytecode, got nil")
}
}
// TestBytecodeSerialization tests serializing and deserializing bytecode
func TestBytecodeSerialization(t *testing.T) {
// First state to compile
state1 := luajit.New()
if state1 == nil {
t.Fatal("Failed to create first Lua state")
}
defer state1.Close()
// Compile bytecode
code := `
function add(a, b)
return a + b
end
result = add(10, 20)
`
bytecode, err := state1.CompileBytecode(code, "test")
if err != nil {
t.Fatalf("CompileBytecode failed: %v", err)
}
// Second state to execute
state2 := luajit.New()
if state2 == nil {
t.Fatal("Failed to create second Lua state")
}
defer state2.Close()
// Load and run the bytecode in the second state
err = state2.LoadAndRunBytecode(bytecode, "test")
if err != nil {
t.Fatalf("LoadAndRunBytecode failed: %v", err)
}
// Verify execution
state2.GetGlobal("result")
if !state2.IsNumber(-1) || state2.ToNumber(-1) != 30 {
t.Fatalf("Expected result to be 30, got %v", state2.ToNumber(-1))
}
state2.Pop(1)
// Call the function to verify it was properly transferred
state2.GetGlobal("add")
if !state2.IsFunction(-1) {
t.Fatal("Expected add to be a function")
}
state2.PushNumber(5)
state2.PushNumber(7)
if err := state2.Call(2, 1); err != nil {
t.Fatalf("Failed to call function: %v", err)
}
if state2.ToNumber(-1) != 12 {
t.Fatalf("Expected add(5, 7) to return 12, got %v", state2.ToNumber(-1))
}
state2.Pop(1)
}
// TestCompilationError tests error handling for compilation errors
func TestCompilationError(t *testing.T) {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
// Invalid Lua code that should fail to compile
code := "function without end"
// Try to compile
_, err := state.CompileBytecode(code, "invalid")
if err == nil {
t.Fatal("Expected compilation error, got nil")
}
// Check error type
var luaErr *luajit.LuaError
if !errors.As(err, &luaErr) {
t.Fatalf("Expected error to wrap *luajit.LuaError, got %T", err)
}
}
// TestExecutionError tests error handling for runtime errors
func TestExecutionError(t *testing.T) {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
// Code that compiles but fails at runtime
code := "error('deliberate error')"
// Compile bytecode
bytecode, err := state.CompileBytecode(code, "error")
if err != nil {
t.Fatalf("CompileBytecode failed: %v", err)
}
// Try to execute
err = state.LoadAndRunBytecode(bytecode, "error")
if err == nil {
t.Fatal("Expected execution error, got nil")
}
// Check error type
if _, ok := err.(*luajit.LuaError); !ok {
t.Fatalf("Expected *luajit.LuaError, got %T", err)
}
}
// TestBytecodeEquivalence tests that bytecode execution produces the same results as direct execution
func TestBytecodeEquivalence(t *testing.T) {
code := `
local result = 0
for i = 1, 10 do
result = result + i
end
return result
`
// First, execute directly
state1 := luajit.New()
if state1 == nil {
t.Fatal("Failed to create first Lua state")
}
defer state1.Close()
directResult, err := state1.ExecuteWithResult(code)
if err != nil {
t.Fatalf("ExecuteWithResult failed: %v", err)
}
// Then, compile and execute bytecode
state2 := luajit.New()
if state2 == nil {
t.Fatal("Failed to create second Lua state")
}
defer state2.Close()
bytecode, err := state2.CompileBytecode(code, "test")
if err != nil {
t.Fatalf("CompileBytecode failed: %v", err)
}
err = state2.LoadBytecode(bytecode, "test")
if err != nil {
t.Fatalf("LoadBytecode failed: %v", err)
}
err = state2.Call(0, 1)
if err != nil {
t.Fatalf("Call failed: %v", err)
}
bytecodeResult, err := state2.ToValue(-1)
if err != nil {
t.Fatalf("ToValue failed: %v", err)
}
state2.Pop(1)
// Compare results
if directResult != bytecodeResult {
t.Fatalf("Results differ: direct=%v, bytecode=%v", directResult, bytecodeResult)
}
}
// TestBytecodeReuse tests reusing the same bytecode multiple times
func TestBytecodeReuse(t *testing.T) {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
// Create a function in bytecode
code := `
return function(x)
return x * 2
end
`
bytecode, err := state.CompileBytecode(code, "func")
if err != nil {
t.Fatalf("CompileBytecode failed: %v", err)
}
// Execute it several times
for i := 1; i <= 3; i++ {
// Load and run to get the function
err = state.LoadAndRunBytecodeWithResults(bytes.Clone(bytecode), "func", 1)
if err != nil {
t.Fatalf("LoadAndRunBytecodeWithResults failed: %v", err)
}
// Stack now has the function at the top
if !state.IsFunction(-1) {
t.Fatal("Expected function at top of stack")
}
// Call with parameter i
state.PushNumber(float64(i))
if err := state.Call(1, 1); err != nil {
t.Fatalf("Call failed: %v", err)
}
// Check result
expected := float64(i * 2)
if state.ToNumber(-1) != expected {
t.Fatalf("Expected %v, got %v", expected, state.ToNumber(-1))
}
// Pop the result
state.Pop(1)
}
}
// TestBytecodeClosure tests that bytecode properly handles closures and upvalues
func TestBytecodeClosure(t *testing.T) {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
// Create a closure
code := `
local counter = 0
return function()
counter = counter + 1
return counter
end
`
// Compile to bytecode
bytecode, err := state.CompileBytecode(code, "closure")
if err != nil {
t.Fatalf("CompileBytecode failed: %v", err)
}
// Load and run to get the counter function
err = state.LoadAndRunBytecodeWithResults(bytecode, "closure", 1)
if err != nil {
t.Fatalf("LoadAndRunBytecode failed: %v", err)
}
// Stack now has the function at the top
if !state.IsFunction(-1) {
t.Fatal("Expected function at top of stack")
}
// Store in a global
state.SetGlobal("counter_func")
// Call it multiple times and check the results
for i := 1; i <= 3; i++ {
state.GetGlobal("counter_func")
if err := state.Call(0, 1); err != nil {
t.Fatalf("Call failed: %v", err)
}
if state.ToNumber(-1) != float64(i) {
t.Fatalf("Expected counter to be %d, got %v", i, state.ToNumber(-1))
}
state.Pop(1)
}
}

178
tests/functions_test.go Normal file
View File

@ -0,0 +1,178 @@
package luajit_test
import (
"testing"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
func TestPushGoFunction(t *testing.T) {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
// Define a simple function that adds two numbers
add := func(s *luajit.State) int {
a := s.ToNumber(1)
b := s.ToNumber(2)
s.PushNumber(a + b)
return 1 // Return one result
}
// Push the function onto the stack
if err := state.PushGoFunction(add); err != nil {
t.Fatalf("PushGoFunction failed: %v", err)
}
// Verify that a function is on the stack
if !state.IsFunction(-1) {
t.Fatalf("Expected function at top of stack")
}
// Push arguments
state.PushNumber(3)
state.PushNumber(4)
// Call the function
if err := state.Call(2, 1); err != nil {
t.Fatalf("Failed to call function: %v", err)
}
// Check the result
if state.ToNumber(-1) != 7 {
t.Fatalf("Function returned %f, expected 7", state.ToNumber(-1))
}
state.Pop(1)
}
func TestRegisterGoFunction(t *testing.T) {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
// Define a function that squares a number
square := func(s *luajit.State) int {
x := s.ToNumber(1)
s.PushNumber(x * x)
return 1
}
// Register the function
if err := state.RegisterGoFunction("square", square); err != nil {
t.Fatalf("RegisterGoFunction failed: %v", err)
}
// Call the function from Lua
if err := state.DoString("result = square(5)"); err != nil {
t.Fatalf("Failed to call registered function: %v", err)
}
// Check the result
state.GetGlobal("result")
if state.ToNumber(-1) != 25 {
t.Fatalf("Function returned %f, expected 25", state.ToNumber(-1))
}
state.Pop(1)
// Test UnregisterGoFunction
state.UnregisterGoFunction("square")
// Function should no longer exist
err := state.DoString("result = square(5)")
if err == nil {
t.Fatalf("Expected error after unregistering function, got nil")
}
}
func TestGoFunctionWithErrorHandling(t *testing.T) {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
// Function that returns an error in Lua
errFunc := func(s *luajit.State) int {
s.PushString("error from Go function")
return -1 // Signal error
}
// Register the function
if err := state.RegisterGoFunction("errorFunc", errFunc); err != nil {
t.Fatalf("RegisterGoFunction failed: %v", err)
}
// Call the function expecting an error
err := state.DoString("result = errorFunc()")
if err == nil {
t.Fatalf("Expected error from function, got nil")
}
// Error message should contain our message
luaErr, ok := err.(*luajit.LuaError)
if !ok {
t.Fatalf("Expected LuaError, got %T: %v", err, err)
}
if luaErr.Message == "" {
t.Fatalf("Expected non-empty error message from Go function")
}
}
func TestCleanup(t *testing.T) {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
// Register several functions
for i := 0; i < 5; i++ {
dummy := func(s *luajit.State) int { return 0 }
if err := state.RegisterGoFunction("dummy", dummy); err != nil {
t.Fatalf("RegisterGoFunction failed: %v", err)
}
}
// Call Cleanup explicitly
state.Cleanup()
// Make sure we can still close the state
state.Close()
// Also test that Close can be called after Cleanup
state = luajit.New()
if state == nil {
t.Fatal("Failed to create second Lua state")
}
state.Close() // Should call Cleanup internally
}
func TestGoFunctionErrorPointer(t *testing.T) {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
// Create a Lua function that calls a non-existent Go function pointer
// This isn't a direct test of internal implementation, but tries to cover
// error cases in the goFunctionWrapper
code := `
function test()
-- This is a stub that doesn't actually call the wrapper,
-- but we're testing error handling in our State.DoString
return "test"
end
`
if err := state.DoString(code); err != nil {
t.Fatalf("Failed to define test function: %v", err)
}
// The real test is that Cleanup doesn't crash
state.Cleanup()
}

53
tests/stack_test.go Normal file
View File

@ -0,0 +1,53 @@
package luajit_test
import (
"strings"
"testing"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
func TestLuaError(t *testing.T) {
err := &luajit.LuaError{
Code: 123,
Message: "test error",
}
expected := "lua error (code=123): test error"
if err.Error() != expected {
t.Errorf("Expected error message %q, got %q", expected, err.Error())
}
}
func TestGetStackTrace(t *testing.T) {
s := luajit.New()
defer s.Close()
// Test with debug library available
trace := s.GetStackTrace()
if !strings.Contains(trace, "stack traceback:") {
t.Errorf("Expected trace to contain 'stack traceback:', got %q", trace)
}
// Test when debug table is not available
err := s.DoString("debug = nil")
if err != nil {
t.Fatalf("Failed to set debug to nil: %v", err)
}
trace = s.GetStackTrace()
if trace != "debug table not available" {
t.Errorf("Expected 'debug table not available', got %q", trace)
}
// Test when debug.traceback is not available
err = s.DoString("debug = {}")
if err != nil {
t.Fatalf("Failed to set debug to empty table: %v", err)
}
trace = s.GetStackTrace()
if trace != "debug.traceback not available" {
t.Errorf("Expected 'debug.traceback not available', got %q", trace)
}
}

246
tests/table_test.go Normal file
View File

@ -0,0 +1,246 @@
package luajit_test
import (
"reflect"
"testing"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
func TestGetTableLength(t *testing.T) {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
// Create a table with numeric indices
if err := state.DoString("t = {10, 20, 30, 40, 50}"); err != nil {
t.Fatalf("Failed to create test table: %v", err)
}
// Get the table
state.GetGlobal("t")
length := state.GetTableLength(-1)
if length != 5 {
t.Fatalf("Expected length 5, got %d", length)
}
state.Pop(1)
// Create a table with string keys
if err := state.DoString("t2 = {a=1, b=2, c=3}"); err != nil {
t.Fatalf("Failed to create test table: %v", err)
}
// Get the table
state.GetGlobal("t2")
length = state.GetTableLength(-1)
if length != 0 {
t.Fatalf("Expected length 0 for string-keyed table, got %d", length)
}
state.Pop(1)
}
func TestPushTable(t *testing.T) {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
// Create a test table
testTable := map[string]any{
"int": 42,
"float": 3.14,
"string": "hello",
"boolean": true,
"nil": nil,
}
// Push the table onto the stack
if err := state.PushTable(testTable); err != nil {
t.Fatalf("Failed to push table: %v", err)
}
// Execute Lua code to test the table contents
if err := state.DoString(`
function validate_table(t)
return t.int == 42 and
math.abs(t.float - 3.14) < 0.0001 and
t.string == "hello" and
t.boolean == true and
t["nil"] == nil
end
`); err != nil {
t.Fatalf("Failed to create validation function: %v", err)
}
// Call the validation function
state.GetGlobal("validate_table")
state.PushCopy(-2) // Copy the table to the top
if err := state.Call(1, 1); err != nil {
t.Fatalf("Failed to call validation function: %v", err)
}
if !state.ToBoolean(-1) {
t.Fatalf("Table validation failed")
}
state.Pop(2) // Pop the result and the table
}
func TestToTable(t *testing.T) {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
// Test regular table conversion
if err := state.DoString(`t = {a=1, b=2.5, c="test", d=true, e=nil}`); err != nil {
t.Fatalf("Failed to create test table: %v", err)
}
state.GetGlobal("t")
table, err := state.ToTable(-1)
if err != nil {
t.Fatalf("Failed to convert table: %v", err)
}
state.Pop(1)
expected := map[string]any{
"a": float64(1),
"b": 2.5,
"c": "test",
"d": true,
}
for k, v := range expected {
if table[k] != v {
t.Fatalf("Expected table[%s] = %v, got %v", k, v, table[k])
}
}
// Test array-like table conversion
if err := state.DoString(`arr = {10, 20, 30, 40, 50}`); err != nil {
t.Fatalf("Failed to create test array: %v", err)
}
state.GetGlobal("arr")
table, err = state.ToTable(-1)
if err != nil {
t.Fatalf("Failed to convert array table: %v", err)
}
state.Pop(1)
// For array tables, we should get a special format with an empty key
// and the array as the value
expectedArray := []float64{10, 20, 30, 40, 50}
if arr, ok := table[""].([]float64); !ok {
t.Fatalf("Expected array table to be converted with empty key, got: %v", table)
} else if !reflect.DeepEqual(arr, expectedArray) {
t.Fatalf("Expected %v, got %v", expectedArray, arr)
}
// Test invalid table index
_, err = state.ToTable(100)
if err == nil {
t.Fatalf("Expected error for invalid table index, got nil")
}
// Test non-table value
state.PushNumber(123)
_, err = state.ToTable(-1)
if err == nil {
t.Fatalf("Expected error for non-table value, got nil")
}
state.Pop(1)
// Test mixed array with non-numeric values
if err := state.DoString(`mixed = {10, 20, key="value", 30}`); err != nil {
t.Fatalf("Failed to create mixed table: %v", err)
}
state.GetGlobal("mixed")
table, err = state.ToTable(-1)
if err != nil {
t.Fatalf("Failed to convert mixed table: %v", err)
}
// Let's print the table for debugging
t.Logf("Table contents: %v", table)
state.Pop(1)
// Check if the array part is detected and stored with empty key
if arr, ok := table[""]; !ok {
t.Fatalf("Expected array-like part to be detected, got: %v", table)
} else {
// Verify the array contains the expected values
expectedArr := []float64{10, 20, 30}
actualArr := arr.([]float64)
if len(actualArr) != len(expectedArr) {
t.Fatalf("Expected array length %d, got %d", len(expectedArr), len(actualArr))
}
for i, v := range expectedArr {
if actualArr[i] != v {
t.Fatalf("Expected array[%d] = %v, got %v", i, v, actualArr[i])
}
}
}
// Based on the implementation, we need to create a separate test for string keys
if err := state.DoString(`dict = {foo="bar", baz="qux"}`); err != nil {
t.Fatalf("Failed to create dict table: %v", err)
}
state.GetGlobal("dict")
dictTable, err := state.ToTable(-1)
if err != nil {
t.Fatalf("Failed to convert dict table: %v", err)
}
state.Pop(1)
// Check the string keys
if val, ok := dictTable["foo"]; !ok || val != "bar" {
t.Fatalf("Expected dictTable[\"foo\"] = \"bar\", got: %v", val)
}
if val, ok := dictTable["baz"]; !ok || val != "qux" {
t.Fatalf("Expected dictTable[\"baz\"] = \"qux\", got: %v", val)
}
}
func TestTablePooling(t *testing.T) {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
// Create a Lua table and push it onto the stack
if err := state.DoString(`t = {a=1, b=2}`); err != nil {
t.Fatalf("Failed to create test table: %v", err)
}
state.GetGlobal("t")
// First conversion - should get a table from the pool
table1, err := state.ToTable(-1)
if err != nil {
t.Fatalf("Failed to convert table (1): %v", err)
}
// Second conversion - should get another table from the pool
table2, err := state.ToTable(-1)
if err != nil {
t.Fatalf("Failed to convert table (2): %v", err)
}
// Both tables should have the same content
if !reflect.DeepEqual(table1, table2) {
t.Fatalf("Tables should have the same content: %v vs %v", table1, table2)
}
// Clean up
state.Pop(1)
}

473
tests/wrapper_test.go Normal file
View File

@ -0,0 +1,473 @@
package luajit_test
import (
"os"
"reflect"
"testing"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
func TestStateLifecycle(t *testing.T) {
// Test creation
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
// Test close
state.Close()
// Test close is idempotent (doesn't crash)
state.Close()
}
func TestStackManipulation(t *testing.T) {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
// Test initial stack size
if state.GetTop() != 0 {
t.Fatalf("Expected empty stack, got %d elements", state.GetTop())
}
// Push values
state.PushNil()
state.PushBoolean(true)
state.PushNumber(42)
state.PushString("hello")
// Check stack size
if state.GetTop() != 4 {
t.Fatalf("Expected 4 elements, got %d", state.GetTop())
}
// Test SetTop
state.SetTop(2)
if state.GetTop() != 2 {
t.Fatalf("Expected 2 elements after SetTop, got %d", state.GetTop())
}
// Test PushCopy
state.PushCopy(2) // Copy the boolean
if !state.IsBoolean(-1) {
t.Fatalf("Expected boolean at top of stack")
}
// Test Pop
state.Pop(1)
if state.GetTop() != 2 {
t.Fatalf("Expected 2 elements after Pop, got %d", state.GetTop())
}
// Test Remove
state.PushNumber(99)
state.Remove(1) // Remove the first element (nil)
if state.GetTop() != 2 {
t.Fatalf("Expected 2 elements after Remove, got %d", state.GetTop())
}
// Verify first element is now boolean
if !state.IsBoolean(1) {
t.Fatalf("Expected boolean at index 1 after Remove")
}
}
func TestTypeChecking(t *testing.T) {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
// Push values of different types
state.PushNil()
state.PushBoolean(true)
state.PushNumber(42)
state.PushString("hello")
state.NewTable()
// Check types with GetType
if state.GetType(1) != luajit.TypeNil {
t.Fatalf("Expected nil type at index 1, got %s", state.GetType(1))
}
if state.GetType(2) != luajit.TypeBoolean {
t.Fatalf("Expected boolean type at index 2, got %s", state.GetType(2))
}
if state.GetType(3) != luajit.TypeNumber {
t.Fatalf("Expected number type at index 3, got %s", state.GetType(3))
}
if state.GetType(4) != luajit.TypeString {
t.Fatalf("Expected string type at index 4, got %s", state.GetType(4))
}
if state.GetType(5) != luajit.TypeTable {
t.Fatalf("Expected table type at index 5, got %s", state.GetType(5))
}
// Test individual type checking functions
if !state.IsNil(1) {
t.Fatalf("IsNil failed for nil value")
}
if !state.IsBoolean(2) {
t.Fatalf("IsBoolean failed for boolean value")
}
if !state.IsNumber(3) {
t.Fatalf("IsNumber failed for number value")
}
if !state.IsString(4) {
t.Fatalf("IsString failed for string value")
}
if !state.IsTable(5) {
t.Fatalf("IsTable failed for table value")
}
// Function test
state.DoString("function test() return true end")
state.GetGlobal("test")
if !state.IsFunction(-1) {
t.Fatalf("IsFunction failed for function value")
}
}
func TestValueConversion(t *testing.T) {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
// Push values
state.PushBoolean(true)
state.PushNumber(42.5)
state.PushString("hello")
// Test conversion
if !state.ToBoolean(1) {
t.Fatalf("ToBoolean failed")
}
if state.ToNumber(2) != 42.5 {
t.Fatalf("ToNumber failed, expected 42.5, got %f", state.ToNumber(2))
}
if state.ToString(3) != "hello" {
t.Fatalf("ToString failed, expected 'hello', got '%s'", state.ToString(3))
}
}
func TestTableOperations(t *testing.T) {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
// Test CreateTable
state.CreateTable(0, 3)
// Add fields using SetField
state.PushNumber(42)
state.SetField(-2, "answer")
state.PushString("hello")
state.SetField(-2, "greeting")
state.PushBoolean(true)
state.SetField(-2, "flag")
// Test GetField
state.GetField(-1, "answer")
if state.ToNumber(-1) != 42 {
t.Fatalf("GetField for 'answer' failed")
}
state.Pop(1)
state.GetField(-1, "greeting")
if state.ToString(-1) != "hello" {
t.Fatalf("GetField for 'greeting' failed")
}
state.Pop(1)
// Test Next for iteration
state.PushNil() // Start iteration
count := 0
for state.Next(-2) {
count++
state.Pop(1) // Pop value, leave key for next iteration
}
if count != 3 {
t.Fatalf("Expected 3 table entries, found %d", count)
}
// Clean up
state.Pop(1) // Pop the table
}
func TestGlobalOperations(t *testing.T) {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
// Set a global value
state.PushNumber(42)
state.SetGlobal("answer")
// Get the global value
state.GetGlobal("answer")
if state.ToNumber(-1) != 42 {
t.Fatalf("GetGlobal failed, expected 42, got %f", state.ToNumber(-1))
}
state.Pop(1)
// Test non-existent global (should be nil)
state.GetGlobal("nonexistent")
if !state.IsNil(-1) {
t.Fatalf("Expected nil for non-existent global")
}
state.Pop(1)
}
func TestCodeExecution(t *testing.T) {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
// Test LoadString
if err := state.LoadString("return 42"); err != nil {
t.Fatalf("LoadString failed: %v", err)
}
// Test Call
if err := state.Call(0, 1); err != nil {
t.Fatalf("Call failed: %v", err)
}
if state.ToNumber(-1) != 42 {
t.Fatalf("Call result incorrect, expected 42, got %f", state.ToNumber(-1))
}
state.Pop(1)
// Test DoString
if err := state.DoString("answer = 42 + 1"); err != nil {
t.Fatalf("DoString failed: %v", err)
}
state.GetGlobal("answer")
if state.ToNumber(-1) != 43 {
t.Fatalf("DoString execution incorrect, expected 43, got %f", state.ToNumber(-1))
}
state.Pop(1)
// Test Execute
nresults, err := state.Execute("return 5, 10, 15")
if err != nil {
t.Fatalf("Execute failed: %v", err)
}
if nresults != 3 {
t.Fatalf("Execute returned %d results, expected 3", nresults)
}
if state.ToNumber(-3) != 5 || state.ToNumber(-2) != 10 || state.ToNumber(-1) != 15 {
t.Fatalf("Execute results incorrect")
}
state.Pop(3)
// Test ExecuteWithResult
result, err := state.ExecuteWithResult("return 'hello'")
if err != nil {
t.Fatalf("ExecuteWithResult failed: %v", err)
}
if result != "hello" {
t.Fatalf("ExecuteWithResult returned %v, expected 'hello'", result)
}
// Test error handling
err = state.DoString("this is not valid lua code")
if err == nil {
t.Fatalf("Expected error for invalid code, got nil")
}
}
func TestDoFile(t *testing.T) {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
// Create a temporary Lua file
content := []byte("answer = 42")
tmpfile, err := os.CreateTemp("", "test-*.lua")
if err != nil {
t.Fatalf("Failed to create temp file: %v", err)
}
defer os.Remove(tmpfile.Name())
if _, err := tmpfile.Write(content); err != nil {
t.Fatalf("Failed to write to temp file: %v", err)
}
if err := tmpfile.Close(); err != nil {
t.Fatalf("Failed to close temp file: %v", err)
}
// Test LoadFile and DoFile
if err := state.LoadFile(tmpfile.Name()); err != nil {
t.Fatalf("LoadFile failed: %v", err)
}
if err := state.Call(0, 0); err != nil {
t.Fatalf("Call failed after LoadFile: %v", err)
}
state.GetGlobal("answer")
if state.ToNumber(-1) != 42 {
t.Fatalf("Incorrect result after LoadFile, expected 42, got %f", state.ToNumber(-1))
}
state.Pop(1)
// Reset global
if err := state.DoString("answer = nil"); err != nil {
t.Fatalf("Failed to reset answer: %v", err)
}
// Test DoFile
if err := state.DoFile(tmpfile.Name()); err != nil {
t.Fatalf("DoFile failed: %v", err)
}
state.GetGlobal("answer")
if state.ToNumber(-1) != 42 {
t.Fatalf("Incorrect result after DoFile, expected 42, got %f", state.ToNumber(-1))
}
state.Pop(1)
}
func TestPackagePath(t *testing.T) {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
// Test SetPackagePath
testPath := "/test/path/?.lua"
if err := state.SetPackagePath(testPath); err != nil {
t.Fatalf("SetPackagePath failed: %v", err)
}
result, err := state.ExecuteWithResult("return package.path")
if err != nil {
t.Fatalf("Failed to get package.path: %v", err)
}
if result != testPath {
t.Fatalf("Expected package.path to be '%s', got '%s'", testPath, result)
}
// Test AddPackagePath
addPath := "/another/path/?.lua"
if err := state.AddPackagePath(addPath); err != nil {
t.Fatalf("AddPackagePath failed: %v", err)
}
result, err = state.ExecuteWithResult("return package.path")
if err != nil {
t.Fatalf("Failed to get package.path: %v", err)
}
expected := testPath + ";" + addPath
if result != expected {
t.Fatalf("Expected package.path to be '%s', got '%s'", expected, result)
}
}
func TestPushValueAndToValue(t *testing.T) {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
testCases := []struct {
value any
}{
{nil},
{true},
{false},
{42},
{42.5},
{"hello"},
{[]float64{1, 2, 3, 4, 5}},
{[]any{1, "test", true}},
{map[string]any{"a": 1, "b": "test", "c": true}},
}
for i, tc := range testCases {
// Push value
err := state.PushValue(tc.value)
if err != nil {
t.Fatalf("PushValue failed for testCase %d: %v", i, err)
}
// Check stack
if state.GetTop() != i+1 {
t.Fatalf("Stack size incorrect after push, expected %d, got %d", i+1, state.GetTop())
}
}
// Test conversion back to Go
for i := range testCases {
index := len(testCases) - i
value, err := state.ToValue(index)
if err != nil {
t.Fatalf("ToValue failed for index %d: %v", index, err)
}
// For tables, we need special handling due to how Go types are stored
switch expected := testCases[index-1].value.(type) {
case []float64:
// Arrays come back as map[string]any with empty key
if m, ok := value.(map[string]any); ok {
if arr, ok := m[""].([]float64); ok {
if !reflect.DeepEqual(arr, expected) {
t.Fatalf("Value mismatch for testCase %d: expected %v, got %v", index-1, expected, arr)
}
} else {
t.Fatalf("Invalid array conversion for testCase %d", index-1)
}
} else {
t.Fatalf("Expected map for array value in testCase %d, got %T", index-1, value)
}
case int:
if num, ok := value.(float64); ok {
if float64(expected) == num {
continue // Values match after type conversion
}
}
case []any:
// Skip detailed comparison for mixed arrays
case map[string]any:
// Skip detailed comparison for maps
default:
if !reflect.DeepEqual(value, testCases[index-1].value) {
t.Fatalf("Value mismatch for testCase %d: expected %v, got %v",
index-1, testCases[index-1].value, value)
}
}
}
// Test unsupported type
complex := complex(1, 2)
err := state.PushValue(complex)
if err == nil {
t.Fatalf("Expected error for unsupported type")
}
}

103
types.go
View File

@ -4,12 +4,16 @@ package luajit
#include <lua.h>
*/
import "C"
import (
"fmt"
"strconv"
)
// LuaType represents Lua value types
type LuaType int
const (
// These constants must match lua.h's LUA_T* values
// These constants match lua.h's LUA_T* values
TypeNone LuaType = -1
TypeNil LuaType = 0
TypeBoolean LuaType = 1
@ -49,3 +53,100 @@ func (t LuaType) String() string {
return "unknown"
}
}
// ConvertValue converts a value to the requested type with proper type conversion
func ConvertValue[T any](value any) (T, bool) {
var zero T
// Handle nil case
if value == nil {
return zero, false
}
// Try direct type assertion first
if result, ok := value.(T); ok {
return result, true
}
// Type-specific conversions
switch any(zero).(type) {
case string:
switch v := value.(type) {
case float64:
return any(fmt.Sprintf("%g", v)).(T), true
case int:
return any(strconv.Itoa(v)).(T), true
case bool:
if v {
return any("true").(T), true
}
return any("false").(T), true
}
case int:
switch v := value.(type) {
case float64:
return any(int(v)).(T), true
case string:
if i, err := strconv.Atoi(v); err == nil {
return any(i).(T), true
}
case bool:
if v {
return any(1).(T), true
}
return any(0).(T), true
}
case float64:
switch v := value.(type) {
case int:
return any(float64(v)).(T), true
case string:
if f, err := strconv.ParseFloat(v, 64); err == nil {
return any(f).(T), true
}
case bool:
if v {
return any(1.0).(T), true
}
return any(0.0).(T), true
}
case bool:
switch v := value.(type) {
case string:
switch v {
case "true", "yes", "1":
return any(true).(T), true
case "false", "no", "0":
return any(false).(T), true
}
case int:
return any(v != 0).(T), true
case float64:
return any(v != 0).(T), true
}
}
return zero, false
}
// GetTypedValue gets a value from the state with type conversion
func GetTypedValue[T any](s *State, index int) (T, bool) {
var zero T
// Get the value as any type
value, err := s.ToValue(index)
if err != nil {
return zero, false
}
// Convert it to the requested type
return ConvertValue[T](value)
}
// GetGlobalTyped gets a global variable with type conversion
func GetGlobalTyped[T any](s *State, name string) (T, bool) {
s.GetGlobal(name)
defer s.Pop(1)
return GetTypedValue[T](s, -1)
}

View File

@ -1,127 +1,275 @@
package luajit
/*
#cgo CFLAGS: -I${SRCDIR}/vendor/luajit/include
#cgo !windows pkg-config: --static luajit
#cgo windows CFLAGS: -I${SRCDIR}/vendor/luajit/include
#cgo windows LDFLAGS: -L${SRCDIR}/vendor/luajit/windows -lluajit -static
#cgo !windows LDFLAGS: -L${SRCDIR}/vendor/luajit/linux -lluajit -static
#include <lua.h>
#include <lualib.h>
#include <lauxlib.h>
#include <stdlib.h>
#include <string.h>
// Direct execution helpers to minimize CGO transitions
static int do_string(lua_State *L, const char *s) {
int status = luaL_loadstring(L, s);
if (status) return status;
return lua_pcall(L, 0, LUA_MULTRET, 0);
if (status == 0) {
status = lua_pcall(L, 0, 0, 0);
}
return status;
}
static int do_file(lua_State *L, const char *filename) {
int status = luaL_loadfile(L, filename);
if (status) return status;
return lua_pcall(L, 0, LUA_MULTRET, 0);
if (status == 0) {
status = lua_pcall(L, 0, 0, 0);
}
return status;
}
static int execute_with_results(lua_State *L, const char *code, int store_results) {
int status = luaL_loadstring(L, code);
if (status != 0) return status;
return lua_pcall(L, 0, store_results ? LUA_MULTRET : 0, 0);
}
*/
import "C"
import (
"fmt"
"path/filepath"
"strings"
"sync"
"unsafe"
)
// State represents a Lua state with configurable stack safety
// Type pool for common objects to reduce GC pressure
var stringBufferPool = sync.Pool{
New: func() any {
return new(strings.Builder)
},
}
// State represents a Lua state
type State struct {
L *C.lua_State
safeStack bool
L *C.lua_State
}
// NewSafe creates a new Lua state with full stack safety guarantees
func NewSafe() *State {
// New creates a new Lua state with optional standard libraries; true if not specified
func New(openLibs ...bool) *State {
L := C.luaL_newstate()
if L == nil {
return nil
}
C.luaL_openlibs(L)
return &State{L: L, safeStack: true}
}
// New creates a new Lua state with minimal stack checking
func New() *State {
L := C.luaL_newstate()
if L == nil {
return nil
if len(openLibs) == 0 || openLibs[0] {
C.luaL_openlibs(L)
}
C.luaL_openlibs(L)
return &State{L: L, safeStack: false}
return &State{L: L}
}
// Close closes the Lua state
// Close closes the Lua state and frees resources
func (s *State) Close() {
if s.L != nil {
s.Cleanup() // Clean up Go function registry
C.lua_close(s.L)
s.L = nil
}
}
// DoString executes a Lua string with appropriate stack management
func (s *State) DoString(str string) error {
cstr := C.CString(str)
defer C.free(unsafe.Pointer(cstr))
// Stack manipulation methods
if s.safeStack {
return stackGuardErr(s, func() error {
return s.safeCall(func() C.int {
return C.do_string(s.L, cstr)
})
})
}
status := C.do_string(s.L, cstr)
if status != 0 {
return &LuaError{
Code: int(status),
Message: s.ToString(-1),
}
}
return nil
// GetTop returns the index of the top element in the stack
func (s *State) GetTop() int {
return int(C.lua_gettop(s.L))
}
// PushValue pushes a Go value onto the stack
func (s *State) PushValue(v interface{}) error {
if s.safeStack {
return stackGuardErr(s, func() error {
if err := s.checkStack(1); err != nil {
return fmt.Errorf("pushing value: %w", err)
}
return s.pushValueUnsafe(v)
})
}
return s.pushValueUnsafe(v)
// SetTop sets the stack top to a specific index
func (s *State) SetTop(index int) {
C.lua_settop(s.L, C.int(index))
}
func (s *State) pushValueSafe(v interface{}) error {
if err := s.checkStack(1); err != nil {
return fmt.Errorf("pushing value: %w", err)
}
return s.pushValueUnsafe(v)
// PushCopy pushes a copy of the value at the given index onto the stack
func (s *State) PushCopy(index int) {
C.lua_pushvalue(s.L, C.int(index))
}
func (s *State) pushValueUnsafe(v interface{}) error {
// Pop pops n elements from the stack
func (s *State) Pop(n int) {
C.lua_settop(s.L, C.int(-n-1))
}
// Remove removes the element at the given valid index
func (s *State) Remove(index int) {
C.lua_remove(s.L, C.int(index))
}
// absIndex converts a possibly negative index to its absolute position
func (s *State) absIndex(index int) int {
if index > 0 || index <= LUA_REGISTRYINDEX {
return index
}
return s.GetTop() + index + 1
}
// Type checking methods
// GetType returns the type of the value at the given index
func (s *State) GetType(index int) LuaType {
return LuaType(C.lua_type(s.L, C.int(index)))
}
// IsNil checks if the value at the given index is nil
func (s *State) IsNil(index int) bool {
return s.GetType(index) == TypeNil
}
// IsBoolean checks if the value at the given index is a boolean
func (s *State) IsBoolean(index int) bool {
return s.GetType(index) == TypeBoolean
}
// IsNumber checks if the value at the given index is a number
func (s *State) IsNumber(index int) bool {
return C.lua_isnumber(s.L, C.int(index)) != 0
}
// IsString checks if the value at the given index is a string
func (s *State) IsString(index int) bool {
return C.lua_isstring(s.L, C.int(index)) != 0
}
// IsTable checks if the value at the given index is a table
func (s *State) IsTable(index int) bool {
return s.GetType(index) == TypeTable
}
// IsFunction checks if the value at the given index is a function
func (s *State) IsFunction(index int) bool {
return s.GetType(index) == TypeFunction
}
// Value conversion methods
// ToBoolean returns the value at the given index as a boolean
func (s *State) ToBoolean(index int) bool {
return C.lua_toboolean(s.L, C.int(index)) != 0
}
// ToNumber returns the value at the given index as a number
func (s *State) ToNumber(index int) float64 {
return float64(C.lua_tonumber(s.L, C.int(index)))
}
// ToString returns the value at the given index as a string
func (s *State) ToString(index int) string {
var length C.size_t
cstr := C.lua_tolstring(s.L, C.int(index), &length)
if cstr == nil {
return ""
}
return C.GoStringN(cstr, C.int(length))
}
// Push methods
// PushNil pushes a nil value onto the stack
func (s *State) PushNil() {
C.lua_pushnil(s.L)
}
// PushBoolean pushes a boolean value onto the stack
func (s *State) PushBoolean(b bool) {
var value C.int
if b {
value = 1
}
C.lua_pushboolean(s.L, value)
}
// PushNumber pushes a number value onto the stack
func (s *State) PushNumber(n float64) {
C.lua_pushnumber(s.L, C.lua_Number(n))
}
// PushString pushes a string value onto the stack
func (s *State) PushString(str string) {
// Use direct C string for short strings (avoid allocations)
if len(str) < 128 {
cstr := C.CString(str)
defer C.free(unsafe.Pointer(cstr))
C.lua_pushlstring(s.L, cstr, C.size_t(len(str)))
return
}
// For longer strings, avoid double copy by using unsafe pointer
header := (*struct {
p unsafe.Pointer
len int
cap int
})(unsafe.Pointer(&str))
C.lua_pushlstring(s.L, (*C.char)(header.p), C.size_t(len(str)))
}
// Table operations
// CreateTable creates a new table and pushes it onto the stack
func (s *State) CreateTable(narr, nrec int) {
C.lua_createtable(s.L, C.int(narr), C.int(nrec))
}
// NewTable creates a new empty table and pushes it onto the stack
func (s *State) NewTable() {
C.lua_createtable(s.L, 0, 0)
}
// GetTable gets a table field (t[k]) where t is at the given index and k is at the top of the stack
func (s *State) GetTable(index int) {
C.lua_gettable(s.L, C.int(index))
}
// SetTable sets a table field (t[k] = v) where t is at the given index, k is at -2, and v is at -1
func (s *State) SetTable(index int) {
C.lua_settable(s.L, C.int(index))
}
// GetField gets a table field t[k] and pushes it onto the stack
func (s *State) GetField(index int, key string) {
ckey := C.CString(key)
defer C.free(unsafe.Pointer(ckey))
C.lua_getfield(s.L, C.int(index), ckey)
}
// SetField sets a table field t[k] = v, where v is the value at the top of the stack
func (s *State) SetField(index int, key string) {
ckey := C.CString(key)
defer C.free(unsafe.Pointer(ckey))
C.lua_setfield(s.L, C.int(index), ckey)
}
// Next pops a key from the stack and pushes the next key-value pair from the table at the given index
func (s *State) Next(index int) bool {
return C.lua_next(s.L, C.int(index)) != 0
}
// PushValue pushes a Go value onto the stack with proper type conversion
func (s *State) PushValue(v any) error {
switch v := v.(type) {
case nil:
s.PushNil()
case bool:
s.PushBoolean(v)
case float64:
s.PushNumber(v)
case int:
s.PushNumber(float64(v))
case float64:
s.PushNumber(v)
case string:
s.PushString(v)
case map[string]interface{}:
case map[string]any:
// Special case: handle array stored in map
if arr, ok := v[""].([]float64); ok {
s.NewTable()
s.CreateTable(len(arr), 0)
for i, elem := range arr {
s.PushNumber(float64(i + 1))
s.PushNumber(elem)
@ -129,19 +277,19 @@ func (s *State) pushValueUnsafe(v interface{}) error {
}
return nil
}
return s.pushTableUnsafe(v)
return s.PushTable(v)
case []float64:
s.NewTable()
s.CreateTable(len(v), 0)
for i, elem := range v {
s.PushNumber(float64(i + 1))
s.PushNumber(elem)
s.SetTable(-3)
}
case []interface{}:
s.NewTable()
case []any:
s.CreateTable(len(v), 0)
for i, elem := range v {
s.PushNumber(float64(i + 1))
if err := s.pushValueUnsafe(elem); err != nil {
if err := s.PushValue(elem); err != nil {
return err
}
s.SetTable(-3)
@ -152,18 +300,10 @@ func (s *State) pushValueUnsafe(v interface{}) error {
return nil
}
// ToValue converts a Lua value to a Go value
func (s *State) ToValue(index int) (interface{}, error) {
if s.safeStack {
return stackGuardValue[interface{}](s, func() (interface{}, error) {
return s.toValueUnsafe(index)
})
}
return s.toValueUnsafe(index)
}
func (s *State) toValueUnsafe(index int) (interface{}, error) {
switch s.GetType(index) {
// ToValue converts a Lua value at the given index to a Go value
func (s *State) ToValue(index int) (any, error) {
luaType := s.GetType(index)
switch luaType {
case TypeNil:
return nil, nil
case TypeBoolean:
@ -173,151 +313,171 @@ func (s *State) toValueUnsafe(index int) (interface{}, error) {
case TypeString:
return s.ToString(index), nil
case TypeTable:
if !s.IsTable(index) {
return nil, fmt.Errorf("not a table at index %d", index)
}
return s.toTableUnsafe(index)
return s.ToTable(index)
default:
return nil, fmt.Errorf("unsupported type: %s", s.GetType(index))
return nil, fmt.Errorf("unsupported type: %s", luaType)
}
}
// Simple operations remain unchanged as they don't need stack protection
// Global operations
func (s *State) GetType(index int) LuaType { return LuaType(C.lua_type(s.L, C.int(index))) }
func (s *State) IsFunction(index int) bool { return s.GetType(index) == TypeFunction }
func (s *State) IsTable(index int) bool { return s.GetType(index) == TypeTable }
func (s *State) ToBoolean(index int) bool { return C.lua_toboolean(s.L, C.int(index)) != 0 }
func (s *State) ToNumber(index int) float64 { return float64(C.lua_tonumber(s.L, C.int(index))) }
func (s *State) ToString(index int) string {
return C.GoString(C.lua_tolstring(s.L, C.int(index), nil))
}
func (s *State) GetTop() int { return int(C.lua_gettop(s.L)) }
func (s *State) Pop(n int) { C.lua_settop(s.L, C.int(-n-1)) }
func (s *State) SetTop(index int) { C.lua_settop(s.L, C.int(index)) }
// Push operations
func (s *State) PushNil() { C.lua_pushnil(s.L) }
func (s *State) PushBoolean(b bool) { C.lua_pushboolean(s.L, C.int(bool2int(b))) }
func (s *State) PushNumber(n float64) { C.lua_pushnumber(s.L, C.double(n)) }
func (s *State) PushString(str string) {
cstr := C.CString(str)
defer C.free(unsafe.Pointer(cstr))
C.lua_pushstring(s.L, cstr)
}
// Helper functions
func bool2int(b bool) int {
if b {
return 1
}
return 0
}
func (s *State) absIndex(index int) int {
if index > 0 || index <= LUA_REGISTRYINDEX {
return index
}
return s.GetTop() + index + 1
}
// SetField sets a field in a table at the given index with cached absolute index
func (s *State) SetField(index int, key string) {
absIdx := index
if s.safeStack && (index < 0 && index > LUA_REGISTRYINDEX) {
absIdx = s.GetTop() + index + 1
}
cstr := C.CString(key)
defer C.free(unsafe.Pointer(cstr))
C.lua_setfield(s.L, C.int(absIdx), cstr)
}
// GetField gets a field from a table with cached absolute index
func (s *State) GetField(index int, key string) {
absIdx := index
if s.safeStack && (index < 0 && index > LUA_REGISTRYINDEX) {
absIdx = s.GetTop() + index + 1
}
if s.safeStack {
if err := s.checkStack(1); err != nil {
s.PushNil()
return
}
}
cstr := C.CString(key)
defer C.free(unsafe.Pointer(cstr))
C.lua_getfield(s.L, C.int(absIdx), cstr)
}
// GetGlobal gets a global variable and pushes it onto the stack
// GetGlobal pushes the global variable with the given name onto the stack
func (s *State) GetGlobal(name string) {
if s.safeStack {
if err := s.checkStack(1); err != nil {
s.PushNil()
return
}
}
cstr := C.CString(name)
defer C.free(unsafe.Pointer(cstr))
C.lua_getfield(s.L, C.LUA_GLOBALSINDEX, cstr)
s.GetField(LUA_GLOBALSINDEX, name)
}
// SetGlobal sets a global variable from the value at the top of the stack
// SetGlobal sets the global variable with the given name to the value at the top of the stack
func (s *State) SetGlobal(name string) {
// SetGlobal doesn't need stack space checking as it pops the value
cstr := C.CString(name)
defer C.free(unsafe.Pointer(cstr))
C.lua_setfield(s.L, C.LUA_GLOBALSINDEX, cstr)
s.SetField(LUA_GLOBALSINDEX, name)
}
// Remove removes element with cached absolute index
func (s *State) Remove(index int) {
absIdx := index
if s.safeStack && (index < 0 && index > LUA_REGISTRYINDEX) {
absIdx = s.GetTop() + index + 1
// Code execution methods
// LoadString loads a Lua chunk from a string without executing it
func (s *State) LoadString(code string) error {
ccode := C.CString(code)
defer C.free(unsafe.Pointer(ccode))
status := C.luaL_loadstring(s.L, ccode)
if status != 0 {
err := &LuaError{
Code: int(status),
Message: s.ToString(-1),
}
s.Pop(1) // Remove error message
return err
}
C.lua_remove(s.L, C.int(absIdx))
return nil
}
// DoFile executes a Lua file with appropriate stack management
// LoadFile loads a Lua chunk from a file without executing it
func (s *State) LoadFile(filename string) error {
cfilename := C.CString(filename)
defer C.free(unsafe.Pointer(cfilename))
status := C.luaL_loadfile(s.L, cfilename)
if status != 0 {
err := &LuaError{
Code: int(status),
Message: s.ToString(-1),
}
s.Pop(1) // Remove error message
return err
}
return nil
}
// Call calls a function with the given number of arguments and results
func (s *State) Call(nargs, nresults int) error {
status := C.lua_pcall(s.L, C.int(nargs), C.int(nresults), 0)
if status != 0 {
err := &LuaError{
Code: int(status),
Message: s.ToString(-1),
}
s.Pop(1) // Remove error message
return err
}
return nil
}
// DoString executes a Lua string and cleans up the stack
func (s *State) DoString(code string) error {
ccode := C.CString(code)
defer C.free(unsafe.Pointer(ccode))
status := C.do_string(s.L, ccode)
if status != 0 {
err := &LuaError{
Code: int(status),
Message: s.ToString(-1),
}
s.Pop(1) // Remove error message
return err
}
return nil
}
// DoFile executes a Lua file and cleans up the stack
func (s *State) DoFile(filename string) error {
cfilename := C.CString(filename)
defer C.free(unsafe.Pointer(cfilename))
if s.safeStack {
return stackGuardErr(s, func() error {
return s.safeCall(func() C.int {
return C.do_file(s.L, cfilename)
})
})
}
status := C.do_file(s.L, cfilename)
if status != 0 {
return &LuaError{
err := &LuaError{
Code: int(status),
Message: s.ToString(-1),
}
s.Pop(1) // Remove error message
return err
}
return nil
}
// Execute executes a Lua string and returns the number of results left on the stack
func (s *State) Execute(code string) (int, error) {
baseTop := s.GetTop()
ccode := C.CString(code)
defer C.free(unsafe.Pointer(ccode))
status := C.execute_with_results(s.L, ccode, 1) // store_results=true
if status != 0 {
err := &LuaError{
Code: int(status),
Message: s.ToString(-1),
}
s.Pop(1) // Remove error message
return 0, err
}
return s.GetTop() - baseTop, nil
}
// ExecuteWithResult executes a Lua string and returns the first result
func (s *State) ExecuteWithResult(code string) (any, error) {
top := s.GetTop()
defer s.SetTop(top) // Restore stack when done
nresults, err := s.Execute(code)
if err != nil {
return nil, err
}
if nresults == 0 {
return nil, nil
}
return s.ToValue(-nresults)
}
// BatchExecute executes multiple statements with a single CGO transition
func (s *State) BatchExecute(statements []string) error {
// Join statements with semicolons
combinedCode := ""
for i, stmt := range statements {
combinedCode += stmt
if i < len(statements)-1 {
combinedCode += "; "
}
}
return s.DoString(combinedCode)
}
// Package path operations
// SetPackagePath sets the Lua package.path
func (s *State) SetPackagePath(path string) error {
path = filepath.ToSlash(path)
if err := s.DoString(fmt.Sprintf(`package.path = %q`, path)); err != nil {
return fmt.Errorf("setting package.path: %w", err)
}
return nil
path = strings.ReplaceAll(path, "\\", "/") // Convert Windows paths
code := fmt.Sprintf(`package.path = %q`, path)
return s.DoString(code)
}
// AddPackagePath adds a path to package.path
func (s *State) AddPackagePath(path string) error {
path = filepath.ToSlash(path)
if err := s.DoString(fmt.Sprintf(`package.path = package.path .. ";%s"`, path)); err != nil {
return fmt.Errorf("adding to package.path: %w", err)
}
return nil
path = strings.ReplaceAll(path, "\\", "/") // Convert Windows paths
code := fmt.Sprintf(`package.path = package.path .. ";%s"`, path)
return s.DoString(code)
}

View File

@ -1,276 +0,0 @@
package luajit
import (
"fmt"
"os"
"path/filepath"
"testing"
)
type stateFactory struct {
name string
new func() *State
}
var factories = []stateFactory{
{"unsafe", New},
{"safe", NewSafe},
}
func TestNew(t *testing.T) {
for _, f := range factories {
t.Run(f.name, func(t *testing.T) {
L := f.new()
if L == nil {
t.Fatal("Failed to create Lua state")
}
defer L.Close()
})
}
}
func TestDoString(t *testing.T) {
tests := []struct {
name string
code string
wantErr bool
}{
{"simple addition", "return 1 + 1", false},
{"set global", "test = 42", false},
{"syntax error", "this is not valid lua", true},
{"runtime error", "error('test error')", true},
}
for _, f := range factories {
for _, tt := range tests {
t.Run(f.name+"/"+tt.name, func(t *testing.T) {
L := f.new()
if L == nil {
t.Fatal("Failed to create Lua state")
}
defer L.Close()
err := L.DoString(tt.code)
if (err != nil) != tt.wantErr {
t.Errorf("DoString() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
}
func TestPushAndGetValues(t *testing.T) {
values := []struct {
name string
push func(*State)
check func(*State) error
}{
{
name: "string",
push: func(L *State) { L.PushString("hello") },
check: func(L *State) error {
if got := L.ToString(-1); got != "hello" {
return fmt.Errorf("got %q, want %q", got, "hello")
}
return nil
},
},
{
name: "number",
push: func(L *State) { L.PushNumber(42.5) },
check: func(L *State) error {
if got := L.ToNumber(-1); got != 42.5 {
return fmt.Errorf("got %f, want %f", got, 42.5)
}
return nil
},
},
{
name: "boolean",
push: func(L *State) { L.PushBoolean(true) },
check: func(L *State) error {
if got := L.ToBoolean(-1); !got {
return fmt.Errorf("got %v, want true", got)
}
return nil
},
},
{
name: "nil",
push: func(L *State) { L.PushNil() },
check: func(L *State) error {
if typ := L.GetType(-1); typ != TypeNil {
return fmt.Errorf("got type %v, want TypeNil", typ)
}
return nil
},
},
}
for _, f := range factories {
for _, v := range values {
t.Run(f.name+"/"+v.name, func(t *testing.T) {
L := f.new()
if L == nil {
t.Fatal("Failed to create Lua state")
}
defer L.Close()
v.push(L)
if err := v.check(L); err != nil {
t.Error(err)
}
})
}
}
}
func TestStackManipulation(t *testing.T) {
for _, f := range factories {
t.Run(f.name, func(t *testing.T) {
L := f.new()
if L == nil {
t.Fatal("Failed to create Lua state")
}
defer L.Close()
// Push values
values := []string{"first", "second", "third"}
for _, v := range values {
L.PushString(v)
}
// Check size
if top := L.GetTop(); top != len(values) {
t.Errorf("stack size = %d, want %d", top, len(values))
}
// Pop one value
L.Pop(1)
// Check new top
if str := L.ToString(-1); str != "second" {
t.Errorf("top value = %q, want 'second'", str)
}
// Check new size
if top := L.GetTop(); top != len(values)-1 {
t.Errorf("stack size after pop = %d, want %d", top, len(values)-1)
}
})
}
}
func TestGlobals(t *testing.T) {
for _, f := range factories {
t.Run(f.name, func(t *testing.T) {
L := f.new()
if L == nil {
t.Fatal("Failed to create Lua state")
}
defer L.Close()
// Test via Lua
if err := L.DoString(`globalVar = "test"`); err != nil {
t.Fatalf("DoString error: %v", err)
}
// Get the global
L.GetGlobal("globalVar")
if str := L.ToString(-1); str != "test" {
t.Errorf("global value = %q, want 'test'", str)
}
L.Pop(1)
// Set and get via API
L.PushNumber(42)
L.SetGlobal("testNum")
L.GetGlobal("testNum")
if num := L.ToNumber(-1); num != 42 {
t.Errorf("global number = %f, want 42", num)
}
})
}
}
func TestDoFile(t *testing.T) {
L := NewSafe()
defer L.Close()
// Create test file
content := []byte(`
function add(a, b)
return a + b
end
result = add(40, 2)
`)
tmpDir := t.TempDir()
filename := filepath.Join(tmpDir, "test.lua")
if err := os.WriteFile(filename, content, 0644); err != nil {
t.Fatalf("Failed to create test file: %v", err)
}
if err := L.DoFile(filename); err != nil {
t.Fatalf("DoFile failed: %v", err)
}
L.GetGlobal("result")
if result := L.ToNumber(-1); result != 42 {
t.Errorf("Expected result=42, got %v", result)
}
}
func TestRequireAndPackagePath(t *testing.T) {
L := NewSafe()
defer L.Close()
tmpDir := t.TempDir()
// Create module file
moduleContent := []byte(`
local M = {}
function M.multiply(a, b)
return a * b
end
return M
`)
if err := os.WriteFile(filepath.Join(tmpDir, "mathmod.lua"), moduleContent, 0644); err != nil {
t.Fatalf("Failed to create module file: %v", err)
}
// Add module path and test require
if err := L.AddPackagePath(filepath.Join(tmpDir, "?.lua")); err != nil {
t.Fatalf("AddPackagePath failed: %v", err)
}
if err := L.DoString(`
local math = require("mathmod")
result = math.multiply(6, 7)
`); err != nil {
t.Fatalf("Failed to require module: %v", err)
}
L.GetGlobal("result")
if result := L.ToNumber(-1); result != 42 {
t.Errorf("Expected result=42, got %v", result)
}
}
func TestSetPackagePath(t *testing.T) {
L := NewSafe()
defer L.Close()
customPath := "./custom/?.lua"
if err := L.SetPackagePath(customPath); err != nil {
t.Fatalf("SetPackagePath failed: %v", err)
}
L.GetGlobal("package")
L.GetField(-1, "path")
if path := L.ToString(-1); path != customPath {
t.Errorf("Expected package.path=%q, got %q", customPath, path)
}
}