Compare commits
26 Commits
Author | SHA1 | Date | |
---|---|---|---|
6b9e2a0e20 | |||
e58f9a6028 | |||
a2b4b1c927 | |||
44337fffe3 | |||
0756cabcaa | |||
656ac1a703 | |||
5774808064 | |||
875abee366 | |||
4ad87f81f3 | |||
9e5092acdb | |||
b83f77d7a6 | |||
29679349ef | |||
fed0c2ad34 | |||
faab0a2d08 | |||
f106dfd9ea | |||
936e4ccdc2 | |||
075b45768f | |||
13686b3e66 | |||
98ca857d73 | |||
143b9333c6 | |||
865ac8859f | |||
4dc266201f | |||
7c79616cac | |||
146b0a51db | |||
c74ad4bbc9 | |||
229884ba97 |
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -21,3 +21,5 @@
|
|||
go.work
|
||||
|
||||
.idea
|
||||
|
||||
bench/profile_results
|
||||
|
|
267
DOCS.md
267
DOCS.md
|
@ -1,19 +1,11 @@
|
|||
# LuaJIT Go Wrapper API Documentation
|
||||
# API Documentation
|
||||
|
||||
## State Management
|
||||
|
||||
### NewSafe() *State
|
||||
Creates a new Lua state with stack safety enabled.
|
||||
```go
|
||||
L := luajit.NewSafe()
|
||||
defer L.Close()
|
||||
defer L.Cleanup()
|
||||
```
|
||||
|
||||
### New() *State
|
||||
Creates a new Lua state without stack safety checks.
|
||||
New creates a new Lua state with optional standard libraries; true if not specified
|
||||
```go
|
||||
L := luajit.New()
|
||||
L := luajit.New() // or luajit.New(false)
|
||||
defer L.Close()
|
||||
defer L.Cleanup()
|
||||
```
|
||||
|
@ -38,6 +30,18 @@ Returns the index of the top element in the stack.
|
|||
top := L.GetTop() // 0 for empty stack
|
||||
```
|
||||
|
||||
### SetTop(index int)
|
||||
Sets the stack top to a specific index.
|
||||
```go
|
||||
L.SetTop(2) // Truncate stack to 2 elements
|
||||
```
|
||||
|
||||
### PushCopy(index int)
|
||||
Pushes a copy of the value at the given index onto the stack.
|
||||
```go
|
||||
L.PushCopy(-1) // Duplicate the top element
|
||||
```
|
||||
|
||||
### Pop(n int)
|
||||
Removes n elements from the stack.
|
||||
```go
|
||||
|
@ -52,6 +56,9 @@ L.Remove(-1) // Remove top element
|
|||
L.Remove(1) // Remove first element
|
||||
```
|
||||
|
||||
### absIndex(index int) int
|
||||
Internal function that converts a possibly negative index to its absolute position.
|
||||
|
||||
### checkStack(n int) error
|
||||
Internal function that ensures there's enough space for n new elements.
|
||||
```go
|
||||
|
@ -70,9 +77,12 @@ if L.GetType(-1) == TypeString {
|
|||
}
|
||||
```
|
||||
|
||||
### IsFunction(index int) bool
|
||||
### IsNil(index int) bool
|
||||
### IsBoolean(index int) bool
|
||||
### IsNumber(index int) bool
|
||||
### IsString(index int) bool
|
||||
### IsTable(index int) bool
|
||||
### IsUserData(index int) bool
|
||||
### IsFunction(index int) bool
|
||||
Type checking functions for specific Lua types.
|
||||
```go
|
||||
if L.IsTable(-1) {
|
||||
|
@ -100,7 +110,7 @@ Converts the value to a boolean.
|
|||
bool := L.ToBoolean(-1)
|
||||
```
|
||||
|
||||
### ToValue(index int) (interface{}, error)
|
||||
### ToValue(index int) (any, error)
|
||||
Converts any Lua value to its Go equivalent.
|
||||
```go
|
||||
val, err := L.ToValue(-1)
|
||||
|
@ -109,7 +119,7 @@ if err != nil {
|
|||
}
|
||||
```
|
||||
|
||||
### ToTable(index int) (map[string]interface{}, error)
|
||||
### ToTable(index int) (map[string]any, error)
|
||||
Converts a Lua table to a Go map.
|
||||
```go
|
||||
table, err := L.ToTable(-1)
|
||||
|
@ -118,6 +128,12 @@ if err != nil {
|
|||
}
|
||||
```
|
||||
|
||||
### GetTableLength(index int) int
|
||||
Returns the length of a table at the given index.
|
||||
```go
|
||||
length := L.GetTableLength(-1)
|
||||
```
|
||||
|
||||
## Value Pushing
|
||||
|
||||
### PushNil()
|
||||
|
@ -132,32 +148,98 @@ L.PushBoolean(true)
|
|||
L.PushNil()
|
||||
```
|
||||
|
||||
### PushValue(v interface{}) error
|
||||
### PushValue(v any) error
|
||||
Pushes any Go value onto the stack.
|
||||
```go
|
||||
err := L.PushValue(myValue)
|
||||
```
|
||||
|
||||
### PushTable(table map[string]interface{}) error
|
||||
### PushTable(table map[string]any) error
|
||||
Pushes a Go map as a Lua table.
|
||||
```go
|
||||
data := map[string]interface{}{
|
||||
data := map[string]any{
|
||||
"key": "value",
|
||||
"numbers": []float64{1, 2, 3},
|
||||
}
|
||||
err := L.PushTable(data)
|
||||
```
|
||||
|
||||
## Function Registration
|
||||
## Table Operations
|
||||
|
||||
### RegisterGoFunction(name string, fn GoFunction) error
|
||||
Registers a Go function that can be called from Lua.
|
||||
### CreateTable(narr, nrec int)
|
||||
Creates a new table with pre-allocated space.
|
||||
```go
|
||||
adder := func(s *State) int {
|
||||
L.CreateTable(10, 5) // Space for 10 array elements, 5 records
|
||||
```
|
||||
|
||||
### NewTable()
|
||||
Creates a new empty table and pushes it onto the stack.
|
||||
```go
|
||||
L.NewTable()
|
||||
```
|
||||
|
||||
### GetTable(index int)
|
||||
Gets a table field (t[k]) where t is at the given index and k is at the top of the stack.
|
||||
```go
|
||||
L.PushString("key")
|
||||
L.GetTable(-2) // Gets table["key"]
|
||||
```
|
||||
|
||||
### SetTable(index int)
|
||||
Sets a table field (t[k] = v) where t is at the given index, k is at -2, and v is at -1.
|
||||
```go
|
||||
L.PushString("key")
|
||||
L.PushString("value")
|
||||
L.SetTable(-3) // table["key"] = "value"
|
||||
```
|
||||
|
||||
### GetField(index int, key string)
|
||||
Gets a table field t[k] and pushes it onto the stack.
|
||||
```go
|
||||
L.GetField(-1, "name") // gets table.name
|
||||
```
|
||||
|
||||
### SetField(index int, key string)
|
||||
Sets a table field t[k] = v, where v is the value at the top of the stack.
|
||||
```go
|
||||
L.PushString("value")
|
||||
L.SetField(-2, "key") // table.key = "value"
|
||||
```
|
||||
|
||||
### Next(index int) bool
|
||||
Pops a key from the stack and pushes the next key-value pair from the table.
|
||||
```go
|
||||
L.PushNil() // Start iteration
|
||||
for L.Next(-2) {
|
||||
// Stack now has key at -2 and value at -1
|
||||
key := L.ToString(-2)
|
||||
value := L.ToString(-1)
|
||||
L.Pop(1) // Remove value, keep key for next iteration
|
||||
}
|
||||
```
|
||||
|
||||
## Function Registration and Calling
|
||||
|
||||
### GoFunction
|
||||
Type definition for Go functions callable from Lua.
|
||||
```go
|
||||
type GoFunction func(*State) int
|
||||
```
|
||||
|
||||
### PushGoFunction(fn GoFunction) error
|
||||
Wraps a Go function and pushes it onto the Lua stack.
|
||||
```go
|
||||
adder := func(s *luajit.State) int {
|
||||
sum := s.ToNumber(1) + s.ToNumber(2)
|
||||
s.PushNumber(sum)
|
||||
return 1
|
||||
}
|
||||
err := L.PushGoFunction(adder)
|
||||
```
|
||||
|
||||
### RegisterGoFunction(name string, fn GoFunction) error
|
||||
Registers a Go function as a global Lua function.
|
||||
```go
|
||||
err := L.RegisterGoFunction("add", adder)
|
||||
```
|
||||
|
||||
|
@ -167,23 +249,45 @@ Removes a previously registered function.
|
|||
L.UnregisterGoFunction("add")
|
||||
```
|
||||
|
||||
## Package Management
|
||||
|
||||
### SetPackagePath(path string) error
|
||||
Sets the Lua package.path variable.
|
||||
### Call(nargs, nresults int) error
|
||||
Calls a function with the given number of arguments and results.
|
||||
```go
|
||||
err := L.SetPackagePath("./?.lua;/usr/local/share/lua/5.1/?.lua")
|
||||
L.GetGlobal("myfunction")
|
||||
L.PushNumber(1)
|
||||
L.PushNumber(2)
|
||||
err := L.Call(2, 1) // Call with 2 args, expect 1 result
|
||||
```
|
||||
|
||||
### AddPackagePath(path string) error
|
||||
Adds a path to the existing package.path.
|
||||
## Global Operations
|
||||
|
||||
### GetGlobal(name string)
|
||||
Gets a global variable and pushes it onto the stack.
|
||||
```go
|
||||
err := L.AddPackagePath("./modules/?.lua")
|
||||
L.GetGlobal("myGlobal")
|
||||
```
|
||||
|
||||
### SetGlobal(name string)
|
||||
Sets a global variable from the value at the top of the stack.
|
||||
```go
|
||||
L.PushNumber(42)
|
||||
L.SetGlobal("answer") // answer = 42
|
||||
```
|
||||
|
||||
## Code Execution
|
||||
|
||||
### DoString(str string) error
|
||||
### LoadString(code string) error
|
||||
Loads a Lua chunk from a string without executing it.
|
||||
```go
|
||||
err := L.LoadString("return 42")
|
||||
```
|
||||
|
||||
### LoadFile(filename string) error
|
||||
Loads a Lua chunk from a file without executing it.
|
||||
```go
|
||||
err := L.LoadFile("script.lua")
|
||||
```
|
||||
|
||||
### DoString(code string) error
|
||||
Executes a string of Lua code.
|
||||
```go
|
||||
err := L.DoString(`
|
||||
|
@ -199,32 +303,76 @@ Executes a Lua file.
|
|||
err := L.DoFile("script.lua")
|
||||
```
|
||||
|
||||
## Table Operations
|
||||
|
||||
### GetField(index int, key string)
|
||||
Gets a field from a table at the given index.
|
||||
### Execute(code string) (int, error)
|
||||
Executes a Lua string and returns the number of results left on the stack.
|
||||
```go
|
||||
L.GetField(-1, "name") // gets table.name
|
||||
nresults, err := L.Execute("return 1, 2, 3")
|
||||
// nresults would be 3
|
||||
```
|
||||
|
||||
### SetField(index int, key string)
|
||||
Sets a field in a table at the given index.
|
||||
### ExecuteWithResult(code string) (any, error)
|
||||
Executes a Lua string and returns the first result.
|
||||
```go
|
||||
L.PushString("value")
|
||||
L.SetField(-2, "key") // table.key = "value"
|
||||
result, err := L.ExecuteWithResult("return 'hello'")
|
||||
// result would be "hello"
|
||||
```
|
||||
|
||||
### GetGlobal(name string)
|
||||
Gets a global variable.
|
||||
## Bytecode Operations
|
||||
|
||||
### CompileBytecode(code string, name string) ([]byte, error)
|
||||
Compiles a Lua chunk to bytecode without executing it.
|
||||
```go
|
||||
L.GetGlobal("myGlobal")
|
||||
bytecode, err := L.CompileBytecode("return 42", "test")
|
||||
```
|
||||
|
||||
### SetGlobal(name string)
|
||||
Sets a global variable from the value at the top of the stack.
|
||||
### LoadBytecode(bytecode []byte, name string) error
|
||||
Loads precompiled bytecode without executing it.
|
||||
```go
|
||||
L.PushNumber(42)
|
||||
L.SetGlobal("answer") // answer = 42
|
||||
err := L.LoadBytecode(bytecode, "test")
|
||||
```
|
||||
|
||||
### RunBytecode() error
|
||||
Executes previously loaded bytecode with 0 results.
|
||||
```go
|
||||
err := L.RunBytecode()
|
||||
```
|
||||
|
||||
### RunBytecodeWithResults(nresults int) error
|
||||
Executes bytecode and keeps nresults on the stack.
|
||||
```go
|
||||
err := L.RunBytecodeWithResults(1)
|
||||
```
|
||||
|
||||
### LoadAndRunBytecode(bytecode []byte, name string) error
|
||||
Loads and executes bytecode.
|
||||
```go
|
||||
err := L.LoadAndRunBytecode(bytecode, "test")
|
||||
```
|
||||
|
||||
### LoadAndRunBytecodeWithResults(bytecode []byte, name string, nresults int) error
|
||||
Loads and executes bytecode, preserving results.
|
||||
```go
|
||||
err := L.LoadAndRunBytecodeWithResults(bytecode, "test", 1)
|
||||
```
|
||||
|
||||
### CompileAndRun(code string, name string) error
|
||||
Compiles and immediately executes Lua code.
|
||||
```go
|
||||
err := L.CompileAndRun("answer = 42", "test")
|
||||
```
|
||||
|
||||
## Package Path Operations
|
||||
|
||||
### SetPackagePath(path string) error
|
||||
Sets the Lua package.path.
|
||||
```go
|
||||
err := L.SetPackagePath("./?.lua;/usr/local/share/lua/5.1/?.lua")
|
||||
```
|
||||
|
||||
### AddPackagePath(path string) error
|
||||
Adds a path to package.path.
|
||||
```go
|
||||
err := L.AddPackagePath("./modules/?.lua")
|
||||
```
|
||||
|
||||
## Error Handling
|
||||
|
@ -238,13 +386,21 @@ type LuaError struct {
|
|||
}
|
||||
```
|
||||
|
||||
### getStackTrace() string
|
||||
### GetStackTrace() string
|
||||
Gets the current Lua stack trace.
|
||||
```go
|
||||
trace := L.getStackTrace()
|
||||
trace := L.GetStackTrace()
|
||||
fmt.Println(trace)
|
||||
```
|
||||
|
||||
### safeCall(f func() C.int) error
|
||||
Internal function that wraps a potentially dangerous C call with stack checking.
|
||||
```go
|
||||
err := s.safeCall(func() C.int {
|
||||
return C.lua_pcall(s.L, C.int(nargs), C.int(nresults), 0)
|
||||
})
|
||||
```
|
||||
|
||||
## Thread Safety Notes
|
||||
|
||||
- The function registry is thread-safe
|
||||
|
@ -256,15 +412,20 @@ fmt.Println(trace)
|
|||
|
||||
Always pair state creation with cleanup:
|
||||
```go
|
||||
L := luajit.NewSafe()
|
||||
L := luajit.New()
|
||||
defer L.Close()
|
||||
defer L.Cleanup()
|
||||
```
|
||||
|
||||
Stack management in unsafe mode requires manual attention:
|
||||
Stack management requires manual attention:
|
||||
```go
|
||||
L := luajit.New()
|
||||
L.PushString("hello")
|
||||
// ... use the string
|
||||
L.Pop(1) // Clean up when done
|
||||
```
|
||||
```
|
||||
|
||||
Sandbox management:
|
||||
```go
|
||||
sandbox := luajit.NewSandbox()
|
||||
defer sandbox.Close()
|
||||
```
|
||||
|
|
2
LICENSE
2
LICENSE
|
@ -1,6 +1,6 @@
|
|||
MIT License
|
||||
|
||||
Copyright (c) 2025 Sky
|
||||
Copyright (c) 2025 Sharkk, Skylear Johnson
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
||||
|
||||
|
|
159
README.md
159
README.md
|
@ -1,6 +1,6 @@
|
|||
# LuaJIT Go Wrapper
|
||||
|
||||
Hey there! This is a Go wrapper for LuaJIT that makes it easy to embed Lua in your Go applications. We've focused on making it both safe and fast, while keeping the API clean and intuitive.
|
||||
This is a Go wrapper for LuaJIT that makes it easy to embed Lua in your Go applications. We've focused on making it both performant and developer-friendly, with an API that feels natural to use.
|
||||
|
||||
## What's This For?
|
||||
|
||||
|
@ -22,51 +22,13 @@ You'll need LuaJIT's development files, but don't worry - we include libraries f
|
|||
|
||||
Here's the simplest thing you can do:
|
||||
```go
|
||||
L := luajit.NewSafe()
|
||||
L := luajit.New() // pass false to not load standard libs
|
||||
defer L.Close()
|
||||
defer L.Cleanup()
|
||||
|
||||
err := L.DoString(`print("Hey from Lua!")`)
|
||||
```
|
||||
|
||||
## Stack Safety: Choose Your Adventure
|
||||
|
||||
One of the key decisions you'll make is whether to use stack-safe mode. Here's what that means:
|
||||
|
||||
### Stack-Safe Mode (NewSafe())
|
||||
```go
|
||||
L := luajit.NewSafe()
|
||||
```
|
||||
Think of this as driving with guardrails. It's perfect when:
|
||||
- You're new to Lua or embedding scripting languages
|
||||
- You're writing a server or long-running application
|
||||
- You want to handle untrusted Lua code
|
||||
- You'd rather have slightly slower code than mysterious crashes
|
||||
|
||||
The safe mode will:
|
||||
- Prevent stack overflows
|
||||
- Check types more thoroughly
|
||||
- Clean up after messy Lua code
|
||||
- Give you better error messages
|
||||
|
||||
### Non-Stack-Safe Mode (New())
|
||||
```go
|
||||
L := luajit.New()
|
||||
```
|
||||
This is like taking off the training wheels. Use it when:
|
||||
- You know exactly how your Lua code behaves
|
||||
- You've profiled your application and need more speed
|
||||
- You're doing lots of rapid, simple Lua calls
|
||||
- You're writing performance-critical code
|
||||
|
||||
The unsafe mode:
|
||||
- Skips most safety checks
|
||||
- Runs noticeably faster
|
||||
- Gives you direct control over the stack
|
||||
- Can crash spectacularly if you make a mistake
|
||||
|
||||
Most applications should start with stack-safe mode and only switch to unsafe mode if profiling shows it's necessary.
|
||||
|
||||
## Working with Bytecode
|
||||
|
||||
Need even more performance? You can compile your Lua code to bytecode and reuse it:
|
||||
|
@ -82,28 +44,30 @@ bytecode, err := L.CompileBytecode(`
|
|||
|
||||
// Execute many times
|
||||
for i := 0; i < 1000; i++ {
|
||||
err := L.LoadBytecode(bytecode, "calc")
|
||||
err := L.LoadAndRunBytecode(bytecode, "calc")
|
||||
}
|
||||
|
||||
// Or do both at once
|
||||
err := L.CompileAndLoad(`return "hello"`, "greeting")
|
||||
err := L.CompileAndRun(`return "hello"`, "greeting")
|
||||
```
|
||||
|
||||
### When to Use Bytecode
|
||||
|
||||
Bytecode execution is consistently faster than direct execution:
|
||||
- Simple operations: 20-60% faster
|
||||
- String operations: Up to 60% speedup
|
||||
- Loop-heavy code: 10-15% improvement
|
||||
- Table operations: 10-15% faster
|
||||
|
||||
Some benchmark results on a typical system:
|
||||
```
|
||||
Operation Direct Exec Bytecode Exec
|
||||
----------------------------------------
|
||||
Simple Math 1.5M ops/sec 2.4M ops/sec
|
||||
String Ops 370K ops/sec 600K ops/sec
|
||||
Table Creation 127K ops/sec 146K ops/sec
|
||||
Benchmark Ops/sec Comparison
|
||||
----------------------------------------------------------------------------
|
||||
BenchmarkSimpleDoString 2,561,012 Base
|
||||
BenchmarkSimplePrecompiledBytecode 3,828,841 +49.5% faster
|
||||
BenchmarkFunctionCallDoString 2,021,098 Base
|
||||
BenchmarkFunctionCallPrecompiled 3,482,074 +72.3% faster
|
||||
BenchmarkLoopDoString 188,119 Base
|
||||
BenchmarkLoopPrecompiled 211,081 +12.2% faster
|
||||
BenchmarkTableOperationsDoString 84,086 Base
|
||||
BenchmarkTableOperationsPrecompiled 93,655 +11.4% faster
|
||||
BenchmarkComplexScript 33,133 Base
|
||||
BenchmarkComplexScriptPrecompiled 41,044 +23.9% faster
|
||||
```
|
||||
|
||||
Use bytecode when you:
|
||||
|
@ -114,7 +78,7 @@ Use bytecode when you:
|
|||
|
||||
## Registering Go Functions
|
||||
|
||||
Want to call Go code from Lua? Easy:
|
||||
Want to call Go code from Lua? It's straightforward:
|
||||
```go
|
||||
// This function adds two numbers and returns the result
|
||||
adder := func(s *luajit.State) int {
|
||||
|
@ -137,7 +101,7 @@ Lua tables are pretty powerful - they're like a mix of Go's maps and slices. We
|
|||
|
||||
```go
|
||||
// Go → Lua
|
||||
stuff := map[string]interface{}{
|
||||
stuff := map[string]any{
|
||||
"name": "Arthur Dent",
|
||||
"age": 30,
|
||||
"items": []float64{1, 2, 3},
|
||||
|
@ -151,25 +115,96 @@ result, err := L.ToTable(-1)
|
|||
|
||||
## Error Handling
|
||||
|
||||
We try to give you useful errors instead of mysterious panics:
|
||||
We provide useful errors instead of mysterious panics:
|
||||
```go
|
||||
if err := L.DoString("this isn't valid Lua!"); err != nil {
|
||||
if luaErr, ok := err.(*luajit.LuaError); ok {
|
||||
fmt.Printf("Oops: %s\n", luaErr.Message)
|
||||
fmt.Printf("Error: %s\n", luaErr.Message)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## A Few Tips
|
||||
## Memory Management
|
||||
|
||||
- Always use those `defer L.Close()` and `defer L.Cleanup()` calls - they prevent memory leaks
|
||||
The wrapper uses a custom table pooling system to reduce GC pressure when handling many tables:
|
||||
|
||||
```go
|
||||
// Tables are pooled and reused internally for better performance
|
||||
for i := 0; i < 1000; i++ {
|
||||
L.GetGlobal("table")
|
||||
table, _ := L.ToTable(-1)
|
||||
// Use table...
|
||||
L.Pop(1)
|
||||
// Table is automatically returned to pool
|
||||
}
|
||||
```
|
||||
|
||||
The sandbox also manages its environment efficiently:
|
||||
|
||||
```go
|
||||
// Environment objects are pooled and reused
|
||||
for i := 0; i < 1000; i++ {
|
||||
result, _ := sandbox.Run("return i + 1")
|
||||
}
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
### State Management
|
||||
- Always use `defer L.Close()` and `defer L.Cleanup()` to prevent memory leaks
|
||||
- Each Lua state should stick to one goroutine
|
||||
- For concurrent stuff, create multiple states
|
||||
- For concurrent operations, create multiple states
|
||||
- You can share functions between states safely
|
||||
- Keep an eye on your stack in unsafe mode - it won't clean up after itself
|
||||
- Start with stack-safe mode and measure before optimizing
|
||||
- Keep an eye on your stack management - pop as many items as you push
|
||||
|
||||
### Bytecode Optimization
|
||||
- Use bytecode for frequently executed code paths
|
||||
- Consider compiling critical Lua code to bytecode at startup
|
||||
- For small scripts (< 1024 bytes), direct execution might be faster
|
||||
|
||||
## Advanced Features
|
||||
|
||||
### Bytecode Serialization
|
||||
|
||||
You can serialize bytecode for distribution or caching:
|
||||
|
||||
```go
|
||||
// Compile once
|
||||
bytecode, _ := L.CompileBytecode(complexScript, "module")
|
||||
|
||||
// Save to file
|
||||
ioutil.WriteFile("module.luac", bytecode, 0644)
|
||||
|
||||
// Later, load from file
|
||||
bytecode, _ := ioutil.ReadFile("module.luac")
|
||||
L.LoadAndRunBytecode(bytecode, "module")
|
||||
```
|
||||
|
||||
### Closures and Upvalues
|
||||
|
||||
Bytecode properly preserves closures and upvalues:
|
||||
|
||||
```go
|
||||
code := `
|
||||
local counter = 0
|
||||
return function()
|
||||
counter = counter + 1
|
||||
return counter
|
||||
end
|
||||
`
|
||||
|
||||
bytecode, _ := L.CompileBytecode(code, "counter")
|
||||
L.LoadAndRunBytecodeWithResults(bytecode, "counter", 1)
|
||||
L.SetGlobal("increment")
|
||||
|
||||
// Later...
|
||||
L.GetGlobal("increment")
|
||||
L.Call(0, 1) // Returns 1
|
||||
L.Pop(1)
|
||||
|
||||
L.GetGlobal("increment")
|
||||
L.Call(0, 1) // Returns 2
|
||||
```
|
||||
|
||||
## Need Help?
|
||||
|
||||
|
@ -177,4 +212,4 @@ Check out the tests in the repository - they're full of examples. If you're stuc
|
|||
|
||||
## License
|
||||
|
||||
MIT Licensed - do whatever you want with it!
|
||||
MIT Licensed - do whatever you want with it!
|
||||
|
|
171
bench/bench_profile.go
Normal file
171
bench/bench_profile.go
Normal file
|
@ -0,0 +1,171 @@
|
|||
package luajit_bench
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
"runtime/pprof"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// Profiling flags
|
||||
var (
|
||||
cpuProfile = flag.String("cpuprofile", "", "write cpu profile to `file`")
|
||||
memProfile = flag.String("memprofile", "", "write memory profile to `file`")
|
||||
memProfileGC = flag.Bool("memprofilegc", false, "force GC before writing memory profile")
|
||||
blockProfile = flag.String("blockprofile", "", "write block profile to `file`")
|
||||
mutexProfile = flag.String("mutexprofile", "", "write mutex profile to `file`")
|
||||
)
|
||||
|
||||
// setupTestMain configures profiling for benchmarks
|
||||
func setupTestMain() {
|
||||
// Make sure the flags are parsed
|
||||
if !flag.Parsed() {
|
||||
flag.Parse()
|
||||
}
|
||||
|
||||
// CPU profiling
|
||||
if *cpuProfile != "" {
|
||||
f, err := os.Create(*cpuProfile)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Failed to create CPU profile: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
if err := pprof.StartCPUProfile(f); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Failed to start CPU profile: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
fmt.Println("CPU profiling enabled")
|
||||
}
|
||||
|
||||
// Block profiling (goroutine blocking)
|
||||
if *blockProfile != "" {
|
||||
runtime.SetBlockProfileRate(1)
|
||||
fmt.Println("Block profiling enabled")
|
||||
}
|
||||
|
||||
// Mutex profiling (lock contention)
|
||||
if *mutexProfile != "" {
|
||||
runtime.SetMutexProfileFraction(1)
|
||||
fmt.Println("Mutex profiling enabled")
|
||||
}
|
||||
}
|
||||
|
||||
// teardownTestMain completes profiling and writes output files
|
||||
func teardownTestMain() {
|
||||
// Stop CPU profile
|
||||
if *cpuProfile != "" {
|
||||
pprof.StopCPUProfile()
|
||||
fmt.Println("CPU profile written to", *cpuProfile)
|
||||
}
|
||||
|
||||
// Write memory profile
|
||||
if *memProfile != "" {
|
||||
f, err := os.Create(*memProfile)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Failed to create memory profile: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
// Force garbage collection before writing memory profile if requested
|
||||
if *memProfileGC {
|
||||
runtime.GC()
|
||||
}
|
||||
|
||||
if err := pprof.WriteHeapProfile(f); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Failed to write memory profile: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
fmt.Println("Memory profile written to", *memProfile)
|
||||
}
|
||||
|
||||
// Write block profile
|
||||
if *blockProfile != "" {
|
||||
f, err := os.Create(*blockProfile)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Failed to create block profile: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if err := pprof.Lookup("block").WriteTo(f, 0); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Failed to write block profile: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
fmt.Println("Block profile written to", *blockProfile)
|
||||
}
|
||||
|
||||
// Write mutex profile
|
||||
if *mutexProfile != "" {
|
||||
f, err := os.Create(*mutexProfile)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Failed to create mutex profile: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if err := pprof.Lookup("mutex").WriteTo(f, 0); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Failed to write mutex profile: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
fmt.Println("Mutex profile written to", *mutexProfile)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMain is the entry point for all tests in this package
|
||||
func TestMain(m *testing.M) {
|
||||
setupTestMain()
|
||||
code := m.Run()
|
||||
teardownTestMain()
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
// MemStats captures a snapshot of memory statistics
|
||||
type MemStats struct {
|
||||
Alloc uint64
|
||||
TotalAlloc uint64
|
||||
Sys uint64
|
||||
Mallocs uint64
|
||||
Frees uint64
|
||||
HeapAlloc uint64
|
||||
}
|
||||
|
||||
// CaptureMemStats returns current memory statistics
|
||||
func CaptureMemStats() MemStats {
|
||||
var m runtime.MemStats
|
||||
runtime.ReadMemStats(&m)
|
||||
return MemStats{
|
||||
Alloc: m.Alloc,
|
||||
TotalAlloc: m.TotalAlloc,
|
||||
Sys: m.Sys,
|
||||
Mallocs: m.Mallocs,
|
||||
Frees: m.Frees,
|
||||
HeapAlloc: m.HeapAlloc,
|
||||
}
|
||||
}
|
||||
|
||||
// TrackMemoryUsage runs fn and reports memory usage before and after
|
||||
func TrackMemoryUsage(b *testing.B, name string, fn func()) {
|
||||
b.Helper()
|
||||
|
||||
// Force GC before measurement
|
||||
runtime.GC()
|
||||
|
||||
// Capture memory stats before
|
||||
before := CaptureMemStats()
|
||||
|
||||
// Run the function
|
||||
fn()
|
||||
|
||||
// Force GC after measurement to get accurate stats
|
||||
runtime.GC()
|
||||
|
||||
// Capture memory stats after
|
||||
after := CaptureMemStats()
|
||||
|
||||
// Report stats
|
||||
b.ReportMetric(float64(after.Mallocs-before.Mallocs), name+"-mallocs")
|
||||
b.ReportMetric(float64(after.TotalAlloc-before.TotalAlloc)/float64(b.N), name+"-bytes/op")
|
||||
}
|
472
bench/bench_test.go
Normal file
472
bench/bench_test.go
Normal file
|
@ -0,0 +1,472 @@
|
|||
package luajit_bench
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
// BenchmarkSimpleDoString benchmarks direct execution of a simple expression
|
||||
func BenchmarkSimpleDoString(b *testing.B) {
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
b.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
code := "local x = 1 + 1"
|
||||
b.ResetTimer()
|
||||
|
||||
TrackMemoryUsage(b, "dostring", func() {
|
||||
for i := 0; i < b.N; i++ {
|
||||
if err := state.DoString(code); err != nil {
|
||||
b.Fatalf("DoString failed: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// BenchmarkSimpleCompileAndRun benchmarks compile and run of a simple expression
|
||||
func BenchmarkSimpleCompileAndRun(b *testing.B) {
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
b.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
code := "local x = 1 + 1"
|
||||
b.ResetTimer()
|
||||
|
||||
TrackMemoryUsage(b, "compile-run", func() {
|
||||
for i := 0; i < b.N; i++ {
|
||||
if err := state.CompileAndRun(code, "simple"); err != nil {
|
||||
b.Fatalf("CompileAndRun failed: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// BenchmarkSimpleCompileLoadRun benchmarks compile, load, and run of a simple expression
|
||||
func BenchmarkSimpleCompileLoadRun(b *testing.B) {
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
b.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
code := "local x = 1 + 1"
|
||||
b.ResetTimer()
|
||||
|
||||
TrackMemoryUsage(b, "compile-load-run", func() {
|
||||
for i := 0; i < b.N; i++ {
|
||||
bytecode, err := state.CompileBytecode(code, "simple")
|
||||
if err != nil {
|
||||
b.Fatalf("CompileBytecode failed: %v", err)
|
||||
}
|
||||
if err := state.LoadAndRunBytecode(bytecode, "simple"); err != nil {
|
||||
b.Fatalf("LoadAndRunBytecode failed: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// BenchmarkSimplePrecompiledBytecode benchmarks running precompiled bytecode
|
||||
func BenchmarkSimplePrecompiledBytecode(b *testing.B) {
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
b.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
code := "local x = 1 + 1"
|
||||
bytecode, err := state.CompileBytecode(code, "simple")
|
||||
if err != nil {
|
||||
b.Fatalf("CompileBytecode failed: %v", err)
|
||||
}
|
||||
b.ResetTimer()
|
||||
|
||||
TrackMemoryUsage(b, "precompiled", func() {
|
||||
for i := 0; i < b.N; i++ {
|
||||
if err := state.LoadAndRunBytecode(bytecode, "simple"); err != nil {
|
||||
b.Fatalf("LoadAndRunBytecode failed: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// BenchmarkFunctionCallDoString benchmarks direct execution of a function call
|
||||
func BenchmarkFunctionCallDoString(b *testing.B) {
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
b.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
// Setup function
|
||||
setupCode := `
|
||||
function add(a, b)
|
||||
return a + b
|
||||
end
|
||||
`
|
||||
if err := state.DoString(setupCode); err != nil {
|
||||
b.Fatalf("Failed to set up function: %v", err)
|
||||
}
|
||||
|
||||
code := "local result = add(10, 20)"
|
||||
b.ResetTimer()
|
||||
|
||||
TrackMemoryUsage(b, "func-dostring", func() {
|
||||
for i := 0; i < b.N; i++ {
|
||||
if err := state.DoString(code); err != nil {
|
||||
b.Fatalf("DoString failed: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// BenchmarkFunctionCallPrecompiled benchmarks precompiled function call
|
||||
func BenchmarkFunctionCallPrecompiled(b *testing.B) {
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
b.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
// Setup function
|
||||
setupCode := `
|
||||
function add(a, b)
|
||||
return a + b
|
||||
end
|
||||
`
|
||||
if err := state.DoString(setupCode); err != nil {
|
||||
b.Fatalf("Failed to set up function: %v", err)
|
||||
}
|
||||
|
||||
code := "local result = add(10, 20)"
|
||||
bytecode, err := state.CompileBytecode(code, "call")
|
||||
if err != nil {
|
||||
b.Fatalf("CompileBytecode failed: %v", err)
|
||||
}
|
||||
b.ResetTimer()
|
||||
|
||||
TrackMemoryUsage(b, "func-precompiled", func() {
|
||||
for i := 0; i < b.N; i++ {
|
||||
if err := state.LoadAndRunBytecode(bytecode, "call"); err != nil {
|
||||
b.Fatalf("LoadAndRunBytecode failed: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// BenchmarkLoopDoString benchmarks direct execution of a loop
|
||||
func BenchmarkLoopDoString(b *testing.B) {
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
b.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
code := `
|
||||
local sum = 0
|
||||
for i = 1, 1000 do
|
||||
sum = sum + i
|
||||
end
|
||||
`
|
||||
b.ResetTimer()
|
||||
|
||||
TrackMemoryUsage(b, "loop-dostring", func() {
|
||||
for i := 0; i < b.N; i++ {
|
||||
if err := state.DoString(code); err != nil {
|
||||
b.Fatalf("DoString failed: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// BenchmarkLoopPrecompiled benchmarks precompiled loop execution
|
||||
func BenchmarkLoopPrecompiled(b *testing.B) {
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
b.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
code := `
|
||||
local sum = 0
|
||||
for i = 1, 1000 do
|
||||
sum = sum + i
|
||||
end
|
||||
`
|
||||
bytecode, err := state.CompileBytecode(code, "loop")
|
||||
if err != nil {
|
||||
b.Fatalf("CompileBytecode failed: %v", err)
|
||||
}
|
||||
b.ResetTimer()
|
||||
|
||||
TrackMemoryUsage(b, "loop-precompiled", func() {
|
||||
for i := 0; i < b.N; i++ {
|
||||
if err := state.LoadAndRunBytecode(bytecode, "loop"); err != nil {
|
||||
b.Fatalf("LoadAndRunBytecode failed: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// BenchmarkTableOperationsDoString benchmarks direct execution of table operations
|
||||
func BenchmarkTableOperationsDoString(b *testing.B) {
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
b.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
code := `
|
||||
local t = {}
|
||||
for i = 1, 100 do
|
||||
t[i] = i * 2
|
||||
end
|
||||
local sum = 0
|
||||
for i, v in ipairs(t) do
|
||||
sum = sum + v
|
||||
end
|
||||
`
|
||||
b.ResetTimer()
|
||||
|
||||
TrackMemoryUsage(b, "table-dostring", func() {
|
||||
for i := 0; i < b.N; i++ {
|
||||
if err := state.DoString(code); err != nil {
|
||||
b.Fatalf("DoString failed: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// BenchmarkTableOperationsPrecompiled benchmarks precompiled table operations
|
||||
func BenchmarkTableOperationsPrecompiled(b *testing.B) {
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
b.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
code := `
|
||||
local t = {}
|
||||
for i = 1, 100 do
|
||||
t[i] = i * 2
|
||||
end
|
||||
local sum = 0
|
||||
for i, v in ipairs(t) do
|
||||
sum = sum + v
|
||||
end
|
||||
`
|
||||
bytecode, err := state.CompileBytecode(code, "table")
|
||||
if err != nil {
|
||||
b.Fatalf("CompileBytecode failed: %v", err)
|
||||
}
|
||||
b.ResetTimer()
|
||||
|
||||
TrackMemoryUsage(b, "table-precompiled", func() {
|
||||
for i := 0; i < b.N; i++ {
|
||||
if err := state.LoadAndRunBytecode(bytecode, "table"); err != nil {
|
||||
b.Fatalf("LoadAndRunBytecode failed: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// BenchmarkGoFunctionCall benchmarks calling a Go function from Lua
|
||||
func BenchmarkGoFunctionCall(b *testing.B) {
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
b.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
// Register a simple Go function
|
||||
add := func(s *luajit.State) int {
|
||||
a := s.ToNumber(1)
|
||||
b := s.ToNumber(2)
|
||||
s.PushNumber(a + b)
|
||||
return 1
|
||||
}
|
||||
if err := state.RegisterGoFunction("add", add); err != nil {
|
||||
b.Fatalf("RegisterGoFunction failed: %v", err)
|
||||
}
|
||||
|
||||
code := "local result = add(10, 20)"
|
||||
bytecode, err := state.CompileBytecode(code, "gofunc")
|
||||
if err != nil {
|
||||
b.Fatalf("CompileBytecode failed: %v", err)
|
||||
}
|
||||
b.ResetTimer()
|
||||
|
||||
TrackMemoryUsage(b, "go-func-call", func() {
|
||||
for i := 0; i < b.N; i++ {
|
||||
if err := state.LoadAndRunBytecode(bytecode, "gofunc"); err != nil {
|
||||
b.Fatalf("LoadAndRunBytecode failed: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// BenchmarkComplexScript benchmarks a more complex script
|
||||
func BenchmarkComplexScript(b *testing.B) {
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
b.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
code := `
|
||||
-- Define a simple class
|
||||
local Class = {}
|
||||
Class.__index = Class
|
||||
|
||||
function Class.new(x, y)
|
||||
local self = setmetatable({}, Class)
|
||||
self.x = x or 0
|
||||
self.y = y or 0
|
||||
return self
|
||||
end
|
||||
|
||||
function Class:move(dx, dy)
|
||||
self.x = self.x + dx
|
||||
self.y = self.y + dy
|
||||
return self
|
||||
end
|
||||
|
||||
function Class:getPosition()
|
||||
return self.x, self.y
|
||||
end
|
||||
|
||||
-- Create instances and operate on them
|
||||
local instances = {}
|
||||
for i = 1, 50 do
|
||||
instances[i] = Class.new(i, i*2)
|
||||
end
|
||||
|
||||
local result = 0
|
||||
for i, obj in ipairs(instances) do
|
||||
obj:move(i, -i)
|
||||
local x, y = obj:getPosition()
|
||||
result = result + x + y
|
||||
end
|
||||
|
||||
return result
|
||||
`
|
||||
b.ResetTimer()
|
||||
|
||||
TrackMemoryUsage(b, "complex-script", func() {
|
||||
for i := 0; i < b.N; i++ {
|
||||
if _, err := state.ExecuteWithResult(code); err != nil {
|
||||
b.Fatalf("ExecuteWithResult failed: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// BenchmarkComplexScriptPrecompiled benchmarks a precompiled complex script
|
||||
func BenchmarkComplexScriptPrecompiled(b *testing.B) {
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
b.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
code := `
|
||||
-- Define a simple class
|
||||
local Class = {}
|
||||
Class.__index = Class
|
||||
|
||||
function Class.new(x, y)
|
||||
local self = setmetatable({}, Class)
|
||||
self.x = x or 0
|
||||
self.y = y or 0
|
||||
return self
|
||||
end
|
||||
|
||||
function Class:move(dx, dy)
|
||||
self.x = self.x + dx
|
||||
self.y = self.y + dy
|
||||
return self
|
||||
end
|
||||
|
||||
function Class:getPosition()
|
||||
return self.x, self.y
|
||||
end
|
||||
|
||||
-- Create instances and operate on them
|
||||
local instances = {}
|
||||
for i = 1, 50 do
|
||||
instances[i] = Class.new(i, i*2)
|
||||
end
|
||||
|
||||
local result = 0
|
||||
for i, obj in ipairs(instances) do
|
||||
obj:move(i, -i)
|
||||
local x, y = obj:getPosition()
|
||||
result = result + x + y
|
||||
end
|
||||
|
||||
return result
|
||||
`
|
||||
bytecode, err := state.CompileBytecode(code, "complex")
|
||||
if err != nil {
|
||||
b.Fatalf("CompileBytecode failed: %v", err)
|
||||
}
|
||||
b.ResetTimer()
|
||||
|
||||
TrackMemoryUsage(b, "complex-precompiled", func() {
|
||||
for i := 0; i < b.N; i++ {
|
||||
if err := state.LoadBytecode(bytecode, "complex"); err != nil {
|
||||
b.Fatalf("LoadBytecode failed: %v", err)
|
||||
}
|
||||
if err := state.RunBytecodeWithResults(1); err != nil {
|
||||
b.Fatalf("RunBytecodeWithResults failed: %v", err)
|
||||
}
|
||||
state.Pop(1) // Pop the result
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// BenchmarkMultipleExecutions benchmarks executing the same bytecode multiple times
|
||||
func BenchmarkMultipleExecutions(b *testing.B) {
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
b.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
// Setup a stateful environment
|
||||
setupCode := `
|
||||
counter = 0
|
||||
function increment(amount)
|
||||
counter = counter + (amount or 1)
|
||||
return counter
|
||||
end
|
||||
`
|
||||
if err := state.DoString(setupCode); err != nil {
|
||||
b.Fatalf("Failed to set up environment: %v", err)
|
||||
}
|
||||
|
||||
// Compile the function call
|
||||
code := "return increment(5)"
|
||||
bytecode, err := state.CompileBytecode(code, "increment")
|
||||
if err != nil {
|
||||
b.Fatalf("CompileBytecode failed: %v", err)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
TrackMemoryUsage(b, "multiple-executions", func() {
|
||||
for i := 0; i < b.N; i++ {
|
||||
if err := state.LoadBytecode(bytecode, "increment"); err != nil {
|
||||
b.Fatalf("LoadBytecode failed: %v", err)
|
||||
}
|
||||
if err := state.RunBytecodeWithResults(1); err != nil {
|
||||
b.Fatalf("RunBytecodeWithResults failed: %v", err)
|
||||
}
|
||||
state.Pop(1) // Pop the result
|
||||
}
|
||||
})
|
||||
}
|
138
bench/ezbench_test.go
Normal file
138
bench/ezbench_test.go
Normal file
|
@ -0,0 +1,138 @@
|
|||
package luajit_bench
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
var benchCases = []struct {
|
||||
name string
|
||||
code string
|
||||
}{
|
||||
{
|
||||
name: "SimpleAddition",
|
||||
code: `return 1 + 1`,
|
||||
},
|
||||
{
|
||||
name: "LoopSum",
|
||||
code: `
|
||||
local sum = 0
|
||||
for i = 1, 1000 do
|
||||
sum = sum + i
|
||||
end
|
||||
return sum
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "FunctionCall",
|
||||
code: `
|
||||
local result = 0
|
||||
for i = 1, 100 do
|
||||
result = result + i
|
||||
end
|
||||
return result
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "TableCreation",
|
||||
code: `
|
||||
local t = {}
|
||||
for i = 1, 100 do
|
||||
t[i] = i * 2
|
||||
end
|
||||
return t[50]
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "StringOperations",
|
||||
code: `
|
||||
local s = "hello"
|
||||
for i = 1, 10 do
|
||||
s = s .. " world"
|
||||
end
|
||||
return #s
|
||||
`,
|
||||
},
|
||||
}
|
||||
|
||||
func BenchmarkLuaDirectExecution(b *testing.B) {
|
||||
for _, bc := range benchCases {
|
||||
b.Run(bc.name, func(b *testing.B) {
|
||||
L := luajit.New()
|
||||
if L == nil {
|
||||
b.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer L.Close()
|
||||
defer L.Cleanup()
|
||||
|
||||
// First verify we can execute the code
|
||||
if err := L.DoString(bc.code); err != nil {
|
||||
b.Fatalf("Failed to execute test code: %v", err)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
TrackMemoryUsage(b, "direct-"+bc.name, func() {
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Execute string and get results
|
||||
nresults, err := L.Execute(bc.code)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to execute code: %v", err)
|
||||
}
|
||||
L.Pop(nresults) // Clean up any results
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkLuaBytecodeExecution(b *testing.B) {
|
||||
// First compile all bytecode
|
||||
bytecodes := make(map[string][]byte)
|
||||
for _, bc := range benchCases {
|
||||
L := luajit.New()
|
||||
if L == nil {
|
||||
b.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer L.Cleanup()
|
||||
|
||||
bytecode, err := L.CompileBytecode(bc.code, bc.name)
|
||||
if err != nil {
|
||||
L.Close()
|
||||
b.Fatalf("Error compiling bytecode for %s: %v", bc.name, err)
|
||||
}
|
||||
bytecodes[bc.name] = bytecode
|
||||
L.Close()
|
||||
}
|
||||
|
||||
for _, bc := range benchCases {
|
||||
b.Run(bc.name, func(b *testing.B) {
|
||||
L := luajit.New()
|
||||
if L == nil {
|
||||
b.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer L.Close()
|
||||
defer L.Cleanup()
|
||||
|
||||
bytecode := bytecodes[bc.name]
|
||||
|
||||
// First verify we can execute the bytecode
|
||||
if err := L.LoadAndRunBytecodeWithResults(bytecode, bc.name, 1); err != nil {
|
||||
b.Fatalf("Failed to execute test bytecode: %v", err)
|
||||
}
|
||||
L.Pop(1) // Clean up the result
|
||||
|
||||
b.ResetTimer()
|
||||
b.SetBytes(int64(len(bytecode))) // Track bytecode size in benchmarks
|
||||
|
||||
TrackMemoryUsage(b, "bytecode-"+bc.name, func() {
|
||||
for i := 0; i < b.N; i++ {
|
||||
if err := L.LoadAndRunBytecode(bytecode, bc.name); err != nil {
|
||||
b.Fatalf("Error executing bytecode: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
78
bench/profile.sh
Executable file
78
bench/profile.sh
Executable file
|
@ -0,0 +1,78 @@
|
|||
#!/bin/bash
|
||||
|
||||
# Easy script to run benchmarks with profiling enabled
|
||||
# Usage: ./profile_benchmarks.sh [benchmark_pattern]
|
||||
|
||||
set -e
|
||||
|
||||
# Default values
|
||||
BENCHMARK=${1:-"."}
|
||||
OUTPUT_DIR="./profile_results"
|
||||
CPU_PROFILE="$OUTPUT_DIR/cpu.prof"
|
||||
MEM_PROFILE="$OUTPUT_DIR/mem.prof"
|
||||
BLOCK_PROFILE="$OUTPUT_DIR/block.prof"
|
||||
MUTEX_PROFILE="$OUTPUT_DIR/mutex.prof"
|
||||
TRACE_FILE="$OUTPUT_DIR/trace.out"
|
||||
HTML_OUTPUT="$OUTPUT_DIR/profile_report.html"
|
||||
|
||||
# Create output directory
|
||||
mkdir -p "$OUTPUT_DIR"
|
||||
|
||||
echo "Running benchmarks with profiling enabled..."
|
||||
|
||||
# Run benchmarks with profiling flags
|
||||
go test -bench="$BENCHMARK" -benchmem -cpuprofile="$CPU_PROFILE" -memprofile="$MEM_PROFILE" -blockprofile="$BLOCK_PROFILE" -mutexprofile="$MUTEX_PROFILE" -count=5 -timeout=30m
|
||||
|
||||
echo "Generating CPU profile analysis..."
|
||||
go tool pprof -http=":1880" -output="$OUTPUT_DIR/cpu_graph.svg" "$CPU_PROFILE"
|
||||
|
||||
echo "Generating memory profile analysis..."
|
||||
go tool pprof -http=":1880" -output="$OUTPUT_DIR/mem_graph.svg" "$MEM_PROFILE"
|
||||
|
||||
# Generate a simple HTML report
|
||||
cat > "$HTML_OUTPUT" << EOF
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>LuaJIT Benchmark Profiling Results</title>
|
||||
<style>
|
||||
body { font-family: Arial, sans-serif; margin: 20px; }
|
||||
h1, h2 { color: #333; }
|
||||
.profile { margin-bottom: 30px; }
|
||||
img { max-width: 100%; border: 1px solid #ddd; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>LuaJIT Benchmark Profiling Results</h1>
|
||||
<p>Generated on: $(date)</p>
|
||||
|
||||
<div class="profile">
|
||||
<h2>CPU Profile</h2>
|
||||
<img src="cpu_graph.svg" alt="CPU Profile Graph">
|
||||
<p>Command to explore: <code>go tool pprof $CPU_PROFILE</code></p>
|
||||
</div>
|
||||
|
||||
<div class="profile">
|
||||
<h2>Memory Profile</h2>
|
||||
<img src="mem_graph.svg" alt="Memory Profile Graph">
|
||||
<p>Command to explore: <code>go tool pprof $MEM_PROFILE</code></p>
|
||||
</div>
|
||||
|
||||
<div class="profile">
|
||||
<h2>Tips for Profile Analysis</h2>
|
||||
<ul>
|
||||
<li>Use <code>go tool pprof -http=:8080 $CPU_PROFILE</code> for interactive web UI</li>
|
||||
<li>Use <code>top10</code> in pprof to see the top 10 functions by CPU/memory usage</li>
|
||||
<li>Use <code>list FunctionName</code> to see line-by-line stats for a specific function</li>
|
||||
</ul>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
EOF
|
||||
|
||||
echo "Profiling complete! Results available in $OUTPUT_DIR"
|
||||
echo "View the HTML report at $HTML_OUTPUT"
|
||||
echo ""
|
||||
echo "For detailed interactive analysis, run:"
|
||||
echo " go tool pprof -http=:1880 $CPU_PROFILE # For CPU profile"
|
||||
echo " go tool pprof -http=:1880 $MEM_PROFILE # For memory profile"
|
|
@ -1,148 +0,0 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
type benchCase struct {
|
||||
name string
|
||||
code string
|
||||
}
|
||||
|
||||
var cases = []benchCase{
|
||||
{
|
||||
name: "Simple Addition",
|
||||
code: `return 1 + 1`,
|
||||
},
|
||||
{
|
||||
name: "Loop Sum",
|
||||
code: `
|
||||
local sum = 0
|
||||
for i = 1, 1000 do
|
||||
sum = sum + i
|
||||
end
|
||||
return sum
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "Function Call",
|
||||
code: `
|
||||
local result = 0
|
||||
for i = 1, 100 do
|
||||
result = result + i
|
||||
end
|
||||
return result
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "Table Creation",
|
||||
code: `
|
||||
local t = {}
|
||||
for i = 1, 100 do
|
||||
t[i] = i * 2
|
||||
end
|
||||
return t[50]
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "String Operations",
|
||||
code: `
|
||||
local s = "hello"
|
||||
for i = 1, 10 do
|
||||
s = s .. " world"
|
||||
end
|
||||
return #s
|
||||
`,
|
||||
},
|
||||
}
|
||||
|
||||
func runBenchmark(L *luajit.State, code string, duration time.Duration) (time.Duration, int64) {
|
||||
start := time.Now()
|
||||
deadline := start.Add(duration)
|
||||
var ops int64
|
||||
|
||||
for time.Now().Before(deadline) {
|
||||
if err := L.DoString(code); err != nil {
|
||||
fmt.Printf("Error executing code: %v\n", err)
|
||||
return 0, 0
|
||||
}
|
||||
L.Pop(1)
|
||||
ops++
|
||||
}
|
||||
|
||||
return time.Since(start), ops
|
||||
}
|
||||
|
||||
func runBytecodeTest(L *luajit.State, code string, duration time.Duration) (time.Duration, int64) {
|
||||
// First compile the bytecode
|
||||
bytecode, err := L.CompileBytecode(code, "bench")
|
||||
if err != nil {
|
||||
fmt.Printf("Error compiling bytecode: %v\n", err)
|
||||
return 0, 0
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
deadline := start.Add(duration)
|
||||
var ops int64
|
||||
|
||||
for time.Now().Before(deadline) {
|
||||
if err := L.LoadBytecode(bytecode, "bench"); err != nil {
|
||||
fmt.Printf("Error executing bytecode: %v\n", err)
|
||||
return 0, 0
|
||||
}
|
||||
ops++
|
||||
}
|
||||
|
||||
return time.Since(start), ops
|
||||
}
|
||||
|
||||
func benchmarkCase(newState func() *luajit.State, bc benchCase) {
|
||||
fmt.Printf("\n%s:\n", bc.name)
|
||||
|
||||
// Direct execution benchmark
|
||||
L := newState()
|
||||
if L == nil {
|
||||
fmt.Printf(" Failed to create Lua state\n")
|
||||
return
|
||||
}
|
||||
execTime, ops := runBenchmark(L, bc.code, 2*time.Second)
|
||||
L.Close()
|
||||
if ops > 0 {
|
||||
opsPerSec := float64(ops) / execTime.Seconds()
|
||||
fmt.Printf(" Direct: %.0f ops/sec\n", opsPerSec)
|
||||
}
|
||||
|
||||
// Bytecode execution benchmark
|
||||
L = newState()
|
||||
if L == nil {
|
||||
fmt.Printf(" Failed to create Lua state\n")
|
||||
return
|
||||
}
|
||||
execTime, ops = runBytecodeTest(L, bc.code, 2*time.Second)
|
||||
L.Close()
|
||||
if ops > 0 {
|
||||
opsPerSec := float64(ops) / execTime.Seconds()
|
||||
fmt.Printf(" Bytecode: %.0f ops/sec\n", opsPerSec)
|
||||
}
|
||||
}
|
||||
|
||||
func main() {
|
||||
modes := []struct {
|
||||
name string
|
||||
newState func() *luajit.State
|
||||
}{
|
||||
{"Safe", luajit.NewSafe},
|
||||
{"Unsafe", luajit.New},
|
||||
}
|
||||
|
||||
for _, mode := range modes {
|
||||
fmt.Printf("\n=== %s Mode ===\n", mode.name)
|
||||
|
||||
for _, c := range cases {
|
||||
benchmarkCase(mode.newState, c)
|
||||
}
|
||||
}
|
||||
}
|
207
bytecode.go
207
bytecode.go
|
@ -12,7 +12,7 @@ typedef struct {
|
|||
const char *name;
|
||||
} BytecodeReader;
|
||||
|
||||
static const char *bytecode_reader(lua_State *L, void *ud, size_t *size) {
|
||||
const char *bytecode_reader(lua_State *L, void *ud, size_t *size) {
|
||||
BytecodeReader *r = (BytecodeReader *)ud;
|
||||
(void)L; // unused
|
||||
if (r->size == 0) return NULL;
|
||||
|
@ -21,29 +21,33 @@ static const char *bytecode_reader(lua_State *L, void *ud, size_t *size) {
|
|||
return (const char *)r->buf;
|
||||
}
|
||||
|
||||
static int load_bytecode_chunk(lua_State *L, const unsigned char *buf, size_t len, const char *name) {
|
||||
int load_bytecode(lua_State *L, const unsigned char *buf, size_t len, const char *name) {
|
||||
BytecodeReader reader = {buf, len, name};
|
||||
return lua_load(L, bytecode_reader, &reader, name);
|
||||
}
|
||||
|
||||
typedef struct {
|
||||
unsigned char *buf;
|
||||
size_t len;
|
||||
} BytecodeWriter;
|
||||
|
||||
int bytecode_writer(lua_State *L, const void *p, size_t sz, void *ud) {
|
||||
BytecodeWriter *w = (BytecodeWriter *)ud;
|
||||
unsigned char *newbuf;
|
||||
(void)L; // unused
|
||||
|
||||
newbuf = (unsigned char *)realloc(w->buf, w->len + sz);
|
||||
// Direct bytecode dumping without intermediate buffer - more efficient
|
||||
int direct_bytecode_writer(lua_State *L, const void *p, size_t sz, void *ud) {
|
||||
void **data = (void **)ud;
|
||||
size_t current_size = (size_t)data[1];
|
||||
void *newbuf = realloc(data[0], current_size + sz);
|
||||
if (newbuf == NULL) return 1;
|
||||
|
||||
memcpy(newbuf + w->len, p, sz);
|
||||
w->buf = newbuf;
|
||||
w->len += sz;
|
||||
memcpy((unsigned char*)newbuf + current_size, p, sz);
|
||||
data[0] = newbuf;
|
||||
data[1] = (void*)(current_size + sz);
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Combined load and run bytecode in a single call
|
||||
int load_and_run_bytecode(lua_State *L, const unsigned char *buf, size_t len,
|
||||
const char *name, int nresults) {
|
||||
BytecodeReader reader = {buf, len, name};
|
||||
int status = lua_load(L, bytecode_reader, &reader, name);
|
||||
if (status != 0) return status;
|
||||
|
||||
return lua_pcall(L, 0, nresults, 0);
|
||||
}
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
|
@ -51,55 +55,46 @@ import (
|
|||
"unsafe"
|
||||
)
|
||||
|
||||
func (s *State) compileBytecodeUnsafe(code string, name string) ([]byte, error) {
|
||||
// First load the string but don't execute it
|
||||
ccode := C.CString(code)
|
||||
defer C.free(unsafe.Pointer(ccode))
|
||||
|
||||
cname := C.CString(name)
|
||||
defer C.free(unsafe.Pointer(cname))
|
||||
|
||||
if C.luaL_loadstring(s.L, ccode) != 0 {
|
||||
err := &LuaError{
|
||||
Code: int(C.lua_status(s.L)),
|
||||
Message: s.ToString(-1),
|
||||
}
|
||||
s.Pop(1)
|
||||
// CompileBytecode compiles a Lua chunk to bytecode without executing it
|
||||
func (s *State) CompileBytecode(code string, name string) ([]byte, error) {
|
||||
if err := s.LoadString(code); err != nil {
|
||||
return nil, fmt.Errorf("failed to load string: %w", err)
|
||||
}
|
||||
|
||||
// Set up writer
|
||||
var writer C.BytecodeWriter
|
||||
writer.buf = nil
|
||||
writer.len = 0
|
||||
// Use a simpler direct writer with just two pointers
|
||||
data := [2]unsafe.Pointer{nil, nil}
|
||||
|
||||
// Dump the function to bytecode
|
||||
if C.lua_dump(s.L, (*[0]byte)(C.bytecode_writer), unsafe.Pointer(&writer)) != 0 {
|
||||
if writer.buf != nil {
|
||||
C.free(unsafe.Pointer(writer.buf))
|
||||
}
|
||||
s.Pop(1)
|
||||
return nil, fmt.Errorf("failed to dump bytecode")
|
||||
status := C.lua_dump(s.L, (*[0]byte)(unsafe.Pointer(C.direct_bytecode_writer)), unsafe.Pointer(&data))
|
||||
if status != 0 {
|
||||
return nil, fmt.Errorf("failed to dump bytecode: status %d", status)
|
||||
}
|
||||
|
||||
// Copy to Go slice
|
||||
bytecode := C.GoBytes(unsafe.Pointer(writer.buf), C.int(writer.len))
|
||||
|
||||
// Clean up
|
||||
if writer.buf != nil {
|
||||
C.free(unsafe.Pointer(writer.buf))
|
||||
// Get result
|
||||
var bytecode []byte
|
||||
if data[0] != nil {
|
||||
// Create Go slice that references the C memory
|
||||
length := uintptr(data[1])
|
||||
bytecode = C.GoBytes(data[0], C.int(length))
|
||||
C.free(data[0])
|
||||
}
|
||||
s.Pop(1) // Remove the function
|
||||
|
||||
s.Pop(1) // Remove the function from stack
|
||||
|
||||
return bytecode, nil
|
||||
}
|
||||
|
||||
func (s *State) loadBytecodeUnsafe(bytecode []byte, name string) error {
|
||||
// LoadBytecode loads precompiled bytecode without executing it
|
||||
func (s *State) LoadBytecode(bytecode []byte, name string) error {
|
||||
if len(bytecode) == 0 {
|
||||
return fmt.Errorf("empty bytecode")
|
||||
}
|
||||
|
||||
cname := C.CString(name)
|
||||
defer C.free(unsafe.Pointer(cname))
|
||||
|
||||
// Load the bytecode
|
||||
status := C.load_bytecode_chunk(
|
||||
status := C.load_bytecode(
|
||||
s.L,
|
||||
(*C.uchar)(unsafe.Pointer(&bytecode[0])),
|
||||
C.size_t(len(bytecode)),
|
||||
|
@ -111,49 +106,107 @@ func (s *State) loadBytecodeUnsafe(bytecode []byte, name string) error {
|
|||
Code: int(status),
|
||||
Message: s.ToString(-1),
|
||||
}
|
||||
s.Pop(1)
|
||||
return fmt.Errorf("failed to load bytecode: %w", err)
|
||||
}
|
||||
|
||||
// Execute the loaded chunk
|
||||
if err := s.safeCall(func() C.int {
|
||||
return C.lua_pcall(s.L, 0, 0, 0)
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to execute bytecode: %w", err)
|
||||
s.Pop(1) // Remove error message
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CompileBytecode compiles a Lua chunk to bytecode without executing it
|
||||
func (s *State) CompileBytecode(code string, name string) ([]byte, error) {
|
||||
if s.safeStack {
|
||||
return stackGuardValue[[]byte](s, func() ([]byte, error) {
|
||||
return s.compileBytecodeUnsafe(code, name)
|
||||
})
|
||||
}
|
||||
return s.compileBytecodeUnsafe(code, name)
|
||||
// RunBytecode executes previously loaded bytecode with 0 results
|
||||
func (s *State) RunBytecode() error {
|
||||
return s.RunBytecodeWithResults(0)
|
||||
}
|
||||
|
||||
// LoadBytecode loads precompiled bytecode and executes it
|
||||
func (s *State) LoadBytecode(bytecode []byte, name string) error {
|
||||
if s.safeStack {
|
||||
return stackGuardErr(s, func() error {
|
||||
return s.loadBytecodeUnsafe(bytecode, name)
|
||||
})
|
||||
// RunBytecodeWithResults executes bytecode and keeps nresults on the stack
|
||||
// Use LUA_MULTRET (-1) to keep all results
|
||||
func (s *State) RunBytecodeWithResults(nresults int) error {
|
||||
status := C.lua_pcall(s.L, 0, C.int(nresults), 0)
|
||||
if status != 0 {
|
||||
err := &LuaError{
|
||||
Code: int(status),
|
||||
Message: s.ToString(-1),
|
||||
}
|
||||
s.Pop(1) // Remove error message
|
||||
return err
|
||||
}
|
||||
return s.loadBytecodeUnsafe(bytecode, name)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Helper function to compile and immediately load/execute bytecode
|
||||
func (s *State) CompileAndLoad(code string, name string) error {
|
||||
// LoadAndRunBytecode loads and executes bytecode in a single CGO transition
|
||||
func (s *State) LoadAndRunBytecode(bytecode []byte, name string) error {
|
||||
if len(bytecode) == 0 {
|
||||
return fmt.Errorf("empty bytecode")
|
||||
}
|
||||
|
||||
cname := C.CString(name)
|
||||
defer C.free(unsafe.Pointer(cname))
|
||||
|
||||
// Use combined load and run function
|
||||
status := C.load_and_run_bytecode(
|
||||
s.L,
|
||||
(*C.uchar)(unsafe.Pointer(&bytecode[0])),
|
||||
C.size_t(len(bytecode)),
|
||||
cname,
|
||||
0, // No results
|
||||
)
|
||||
|
||||
if status != 0 {
|
||||
err := &LuaError{
|
||||
Code: int(status),
|
||||
Message: s.ToString(-1),
|
||||
}
|
||||
s.Pop(1) // Remove error message
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadAndRunBytecodeWithResults loads and executes bytecode, preserving results
|
||||
func (s *State) LoadAndRunBytecodeWithResults(bytecode []byte, name string, nresults int) error {
|
||||
if len(bytecode) == 0 {
|
||||
return fmt.Errorf("empty bytecode")
|
||||
}
|
||||
|
||||
cname := C.CString(name)
|
||||
defer C.free(unsafe.Pointer(cname))
|
||||
|
||||
// Use combined load and run function
|
||||
status := C.load_and_run_bytecode(
|
||||
s.L,
|
||||
(*C.uchar)(unsafe.Pointer(&bytecode[0])),
|
||||
C.size_t(len(bytecode)),
|
||||
cname,
|
||||
C.int(nresults),
|
||||
)
|
||||
|
||||
if status != 0 {
|
||||
err := &LuaError{
|
||||
Code: int(status),
|
||||
Message: s.ToString(-1),
|
||||
}
|
||||
s.Pop(1) // Remove error message
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CompileAndRun compiles and immediately executes Lua code
|
||||
func (s *State) CompileAndRun(code string, name string) error {
|
||||
// Skip bytecode step for small scripts - direct execution is faster
|
||||
if len(code) < 1024 {
|
||||
return s.DoString(code)
|
||||
}
|
||||
|
||||
bytecode, err := s.CompileBytecode(code, name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("compile error: %w", err)
|
||||
}
|
||||
|
||||
if err := s.LoadBytecode(bytecode, name); err != nil {
|
||||
return fmt.Errorf("load error: %w", err)
|
||||
if err := s.LoadAndRunBytecode(bytecode, name); err != nil {
|
||||
return fmt.Errorf("execution error: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
178
bytecode_test.go
178
bytecode_test.go
|
@ -1,178 +0,0 @@
|
|||
package luajit
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBytecodeCompilation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
code string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "simple assignment",
|
||||
code: "x = 42",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "function definition",
|
||||
code: "function add(a,b) return a+b end",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "syntax error",
|
||||
code: "function bad syntax",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, f := range factories {
|
||||
for _, tt := range tests {
|
||||
t.Run(f.name+"/"+tt.name, func(t *testing.T) {
|
||||
L := f.new()
|
||||
if L == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer L.Close()
|
||||
|
||||
bytecode, err := L.CompileBytecode(tt.code, "test")
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("CompileBytecode() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if !tt.wantErr {
|
||||
if len(bytecode) == 0 {
|
||||
t.Error("CompileBytecode() returned empty bytecode")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBytecodeExecution(t *testing.T) {
|
||||
for _, f := range factories {
|
||||
t.Run(f.name, func(t *testing.T) {
|
||||
L := f.new()
|
||||
if L == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer L.Close()
|
||||
|
||||
// Compile some test code
|
||||
code := `
|
||||
function add(a, b)
|
||||
return a + b
|
||||
end
|
||||
result = add(40, 2)
|
||||
`
|
||||
|
||||
bytecode, err := L.CompileBytecode(code, "test")
|
||||
if err != nil {
|
||||
t.Fatalf("CompileBytecode() error = %v", err)
|
||||
}
|
||||
|
||||
// Load and execute the bytecode
|
||||
if err := L.LoadBytecode(bytecode, "test"); err != nil {
|
||||
t.Fatalf("LoadBytecode() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify the result
|
||||
L.GetGlobal("result")
|
||||
if result := L.ToNumber(-1); result != 42 {
|
||||
t.Errorf("got result = %v, want 42", result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestInvalidBytecode(t *testing.T) {
|
||||
for _, f := range factories {
|
||||
t.Run(f.name, func(t *testing.T) {
|
||||
L := f.new()
|
||||
if L == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer L.Close()
|
||||
|
||||
// Test with invalid bytecode
|
||||
invalidBytecode := []byte("this is not valid bytecode")
|
||||
if err := L.LoadBytecode(invalidBytecode, "test"); err == nil {
|
||||
t.Error("LoadBytecode() expected error with invalid bytecode")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBytecodeRoundTrip(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
code string
|
||||
check func(*State) error
|
||||
}{
|
||||
{
|
||||
name: "global variable",
|
||||
code: "x = 42",
|
||||
check: func(L *State) error {
|
||||
L.GetGlobal("x")
|
||||
if x := L.ToNumber(-1); x != 42 {
|
||||
return fmt.Errorf("got x = %v, want 42", x)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "function definition",
|
||||
code: "function test() return 'hello' end",
|
||||
check: func(L *State) error {
|
||||
if err := L.DoString("result = test()"); err != nil {
|
||||
return err
|
||||
}
|
||||
L.GetGlobal("result")
|
||||
if s := L.ToString(-1); s != "hello" {
|
||||
return fmt.Errorf("got result = %q, want 'hello'", s)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, f := range factories {
|
||||
for _, tt := range tests {
|
||||
t.Run(f.name+"/"+tt.name, func(t *testing.T) {
|
||||
// First state for compilation
|
||||
L1 := f.new()
|
||||
if L1 == nil {
|
||||
t.Fatal("Failed to create first Lua state")
|
||||
}
|
||||
defer L1.Close()
|
||||
|
||||
// Compile the code
|
||||
bytecode, err := L1.CompileBytecode(tt.code, "test")
|
||||
if err != nil {
|
||||
t.Fatalf("CompileBytecode() error = %v", err)
|
||||
}
|
||||
|
||||
// Second state for execution
|
||||
L2 := f.new()
|
||||
if L2 == nil {
|
||||
t.Fatal("Failed to create second Lua state")
|
||||
}
|
||||
defer L2.Close()
|
||||
|
||||
// Load and execute the bytecode
|
||||
if err := L2.LoadBytecode(bytecode, "test"); err != nil {
|
||||
t.Fatalf("LoadBytecode() error = %v", err)
|
||||
}
|
||||
|
||||
// Run the check function
|
||||
if err := tt.check(L2); err != nil {
|
||||
t.Errorf("check failed: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
70
example/main.go
Normal file
70
example/main.go
Normal file
|
@ -0,0 +1,70 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
func main() {
|
||||
if len(os.Args) < 2 {
|
||||
fmt.Println("Usage: go run main.go script.lua")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
scriptPath := os.Args[1]
|
||||
|
||||
// Create a new Lua state
|
||||
L := luajit.New()
|
||||
if L == nil {
|
||||
log.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer L.Close()
|
||||
|
||||
// Register a Go function to be called from Lua
|
||||
L.RegisterGoFunction("printFromGo", func(s *luajit.State) int {
|
||||
msg := s.ToString(1) // Get first argument
|
||||
fmt.Printf("Go received from Lua: %s\n", msg)
|
||||
|
||||
// Return a value to Lua
|
||||
s.PushString("Hello from Go!")
|
||||
return 1 // Number of return values
|
||||
})
|
||||
|
||||
// Add some values to the Lua environment
|
||||
L.PushValue(map[string]any{
|
||||
"appName": "LuaJIT Example",
|
||||
"version": 1.0,
|
||||
"features": []float64{1, 2, 3},
|
||||
})
|
||||
L.SetGlobal("config")
|
||||
|
||||
// Get the directory of the script to properly handle requires
|
||||
dir := filepath.Dir(scriptPath)
|
||||
L.AddPackagePath(filepath.Join(dir, "?.lua"))
|
||||
|
||||
// Execute the script
|
||||
fmt.Printf("Running Lua script: %s\n", scriptPath)
|
||||
if err := L.DoFile(scriptPath); err != nil {
|
||||
log.Fatalf("Error executing script: %v", err)
|
||||
}
|
||||
|
||||
// Call a Lua function and get its result
|
||||
L.GetGlobal("getResult")
|
||||
if L.IsFunction(-1) {
|
||||
if err := L.Call(0, 1); err != nil {
|
||||
log.Fatalf("Error calling Lua function: %v", err)
|
||||
}
|
||||
|
||||
result, err := L.ToValue(-1)
|
||||
if err != nil {
|
||||
log.Fatalf("Error converting Lua result: %v", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Result from Lua: %v\n", result)
|
||||
L.Pop(1) // Clean up the result
|
||||
}
|
||||
}
|
35
example/script.lua
Normal file
35
example/script.lua
Normal file
|
@ -0,0 +1,35 @@
|
|||
-- Example Lua script to demonstrate Go-Lua integration
|
||||
|
||||
-- Access the config table passed from Go
|
||||
print("Script started")
|
||||
print("App name:", config.appName)
|
||||
print("Version:", config.version)
|
||||
print("Features:", table.concat(config.features, ", "))
|
||||
|
||||
-- Call the Go function
|
||||
local response = printFromGo("Hello from Lua!")
|
||||
print("Response from Go:", response)
|
||||
|
||||
-- Function that will be called from Go
|
||||
function getResult()
|
||||
local result = {
|
||||
status = "success",
|
||||
calculations = {
|
||||
sum = 10 + 20,
|
||||
product = 5 * 7
|
||||
},
|
||||
message = "Calculation completed"
|
||||
}
|
||||
return result
|
||||
end
|
||||
|
||||
-- Load external module (if available)
|
||||
local success, utils = pcall(require, "utils")
|
||||
if success then
|
||||
print("Utils module loaded")
|
||||
utils.doSomething()
|
||||
else
|
||||
print("Utils module not available:", utils)
|
||||
end
|
||||
|
||||
print("Script completed")
|
19
example/utils.lua
Normal file
19
example/utils.lua
Normal file
|
@ -0,0 +1,19 @@
|
|||
-- Optional utility module
|
||||
|
||||
local utils = {}
|
||||
|
||||
function utils.doSomething()
|
||||
print("Utils module function called")
|
||||
return true
|
||||
end
|
||||
|
||||
function utils.calculate(a, b)
|
||||
return {
|
||||
sum = a + b,
|
||||
difference = a - b,
|
||||
product = a * b,
|
||||
quotient = a / b
|
||||
}
|
||||
end
|
||||
|
||||
return utils
|
46
functions.go
46
functions.go
|
@ -7,8 +7,9 @@ package luajit
|
|||
|
||||
extern int goFunctionWrapper(lua_State* L);
|
||||
|
||||
// Helper function to access upvalues
|
||||
static int get_upvalue_index(int i) {
|
||||
return -10002 - i; // LUA_GLOBALSINDEX - i
|
||||
return lua_upvalueindex(i);
|
||||
}
|
||||
*/
|
||||
import "C"
|
||||
|
@ -18,28 +19,35 @@ import (
|
|||
"unsafe"
|
||||
)
|
||||
|
||||
// GoFunction defines the signature for Go functions callable from Lua
|
||||
type GoFunction func(*State) int
|
||||
|
||||
// Static registry size reduces resizing operations
|
||||
const initialRegistrySize = 64
|
||||
|
||||
var (
|
||||
// functionRegistry stores all registered Go functions
|
||||
functionRegistry = struct {
|
||||
sync.RWMutex
|
||||
funcs map[unsafe.Pointer]GoFunction
|
||||
funcs map[unsafe.Pointer]GoFunction
|
||||
initOnce sync.Once
|
||||
}{
|
||||
funcs: make(map[unsafe.Pointer]GoFunction),
|
||||
funcs: make(map[unsafe.Pointer]GoFunction, initialRegistrySize),
|
||||
}
|
||||
)
|
||||
|
||||
//export goFunctionWrapper
|
||||
func goFunctionWrapper(L *C.lua_State) C.int {
|
||||
state := &State{L: L, safeStack: true}
|
||||
state := &State{L: L}
|
||||
|
||||
// Get upvalue using standard Lua 5.1 macro
|
||||
// Get function pointer from the first upvalue
|
||||
ptr := C.lua_touserdata(L, C.get_upvalue_index(1))
|
||||
if ptr == nil {
|
||||
state.PushString("error: function not found")
|
||||
state.PushString("error: function pointer not found")
|
||||
return -1
|
||||
}
|
||||
|
||||
// Use read-lock for better concurrency
|
||||
functionRegistry.RLock()
|
||||
fn, ok := functionRegistry.funcs[ptr]
|
||||
functionRegistry.RUnlock()
|
||||
|
@ -49,50 +57,56 @@ func goFunctionWrapper(L *C.lua_State) C.int {
|
|||
return -1
|
||||
}
|
||||
|
||||
result := fn(state)
|
||||
return C.int(result)
|
||||
// Call the Go function
|
||||
return C.int(fn(state))
|
||||
}
|
||||
|
||||
// PushGoFunction wraps a Go function and pushes it onto the Lua stack
|
||||
func (s *State) PushGoFunction(fn GoFunction) error {
|
||||
// Push lightuserdata as upvalue and create closure
|
||||
// Allocate a pointer to use as the function key
|
||||
ptr := C.malloc(1)
|
||||
if ptr == nil {
|
||||
return fmt.Errorf("failed to allocate memory for function pointer")
|
||||
}
|
||||
|
||||
// Register the function
|
||||
functionRegistry.Lock()
|
||||
functionRegistry.funcs[ptr] = fn
|
||||
functionRegistry.Unlock()
|
||||
|
||||
// Push the pointer as lightuserdata (first upvalue)
|
||||
C.lua_pushlightuserdata(s.L, ptr)
|
||||
|
||||
// Create closure with the C wrapper and the upvalue
|
||||
C.lua_pushcclosure(s.L, (*[0]byte)(C.goFunctionWrapper), 1)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RegisterGoFunction registers a Go function as a global Lua function
|
||||
func (s *State) RegisterGoFunction(name string, fn GoFunction) error {
|
||||
if err := s.PushGoFunction(fn); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cname := C.CString(name)
|
||||
defer C.free(unsafe.Pointer(cname))
|
||||
C.lua_setfield(s.L, C.LUA_GLOBALSINDEX, cname)
|
||||
s.SetGlobal(name)
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnregisterGoFunction removes a global function
|
||||
func (s *State) UnregisterGoFunction(name string) {
|
||||
s.PushNil()
|
||||
cname := C.CString(name)
|
||||
defer C.free(unsafe.Pointer(cname))
|
||||
C.lua_setfield(s.L, C.LUA_GLOBALSINDEX, cname)
|
||||
s.SetGlobal(name)
|
||||
}
|
||||
|
||||
// Cleanup frees all function pointers and clears the registry
|
||||
func (s *State) Cleanup() {
|
||||
functionRegistry.Lock()
|
||||
defer functionRegistry.Unlock()
|
||||
|
||||
// Free all allocated pointers
|
||||
for ptr := range functionRegistry.funcs {
|
||||
C.free(ptr)
|
||||
delete(functionRegistry.funcs, ptr)
|
||||
}
|
||||
functionRegistry.funcs = make(map[unsafe.Pointer]GoFunction)
|
||||
}
|
||||
|
|
|
@ -1,109 +0,0 @@
|
|||
package luajit
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestGoFunctions(t *testing.T) {
|
||||
for _, f := range factories {
|
||||
t.Run(f.name, func(t *testing.T) {
|
||||
L := f.new()
|
||||
if L == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer L.Close()
|
||||
defer L.Cleanup()
|
||||
|
||||
addFunc := func(s *State) int {
|
||||
s.PushNumber(s.ToNumber(1) + s.ToNumber(2))
|
||||
return 1
|
||||
}
|
||||
|
||||
if err := L.RegisterGoFunction("add", addFunc); err != nil {
|
||||
t.Fatalf("Failed to register function: %v", err)
|
||||
}
|
||||
|
||||
// Test basic function call
|
||||
if err := L.DoString("result = add(40, 2)"); err != nil {
|
||||
t.Fatalf("Failed to call function: %v", err)
|
||||
}
|
||||
|
||||
L.GetGlobal("result")
|
||||
if result := L.ToNumber(-1); result != 42 {
|
||||
t.Errorf("got %v, want 42", result)
|
||||
}
|
||||
L.Pop(1)
|
||||
|
||||
// Test multiple return values
|
||||
multiFunc := func(s *State) int {
|
||||
s.PushString("hello")
|
||||
s.PushNumber(42)
|
||||
s.PushBoolean(true)
|
||||
return 3
|
||||
}
|
||||
|
||||
if err := L.RegisterGoFunction("multi", multiFunc); err != nil {
|
||||
t.Fatalf("Failed to register multi function: %v", err)
|
||||
}
|
||||
|
||||
code := `
|
||||
a, b, c = multi()
|
||||
result = (a == "hello" and b == 42 and c == true)
|
||||
`
|
||||
|
||||
if err := L.DoString(code); err != nil {
|
||||
t.Fatalf("Failed to call multi function: %v", err)
|
||||
}
|
||||
|
||||
L.GetGlobal("result")
|
||||
if !L.ToBoolean(-1) {
|
||||
t.Error("Multiple return values test failed")
|
||||
}
|
||||
L.Pop(1)
|
||||
|
||||
// Test error handling
|
||||
errFunc := func(s *State) int {
|
||||
s.PushString("test error")
|
||||
return -1
|
||||
}
|
||||
|
||||
if err := L.RegisterGoFunction("err", errFunc); err != nil {
|
||||
t.Fatalf("Failed to register error function: %v", err)
|
||||
}
|
||||
|
||||
if err := L.DoString("err()"); err == nil {
|
||||
t.Error("Expected error from error function")
|
||||
}
|
||||
|
||||
// Test unregistering
|
||||
L.UnregisterGoFunction("add")
|
||||
if err := L.DoString("add(1, 2)"); err == nil {
|
||||
t.Error("Expected error calling unregistered function")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStackSafety(t *testing.T) {
|
||||
L := NewSafe()
|
||||
if L == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer L.Close()
|
||||
defer L.Cleanup()
|
||||
|
||||
// Test stack overflow protection
|
||||
overflowFunc := func(s *State) int {
|
||||
for i := 0; i < 100; i++ {
|
||||
s.PushNumber(float64(i))
|
||||
}
|
||||
s.PushString("done")
|
||||
return 101
|
||||
}
|
||||
|
||||
if err := L.RegisterGoFunction("overflow", overflowFunc); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := L.DoString("overflow()"); err != nil {
|
||||
t.Logf("Got expected error: %v", err)
|
||||
}
|
||||
}
|
89
stack.go
89
stack.go
|
@ -25,84 +25,23 @@ const (
|
|||
LUA_GLOBALSINDEX = -10002 // Pseudo-index for globals table
|
||||
)
|
||||
|
||||
// checkStack ensures there is enough space on the Lua stack
|
||||
func (s *State) checkStack(n int) error {
|
||||
if C.lua_checkstack(s.L, C.int(n)) == 0 {
|
||||
return fmt.Errorf("stack overflow (cannot allocate %d slots)", n)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// safeCall wraps a potentially dangerous C call with stack checking
|
||||
func (s *State) safeCall(f func() C.int) error {
|
||||
// Save current stack size
|
||||
top := s.GetTop()
|
||||
|
||||
// Ensure we have enough stack space (minimum 20 slots as per Lua standard)
|
||||
if err := s.checkStack(LUA_MINSTACK); err != nil {
|
||||
return err
|
||||
// GetStackTrace returns the current Lua stack trace
|
||||
func (s *State) GetStackTrace() string {
|
||||
s.GetGlobal("debug")
|
||||
if !s.IsTable(-1) {
|
||||
s.Pop(1)
|
||||
return "debug table not available"
|
||||
}
|
||||
|
||||
// Make the call
|
||||
status := f()
|
||||
|
||||
// Check for errors
|
||||
if status != 0 {
|
||||
err := &LuaError{
|
||||
Code: int(status),
|
||||
Message: s.ToString(-1),
|
||||
}
|
||||
s.Pop(1) // Remove error message
|
||||
return err
|
||||
s.GetField(-1, "traceback")
|
||||
if !s.IsFunction(-1) {
|
||||
s.Pop(2) // Remove debug table and non-function
|
||||
return "debug.traceback not available"
|
||||
}
|
||||
|
||||
// For lua_pcall, the function and arguments are popped before results are pushed
|
||||
// So we don't consider it an underflow if the new top is less than the original
|
||||
if status == 0 && s.GetType(-1) == TypeFunction {
|
||||
// If we still have a function on the stack, restore original size
|
||||
s.SetTop(top)
|
||||
}
|
||||
s.Call(0, 1)
|
||||
trace := s.ToString(-1)
|
||||
s.Pop(1) // Remove the trace
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// stackGuard wraps a function with stack checking
|
||||
func stackGuard[T any](s *State, f func() (T, error)) (T, error) {
|
||||
// Save current stack size
|
||||
top := s.GetTop()
|
||||
defer func() {
|
||||
// Only restore if stack is larger than original
|
||||
if s.GetTop() > top {
|
||||
s.SetTop(top)
|
||||
}
|
||||
}()
|
||||
|
||||
// Run the protected function
|
||||
return f()
|
||||
}
|
||||
|
||||
// stackGuardValue executes a function with stack protection
|
||||
func stackGuardValue[T any](s *State, f func() (T, error)) (T, error) {
|
||||
return stackGuard(s, f)
|
||||
}
|
||||
|
||||
// stackGuardErr executes a function that only returns an error with stack protection
|
||||
func stackGuardErr(s *State, f func() error) error {
|
||||
// Save current stack size
|
||||
top := s.GetTop()
|
||||
defer func() {
|
||||
// Only restore if stack is larger than original
|
||||
if s.GetTop() > top {
|
||||
s.SetTop(top)
|
||||
}
|
||||
}()
|
||||
|
||||
// Run the protected function
|
||||
return f()
|
||||
}
|
||||
|
||||
// getStackTrace returns the current Lua stack trace
|
||||
func (s *State) getStackTrace() string {
|
||||
// Same implementation...
|
||||
return ""
|
||||
return trace
|
||||
}
|
||||
|
|
207
table.go
207
table.go
|
@ -6,172 +6,159 @@ package luajit
|
|||
#include <lauxlib.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
static int get_table_length(lua_State *L, int index) {
|
||||
// Simple direct length check
|
||||
size_t get_table_length(lua_State *L, int index) {
|
||||
return lua_objlen(L, index);
|
||||
}
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// TableValue represents any value that can be stored in a Lua table
|
||||
type TableValue interface {
|
||||
~string | ~float64 | ~bool | ~int | ~map[string]interface{} | ~[]float64 | ~[]interface{}
|
||||
// GetTableLength returns the length of a table at the given index
|
||||
func (s *State) GetTableLength(index int) int {
|
||||
return int(C.get_table_length(s.L, C.int(index)))
|
||||
}
|
||||
|
||||
func (s *State) GetTableLength(index int) int { return int(C.get_table_length(s.L, C.int(index))) }
|
||||
|
||||
// ToTable converts a Lua table to a Go map
|
||||
func (s *State) ToTable(index int) (map[string]interface{}, error) {
|
||||
if s.safeStack {
|
||||
return stackGuardValue[map[string]interface{}](s, func() (map[string]interface{}, error) {
|
||||
if !s.IsTable(index) {
|
||||
return nil, fmt.Errorf("not a table at index %d", index)
|
||||
// PushTable pushes a Go map onto the Lua stack as a table
|
||||
func (s *State) PushTable(table map[string]any) error {
|
||||
// Fast path for array tables
|
||||
if arr, ok := table[""]; ok {
|
||||
if floatArr, ok := arr.([]float64); ok {
|
||||
s.CreateTable(len(floatArr), 0)
|
||||
for i, v := range floatArr {
|
||||
s.PushNumber(float64(i + 1))
|
||||
s.PushNumber(v)
|
||||
s.SetTable(-3)
|
||||
}
|
||||
return s.toTableUnsafe(index)
|
||||
})
|
||||
}
|
||||
if !s.IsTable(index) {
|
||||
return nil, fmt.Errorf("not a table at index %d", index)
|
||||
}
|
||||
return s.toTableUnsafe(index)
|
||||
}
|
||||
|
||||
func (s *State) pushTableSafe(table map[string]interface{}) error {
|
||||
size := 2
|
||||
if err := s.checkStack(size); err != nil {
|
||||
return fmt.Errorf("insufficient stack space: %w", err)
|
||||
return nil
|
||||
} else if anyArr, ok := arr.([]any); ok {
|
||||
s.CreateTable(len(anyArr), 0)
|
||||
for i, v := range anyArr {
|
||||
s.PushNumber(float64(i + 1))
|
||||
if err := s.PushValue(v); err != nil {
|
||||
return err
|
||||
}
|
||||
s.SetTable(-3)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
s.NewTable()
|
||||
// Regular table case - optimize capacity hint
|
||||
s.CreateTable(0, len(table))
|
||||
|
||||
// Add each key-value pair directly
|
||||
for k, v := range table {
|
||||
if err := s.pushValueSafe(v); err != nil {
|
||||
s.PushString(k)
|
||||
if err := s.PushValue(v); err != nil {
|
||||
return err
|
||||
}
|
||||
s.SetField(-2, k)
|
||||
s.SetTable(-3)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *State) pushTableUnsafe(table map[string]interface{}) error {
|
||||
s.NewTable()
|
||||
for k, v := range table {
|
||||
if err := s.pushValueUnsafe(v); err != nil {
|
||||
return err
|
||||
}
|
||||
s.SetField(-2, k)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *State) toTableSafe(index int) (map[string]interface{}, error) {
|
||||
if err := s.checkStack(2); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.toTableUnsafe(index)
|
||||
}
|
||||
|
||||
func (s *State) toTableUnsafe(index int) (map[string]interface{}, error) {
|
||||
// ToTable converts a Lua table at the given index to a Go map
|
||||
func (s *State) ToTable(index int) (map[string]any, error) {
|
||||
absIdx := s.absIndex(index)
|
||||
table := make(map[string]interface{})
|
||||
if !s.IsTable(absIdx) {
|
||||
return nil, fmt.Errorf("value at index %d is not a table", index)
|
||||
}
|
||||
|
||||
// Check if it's an array-like table
|
||||
// Try to detect array-like tables first
|
||||
length := s.GetTableLength(absIdx)
|
||||
if length > 0 {
|
||||
array := make([]float64, length)
|
||||
isArray := true
|
||||
// Fast path for common array case
|
||||
allNumbers := true
|
||||
|
||||
// Try to convert to array
|
||||
for i := 1; i <= length; i++ {
|
||||
// Sample first few values to check if it's likely an array of numbers
|
||||
for i := 1; i <= min(length, 5); i++ {
|
||||
s.PushNumber(float64(i))
|
||||
s.GetTable(absIdx)
|
||||
if s.GetType(-1) != TypeNumber {
|
||||
isArray = false
|
||||
|
||||
if !s.IsNumber(-1) {
|
||||
allNumbers = false
|
||||
s.Pop(1)
|
||||
break
|
||||
}
|
||||
array[i-1] = s.ToNumber(-1)
|
||||
s.Pop(1)
|
||||
}
|
||||
|
||||
if isArray {
|
||||
return map[string]interface{}{"": array}, nil
|
||||
if allNumbers {
|
||||
// Efficiently extract array values
|
||||
array := make([]float64, length)
|
||||
for i := 1; i <= length; i++ {
|
||||
s.PushNumber(float64(i))
|
||||
s.GetTable(absIdx)
|
||||
array[i-1] = s.ToNumber(-1)
|
||||
s.Pop(1)
|
||||
}
|
||||
|
||||
// Return array as a special table with empty key
|
||||
result := make(map[string]any, 1)
|
||||
result[""] = array
|
||||
return result, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Handle regular table
|
||||
s.PushNil()
|
||||
for C.lua_next(s.L, C.int(absIdx)) != 0 {
|
||||
key := ""
|
||||
valueType := C.lua_type(s.L, -2)
|
||||
if valueType == C.LUA_TSTRING {
|
||||
// Handle regular table with pre-allocated capacity
|
||||
table := make(map[string]any, max(length, 8))
|
||||
|
||||
// Iterate through all key-value pairs
|
||||
s.PushNil() // Start iteration with nil key
|
||||
for s.Next(absIdx) {
|
||||
// Stack now has key at -2 and value at -1
|
||||
|
||||
// Convert key to string
|
||||
var key string
|
||||
keyType := s.GetType(-2)
|
||||
switch keyType {
|
||||
case TypeString:
|
||||
key = s.ToString(-2)
|
||||
} else if valueType == C.LUA_TNUMBER {
|
||||
key = fmt.Sprintf("%g", s.ToNumber(-2))
|
||||
case TypeNumber:
|
||||
key = strconv.FormatFloat(s.ToNumber(-2), 'g', -1, 64)
|
||||
default:
|
||||
// Skip non-string/non-number keys
|
||||
s.Pop(1) // Pop value, leave key for next iteration
|
||||
continue
|
||||
}
|
||||
|
||||
value, err := s.toValueUnsafe(-1)
|
||||
// Convert and store the value
|
||||
value, err := s.ToValue(-1)
|
||||
if err != nil {
|
||||
s.Pop(1)
|
||||
s.Pop(2) // Pop both key and value
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Handle nested array case
|
||||
if m, ok := value.(map[string]interface{}); ok {
|
||||
// Unwrap nested array tables
|
||||
if m, ok := value.(map[string]any); ok {
|
||||
if arr, ok := m[""]; ok {
|
||||
value = arr
|
||||
}
|
||||
}
|
||||
|
||||
table[key] = value
|
||||
s.Pop(1)
|
||||
s.Pop(1) // Pop value, leave key for next iteration
|
||||
}
|
||||
|
||||
return table, nil
|
||||
}
|
||||
|
||||
// NewTable creates a new table and pushes it onto the stack
|
||||
func (s *State) NewTable() {
|
||||
if s.safeStack {
|
||||
if err := s.checkStack(1); err != nil {
|
||||
// Since we can't return an error, we'll push nil instead
|
||||
s.PushNil()
|
||||
return
|
||||
}
|
||||
// Helper functions for min/max operations
|
||||
func min(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
C.lua_createtable(s.L, 0, 0)
|
||||
return b
|
||||
}
|
||||
|
||||
// SetTable sets a table field with cached absolute index
|
||||
func (s *State) SetTable(index int) {
|
||||
absIdx := index
|
||||
if s.safeStack && (index < 0 && index > LUA_REGISTRYINDEX) {
|
||||
absIdx = s.GetTop() + index + 1
|
||||
func max(a, b int) int {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
C.lua_settable(s.L, C.int(absIdx))
|
||||
}
|
||||
|
||||
// GetTable gets a table field with cached absolute index
|
||||
func (s *State) GetTable(index int) {
|
||||
absIdx := index
|
||||
if s.safeStack && (index < 0 && index > LUA_REGISTRYINDEX) {
|
||||
absIdx = s.GetTop() + index + 1
|
||||
}
|
||||
|
||||
if s.safeStack {
|
||||
if err := s.checkStack(1); err != nil {
|
||||
s.PushNil()
|
||||
return
|
||||
}
|
||||
}
|
||||
C.lua_gettable(s.L, C.int(absIdx))
|
||||
}
|
||||
|
||||
// PushTable pushes a Go map onto the Lua stack as a table with stack checking
|
||||
func (s *State) PushTable(table map[string]interface{}) error {
|
||||
if s.safeStack {
|
||||
return s.pushTableSafe(table)
|
||||
}
|
||||
return s.pushTableUnsafe(table)
|
||||
return b
|
||||
}
|
||||
|
|
|
@ -1,97 +0,0 @@
|
|||
package luajit
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestTableOperations(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data map[string]interface{}
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
data: map[string]interface{}{},
|
||||
},
|
||||
{
|
||||
name: "primitives",
|
||||
data: map[string]interface{}{
|
||||
"str": "hello",
|
||||
"num": 42.0,
|
||||
"bool": true,
|
||||
"array": []float64{1.1, 2.2, 3.3},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "nested",
|
||||
data: map[string]interface{}{
|
||||
"nested": map[string]interface{}{
|
||||
"value": 123.0,
|
||||
"array": []float64{4.4, 5.5},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, f := range factories {
|
||||
for _, tt := range tests {
|
||||
t.Run(f.name+"/"+tt.name, func(t *testing.T) {
|
||||
L := f.new()
|
||||
if L == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer L.Close()
|
||||
|
||||
if err := L.PushTable(tt.data); err != nil {
|
||||
t.Fatalf("PushTable() error = %v", err)
|
||||
}
|
||||
|
||||
got, err := L.ToTable(-1)
|
||||
if err != nil {
|
||||
t.Fatalf("ToTable() error = %v", err)
|
||||
}
|
||||
|
||||
if !tablesEqual(got, tt.data) {
|
||||
t.Errorf("table mismatch\ngot = %v\nwant = %v", got, tt.data)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func tablesEqual(a, b map[string]interface{}) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
|
||||
for k, v1 := range a {
|
||||
v2, ok := b[k]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
switch v1 := v1.(type) {
|
||||
case map[string]interface{}:
|
||||
v2, ok := v2.(map[string]interface{})
|
||||
if !ok || !tablesEqual(v1, v2) {
|
||||
return false
|
||||
}
|
||||
case []float64:
|
||||
v2, ok := v2.([]float64)
|
||||
if !ok || len(v1) != len(v2) {
|
||||
return false
|
||||
}
|
||||
for i := range v1 {
|
||||
if math.Abs(v1[i]-v2[i]) > 1e-10 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
default:
|
||||
if v1 != v2 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
443
tests/bytecode_test.go
Normal file
443
tests/bytecode_test.go
Normal file
|
@ -0,0 +1,443 @@
|
|||
package luajit_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
// TestCompileBytecode tests basic bytecode compilation
|
||||
func TestCompileBytecode(t *testing.T) {
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
code := "return 42"
|
||||
bytecode, err := state.CompileBytecode(code, "test")
|
||||
if err != nil {
|
||||
t.Fatalf("CompileBytecode failed: %v", err)
|
||||
}
|
||||
|
||||
if len(bytecode) == 0 {
|
||||
t.Fatal("Expected non-empty bytecode")
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadBytecode tests loading precompiled bytecode
|
||||
func TestLoadBytecode(t *testing.T) {
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
// First compile some bytecode
|
||||
code := "answer = 42"
|
||||
bytecode, err := state.CompileBytecode(code, "test")
|
||||
if err != nil {
|
||||
t.Fatalf("CompileBytecode failed: %v", err)
|
||||
}
|
||||
|
||||
// Then load it
|
||||
err = state.LoadBytecode(bytecode, "test")
|
||||
if err != nil {
|
||||
t.Fatalf("LoadBytecode failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify a function is on the stack
|
||||
if !state.IsFunction(-1) {
|
||||
t.Fatal("Expected function at top of stack after LoadBytecode")
|
||||
}
|
||||
|
||||
// Pop the function
|
||||
state.Pop(1)
|
||||
}
|
||||
|
||||
// TestRunBytecode tests running previously loaded bytecode
|
||||
func TestRunBytecode(t *testing.T) {
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
// First compile and load bytecode
|
||||
code := "answer = 42"
|
||||
bytecode, err := state.CompileBytecode(code, "test")
|
||||
if err != nil {
|
||||
t.Fatalf("CompileBytecode failed: %v", err)
|
||||
}
|
||||
|
||||
err = state.LoadBytecode(bytecode, "test")
|
||||
if err != nil {
|
||||
t.Fatalf("LoadBytecode failed: %v", err)
|
||||
}
|
||||
|
||||
// Run the bytecode
|
||||
err = state.RunBytecode()
|
||||
if err != nil {
|
||||
t.Fatalf("RunBytecode failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify the code has executed correctly
|
||||
state.GetGlobal("answer")
|
||||
if !state.IsNumber(-1) || state.ToNumber(-1) != 42 {
|
||||
t.Fatalf("Expected answer to be 42, got %v", state.ToNumber(-1))
|
||||
}
|
||||
state.Pop(1)
|
||||
}
|
||||
|
||||
// TestLoadAndRunBytecode tests the combined load and run functionality
|
||||
func TestLoadAndRunBytecode(t *testing.T) {
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
// Compile bytecode
|
||||
code := "answer = 42"
|
||||
bytecode, err := state.CompileBytecode(code, "test")
|
||||
if err != nil {
|
||||
t.Fatalf("CompileBytecode failed: %v", err)
|
||||
}
|
||||
|
||||
// Load and run in one step
|
||||
err = state.LoadAndRunBytecode(bytecode, "test")
|
||||
if err != nil {
|
||||
t.Fatalf("LoadAndRunBytecode failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify execution
|
||||
state.GetGlobal("answer")
|
||||
if !state.IsNumber(-1) || state.ToNumber(-1) != 42 {
|
||||
t.Fatalf("Expected answer to be 42, got %v", state.ToNumber(-1))
|
||||
}
|
||||
state.Pop(1)
|
||||
}
|
||||
|
||||
// TestCompileAndRun tests compile and run functionality
|
||||
func TestCompileAndRun(t *testing.T) {
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
// Compile and run in one step
|
||||
code := "answer = 42"
|
||||
err := state.CompileAndRun(code, "test")
|
||||
if err != nil {
|
||||
t.Fatalf("CompileAndRun failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify execution
|
||||
state.GetGlobal("answer")
|
||||
if !state.IsNumber(-1) || state.ToNumber(-1) != 42 {
|
||||
t.Fatalf("Expected answer to be 42, got %v", state.ToNumber(-1))
|
||||
}
|
||||
state.Pop(1)
|
||||
}
|
||||
|
||||
// TestEmptyBytecode tests error handling for empty bytecode
|
||||
func TestEmptyBytecode(t *testing.T) {
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
// Try to load empty bytecode
|
||||
err := state.LoadBytecode([]byte{}, "empty")
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for empty bytecode, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestInvalidBytecode tests error handling for invalid bytecode
|
||||
func TestInvalidBytecode(t *testing.T) {
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
// Create some invalid bytecode
|
||||
invalidBytecode := []byte("not valid bytecode")
|
||||
|
||||
// Try to load invalid bytecode
|
||||
err := state.LoadBytecode(invalidBytecode, "invalid")
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for invalid bytecode, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestBytecodeSerialization tests serializing and deserializing bytecode
|
||||
func TestBytecodeSerialization(t *testing.T) {
|
||||
// First state to compile
|
||||
state1 := luajit.New()
|
||||
if state1 == nil {
|
||||
t.Fatal("Failed to create first Lua state")
|
||||
}
|
||||
defer state1.Close()
|
||||
|
||||
// Compile bytecode
|
||||
code := `
|
||||
function add(a, b)
|
||||
return a + b
|
||||
end
|
||||
result = add(10, 20)
|
||||
`
|
||||
bytecode, err := state1.CompileBytecode(code, "test")
|
||||
if err != nil {
|
||||
t.Fatalf("CompileBytecode failed: %v", err)
|
||||
}
|
||||
|
||||
// Second state to execute
|
||||
state2 := luajit.New()
|
||||
if state2 == nil {
|
||||
t.Fatal("Failed to create second Lua state")
|
||||
}
|
||||
defer state2.Close()
|
||||
|
||||
// Load and run the bytecode in the second state
|
||||
err = state2.LoadAndRunBytecode(bytecode, "test")
|
||||
if err != nil {
|
||||
t.Fatalf("LoadAndRunBytecode failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify execution
|
||||
state2.GetGlobal("result")
|
||||
if !state2.IsNumber(-1) || state2.ToNumber(-1) != 30 {
|
||||
t.Fatalf("Expected result to be 30, got %v", state2.ToNumber(-1))
|
||||
}
|
||||
state2.Pop(1)
|
||||
|
||||
// Call the function to verify it was properly transferred
|
||||
state2.GetGlobal("add")
|
||||
if !state2.IsFunction(-1) {
|
||||
t.Fatal("Expected add to be a function")
|
||||
}
|
||||
state2.PushNumber(5)
|
||||
state2.PushNumber(7)
|
||||
if err := state2.Call(2, 1); err != nil {
|
||||
t.Fatalf("Failed to call function: %v", err)
|
||||
}
|
||||
if state2.ToNumber(-1) != 12 {
|
||||
t.Fatalf("Expected add(5, 7) to return 12, got %v", state2.ToNumber(-1))
|
||||
}
|
||||
state2.Pop(1)
|
||||
}
|
||||
|
||||
// TestCompilationError tests error handling for compilation errors
|
||||
func TestCompilationError(t *testing.T) {
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
// Invalid Lua code that should fail to compile
|
||||
code := "function without end"
|
||||
|
||||
// Try to compile
|
||||
_, err := state.CompileBytecode(code, "invalid")
|
||||
if err == nil {
|
||||
t.Fatal("Expected compilation error, got nil")
|
||||
}
|
||||
|
||||
// Check error type
|
||||
var luaErr *luajit.LuaError
|
||||
if !errors.As(err, &luaErr) {
|
||||
t.Fatalf("Expected error to wrap *luajit.LuaError, got %T", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestExecutionError tests error handling for runtime errors
|
||||
func TestExecutionError(t *testing.T) {
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
// Code that compiles but fails at runtime
|
||||
code := "error('deliberate error')"
|
||||
|
||||
// Compile bytecode
|
||||
bytecode, err := state.CompileBytecode(code, "error")
|
||||
if err != nil {
|
||||
t.Fatalf("CompileBytecode failed: %v", err)
|
||||
}
|
||||
|
||||
// Try to execute
|
||||
err = state.LoadAndRunBytecode(bytecode, "error")
|
||||
if err == nil {
|
||||
t.Fatal("Expected execution error, got nil")
|
||||
}
|
||||
|
||||
// Check error type
|
||||
if _, ok := err.(*luajit.LuaError); !ok {
|
||||
t.Fatalf("Expected *luajit.LuaError, got %T", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBytecodeEquivalence tests that bytecode execution produces the same results as direct execution
|
||||
func TestBytecodeEquivalence(t *testing.T) {
|
||||
code := `
|
||||
local result = 0
|
||||
for i = 1, 10 do
|
||||
result = result + i
|
||||
end
|
||||
return result
|
||||
`
|
||||
|
||||
// First, execute directly
|
||||
state1 := luajit.New()
|
||||
if state1 == nil {
|
||||
t.Fatal("Failed to create first Lua state")
|
||||
}
|
||||
defer state1.Close()
|
||||
|
||||
directResult, err := state1.ExecuteWithResult(code)
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteWithResult failed: %v", err)
|
||||
}
|
||||
|
||||
// Then, compile and execute bytecode
|
||||
state2 := luajit.New()
|
||||
if state2 == nil {
|
||||
t.Fatal("Failed to create second Lua state")
|
||||
}
|
||||
defer state2.Close()
|
||||
|
||||
bytecode, err := state2.CompileBytecode(code, "test")
|
||||
if err != nil {
|
||||
t.Fatalf("CompileBytecode failed: %v", err)
|
||||
}
|
||||
|
||||
err = state2.LoadBytecode(bytecode, "test")
|
||||
if err != nil {
|
||||
t.Fatalf("LoadBytecode failed: %v", err)
|
||||
}
|
||||
|
||||
err = state2.Call(0, 1)
|
||||
if err != nil {
|
||||
t.Fatalf("Call failed: %v", err)
|
||||
}
|
||||
|
||||
bytecodeResult, err := state2.ToValue(-1)
|
||||
if err != nil {
|
||||
t.Fatalf("ToValue failed: %v", err)
|
||||
}
|
||||
state2.Pop(1)
|
||||
|
||||
// Compare results
|
||||
if directResult != bytecodeResult {
|
||||
t.Fatalf("Results differ: direct=%v, bytecode=%v", directResult, bytecodeResult)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBytecodeReuse tests reusing the same bytecode multiple times
|
||||
func TestBytecodeReuse(t *testing.T) {
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
// Create a function in bytecode
|
||||
code := `
|
||||
return function(x)
|
||||
return x * 2
|
||||
end
|
||||
`
|
||||
bytecode, err := state.CompileBytecode(code, "func")
|
||||
if err != nil {
|
||||
t.Fatalf("CompileBytecode failed: %v", err)
|
||||
}
|
||||
|
||||
// Execute it several times
|
||||
for i := 1; i <= 3; i++ {
|
||||
// Load and run to get the function
|
||||
err = state.LoadAndRunBytecodeWithResults(bytes.Clone(bytecode), "func", 1)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadAndRunBytecodeWithResults failed: %v", err)
|
||||
}
|
||||
|
||||
// Stack now has the function at the top
|
||||
if !state.IsFunction(-1) {
|
||||
t.Fatal("Expected function at top of stack")
|
||||
}
|
||||
|
||||
// Call with parameter i
|
||||
state.PushNumber(float64(i))
|
||||
if err := state.Call(1, 1); err != nil {
|
||||
t.Fatalf("Call failed: %v", err)
|
||||
}
|
||||
|
||||
// Check result
|
||||
expected := float64(i * 2)
|
||||
if state.ToNumber(-1) != expected {
|
||||
t.Fatalf("Expected %v, got %v", expected, state.ToNumber(-1))
|
||||
}
|
||||
|
||||
// Pop the result
|
||||
state.Pop(1)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBytecodeClosure tests that bytecode properly handles closures and upvalues
|
||||
func TestBytecodeClosure(t *testing.T) {
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
// Create a closure
|
||||
code := `
|
||||
local counter = 0
|
||||
return function()
|
||||
counter = counter + 1
|
||||
return counter
|
||||
end
|
||||
`
|
||||
|
||||
// Compile to bytecode
|
||||
bytecode, err := state.CompileBytecode(code, "closure")
|
||||
if err != nil {
|
||||
t.Fatalf("CompileBytecode failed: %v", err)
|
||||
}
|
||||
|
||||
// Load and run to get the counter function
|
||||
err = state.LoadAndRunBytecodeWithResults(bytecode, "closure", 1)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadAndRunBytecode failed: %v", err)
|
||||
}
|
||||
|
||||
// Stack now has the function at the top
|
||||
if !state.IsFunction(-1) {
|
||||
t.Fatal("Expected function at top of stack")
|
||||
}
|
||||
|
||||
// Store in a global
|
||||
state.SetGlobal("counter_func")
|
||||
|
||||
// Call it multiple times and check the results
|
||||
for i := 1; i <= 3; i++ {
|
||||
state.GetGlobal("counter_func")
|
||||
if err := state.Call(0, 1); err != nil {
|
||||
t.Fatalf("Call failed: %v", err)
|
||||
}
|
||||
|
||||
if state.ToNumber(-1) != float64(i) {
|
||||
t.Fatalf("Expected counter to be %d, got %v", i, state.ToNumber(-1))
|
||||
}
|
||||
state.Pop(1)
|
||||
}
|
||||
}
|
178
tests/functions_test.go
Normal file
178
tests/functions_test.go
Normal file
|
@ -0,0 +1,178 @@
|
|||
package luajit_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
func TestPushGoFunction(t *testing.T) {
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
// Define a simple function that adds two numbers
|
||||
add := func(s *luajit.State) int {
|
||||
a := s.ToNumber(1)
|
||||
b := s.ToNumber(2)
|
||||
s.PushNumber(a + b)
|
||||
return 1 // Return one result
|
||||
}
|
||||
|
||||
// Push the function onto the stack
|
||||
if err := state.PushGoFunction(add); err != nil {
|
||||
t.Fatalf("PushGoFunction failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify that a function is on the stack
|
||||
if !state.IsFunction(-1) {
|
||||
t.Fatalf("Expected function at top of stack")
|
||||
}
|
||||
|
||||
// Push arguments
|
||||
state.PushNumber(3)
|
||||
state.PushNumber(4)
|
||||
|
||||
// Call the function
|
||||
if err := state.Call(2, 1); err != nil {
|
||||
t.Fatalf("Failed to call function: %v", err)
|
||||
}
|
||||
|
||||
// Check the result
|
||||
if state.ToNumber(-1) != 7 {
|
||||
t.Fatalf("Function returned %f, expected 7", state.ToNumber(-1))
|
||||
}
|
||||
state.Pop(1)
|
||||
}
|
||||
|
||||
func TestRegisterGoFunction(t *testing.T) {
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
// Define a function that squares a number
|
||||
square := func(s *luajit.State) int {
|
||||
x := s.ToNumber(1)
|
||||
s.PushNumber(x * x)
|
||||
return 1
|
||||
}
|
||||
|
||||
// Register the function
|
||||
if err := state.RegisterGoFunction("square", square); err != nil {
|
||||
t.Fatalf("RegisterGoFunction failed: %v", err)
|
||||
}
|
||||
|
||||
// Call the function from Lua
|
||||
if err := state.DoString("result = square(5)"); err != nil {
|
||||
t.Fatalf("Failed to call registered function: %v", err)
|
||||
}
|
||||
|
||||
// Check the result
|
||||
state.GetGlobal("result")
|
||||
if state.ToNumber(-1) != 25 {
|
||||
t.Fatalf("Function returned %f, expected 25", state.ToNumber(-1))
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
// Test UnregisterGoFunction
|
||||
state.UnregisterGoFunction("square")
|
||||
|
||||
// Function should no longer exist
|
||||
err := state.DoString("result = square(5)")
|
||||
if err == nil {
|
||||
t.Fatalf("Expected error after unregistering function, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGoFunctionWithErrorHandling(t *testing.T) {
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
// Function that returns an error in Lua
|
||||
errFunc := func(s *luajit.State) int {
|
||||
s.PushString("error from Go function")
|
||||
return -1 // Signal error
|
||||
}
|
||||
|
||||
// Register the function
|
||||
if err := state.RegisterGoFunction("errorFunc", errFunc); err != nil {
|
||||
t.Fatalf("RegisterGoFunction failed: %v", err)
|
||||
}
|
||||
|
||||
// Call the function expecting an error
|
||||
err := state.DoString("result = errorFunc()")
|
||||
if err == nil {
|
||||
t.Fatalf("Expected error from function, got nil")
|
||||
}
|
||||
|
||||
// Error message should contain our message
|
||||
luaErr, ok := err.(*luajit.LuaError)
|
||||
if !ok {
|
||||
t.Fatalf("Expected LuaError, got %T: %v", err, err)
|
||||
}
|
||||
|
||||
if luaErr.Message == "" {
|
||||
t.Fatalf("Expected non-empty error message from Go function")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanup(t *testing.T) {
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
|
||||
// Register several functions
|
||||
for i := 0; i < 5; i++ {
|
||||
dummy := func(s *luajit.State) int { return 0 }
|
||||
if err := state.RegisterGoFunction("dummy", dummy); err != nil {
|
||||
t.Fatalf("RegisterGoFunction failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Call Cleanup explicitly
|
||||
state.Cleanup()
|
||||
|
||||
// Make sure we can still close the state
|
||||
state.Close()
|
||||
|
||||
// Also test that Close can be called after Cleanup
|
||||
state = luajit.New()
|
||||
if state == nil {
|
||||
t.Fatal("Failed to create second Lua state")
|
||||
}
|
||||
|
||||
state.Close() // Should call Cleanup internally
|
||||
}
|
||||
|
||||
func TestGoFunctionErrorPointer(t *testing.T) {
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
// Create a Lua function that calls a non-existent Go function pointer
|
||||
// This isn't a direct test of internal implementation, but tries to cover
|
||||
// error cases in the goFunctionWrapper
|
||||
code := `
|
||||
function test()
|
||||
-- This is a stub that doesn't actually call the wrapper,
|
||||
-- but we're testing error handling in our State.DoString
|
||||
return "test"
|
||||
end
|
||||
`
|
||||
if err := state.DoString(code); err != nil {
|
||||
t.Fatalf("Failed to define test function: %v", err)
|
||||
}
|
||||
|
||||
// The real test is that Cleanup doesn't crash
|
||||
state.Cleanup()
|
||||
}
|
53
tests/stack_test.go
Normal file
53
tests/stack_test.go
Normal file
|
@ -0,0 +1,53 @@
|
|||
package luajit_test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
func TestLuaError(t *testing.T) {
|
||||
err := &luajit.LuaError{
|
||||
Code: 123,
|
||||
Message: "test error",
|
||||
}
|
||||
|
||||
expected := "lua error (code=123): test error"
|
||||
if err.Error() != expected {
|
||||
t.Errorf("Expected error message %q, got %q", expected, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetStackTrace(t *testing.T) {
|
||||
s := luajit.New()
|
||||
defer s.Close()
|
||||
|
||||
// Test with debug library available
|
||||
trace := s.GetStackTrace()
|
||||
if !strings.Contains(trace, "stack traceback:") {
|
||||
t.Errorf("Expected trace to contain 'stack traceback:', got %q", trace)
|
||||
}
|
||||
|
||||
// Test when debug table is not available
|
||||
err := s.DoString("debug = nil")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to set debug to nil: %v", err)
|
||||
}
|
||||
|
||||
trace = s.GetStackTrace()
|
||||
if trace != "debug table not available" {
|
||||
t.Errorf("Expected 'debug table not available', got %q", trace)
|
||||
}
|
||||
|
||||
// Test when debug.traceback is not available
|
||||
err = s.DoString("debug = {}")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to set debug to empty table: %v", err)
|
||||
}
|
||||
|
||||
trace = s.GetStackTrace()
|
||||
if trace != "debug.traceback not available" {
|
||||
t.Errorf("Expected 'debug.traceback not available', got %q", trace)
|
||||
}
|
||||
}
|
246
tests/table_test.go
Normal file
246
tests/table_test.go
Normal file
|
@ -0,0 +1,246 @@
|
|||
package luajit_test
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
func TestGetTableLength(t *testing.T) {
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
// Create a table with numeric indices
|
||||
if err := state.DoString("t = {10, 20, 30, 40, 50}"); err != nil {
|
||||
t.Fatalf("Failed to create test table: %v", err)
|
||||
}
|
||||
|
||||
// Get the table
|
||||
state.GetGlobal("t")
|
||||
length := state.GetTableLength(-1)
|
||||
if length != 5 {
|
||||
t.Fatalf("Expected length 5, got %d", length)
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
// Create a table with string keys
|
||||
if err := state.DoString("t2 = {a=1, b=2, c=3}"); err != nil {
|
||||
t.Fatalf("Failed to create test table: %v", err)
|
||||
}
|
||||
|
||||
// Get the table
|
||||
state.GetGlobal("t2")
|
||||
length = state.GetTableLength(-1)
|
||||
if length != 0 {
|
||||
t.Fatalf("Expected length 0 for string-keyed table, got %d", length)
|
||||
}
|
||||
state.Pop(1)
|
||||
}
|
||||
|
||||
func TestPushTable(t *testing.T) {
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
// Create a test table
|
||||
testTable := map[string]any{
|
||||
"int": 42,
|
||||
"float": 3.14,
|
||||
"string": "hello",
|
||||
"boolean": true,
|
||||
"nil": nil,
|
||||
}
|
||||
|
||||
// Push the table onto the stack
|
||||
if err := state.PushTable(testTable); err != nil {
|
||||
t.Fatalf("Failed to push table: %v", err)
|
||||
}
|
||||
|
||||
// Execute Lua code to test the table contents
|
||||
if err := state.DoString(`
|
||||
function validate_table(t)
|
||||
return t.int == 42 and
|
||||
math.abs(t.float - 3.14) < 0.0001 and
|
||||
t.string == "hello" and
|
||||
t.boolean == true and
|
||||
t["nil"] == nil
|
||||
end
|
||||
`); err != nil {
|
||||
t.Fatalf("Failed to create validation function: %v", err)
|
||||
}
|
||||
|
||||
// Call the validation function
|
||||
state.GetGlobal("validate_table")
|
||||
state.PushCopy(-2) // Copy the table to the top
|
||||
if err := state.Call(1, 1); err != nil {
|
||||
t.Fatalf("Failed to call validation function: %v", err)
|
||||
}
|
||||
|
||||
if !state.ToBoolean(-1) {
|
||||
t.Fatalf("Table validation failed")
|
||||
}
|
||||
state.Pop(2) // Pop the result and the table
|
||||
}
|
||||
|
||||
func TestToTable(t *testing.T) {
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
// Test regular table conversion
|
||||
if err := state.DoString(`t = {a=1, b=2.5, c="test", d=true, e=nil}`); err != nil {
|
||||
t.Fatalf("Failed to create test table: %v", err)
|
||||
}
|
||||
|
||||
state.GetGlobal("t")
|
||||
table, err := state.ToTable(-1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to convert table: %v", err)
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
expected := map[string]any{
|
||||
"a": float64(1),
|
||||
"b": 2.5,
|
||||
"c": "test",
|
||||
"d": true,
|
||||
}
|
||||
|
||||
for k, v := range expected {
|
||||
if table[k] != v {
|
||||
t.Fatalf("Expected table[%s] = %v, got %v", k, v, table[k])
|
||||
}
|
||||
}
|
||||
|
||||
// Test array-like table conversion
|
||||
if err := state.DoString(`arr = {10, 20, 30, 40, 50}`); err != nil {
|
||||
t.Fatalf("Failed to create test array: %v", err)
|
||||
}
|
||||
|
||||
state.GetGlobal("arr")
|
||||
table, err = state.ToTable(-1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to convert array table: %v", err)
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
// For array tables, we should get a special format with an empty key
|
||||
// and the array as the value
|
||||
expectedArray := []float64{10, 20, 30, 40, 50}
|
||||
if arr, ok := table[""].([]float64); !ok {
|
||||
t.Fatalf("Expected array table to be converted with empty key, got: %v", table)
|
||||
} else if !reflect.DeepEqual(arr, expectedArray) {
|
||||
t.Fatalf("Expected %v, got %v", expectedArray, arr)
|
||||
}
|
||||
|
||||
// Test invalid table index
|
||||
_, err = state.ToTable(100)
|
||||
if err == nil {
|
||||
t.Fatalf("Expected error for invalid table index, got nil")
|
||||
}
|
||||
|
||||
// Test non-table value
|
||||
state.PushNumber(123)
|
||||
_, err = state.ToTable(-1)
|
||||
if err == nil {
|
||||
t.Fatalf("Expected error for non-table value, got nil")
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
// Test mixed array with non-numeric values
|
||||
if err := state.DoString(`mixed = {10, 20, key="value", 30}`); err != nil {
|
||||
t.Fatalf("Failed to create mixed table: %v", err)
|
||||
}
|
||||
|
||||
state.GetGlobal("mixed")
|
||||
table, err = state.ToTable(-1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to convert mixed table: %v", err)
|
||||
}
|
||||
|
||||
// Let's print the table for debugging
|
||||
t.Logf("Table contents: %v", table)
|
||||
|
||||
state.Pop(1)
|
||||
|
||||
// Check if the array part is detected and stored with empty key
|
||||
if arr, ok := table[""]; !ok {
|
||||
t.Fatalf("Expected array-like part to be detected, got: %v", table)
|
||||
} else {
|
||||
// Verify the array contains the expected values
|
||||
expectedArr := []float64{10, 20, 30}
|
||||
actualArr := arr.([]float64)
|
||||
if len(actualArr) != len(expectedArr) {
|
||||
t.Fatalf("Expected array length %d, got %d", len(expectedArr), len(actualArr))
|
||||
}
|
||||
|
||||
for i, v := range expectedArr {
|
||||
if actualArr[i] != v {
|
||||
t.Fatalf("Expected array[%d] = %v, got %v", i, v, actualArr[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Based on the implementation, we need to create a separate test for string keys
|
||||
if err := state.DoString(`dict = {foo="bar", baz="qux"}`); err != nil {
|
||||
t.Fatalf("Failed to create dict table: %v", err)
|
||||
}
|
||||
|
||||
state.GetGlobal("dict")
|
||||
dictTable, err := state.ToTable(-1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to convert dict table: %v", err)
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
// Check the string keys
|
||||
if val, ok := dictTable["foo"]; !ok || val != "bar" {
|
||||
t.Fatalf("Expected dictTable[\"foo\"] = \"bar\", got: %v", val)
|
||||
}
|
||||
if val, ok := dictTable["baz"]; !ok || val != "qux" {
|
||||
t.Fatalf("Expected dictTable[\"baz\"] = \"qux\", got: %v", val)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTablePooling(t *testing.T) {
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
// Create a Lua table and push it onto the stack
|
||||
if err := state.DoString(`t = {a=1, b=2}`); err != nil {
|
||||
t.Fatalf("Failed to create test table: %v", err)
|
||||
}
|
||||
|
||||
state.GetGlobal("t")
|
||||
|
||||
// First conversion - should get a table from the pool
|
||||
table1, err := state.ToTable(-1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to convert table (1): %v", err)
|
||||
}
|
||||
|
||||
// Second conversion - should get another table from the pool
|
||||
table2, err := state.ToTable(-1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to convert table (2): %v", err)
|
||||
}
|
||||
|
||||
// Both tables should have the same content
|
||||
if !reflect.DeepEqual(table1, table2) {
|
||||
t.Fatalf("Tables should have the same content: %v vs %v", table1, table2)
|
||||
}
|
||||
|
||||
// Clean up
|
||||
state.Pop(1)
|
||||
}
|
473
tests/wrapper_test.go
Normal file
473
tests/wrapper_test.go
Normal file
|
@ -0,0 +1,473 @@
|
|||
package luajit_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
func TestStateLifecycle(t *testing.T) {
|
||||
// Test creation
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
|
||||
// Test close
|
||||
state.Close()
|
||||
|
||||
// Test close is idempotent (doesn't crash)
|
||||
state.Close()
|
||||
}
|
||||
|
||||
func TestStackManipulation(t *testing.T) {
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
// Test initial stack size
|
||||
if state.GetTop() != 0 {
|
||||
t.Fatalf("Expected empty stack, got %d elements", state.GetTop())
|
||||
}
|
||||
|
||||
// Push values
|
||||
state.PushNil()
|
||||
state.PushBoolean(true)
|
||||
state.PushNumber(42)
|
||||
state.PushString("hello")
|
||||
|
||||
// Check stack size
|
||||
if state.GetTop() != 4 {
|
||||
t.Fatalf("Expected 4 elements, got %d", state.GetTop())
|
||||
}
|
||||
|
||||
// Test SetTop
|
||||
state.SetTop(2)
|
||||
if state.GetTop() != 2 {
|
||||
t.Fatalf("Expected 2 elements after SetTop, got %d", state.GetTop())
|
||||
}
|
||||
|
||||
// Test PushCopy
|
||||
state.PushCopy(2) // Copy the boolean
|
||||
if !state.IsBoolean(-1) {
|
||||
t.Fatalf("Expected boolean at top of stack")
|
||||
}
|
||||
|
||||
// Test Pop
|
||||
state.Pop(1)
|
||||
if state.GetTop() != 2 {
|
||||
t.Fatalf("Expected 2 elements after Pop, got %d", state.GetTop())
|
||||
}
|
||||
|
||||
// Test Remove
|
||||
state.PushNumber(99)
|
||||
state.Remove(1) // Remove the first element (nil)
|
||||
if state.GetTop() != 2 {
|
||||
t.Fatalf("Expected 2 elements after Remove, got %d", state.GetTop())
|
||||
}
|
||||
|
||||
// Verify first element is now boolean
|
||||
if !state.IsBoolean(1) {
|
||||
t.Fatalf("Expected boolean at index 1 after Remove")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTypeChecking(t *testing.T) {
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
// Push values of different types
|
||||
state.PushNil()
|
||||
state.PushBoolean(true)
|
||||
state.PushNumber(42)
|
||||
state.PushString("hello")
|
||||
state.NewTable()
|
||||
|
||||
// Check types with GetType
|
||||
if state.GetType(1) != luajit.TypeNil {
|
||||
t.Fatalf("Expected nil type at index 1, got %s", state.GetType(1))
|
||||
}
|
||||
if state.GetType(2) != luajit.TypeBoolean {
|
||||
t.Fatalf("Expected boolean type at index 2, got %s", state.GetType(2))
|
||||
}
|
||||
if state.GetType(3) != luajit.TypeNumber {
|
||||
t.Fatalf("Expected number type at index 3, got %s", state.GetType(3))
|
||||
}
|
||||
if state.GetType(4) != luajit.TypeString {
|
||||
t.Fatalf("Expected string type at index 4, got %s", state.GetType(4))
|
||||
}
|
||||
if state.GetType(5) != luajit.TypeTable {
|
||||
t.Fatalf("Expected table type at index 5, got %s", state.GetType(5))
|
||||
}
|
||||
|
||||
// Test individual type checking functions
|
||||
if !state.IsNil(1) {
|
||||
t.Fatalf("IsNil failed for nil value")
|
||||
}
|
||||
if !state.IsBoolean(2) {
|
||||
t.Fatalf("IsBoolean failed for boolean value")
|
||||
}
|
||||
if !state.IsNumber(3) {
|
||||
t.Fatalf("IsNumber failed for number value")
|
||||
}
|
||||
if !state.IsString(4) {
|
||||
t.Fatalf("IsString failed for string value")
|
||||
}
|
||||
if !state.IsTable(5) {
|
||||
t.Fatalf("IsTable failed for table value")
|
||||
}
|
||||
|
||||
// Function test
|
||||
state.DoString("function test() return true end")
|
||||
state.GetGlobal("test")
|
||||
if !state.IsFunction(-1) {
|
||||
t.Fatalf("IsFunction failed for function value")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValueConversion(t *testing.T) {
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
// Push values
|
||||
state.PushBoolean(true)
|
||||
state.PushNumber(42.5)
|
||||
state.PushString("hello")
|
||||
|
||||
// Test conversion
|
||||
if !state.ToBoolean(1) {
|
||||
t.Fatalf("ToBoolean failed")
|
||||
}
|
||||
if state.ToNumber(2) != 42.5 {
|
||||
t.Fatalf("ToNumber failed, expected 42.5, got %f", state.ToNumber(2))
|
||||
}
|
||||
if state.ToString(3) != "hello" {
|
||||
t.Fatalf("ToString failed, expected 'hello', got '%s'", state.ToString(3))
|
||||
}
|
||||
}
|
||||
|
||||
func TestTableOperations(t *testing.T) {
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
// Test CreateTable
|
||||
state.CreateTable(0, 3)
|
||||
|
||||
// Add fields using SetField
|
||||
state.PushNumber(42)
|
||||
state.SetField(-2, "answer")
|
||||
|
||||
state.PushString("hello")
|
||||
state.SetField(-2, "greeting")
|
||||
|
||||
state.PushBoolean(true)
|
||||
state.SetField(-2, "flag")
|
||||
|
||||
// Test GetField
|
||||
state.GetField(-1, "answer")
|
||||
if state.ToNumber(-1) != 42 {
|
||||
t.Fatalf("GetField for 'answer' failed")
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
state.GetField(-1, "greeting")
|
||||
if state.ToString(-1) != "hello" {
|
||||
t.Fatalf("GetField for 'greeting' failed")
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
// Test Next for iteration
|
||||
state.PushNil() // Start iteration
|
||||
count := 0
|
||||
for state.Next(-2) {
|
||||
count++
|
||||
state.Pop(1) // Pop value, leave key for next iteration
|
||||
}
|
||||
|
||||
if count != 3 {
|
||||
t.Fatalf("Expected 3 table entries, found %d", count)
|
||||
}
|
||||
|
||||
// Clean up
|
||||
state.Pop(1) // Pop the table
|
||||
}
|
||||
|
||||
func TestGlobalOperations(t *testing.T) {
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
// Set a global value
|
||||
state.PushNumber(42)
|
||||
state.SetGlobal("answer")
|
||||
|
||||
// Get the global value
|
||||
state.GetGlobal("answer")
|
||||
if state.ToNumber(-1) != 42 {
|
||||
t.Fatalf("GetGlobal failed, expected 42, got %f", state.ToNumber(-1))
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
// Test non-existent global (should be nil)
|
||||
state.GetGlobal("nonexistent")
|
||||
if !state.IsNil(-1) {
|
||||
t.Fatalf("Expected nil for non-existent global")
|
||||
}
|
||||
state.Pop(1)
|
||||
}
|
||||
|
||||
func TestCodeExecution(t *testing.T) {
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
// Test LoadString
|
||||
if err := state.LoadString("return 42"); err != nil {
|
||||
t.Fatalf("LoadString failed: %v", err)
|
||||
}
|
||||
|
||||
// Test Call
|
||||
if err := state.Call(0, 1); err != nil {
|
||||
t.Fatalf("Call failed: %v", err)
|
||||
}
|
||||
|
||||
if state.ToNumber(-1) != 42 {
|
||||
t.Fatalf("Call result incorrect, expected 42, got %f", state.ToNumber(-1))
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
// Test DoString
|
||||
if err := state.DoString("answer = 42 + 1"); err != nil {
|
||||
t.Fatalf("DoString failed: %v", err)
|
||||
}
|
||||
|
||||
state.GetGlobal("answer")
|
||||
if state.ToNumber(-1) != 43 {
|
||||
t.Fatalf("DoString execution incorrect, expected 43, got %f", state.ToNumber(-1))
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
// Test Execute
|
||||
nresults, err := state.Execute("return 5, 10, 15")
|
||||
if err != nil {
|
||||
t.Fatalf("Execute failed: %v", err)
|
||||
}
|
||||
|
||||
if nresults != 3 {
|
||||
t.Fatalf("Execute returned %d results, expected 3", nresults)
|
||||
}
|
||||
|
||||
if state.ToNumber(-3) != 5 || state.ToNumber(-2) != 10 || state.ToNumber(-1) != 15 {
|
||||
t.Fatalf("Execute results incorrect")
|
||||
}
|
||||
state.Pop(3)
|
||||
|
||||
// Test ExecuteWithResult
|
||||
result, err := state.ExecuteWithResult("return 'hello'")
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteWithResult failed: %v", err)
|
||||
}
|
||||
|
||||
if result != "hello" {
|
||||
t.Fatalf("ExecuteWithResult returned %v, expected 'hello'", result)
|
||||
}
|
||||
|
||||
// Test error handling
|
||||
err = state.DoString("this is not valid lua code")
|
||||
if err == nil {
|
||||
t.Fatalf("Expected error for invalid code, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDoFile(t *testing.T) {
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
// Create a temporary Lua file
|
||||
content := []byte("answer = 42")
|
||||
tmpfile, err := os.CreateTemp("", "test-*.lua")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp file: %v", err)
|
||||
}
|
||||
defer os.Remove(tmpfile.Name())
|
||||
|
||||
if _, err := tmpfile.Write(content); err != nil {
|
||||
t.Fatalf("Failed to write to temp file: %v", err)
|
||||
}
|
||||
if err := tmpfile.Close(); err != nil {
|
||||
t.Fatalf("Failed to close temp file: %v", err)
|
||||
}
|
||||
|
||||
// Test LoadFile and DoFile
|
||||
if err := state.LoadFile(tmpfile.Name()); err != nil {
|
||||
t.Fatalf("LoadFile failed: %v", err)
|
||||
}
|
||||
|
||||
if err := state.Call(0, 0); err != nil {
|
||||
t.Fatalf("Call failed after LoadFile: %v", err)
|
||||
}
|
||||
|
||||
state.GetGlobal("answer")
|
||||
if state.ToNumber(-1) != 42 {
|
||||
t.Fatalf("Incorrect result after LoadFile, expected 42, got %f", state.ToNumber(-1))
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
// Reset global
|
||||
if err := state.DoString("answer = nil"); err != nil {
|
||||
t.Fatalf("Failed to reset answer: %v", err)
|
||||
}
|
||||
|
||||
// Test DoFile
|
||||
if err := state.DoFile(tmpfile.Name()); err != nil {
|
||||
t.Fatalf("DoFile failed: %v", err)
|
||||
}
|
||||
|
||||
state.GetGlobal("answer")
|
||||
if state.ToNumber(-1) != 42 {
|
||||
t.Fatalf("Incorrect result after DoFile, expected 42, got %f", state.ToNumber(-1))
|
||||
}
|
||||
state.Pop(1)
|
||||
}
|
||||
|
||||
func TestPackagePath(t *testing.T) {
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
// Test SetPackagePath
|
||||
testPath := "/test/path/?.lua"
|
||||
if err := state.SetPackagePath(testPath); err != nil {
|
||||
t.Fatalf("SetPackagePath failed: %v", err)
|
||||
}
|
||||
|
||||
result, err := state.ExecuteWithResult("return package.path")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get package.path: %v", err)
|
||||
}
|
||||
|
||||
if result != testPath {
|
||||
t.Fatalf("Expected package.path to be '%s', got '%s'", testPath, result)
|
||||
}
|
||||
|
||||
// Test AddPackagePath
|
||||
addPath := "/another/path/?.lua"
|
||||
if err := state.AddPackagePath(addPath); err != nil {
|
||||
t.Fatalf("AddPackagePath failed: %v", err)
|
||||
}
|
||||
|
||||
result, err = state.ExecuteWithResult("return package.path")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get package.path: %v", err)
|
||||
}
|
||||
|
||||
expected := testPath + ";" + addPath
|
||||
if result != expected {
|
||||
t.Fatalf("Expected package.path to be '%s', got '%s'", expected, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPushValueAndToValue(t *testing.T) {
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
testCases := []struct {
|
||||
value any
|
||||
}{
|
||||
{nil},
|
||||
{true},
|
||||
{false},
|
||||
{42},
|
||||
{42.5},
|
||||
{"hello"},
|
||||
{[]float64{1, 2, 3, 4, 5}},
|
||||
{[]any{1, "test", true}},
|
||||
{map[string]any{"a": 1, "b": "test", "c": true}},
|
||||
}
|
||||
|
||||
for i, tc := range testCases {
|
||||
// Push value
|
||||
err := state.PushValue(tc.value)
|
||||
if err != nil {
|
||||
t.Fatalf("PushValue failed for testCase %d: %v", i, err)
|
||||
}
|
||||
|
||||
// Check stack
|
||||
if state.GetTop() != i+1 {
|
||||
t.Fatalf("Stack size incorrect after push, expected %d, got %d", i+1, state.GetTop())
|
||||
}
|
||||
}
|
||||
|
||||
// Test conversion back to Go
|
||||
for i := range testCases {
|
||||
index := len(testCases) - i
|
||||
value, err := state.ToValue(index)
|
||||
if err != nil {
|
||||
t.Fatalf("ToValue failed for index %d: %v", index, err)
|
||||
}
|
||||
|
||||
// For tables, we need special handling due to how Go types are stored
|
||||
switch expected := testCases[index-1].value.(type) {
|
||||
case []float64:
|
||||
// Arrays come back as map[string]any with empty key
|
||||
if m, ok := value.(map[string]any); ok {
|
||||
if arr, ok := m[""].([]float64); ok {
|
||||
if !reflect.DeepEqual(arr, expected) {
|
||||
t.Fatalf("Value mismatch for testCase %d: expected %v, got %v", index-1, expected, arr)
|
||||
}
|
||||
} else {
|
||||
t.Fatalf("Invalid array conversion for testCase %d", index-1)
|
||||
}
|
||||
} else {
|
||||
t.Fatalf("Expected map for array value in testCase %d, got %T", index-1, value)
|
||||
}
|
||||
case int:
|
||||
if num, ok := value.(float64); ok {
|
||||
if float64(expected) == num {
|
||||
continue // Values match after type conversion
|
||||
}
|
||||
}
|
||||
case []any:
|
||||
// Skip detailed comparison for mixed arrays
|
||||
case map[string]any:
|
||||
// Skip detailed comparison for maps
|
||||
default:
|
||||
if !reflect.DeepEqual(value, testCases[index-1].value) {
|
||||
t.Fatalf("Value mismatch for testCase %d: expected %v, got %v",
|
||||
index-1, testCases[index-1].value, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test unsupported type
|
||||
complex := complex(1, 2)
|
||||
err := state.PushValue(complex)
|
||||
if err == nil {
|
||||
t.Fatalf("Expected error for unsupported type")
|
||||
}
|
||||
}
|
103
types.go
103
types.go
|
@ -4,12 +4,16 @@ package luajit
|
|||
#include <lua.h>
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// LuaType represents Lua value types
|
||||
type LuaType int
|
||||
|
||||
const (
|
||||
// These constants must match lua.h's LUA_T* values
|
||||
// These constants match lua.h's LUA_T* values
|
||||
TypeNone LuaType = -1
|
||||
TypeNil LuaType = 0
|
||||
TypeBoolean LuaType = 1
|
||||
|
@ -49,3 +53,100 @@ func (t LuaType) String() string {
|
|||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// ConvertValue converts a value to the requested type with proper type conversion
|
||||
func ConvertValue[T any](value any) (T, bool) {
|
||||
var zero T
|
||||
|
||||
// Handle nil case
|
||||
if value == nil {
|
||||
return zero, false
|
||||
}
|
||||
|
||||
// Try direct type assertion first
|
||||
if result, ok := value.(T); ok {
|
||||
return result, true
|
||||
}
|
||||
|
||||
// Type-specific conversions
|
||||
switch any(zero).(type) {
|
||||
case string:
|
||||
switch v := value.(type) {
|
||||
case float64:
|
||||
return any(fmt.Sprintf("%g", v)).(T), true
|
||||
case int:
|
||||
return any(strconv.Itoa(v)).(T), true
|
||||
case bool:
|
||||
if v {
|
||||
return any("true").(T), true
|
||||
}
|
||||
return any("false").(T), true
|
||||
}
|
||||
case int:
|
||||
switch v := value.(type) {
|
||||
case float64:
|
||||
return any(int(v)).(T), true
|
||||
case string:
|
||||
if i, err := strconv.Atoi(v); err == nil {
|
||||
return any(i).(T), true
|
||||
}
|
||||
case bool:
|
||||
if v {
|
||||
return any(1).(T), true
|
||||
}
|
||||
return any(0).(T), true
|
||||
}
|
||||
case float64:
|
||||
switch v := value.(type) {
|
||||
case int:
|
||||
return any(float64(v)).(T), true
|
||||
case string:
|
||||
if f, err := strconv.ParseFloat(v, 64); err == nil {
|
||||
return any(f).(T), true
|
||||
}
|
||||
case bool:
|
||||
if v {
|
||||
return any(1.0).(T), true
|
||||
}
|
||||
return any(0.0).(T), true
|
||||
}
|
||||
case bool:
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
switch v {
|
||||
case "true", "yes", "1":
|
||||
return any(true).(T), true
|
||||
case "false", "no", "0":
|
||||
return any(false).(T), true
|
||||
}
|
||||
case int:
|
||||
return any(v != 0).(T), true
|
||||
case float64:
|
||||
return any(v != 0).(T), true
|
||||
}
|
||||
}
|
||||
|
||||
return zero, false
|
||||
}
|
||||
|
||||
// GetTypedValue gets a value from the state with type conversion
|
||||
func GetTypedValue[T any](s *State, index int) (T, bool) {
|
||||
var zero T
|
||||
|
||||
// Get the value as any type
|
||||
value, err := s.ToValue(index)
|
||||
if err != nil {
|
||||
return zero, false
|
||||
}
|
||||
|
||||
// Convert it to the requested type
|
||||
return ConvertValue[T](value)
|
||||
}
|
||||
|
||||
// GetGlobalTyped gets a global variable with type conversion
|
||||
func GetGlobalTyped[T any](s *State, name string) (T, bool) {
|
||||
s.GetGlobal(name)
|
||||
defer s.Pop(1)
|
||||
|
||||
return GetTypedValue[T](s, -1)
|
||||
}
|
||||
|
|
554
wrapper.go
554
wrapper.go
|
@ -1,127 +1,275 @@
|
|||
package luajit
|
||||
|
||||
/*
|
||||
#cgo CFLAGS: -I${SRCDIR}/vendor/luajit/include
|
||||
#cgo !windows pkg-config: --static luajit
|
||||
#cgo windows CFLAGS: -I${SRCDIR}/vendor/luajit/include
|
||||
#cgo windows LDFLAGS: -L${SRCDIR}/vendor/luajit/windows -lluajit -static
|
||||
#cgo !windows LDFLAGS: -L${SRCDIR}/vendor/luajit/linux -lluajit -static
|
||||
|
||||
#include <lua.h>
|
||||
#include <lualib.h>
|
||||
#include <lauxlib.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
// Direct execution helpers to minimize CGO transitions
|
||||
static int do_string(lua_State *L, const char *s) {
|
||||
int status = luaL_loadstring(L, s);
|
||||
if (status) return status;
|
||||
return lua_pcall(L, 0, LUA_MULTRET, 0);
|
||||
if (status == 0) {
|
||||
status = lua_pcall(L, 0, 0, 0);
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
static int do_file(lua_State *L, const char *filename) {
|
||||
int status = luaL_loadfile(L, filename);
|
||||
if (status) return status;
|
||||
return lua_pcall(L, 0, LUA_MULTRET, 0);
|
||||
if (status == 0) {
|
||||
status = lua_pcall(L, 0, 0, 0);
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
static int execute_with_results(lua_State *L, const char *code, int store_results) {
|
||||
int status = luaL_loadstring(L, code);
|
||||
if (status != 0) return status;
|
||||
return lua_pcall(L, 0, store_results ? LUA_MULTRET : 0, 0);
|
||||
}
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// State represents a Lua state with configurable stack safety
|
||||
// Type pool for common objects to reduce GC pressure
|
||||
var stringBufferPool = sync.Pool{
|
||||
New: func() any {
|
||||
return new(strings.Builder)
|
||||
},
|
||||
}
|
||||
|
||||
// State represents a Lua state
|
||||
type State struct {
|
||||
L *C.lua_State
|
||||
safeStack bool
|
||||
L *C.lua_State
|
||||
}
|
||||
|
||||
// NewSafe creates a new Lua state with full stack safety guarantees
|
||||
func NewSafe() *State {
|
||||
// New creates a new Lua state with optional standard libraries; true if not specified
|
||||
func New(openLibs ...bool) *State {
|
||||
L := C.luaL_newstate()
|
||||
if L == nil {
|
||||
return nil
|
||||
}
|
||||
C.luaL_openlibs(L)
|
||||
return &State{L: L, safeStack: true}
|
||||
}
|
||||
|
||||
// New creates a new Lua state with minimal stack checking
|
||||
func New() *State {
|
||||
L := C.luaL_newstate()
|
||||
if L == nil {
|
||||
return nil
|
||||
if len(openLibs) == 0 || openLibs[0] {
|
||||
C.luaL_openlibs(L)
|
||||
}
|
||||
C.luaL_openlibs(L)
|
||||
return &State{L: L, safeStack: false}
|
||||
|
||||
return &State{L: L}
|
||||
}
|
||||
|
||||
// Close closes the Lua state
|
||||
// Close closes the Lua state and frees resources
|
||||
func (s *State) Close() {
|
||||
if s.L != nil {
|
||||
s.Cleanup() // Clean up Go function registry
|
||||
C.lua_close(s.L)
|
||||
s.L = nil
|
||||
}
|
||||
}
|
||||
|
||||
// DoString executes a Lua string with appropriate stack management
|
||||
func (s *State) DoString(str string) error {
|
||||
cstr := C.CString(str)
|
||||
defer C.free(unsafe.Pointer(cstr))
|
||||
// Stack manipulation methods
|
||||
|
||||
if s.safeStack {
|
||||
return stackGuardErr(s, func() error {
|
||||
return s.safeCall(func() C.int {
|
||||
return C.do_string(s.L, cstr)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
status := C.do_string(s.L, cstr)
|
||||
if status != 0 {
|
||||
return &LuaError{
|
||||
Code: int(status),
|
||||
Message: s.ToString(-1),
|
||||
}
|
||||
}
|
||||
return nil
|
||||
// GetTop returns the index of the top element in the stack
|
||||
func (s *State) GetTop() int {
|
||||
return int(C.lua_gettop(s.L))
|
||||
}
|
||||
|
||||
// PushValue pushes a Go value onto the stack
|
||||
func (s *State) PushValue(v interface{}) error {
|
||||
if s.safeStack {
|
||||
return stackGuardErr(s, func() error {
|
||||
if err := s.checkStack(1); err != nil {
|
||||
return fmt.Errorf("pushing value: %w", err)
|
||||
}
|
||||
return s.pushValueUnsafe(v)
|
||||
})
|
||||
}
|
||||
return s.pushValueUnsafe(v)
|
||||
// SetTop sets the stack top to a specific index
|
||||
func (s *State) SetTop(index int) {
|
||||
C.lua_settop(s.L, C.int(index))
|
||||
}
|
||||
|
||||
func (s *State) pushValueSafe(v interface{}) error {
|
||||
if err := s.checkStack(1); err != nil {
|
||||
return fmt.Errorf("pushing value: %w", err)
|
||||
}
|
||||
return s.pushValueUnsafe(v)
|
||||
// PushCopy pushes a copy of the value at the given index onto the stack
|
||||
func (s *State) PushCopy(index int) {
|
||||
C.lua_pushvalue(s.L, C.int(index))
|
||||
}
|
||||
|
||||
func (s *State) pushValueUnsafe(v interface{}) error {
|
||||
// Pop pops n elements from the stack
|
||||
func (s *State) Pop(n int) {
|
||||
C.lua_settop(s.L, C.int(-n-1))
|
||||
}
|
||||
|
||||
// Remove removes the element at the given valid index
|
||||
func (s *State) Remove(index int) {
|
||||
C.lua_remove(s.L, C.int(index))
|
||||
}
|
||||
|
||||
// absIndex converts a possibly negative index to its absolute position
|
||||
func (s *State) absIndex(index int) int {
|
||||
if index > 0 || index <= LUA_REGISTRYINDEX {
|
||||
return index
|
||||
}
|
||||
return s.GetTop() + index + 1
|
||||
}
|
||||
|
||||
// Type checking methods
|
||||
|
||||
// GetType returns the type of the value at the given index
|
||||
func (s *State) GetType(index int) LuaType {
|
||||
return LuaType(C.lua_type(s.L, C.int(index)))
|
||||
}
|
||||
|
||||
// IsNil checks if the value at the given index is nil
|
||||
func (s *State) IsNil(index int) bool {
|
||||
return s.GetType(index) == TypeNil
|
||||
}
|
||||
|
||||
// IsBoolean checks if the value at the given index is a boolean
|
||||
func (s *State) IsBoolean(index int) bool {
|
||||
return s.GetType(index) == TypeBoolean
|
||||
}
|
||||
|
||||
// IsNumber checks if the value at the given index is a number
|
||||
func (s *State) IsNumber(index int) bool {
|
||||
return C.lua_isnumber(s.L, C.int(index)) != 0
|
||||
}
|
||||
|
||||
// IsString checks if the value at the given index is a string
|
||||
func (s *State) IsString(index int) bool {
|
||||
return C.lua_isstring(s.L, C.int(index)) != 0
|
||||
}
|
||||
|
||||
// IsTable checks if the value at the given index is a table
|
||||
func (s *State) IsTable(index int) bool {
|
||||
return s.GetType(index) == TypeTable
|
||||
}
|
||||
|
||||
// IsFunction checks if the value at the given index is a function
|
||||
func (s *State) IsFunction(index int) bool {
|
||||
return s.GetType(index) == TypeFunction
|
||||
}
|
||||
|
||||
// Value conversion methods
|
||||
|
||||
// ToBoolean returns the value at the given index as a boolean
|
||||
func (s *State) ToBoolean(index int) bool {
|
||||
return C.lua_toboolean(s.L, C.int(index)) != 0
|
||||
}
|
||||
|
||||
// ToNumber returns the value at the given index as a number
|
||||
func (s *State) ToNumber(index int) float64 {
|
||||
return float64(C.lua_tonumber(s.L, C.int(index)))
|
||||
}
|
||||
|
||||
// ToString returns the value at the given index as a string
|
||||
func (s *State) ToString(index int) string {
|
||||
var length C.size_t
|
||||
cstr := C.lua_tolstring(s.L, C.int(index), &length)
|
||||
if cstr == nil {
|
||||
return ""
|
||||
}
|
||||
return C.GoStringN(cstr, C.int(length))
|
||||
}
|
||||
|
||||
// Push methods
|
||||
|
||||
// PushNil pushes a nil value onto the stack
|
||||
func (s *State) PushNil() {
|
||||
C.lua_pushnil(s.L)
|
||||
}
|
||||
|
||||
// PushBoolean pushes a boolean value onto the stack
|
||||
func (s *State) PushBoolean(b bool) {
|
||||
var value C.int
|
||||
if b {
|
||||
value = 1
|
||||
}
|
||||
C.lua_pushboolean(s.L, value)
|
||||
}
|
||||
|
||||
// PushNumber pushes a number value onto the stack
|
||||
func (s *State) PushNumber(n float64) {
|
||||
C.lua_pushnumber(s.L, C.lua_Number(n))
|
||||
}
|
||||
|
||||
// PushString pushes a string value onto the stack
|
||||
func (s *State) PushString(str string) {
|
||||
// Use direct C string for short strings (avoid allocations)
|
||||
if len(str) < 128 {
|
||||
cstr := C.CString(str)
|
||||
defer C.free(unsafe.Pointer(cstr))
|
||||
C.lua_pushlstring(s.L, cstr, C.size_t(len(str)))
|
||||
return
|
||||
}
|
||||
|
||||
// For longer strings, avoid double copy by using unsafe pointer
|
||||
header := (*struct {
|
||||
p unsafe.Pointer
|
||||
len int
|
||||
cap int
|
||||
})(unsafe.Pointer(&str))
|
||||
|
||||
C.lua_pushlstring(s.L, (*C.char)(header.p), C.size_t(len(str)))
|
||||
}
|
||||
|
||||
// Table operations
|
||||
|
||||
// CreateTable creates a new table and pushes it onto the stack
|
||||
func (s *State) CreateTable(narr, nrec int) {
|
||||
C.lua_createtable(s.L, C.int(narr), C.int(nrec))
|
||||
}
|
||||
|
||||
// NewTable creates a new empty table and pushes it onto the stack
|
||||
func (s *State) NewTable() {
|
||||
C.lua_createtable(s.L, 0, 0)
|
||||
}
|
||||
|
||||
// GetTable gets a table field (t[k]) where t is at the given index and k is at the top of the stack
|
||||
func (s *State) GetTable(index int) {
|
||||
C.lua_gettable(s.L, C.int(index))
|
||||
}
|
||||
|
||||
// SetTable sets a table field (t[k] = v) where t is at the given index, k is at -2, and v is at -1
|
||||
func (s *State) SetTable(index int) {
|
||||
C.lua_settable(s.L, C.int(index))
|
||||
}
|
||||
|
||||
// GetField gets a table field t[k] and pushes it onto the stack
|
||||
func (s *State) GetField(index int, key string) {
|
||||
ckey := C.CString(key)
|
||||
defer C.free(unsafe.Pointer(ckey))
|
||||
C.lua_getfield(s.L, C.int(index), ckey)
|
||||
}
|
||||
|
||||
// SetField sets a table field t[k] = v, where v is the value at the top of the stack
|
||||
func (s *State) SetField(index int, key string) {
|
||||
ckey := C.CString(key)
|
||||
defer C.free(unsafe.Pointer(ckey))
|
||||
C.lua_setfield(s.L, C.int(index), ckey)
|
||||
}
|
||||
|
||||
// Next pops a key from the stack and pushes the next key-value pair from the table at the given index
|
||||
func (s *State) Next(index int) bool {
|
||||
return C.lua_next(s.L, C.int(index)) != 0
|
||||
}
|
||||
|
||||
// PushValue pushes a Go value onto the stack with proper type conversion
|
||||
func (s *State) PushValue(v any) error {
|
||||
switch v := v.(type) {
|
||||
case nil:
|
||||
s.PushNil()
|
||||
case bool:
|
||||
s.PushBoolean(v)
|
||||
case float64:
|
||||
s.PushNumber(v)
|
||||
case int:
|
||||
s.PushNumber(float64(v))
|
||||
case float64:
|
||||
s.PushNumber(v)
|
||||
case string:
|
||||
s.PushString(v)
|
||||
case map[string]interface{}:
|
||||
case map[string]any:
|
||||
// Special case: handle array stored in map
|
||||
if arr, ok := v[""].([]float64); ok {
|
||||
s.NewTable()
|
||||
s.CreateTable(len(arr), 0)
|
||||
for i, elem := range arr {
|
||||
s.PushNumber(float64(i + 1))
|
||||
s.PushNumber(elem)
|
||||
|
@ -129,19 +277,19 @@ func (s *State) pushValueUnsafe(v interface{}) error {
|
|||
}
|
||||
return nil
|
||||
}
|
||||
return s.pushTableUnsafe(v)
|
||||
return s.PushTable(v)
|
||||
case []float64:
|
||||
s.NewTable()
|
||||
s.CreateTable(len(v), 0)
|
||||
for i, elem := range v {
|
||||
s.PushNumber(float64(i + 1))
|
||||
s.PushNumber(elem)
|
||||
s.SetTable(-3)
|
||||
}
|
||||
case []interface{}:
|
||||
s.NewTable()
|
||||
case []any:
|
||||
s.CreateTable(len(v), 0)
|
||||
for i, elem := range v {
|
||||
s.PushNumber(float64(i + 1))
|
||||
if err := s.pushValueUnsafe(elem); err != nil {
|
||||
if err := s.PushValue(elem); err != nil {
|
||||
return err
|
||||
}
|
||||
s.SetTable(-3)
|
||||
|
@ -152,18 +300,10 @@ func (s *State) pushValueUnsafe(v interface{}) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// ToValue converts a Lua value to a Go value
|
||||
func (s *State) ToValue(index int) (interface{}, error) {
|
||||
if s.safeStack {
|
||||
return stackGuardValue[interface{}](s, func() (interface{}, error) {
|
||||
return s.toValueUnsafe(index)
|
||||
})
|
||||
}
|
||||
return s.toValueUnsafe(index)
|
||||
}
|
||||
|
||||
func (s *State) toValueUnsafe(index int) (interface{}, error) {
|
||||
switch s.GetType(index) {
|
||||
// ToValue converts a Lua value at the given index to a Go value
|
||||
func (s *State) ToValue(index int) (any, error) {
|
||||
luaType := s.GetType(index)
|
||||
switch luaType {
|
||||
case TypeNil:
|
||||
return nil, nil
|
||||
case TypeBoolean:
|
||||
|
@ -173,151 +313,171 @@ func (s *State) toValueUnsafe(index int) (interface{}, error) {
|
|||
case TypeString:
|
||||
return s.ToString(index), nil
|
||||
case TypeTable:
|
||||
if !s.IsTable(index) {
|
||||
return nil, fmt.Errorf("not a table at index %d", index)
|
||||
}
|
||||
return s.toTableUnsafe(index)
|
||||
return s.ToTable(index)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported type: %s", s.GetType(index))
|
||||
return nil, fmt.Errorf("unsupported type: %s", luaType)
|
||||
}
|
||||
}
|
||||
|
||||
// Simple operations remain unchanged as they don't need stack protection
|
||||
// Global operations
|
||||
|
||||
func (s *State) GetType(index int) LuaType { return LuaType(C.lua_type(s.L, C.int(index))) }
|
||||
func (s *State) IsFunction(index int) bool { return s.GetType(index) == TypeFunction }
|
||||
func (s *State) IsTable(index int) bool { return s.GetType(index) == TypeTable }
|
||||
func (s *State) ToBoolean(index int) bool { return C.lua_toboolean(s.L, C.int(index)) != 0 }
|
||||
func (s *State) ToNumber(index int) float64 { return float64(C.lua_tonumber(s.L, C.int(index))) }
|
||||
func (s *State) ToString(index int) string {
|
||||
return C.GoString(C.lua_tolstring(s.L, C.int(index), nil))
|
||||
}
|
||||
func (s *State) GetTop() int { return int(C.lua_gettop(s.L)) }
|
||||
func (s *State) Pop(n int) { C.lua_settop(s.L, C.int(-n-1)) }
|
||||
func (s *State) SetTop(index int) { C.lua_settop(s.L, C.int(index)) }
|
||||
|
||||
// Push operations
|
||||
|
||||
func (s *State) PushNil() { C.lua_pushnil(s.L) }
|
||||
func (s *State) PushBoolean(b bool) { C.lua_pushboolean(s.L, C.int(bool2int(b))) }
|
||||
func (s *State) PushNumber(n float64) { C.lua_pushnumber(s.L, C.double(n)) }
|
||||
func (s *State) PushString(str string) {
|
||||
cstr := C.CString(str)
|
||||
defer C.free(unsafe.Pointer(cstr))
|
||||
C.lua_pushstring(s.L, cstr)
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
func bool2int(b bool) int {
|
||||
if b {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (s *State) absIndex(index int) int {
|
||||
if index > 0 || index <= LUA_REGISTRYINDEX {
|
||||
return index
|
||||
}
|
||||
return s.GetTop() + index + 1
|
||||
}
|
||||
|
||||
// SetField sets a field in a table at the given index with cached absolute index
|
||||
func (s *State) SetField(index int, key string) {
|
||||
absIdx := index
|
||||
if s.safeStack && (index < 0 && index > LUA_REGISTRYINDEX) {
|
||||
absIdx = s.GetTop() + index + 1
|
||||
}
|
||||
|
||||
cstr := C.CString(key)
|
||||
defer C.free(unsafe.Pointer(cstr))
|
||||
C.lua_setfield(s.L, C.int(absIdx), cstr)
|
||||
}
|
||||
|
||||
// GetField gets a field from a table with cached absolute index
|
||||
func (s *State) GetField(index int, key string) {
|
||||
absIdx := index
|
||||
if s.safeStack && (index < 0 && index > LUA_REGISTRYINDEX) {
|
||||
absIdx = s.GetTop() + index + 1
|
||||
}
|
||||
|
||||
if s.safeStack {
|
||||
if err := s.checkStack(1); err != nil {
|
||||
s.PushNil()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
cstr := C.CString(key)
|
||||
defer C.free(unsafe.Pointer(cstr))
|
||||
C.lua_getfield(s.L, C.int(absIdx), cstr)
|
||||
}
|
||||
|
||||
// GetGlobal gets a global variable and pushes it onto the stack
|
||||
// GetGlobal pushes the global variable with the given name onto the stack
|
||||
func (s *State) GetGlobal(name string) {
|
||||
if s.safeStack {
|
||||
if err := s.checkStack(1); err != nil {
|
||||
s.PushNil()
|
||||
return
|
||||
}
|
||||
}
|
||||
cstr := C.CString(name)
|
||||
defer C.free(unsafe.Pointer(cstr))
|
||||
C.lua_getfield(s.L, C.LUA_GLOBALSINDEX, cstr)
|
||||
s.GetField(LUA_GLOBALSINDEX, name)
|
||||
}
|
||||
|
||||
// SetGlobal sets a global variable from the value at the top of the stack
|
||||
// SetGlobal sets the global variable with the given name to the value at the top of the stack
|
||||
func (s *State) SetGlobal(name string) {
|
||||
// SetGlobal doesn't need stack space checking as it pops the value
|
||||
cstr := C.CString(name)
|
||||
defer C.free(unsafe.Pointer(cstr))
|
||||
C.lua_setfield(s.L, C.LUA_GLOBALSINDEX, cstr)
|
||||
s.SetField(LUA_GLOBALSINDEX, name)
|
||||
}
|
||||
|
||||
// Remove removes element with cached absolute index
|
||||
func (s *State) Remove(index int) {
|
||||
absIdx := index
|
||||
if s.safeStack && (index < 0 && index > LUA_REGISTRYINDEX) {
|
||||
absIdx = s.GetTop() + index + 1
|
||||
// Code execution methods
|
||||
|
||||
// LoadString loads a Lua chunk from a string without executing it
|
||||
func (s *State) LoadString(code string) error {
|
||||
ccode := C.CString(code)
|
||||
defer C.free(unsafe.Pointer(ccode))
|
||||
|
||||
status := C.luaL_loadstring(s.L, ccode)
|
||||
if status != 0 {
|
||||
err := &LuaError{
|
||||
Code: int(status),
|
||||
Message: s.ToString(-1),
|
||||
}
|
||||
s.Pop(1) // Remove error message
|
||||
return err
|
||||
}
|
||||
C.lua_remove(s.L, C.int(absIdx))
|
||||
return nil
|
||||
}
|
||||
|
||||
// DoFile executes a Lua file with appropriate stack management
|
||||
// LoadFile loads a Lua chunk from a file without executing it
|
||||
func (s *State) LoadFile(filename string) error {
|
||||
cfilename := C.CString(filename)
|
||||
defer C.free(unsafe.Pointer(cfilename))
|
||||
|
||||
status := C.luaL_loadfile(s.L, cfilename)
|
||||
if status != 0 {
|
||||
err := &LuaError{
|
||||
Code: int(status),
|
||||
Message: s.ToString(-1),
|
||||
}
|
||||
s.Pop(1) // Remove error message
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Call calls a function with the given number of arguments and results
|
||||
func (s *State) Call(nargs, nresults int) error {
|
||||
status := C.lua_pcall(s.L, C.int(nargs), C.int(nresults), 0)
|
||||
if status != 0 {
|
||||
err := &LuaError{
|
||||
Code: int(status),
|
||||
Message: s.ToString(-1),
|
||||
}
|
||||
s.Pop(1) // Remove error message
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DoString executes a Lua string and cleans up the stack
|
||||
func (s *State) DoString(code string) error {
|
||||
ccode := C.CString(code)
|
||||
defer C.free(unsafe.Pointer(ccode))
|
||||
|
||||
status := C.do_string(s.L, ccode)
|
||||
if status != 0 {
|
||||
err := &LuaError{
|
||||
Code: int(status),
|
||||
Message: s.ToString(-1),
|
||||
}
|
||||
s.Pop(1) // Remove error message
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DoFile executes a Lua file and cleans up the stack
|
||||
func (s *State) DoFile(filename string) error {
|
||||
cfilename := C.CString(filename)
|
||||
defer C.free(unsafe.Pointer(cfilename))
|
||||
|
||||
if s.safeStack {
|
||||
return stackGuardErr(s, func() error {
|
||||
return s.safeCall(func() C.int {
|
||||
return C.do_file(s.L, cfilename)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
status := C.do_file(s.L, cfilename)
|
||||
if status != 0 {
|
||||
return &LuaError{
|
||||
err := &LuaError{
|
||||
Code: int(status),
|
||||
Message: s.ToString(-1),
|
||||
}
|
||||
s.Pop(1) // Remove error message
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Execute executes a Lua string and returns the number of results left on the stack
|
||||
func (s *State) Execute(code string) (int, error) {
|
||||
baseTop := s.GetTop()
|
||||
|
||||
ccode := C.CString(code)
|
||||
defer C.free(unsafe.Pointer(ccode))
|
||||
|
||||
status := C.execute_with_results(s.L, ccode, 1) // store_results=true
|
||||
if status != 0 {
|
||||
err := &LuaError{
|
||||
Code: int(status),
|
||||
Message: s.ToString(-1),
|
||||
}
|
||||
s.Pop(1) // Remove error message
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return s.GetTop() - baseTop, nil
|
||||
}
|
||||
|
||||
// ExecuteWithResult executes a Lua string and returns the first result
|
||||
func (s *State) ExecuteWithResult(code string) (any, error) {
|
||||
top := s.GetTop()
|
||||
defer s.SetTop(top) // Restore stack when done
|
||||
|
||||
nresults, err := s.Execute(code)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if nresults == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return s.ToValue(-nresults)
|
||||
}
|
||||
|
||||
// BatchExecute executes multiple statements with a single CGO transition
|
||||
func (s *State) BatchExecute(statements []string) error {
|
||||
// Join statements with semicolons
|
||||
combinedCode := ""
|
||||
for i, stmt := range statements {
|
||||
combinedCode += stmt
|
||||
if i < len(statements)-1 {
|
||||
combinedCode += "; "
|
||||
}
|
||||
}
|
||||
|
||||
return s.DoString(combinedCode)
|
||||
}
|
||||
|
||||
// Package path operations
|
||||
|
||||
// SetPackagePath sets the Lua package.path
|
||||
func (s *State) SetPackagePath(path string) error {
|
||||
path = filepath.ToSlash(path)
|
||||
if err := s.DoString(fmt.Sprintf(`package.path = %q`, path)); err != nil {
|
||||
return fmt.Errorf("setting package.path: %w", err)
|
||||
}
|
||||
return nil
|
||||
path = strings.ReplaceAll(path, "\\", "/") // Convert Windows paths
|
||||
code := fmt.Sprintf(`package.path = %q`, path)
|
||||
return s.DoString(code)
|
||||
}
|
||||
|
||||
// AddPackagePath adds a path to package.path
|
||||
func (s *State) AddPackagePath(path string) error {
|
||||
path = filepath.ToSlash(path)
|
||||
if err := s.DoString(fmt.Sprintf(`package.path = package.path .. ";%s"`, path)); err != nil {
|
||||
return fmt.Errorf("adding to package.path: %w", err)
|
||||
}
|
||||
return nil
|
||||
path = strings.ReplaceAll(path, "\\", "/") // Convert Windows paths
|
||||
code := fmt.Sprintf(`package.path = package.path .. ";%s"`, path)
|
||||
return s.DoString(code)
|
||||
}
|
||||
|
|
276
wrapper_test.go
276
wrapper_test.go
|
@ -1,276 +0,0 @@
|
|||
package luajit
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type stateFactory struct {
|
||||
name string
|
||||
new func() *State
|
||||
}
|
||||
|
||||
var factories = []stateFactory{
|
||||
{"unsafe", New},
|
||||
{"safe", NewSafe},
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
for _, f := range factories {
|
||||
t.Run(f.name, func(t *testing.T) {
|
||||
L := f.new()
|
||||
if L == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer L.Close()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDoString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
code string
|
||||
wantErr bool
|
||||
}{
|
||||
{"simple addition", "return 1 + 1", false},
|
||||
{"set global", "test = 42", false},
|
||||
{"syntax error", "this is not valid lua", true},
|
||||
{"runtime error", "error('test error')", true},
|
||||
}
|
||||
|
||||
for _, f := range factories {
|
||||
for _, tt := range tests {
|
||||
t.Run(f.name+"/"+tt.name, func(t *testing.T) {
|
||||
L := f.new()
|
||||
if L == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer L.Close()
|
||||
|
||||
err := L.DoString(tt.code)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("DoString() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPushAndGetValues(t *testing.T) {
|
||||
values := []struct {
|
||||
name string
|
||||
push func(*State)
|
||||
check func(*State) error
|
||||
}{
|
||||
{
|
||||
name: "string",
|
||||
push: func(L *State) { L.PushString("hello") },
|
||||
check: func(L *State) error {
|
||||
if got := L.ToString(-1); got != "hello" {
|
||||
return fmt.Errorf("got %q, want %q", got, "hello")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "number",
|
||||
push: func(L *State) { L.PushNumber(42.5) },
|
||||
check: func(L *State) error {
|
||||
if got := L.ToNumber(-1); got != 42.5 {
|
||||
return fmt.Errorf("got %f, want %f", got, 42.5)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "boolean",
|
||||
push: func(L *State) { L.PushBoolean(true) },
|
||||
check: func(L *State) error {
|
||||
if got := L.ToBoolean(-1); !got {
|
||||
return fmt.Errorf("got %v, want true", got)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "nil",
|
||||
push: func(L *State) { L.PushNil() },
|
||||
check: func(L *State) error {
|
||||
if typ := L.GetType(-1); typ != TypeNil {
|
||||
return fmt.Errorf("got type %v, want TypeNil", typ)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, f := range factories {
|
||||
for _, v := range values {
|
||||
t.Run(f.name+"/"+v.name, func(t *testing.T) {
|
||||
L := f.new()
|
||||
if L == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer L.Close()
|
||||
|
||||
v.push(L)
|
||||
if err := v.check(L); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStackManipulation(t *testing.T) {
|
||||
for _, f := range factories {
|
||||
t.Run(f.name, func(t *testing.T) {
|
||||
L := f.new()
|
||||
if L == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer L.Close()
|
||||
|
||||
// Push values
|
||||
values := []string{"first", "second", "third"}
|
||||
for _, v := range values {
|
||||
L.PushString(v)
|
||||
}
|
||||
|
||||
// Check size
|
||||
if top := L.GetTop(); top != len(values) {
|
||||
t.Errorf("stack size = %d, want %d", top, len(values))
|
||||
}
|
||||
|
||||
// Pop one value
|
||||
L.Pop(1)
|
||||
|
||||
// Check new top
|
||||
if str := L.ToString(-1); str != "second" {
|
||||
t.Errorf("top value = %q, want 'second'", str)
|
||||
}
|
||||
|
||||
// Check new size
|
||||
if top := L.GetTop(); top != len(values)-1 {
|
||||
t.Errorf("stack size after pop = %d, want %d", top, len(values)-1)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGlobals(t *testing.T) {
|
||||
for _, f := range factories {
|
||||
t.Run(f.name, func(t *testing.T) {
|
||||
L := f.new()
|
||||
if L == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer L.Close()
|
||||
|
||||
// Test via Lua
|
||||
if err := L.DoString(`globalVar = "test"`); err != nil {
|
||||
t.Fatalf("DoString error: %v", err)
|
||||
}
|
||||
|
||||
// Get the global
|
||||
L.GetGlobal("globalVar")
|
||||
if str := L.ToString(-1); str != "test" {
|
||||
t.Errorf("global value = %q, want 'test'", str)
|
||||
}
|
||||
L.Pop(1)
|
||||
|
||||
// Set and get via API
|
||||
L.PushNumber(42)
|
||||
L.SetGlobal("testNum")
|
||||
|
||||
L.GetGlobal("testNum")
|
||||
if num := L.ToNumber(-1); num != 42 {
|
||||
t.Errorf("global number = %f, want 42", num)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDoFile(t *testing.T) {
|
||||
L := NewSafe()
|
||||
defer L.Close()
|
||||
|
||||
// Create test file
|
||||
content := []byte(`
|
||||
function add(a, b)
|
||||
return a + b
|
||||
end
|
||||
result = add(40, 2)
|
||||
`)
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
filename := filepath.Join(tmpDir, "test.lua")
|
||||
if err := os.WriteFile(filename, content, 0644); err != nil {
|
||||
t.Fatalf("Failed to create test file: %v", err)
|
||||
}
|
||||
|
||||
if err := L.DoFile(filename); err != nil {
|
||||
t.Fatalf("DoFile failed: %v", err)
|
||||
}
|
||||
|
||||
L.GetGlobal("result")
|
||||
if result := L.ToNumber(-1); result != 42 {
|
||||
t.Errorf("Expected result=42, got %v", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireAndPackagePath(t *testing.T) {
|
||||
L := NewSafe()
|
||||
defer L.Close()
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
// Create module file
|
||||
moduleContent := []byte(`
|
||||
local M = {}
|
||||
function M.multiply(a, b)
|
||||
return a * b
|
||||
end
|
||||
return M
|
||||
`)
|
||||
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "mathmod.lua"), moduleContent, 0644); err != nil {
|
||||
t.Fatalf("Failed to create module file: %v", err)
|
||||
}
|
||||
|
||||
// Add module path and test require
|
||||
if err := L.AddPackagePath(filepath.Join(tmpDir, "?.lua")); err != nil {
|
||||
t.Fatalf("AddPackagePath failed: %v", err)
|
||||
}
|
||||
|
||||
if err := L.DoString(`
|
||||
local math = require("mathmod")
|
||||
result = math.multiply(6, 7)
|
||||
`); err != nil {
|
||||
t.Fatalf("Failed to require module: %v", err)
|
||||
}
|
||||
|
||||
L.GetGlobal("result")
|
||||
if result := L.ToNumber(-1); result != 42 {
|
||||
t.Errorf("Expected result=42, got %v", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetPackagePath(t *testing.T) {
|
||||
L := NewSafe()
|
||||
defer L.Close()
|
||||
|
||||
customPath := "./custom/?.lua"
|
||||
if err := L.SetPackagePath(customPath); err != nil {
|
||||
t.Fatalf("SetPackagePath failed: %v", err)
|
||||
}
|
||||
|
||||
L.GetGlobal("package")
|
||||
L.GetField(-1, "path")
|
||||
if path := L.ToString(-1); path != customPath {
|
||||
t.Errorf("Expected package.path=%q, got %q", customPath, path)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user