Compare commits

..

7 Commits

Author SHA1 Message Date
14009697f0 add byte slice to string conversion to PushValue 2025-06-06 17:38:38 -05:00
f47d36eb8b update docs 2025-06-03 17:01:56 -05:00
516d66c2f2 maps always return any 2025-06-03 16:56:58 -05:00
4e59fc6e5c add map array support to pushvalue 2025-06-02 21:04:06 -05:00
c22638b51f update readme 2025-05-31 17:54:08 -05:00
202664f635 update documentation 2025-05-31 17:49:25 -05:00
f4bfff470f massive rewrite
fix go func mallocs
add helper utils
2025-05-31 17:42:58 -05:00
14 changed files with 1652 additions and 976 deletions

82
API.md Normal file
View File

@ -0,0 +1,82 @@
# API Quick Reference
## Core State
- New(openLibs ...bool) *State
- Close()
- Cleanup()
## Stack
- GetTop() int
- SetTop(index int)
- Pop(n int)
- PushCopy(index int)
- Remove(index int)
## Type Checks
- GetType(index int) LuaType
- IsNil/IsBoolean/IsNumber/IsString/IsTable/IsFunction(index int) bool
## Values
- ToString/ToNumber/ToBoolean(index int) T
- ToValue(index int) (any, error)
- ToTable(index int) (any, error)
- PushNil/PushBoolean/PushNumber/PushString/PushValue()
## Tables
- NewTable()
- CreateTable(narr, nrec int)
- GetTable/SetTable(index int)
- GetField/SetField(index int, key string)
- GetFieldString/Number/Bool/Table(index int, key string, default T) T
- GetTableLength(index int) int
- Next(index int) bool
- ForEachTableKV/ForEachArray(index int, fn func)
- NewTableBuilder() *TableBuilder
## Functions
- RegisterGoFunction(name string, fn GoFunction) error
- UnregisterGoFunction(name string)
- PushGoFunction(fn GoFunction) error
- Call(nargs, nresults int) error
- CallGlobal(name string, args ...any) ([]any, error)
## Globals
- GetGlobal/SetGlobal(name string)
## Execution
- LoadString/LoadFile(source string) error
- DoString/DoFile(source string) error
- Execute(code string) (int, error)
- ExecuteWithResult(code string) (any, error)
- BatchExecute(statements []string) error
## Bytecode
- CompileBytecode(code, name string) ([]byte, error)
- LoadBytecode(bytecode []byte, name string) error
- RunBytecode() error
- RunBytecodeWithResults(nresults int) error
- LoadAndRunBytecode(bytecode []byte, name string) error
- LoadAndRunBytecodeWithResults(bytecode []byte, name string, nresults int) error
- CompileAndRun(code, name string) error
## Validation
- CheckArgs(specs ...ArgSpec) error
- CheckMinArgs/CheckExactArgs(n int) error
- SafeToString/Number/Table(index int) (T, error)
## Error Handling
- PushError(format string, args ...any) int
- GetStackTrace() string
- GetErrorInfo(context string) *LuaError
- CreateLuaError(code int, context string) *LuaError
## Package
- SetPackagePath/AddPackagePath(path string) error
## Metatable
- SetMetatable(index int)
- GetMetatable(index int) bool
## Constants
- TypeNil/Boolean/Number/String/Table/Function/UserData/Thread
- LUA_MINSTACK/MAXSTACK/REGISTRYINDEX/GLOBALSINDEX

159
DOCS.md
View File

@ -56,17 +56,6 @@ L.Remove(-1) // Remove top element
L.Remove(1) // Remove first element L.Remove(1) // Remove first element
``` ```
### absIndex(index int) int
Internal function that converts a possibly negative index to its absolute position.
### checkStack(n int) error
Internal function that ensures there's enough space for n new elements.
```go
if err := L.checkStack(2); err != nil {
return err
}
```
## Type Checks ## Type Checks
### GetType(index int) LuaType ### GetType(index int) LuaType
@ -111,7 +100,7 @@ bool := L.ToBoolean(-1)
``` ```
### ToValue(index int) (any, error) ### ToValue(index int) (any, error)
Converts any Lua value to its Go equivalent. Converts any Lua value to its Go equivalent with automatic type detection.
```go ```go
val, err := L.ToValue(-1) val, err := L.ToValue(-1)
if err != nil { if err != nil {
@ -119,8 +108,8 @@ if err != nil {
} }
``` ```
### ToTable(index int) (map[string]any, error) ### ToTable(index int) (any, error)
Converts a Lua table to a Go map. Converts a Lua table to optimal Go type; arrays or `map[string]any`.
```go ```go
table, err := L.ToTable(-1) table, err := L.ToTable(-1)
if err != nil { if err != nil {
@ -149,19 +138,13 @@ L.PushNil()
``` ```
### PushValue(v any) error ### PushValue(v any) error
Pushes any Go value onto the stack. Pushes any Go value onto the stack with comprehensive type support.
```go ```go
err := L.PushValue(myValue) // Supports: primitives, slices, maps with various type combinations
``` err := L.PushValue(map[string]any{
### PushTable(table map[string]any) error
Pushes a Go map as a Lua table.
```go
data := map[string]any{
"key": "value", "key": "value",
"numbers": []float64{1, 2, 3}, "numbers": []float64{1, 2, 3},
} })
err := L.PushTable(data)
``` ```
## Table Operations ## Table Operations
@ -218,6 +201,25 @@ for L.Next(-2) {
} }
``` ```
### GetFieldString/Number/Bool/Table(index int, key string, default T) T
Get typed fields from tables with default values.
```go
name := L.GetFieldString(-1, "name", "unknown")
age := L.GetFieldNumber(-1, "age", 0)
active := L.GetFieldBool(-1, "active", false)
config, ok := L.GetFieldTable(-1, "config")
```
### ForEachTableKV(index int, fn func(key, value string) bool)
### ForEachArray(index int, fn func(i int, state *State) bool)
Convenient iteration helpers.
```go
L.ForEachTableKV(-1, func(key, value string) bool {
fmt.Printf("%s: %s\n", key, value)
return true // continue iteration
})
```
## Function Registration and Calling ## Function Registration and Calling
### GoFunction ### GoFunction
@ -258,6 +260,12 @@ L.PushNumber(2)
err := L.Call(2, 1) // Call with 2 args, expect 1 result err := L.Call(2, 1) // Call with 2 args, expect 1 result
``` ```
### CallGlobal(name string, args ...any) ([]any, error)
Calls a global function with arguments and returns all results.
```go
results, err := L.CallGlobal("myfunction", 1, 2, "hello")
```
## Global Operations ## Global Operations
### GetGlobal(name string) ### GetGlobal(name string)
@ -317,6 +325,16 @@ result, err := L.ExecuteWithResult("return 'hello'")
// result would be "hello" // result would be "hello"
``` ```
### BatchExecute(statements []string) error
Executes multiple statements as a single batch.
```go
err := L.BatchExecute([]string{
"x = 10",
"y = 20",
"result = x + y",
})
```
## Bytecode Operations ## Bytecode Operations
### CompileBytecode(code string, name string) ([]byte, error) ### CompileBytecode(code string, name string) ([]byte, error)
@ -375,14 +393,35 @@ Adds a path to package.path.
err := L.AddPackagePath("./modules/?.lua") err := L.AddPackagePath("./modules/?.lua")
``` ```
## Metatable Operations
### SetMetatable(index int)
Sets the metatable for the value at the given index.
```go
L.SetMetatable(-1)
```
### GetMetatable(index int) bool
Gets the metatable for the value at the given index.
```go
if L.GetMetatable(-1) {
// Metatable is now on stack
L.Pop(1)
}
```
## Error Handling ## Error Handling
### LuaError ### LuaError
Error type containing both an error code and message. Enhanced error type with detailed context information.
```go ```go
type LuaError struct { type LuaError struct {
Code int Code int
Message string Message string
File string
Line int
StackTrace string
Context string
} }
``` ```
@ -393,12 +432,62 @@ trace := L.GetStackTrace()
fmt.Println(trace) fmt.Println(trace)
``` ```
### safeCall(f func() C.int) error ### GetErrorInfo(context string) *LuaError
Internal function that wraps a potentially dangerous C call with stack checking. Extracts detailed error information from the Lua stack.
```go ```go
err := s.safeCall(func() C.int { err := L.GetErrorInfo("MyFunction")
return C.lua_pcall(s.L, C.int(nargs), C.int(nresults), 0) ```
})
### CreateLuaError(code int, context string) *LuaError
Creates a LuaError with full context information.
```go
err := L.CreateLuaError(status, "DoString")
```
### PushError(format string, args ...any) int
Pushes an error string and returns -1.
```go
return s.PushError("invalid argument: %v", arg)
```
## Validation
### CheckArgs(specs ...ArgSpec) error
Validates function arguments against specifications.
```go
err := s.CheckArgs(
ArgSpec{Name: "name", Type: "string", Required: true, Check: CheckString},
ArgSpec{Name: "age", Type: "number", Required: false, Check: CheckNumber},
)
```
### CheckMinArgs/CheckExactArgs(n int) error
Argument count validation.
```go
if err := s.CheckMinArgs(2); err != nil {
return s.PushError(err.Error())
}
```
### SafeToString/Number/Table(index int) (T, error)
Safe value conversion with error handling.
```go
str, err := s.SafeToString(1)
if err != nil {
return s.PushError(err.Error())
}
```
## Table Building
### NewTableBuilder() *TableBuilder
Creates a new table builder for fluent table construction.
```go
L.NewTableBuilder().
SetString("name", "John").
SetNumber("age", 30).
SetBool("active", true).
Build()
``` ```
## Thread Safety Notes ## Thread Safety Notes
@ -423,9 +512,3 @@ L.PushString("hello")
// ... use the string // ... use the string
L.Pop(1) // Clean up when done L.Pop(1) // Clean up when done
``` ```
Sandbox management:
```go
sandbox := luajit.NewSandbox()
defer sandbox.Close()
```

177
README.md
View File

@ -18,8 +18,6 @@ First, grab the package:
go get git.sharkk.net/Sky/LuaJIT-to-Go go get git.sharkk.net/Sky/LuaJIT-to-Go
``` ```
You'll need LuaJIT's development files, but don't worry - we include libraries for Windows and Linux in the vendor directory.
Here's the simplest thing you can do: Here's the simplest thing you can do:
```go ```go
L := luajit.New() // pass false to not load standard libs L := luajit.New() // pass false to not load standard libs
@ -36,25 +34,21 @@ Need even more performance? You can compile your Lua code to bytecode and reuse
```go ```go
// Compile once // Compile once
bytecode, err := L.CompileBytecode(` bytecode, err := L.CompileBytecode(`
local function calculate(x) local function calculate(x)
return x * x + x + 1 return x * x + x + 1
end end
return calculate(10) return calculate(10)
`, "calc") `, "calc")
// Execute many times // Execute many times
for i := 0; i < 1000; i++ { for i := 0; i < 1000; i++ {
err := L.LoadAndRunBytecode(bytecode, "calc") err := L.LoadAndRunBytecode(bytecode, "calc")
} }
// Or do both at once // Or do both at once
err := L.CompileAndRun(`return "hello"`, "greeting") err := L.CompileAndRun(`return "hello"`, "greeting")
``` ```
### When to Use Bytecode
Bytecode execution is consistently faster than direct execution:
``` ```
Benchmark Ops/sec Comparison Benchmark Ops/sec Comparison
---------------------------------------------------------------------------- ----------------------------------------------------------------------------
@ -70,21 +64,15 @@ BenchmarkComplexScript 33,133 Base
BenchmarkComplexScriptPrecompiled 41,044 +23.9% faster BenchmarkComplexScriptPrecompiled 41,044 +23.9% faster
``` ```
Use bytecode when you:
- Have code that runs frequently
- Need maximum performance
- Want to precompile your Lua code
- Are distributing Lua code to many instances
## Registering Go Functions ## Registering Go Functions
Want to call Go code from Lua? It's straightforward: Want to call Go code from Lua? It's straightforward:
```go ```go
// This function adds two numbers and returns the result // This function adds two numbers and returns the result
adder := func(s *luajit.State) int { adder := func(s *luajit.State) int {
sum := s.ToNumber(1) + s.ToNumber(2) sum := s.ToNumber(1) + s.ToNumber(2)
s.PushNumber(sum) s.PushNumber(sum)
return 1 // we're returning one value return 1 // we're returning one value
} }
L.RegisterGoFunction("add", adder) L.RegisterGoFunction("add", adder)
@ -95,22 +83,62 @@ Now in Lua:
result = add(40, 2) -- result = 42 result = add(40, 2) -- result = 42
``` ```
### Function Validation
Validate arguments easily:
```go
calculator := func(s *luajit.State) int {
if err := s.CheckArgs(
luajit.ArgSpec{Name: "x", Type: "number", Required: true, Check: luajit.CheckNumber},
luajit.ArgSpec{Name: "y", Type: "number", Required: true, Check: luajit.CheckNumber},
); err != nil {
return s.PushError(err.Error())
}
result := s.ToNumber(1) + s.ToNumber(2)
s.PushNumber(result)
return 1
}
```
## Working with Tables ## Working with Tables
Lua tables are pretty powerful - they're like a mix of Go's maps and slices. We make it easy to work with them: Lua tables are powerful - they're like a mix of Go's maps and slices. We make it easy to work with them:
```go ```go
// Go → Lua // Go → Lua
stuff := map[string]any{ stuff := map[string]any{
"name": "Arthur Dent", "name": "Arthur Dent",
"age": 30, "age": 30,
"items": []float64{1, 2, 3}, "items": []float64{1, 2, 3},
} }
L.PushTable(stuff) L.PushValue(stuff) // Handles all Go types automatically
// Lua → Go // Lua → Go with automatic type detection
L.GetGlobal("some_table") L.GetGlobal("some_table")
result, err := L.ToTable(-1) result, err := L.ToTable(-1) // Returns optimal Go type (typed array, or map[string]any)
```
### Table Builder
Build tables fluently:
```go
L.NewTableBuilder().
SetString("name", "John").
SetNumber("age", 30).
SetBool("active", true).
SetArray("scores", []any{95, 87, 92}).
Build()
```
### Table Field Access
Get fields with defaults:
```go
L.GetGlobal("config")
host := L.GetFieldString(-1, "host", "localhost")
port := L.GetFieldNumber(-1, "port", 8080)
debug := L.GetFieldBool(-1, "debug", false)
``` ```
## Error Handling ## Error Handling
@ -118,34 +146,31 @@ result, err := L.ToTable(-1)
We provide useful errors instead of mysterious panics: We provide useful errors instead of mysterious panics:
```go ```go
if err := L.DoString("this isn't valid Lua!"); err != nil { if err := L.DoString("this isn't valid Lua!"); err != nil {
if luaErr, ok := err.(*luajit.LuaError); ok { if luaErr, ok := err.(*luajit.LuaError); ok {
fmt.Printf("Error: %s\n", luaErr.Message) fmt.Printf("Error in %s:%d - %s\n", luaErr.File, luaErr.Line, luaErr.Message)
} fmt.Printf("Stack trace:\n%s\n", luaErr.StackTrace)
}
} }
``` ```
## Memory Management ## Memory Management
The wrapper uses a custom table pooling system to reduce GC pressure when handling many tables: The wrapper uses bytecode buffer pooling to reduce allocations:
```go ```go
// Tables are pooled and reused internally for better performance // Bytecode buffers are pooled and reused internally
for i := 0; i < 1000; i++ { for i := 0; i < 1000; i++ {
L.GetGlobal("table") bytecode, _ := L.CompileBytecode(code, "test")
table, _ := L.ToTable(-1) // Buffer automatically returned to pool
// Use table...
L.Pop(1)
// Table is automatically returned to pool
} }
``` ```
The sandbox also manages its environment efficiently: Function pointers are managed safely:
```go ```go
// Environment objects are pooled and reused // Functions are registered in a thread-safe registry
for i := 0; i < 1000; i++ { L.RegisterGoFunction("myFunc", myGoFunc)
result, _ := sandbox.Run("return i + 1") defer L.Cleanup() // Cleans up all registered functions
}
``` ```
## Best Practices ## Best Practices
@ -160,7 +185,12 @@ for i := 0; i < 1000; i++ {
### Bytecode Optimization ### Bytecode Optimization
- Use bytecode for frequently executed code paths - Use bytecode for frequently executed code paths
- Consider compiling critical Lua code to bytecode at startup - Consider compiling critical Lua code to bytecode at startup
- For small scripts (< 1024 bytes), direct execution might be faster - For small scripts (< 1024 bytes), direct execution might be faster due to compilation overhead
### Type Conversion
- Use `ToTable()` for automagic type detection and optimized Go arrays/maps
- Use `PushValue()` for automagic Go-to-Lua conversion
- Leverage typed field accessors for config-style tables
## Advanced Features ## Advanced Features
@ -186,11 +216,11 @@ Bytecode properly preserves closures and upvalues:
```go ```go
code := ` code := `
local counter = 0 local counter = 0
return function() return function()
counter = counter + 1 counter = counter + 1
return counter return counter
end end
` `
bytecode, _ := L.CompileBytecode(code, "counter") bytecode, _ := L.CompileBytecode(code, "counter")
@ -198,14 +228,53 @@ L.LoadAndRunBytecodeWithResults(bytecode, "counter", 1)
L.SetGlobal("increment") L.SetGlobal("increment")
// Later... // Later...
L.GetGlobal("increment") results, _ := L.CallGlobal("increment") // Returns []any{1}
L.Call(0, 1) // Returns 1 results, _ = L.CallGlobal("increment") // Returns []any{2}
L.Pop(1)
L.GetGlobal("increment")
L.Call(0, 1) // Returns 2
``` ```
### Batch Execution
Execute multiple statements efficiently:
```go
statements := []string{
"x = 10",
"y = 20",
"result = x + y",
}
err := L.BatchExecute(statements)
```
### Package Path Management
Manage Lua module paths:
```go
L.SetPackagePath("./?.lua;./modules/?.lua")
L.AddPackagePath("./vendor/?.lua")
```
### Type Conversion System
The wrapper includes a comprehensive type conversion system:
```go
// Get typed values with automatic conversion
value, ok := luajit.GetTypedValue[int](L, -1)
global, ok := luajit.GetGlobalTyped[[]string](L, "myArray")
// Convert between compatible types
result, ok := luajit.ConvertValue[map[string]int](someMap)
```
## Performance Tips
- Use bytecode for repeated execution
- Prefer `CallGlobal()` for simple function calls
- Use typed field accessors for configuration parsing
- Leverage automatic type detection in `ToTable()`
- Pool your Lua states for high-throughput scenarios
## Need Help? ## Need Help?
Check out the tests in the repository - they're full of examples. If you're stuck, open an issue! We're here to help. Check out the tests in the repository - they're full of examples. If you're stuck, open an issue! We're here to help.

View File

@ -321,37 +321,37 @@ func BenchmarkComplexScript(b *testing.B) {
-- Define a simple class -- Define a simple class
local Class = {} local Class = {}
Class.__index = Class Class.__index = Class
function Class.new(x, y) function Class.new(x, y)
local self = setmetatable({}, Class) local self = setmetatable({}, Class)
self.x = x or 0 self.x = x or 0
self.y = y or 0 self.y = y or 0
return self return self
end end
function Class:move(dx, dy) function Class:move(dx, dy)
self.x = self.x + dx self.x = self.x + dx
self.y = self.y + dy self.y = self.y + dy
return self return self
end end
function Class:getPosition() function Class:getPosition()
return self.x, self.y return self.x, self.y
end end
-- Create instances and operate on them -- Create instances and operate on them
local instances = {} local instances = {}
for i = 1, 50 do for i = 1, 50 do
instances[i] = Class.new(i, i*2) instances[i] = Class.new(i, i*2)
end end
local result = 0 local result = 0
for i, obj in ipairs(instances) do for i, obj in ipairs(instances) do
obj:move(i, -i) obj:move(i, -i)
local x, y = obj:getPosition() local x, y = obj:getPosition()
result = result + x + y result = result + x + y
end end
return result return result
` `
b.ResetTimer() b.ResetTimer()
@ -377,37 +377,37 @@ func BenchmarkComplexScriptPrecompiled(b *testing.B) {
-- Define a simple class -- Define a simple class
local Class = {} local Class = {}
Class.__index = Class Class.__index = Class
function Class.new(x, y) function Class.new(x, y)
local self = setmetatable({}, Class) local self = setmetatable({}, Class)
self.x = x or 0 self.x = x or 0
self.y = y or 0 self.y = y or 0
return self return self
end end
function Class:move(dx, dy) function Class:move(dx, dy)
self.x = self.x + dx self.x = self.x + dx
self.y = self.y + dy self.y = self.y + dy
return self return self
end end
function Class:getPosition() function Class:getPosition()
return self.x, self.y return self.x, self.y
end end
-- Create instances and operate on them -- Create instances and operate on them
local instances = {} local instances = {}
for i = 1, 50 do for i = 1, 50 do
instances[i] = Class.new(i, i*2) instances[i] = Class.new(i, i*2)
end end
local result = 0 local result = 0
for i, obj in ipairs(instances) do for i, obj in ipairs(instances) do
obj:move(i, -i) obj:move(i, -i)
local x, y = obj:getPosition() local x, y = obj:getPosition()
result = result + x + y result = result + x + y
end end
return result return result
` `
bytecode, err := state.CompileBytecode(code, "complex") bytecode, err := state.CompileBytecode(code, "complex")

72
builder.go Normal file
View File

@ -0,0 +1,72 @@
package luajit
// TableBuilder provides a fluent interface for building Lua tables
type TableBuilder struct {
state *State
index int
}
// NewTableBuilder creates a new table and returns a builder
func (s *State) NewTableBuilder() *TableBuilder {
s.NewTable()
return &TableBuilder{
state: s,
index: s.GetTop(),
}
}
// SetString sets a string field
func (tb *TableBuilder) SetString(key, value string) *TableBuilder {
tb.state.PushString(value)
tb.state.SetField(tb.index, key)
return tb
}
// SetNumber sets a number field
func (tb *TableBuilder) SetNumber(key string, value float64) *TableBuilder {
tb.state.PushNumber(value)
tb.state.SetField(tb.index, key)
return tb
}
// SetBool sets a boolean field
func (tb *TableBuilder) SetBool(key string, value bool) *TableBuilder {
tb.state.PushBoolean(value)
tb.state.SetField(tb.index, key)
return tb
}
// SetNil sets a nil field
func (tb *TableBuilder) SetNil(key string) *TableBuilder {
tb.state.PushNil()
tb.state.SetField(tb.index, key)
return tb
}
// SetTable sets a table field
func (tb *TableBuilder) SetTable(key string, value any) *TableBuilder {
if err := tb.state.PushValue(value); err == nil {
tb.state.SetField(tb.index, key)
}
return tb
}
// SetArray sets an array field
func (tb *TableBuilder) SetArray(key string, values []any) *TableBuilder {
tb.state.CreateTable(len(values), 0)
for i, v := range values {
tb.state.PushNumber(float64(i + 1))
if err := tb.state.PushValue(v); err == nil {
tb.state.SetTable(-3)
} else {
tb.state.Pop(1)
}
}
tb.state.SetField(tb.index, key)
return tb
}
// Build finalizes the table (no-op, table is already on stack)
func (tb *TableBuilder) Build() {
// Table is already on the stack at tb.index
}

View File

@ -12,6 +12,12 @@ typedef struct {
const char *name; const char *name;
} BytecodeReader; } BytecodeReader;
typedef struct {
unsigned char *buf;
size_t size;
size_t capacity;
} BytecodeBuffer;
const char *bytecode_reader(lua_State *L, void *ud, size_t *size) { const char *bytecode_reader(lua_State *L, void *ud, size_t *size) {
BytecodeReader *r = (BytecodeReader *)ud; BytecodeReader *r = (BytecodeReader *)ud;
(void)L; // unused (void)L; // unused
@ -26,16 +32,24 @@ int load_bytecode(lua_State *L, const unsigned char *buf, size_t len, const char
return lua_load(L, bytecode_reader, &reader, name); return lua_load(L, bytecode_reader, &reader, name);
} }
// Direct bytecode dumping without intermediate buffer - more efficient // Optimized bytecode writer with pre-allocated buffer
int direct_bytecode_writer(lua_State *L, const void *p, size_t sz, void *ud) { int buffered_bytecode_writer(lua_State *L, const void *p, size_t sz, void *ud) {
void **data = (void **)ud; BytecodeBuffer *buf = (BytecodeBuffer *)ud;
size_t current_size = (size_t)data[1];
void *newbuf = realloc(data[0], current_size + sz);
if (newbuf == NULL) return 1;
memcpy((unsigned char*)newbuf + current_size, p, sz); // Grow buffer if needed (double size to avoid frequent reallocs)
data[0] = newbuf; if (buf->size + sz > buf->capacity) {
data[1] = (void*)(current_size + sz); size_t new_capacity = buf->capacity;
while (new_capacity < buf->size + sz) {
new_capacity *= 2;
}
unsigned char *newbuf = realloc(buf->buf, new_capacity);
if (newbuf == NULL) return 1;
buf->buf = newbuf;
buf->capacity = new_capacity;
}
memcpy(buf->buf + buf->size, p, sz);
buf->size += sz;
return 0; return 0;
} }
@ -52,36 +66,56 @@ int load_and_run_bytecode(lua_State *L, const unsigned char *buf, size_t len,
import "C" import "C"
import ( import (
"fmt" "fmt"
"sync"
"unsafe" "unsafe"
) )
// bytecodeBuffer wraps []byte to avoid boxing allocations in sync.Pool
type bytecodeBuffer struct {
data []byte
}
// Buffer pool for bytecode generation
var bytecodeBufferPool = sync.Pool{
New: func() any {
return &bytecodeBuffer{data: make([]byte, 0, 1024)}
},
}
// CompileBytecode compiles a Lua chunk to bytecode without executing it // CompileBytecode compiles a Lua chunk to bytecode without executing it
func (s *State) CompileBytecode(code string, name string) ([]byte, error) { func (s *State) CompileBytecode(code string, name string) ([]byte, error) {
if err := s.LoadString(code); err != nil { if err := s.LoadString(code); err != nil {
return nil, fmt.Errorf("failed to load string: %w", err) return nil, fmt.Errorf("failed to load string: %w", err)
} }
// Use a simpler direct writer with just two pointers // Always use C memory for dump operation to avoid cgo pointer issues
data := [2]unsafe.Pointer{nil, nil} cbuf := C.BytecodeBuffer{
buf: (*C.uchar)(C.malloc(1024)),
size: 0,
capacity: 1024,
}
if cbuf.buf == nil {
return nil, fmt.Errorf("failed to allocate initial buffer")
}
// Dump the function to bytecode // Dump the function to bytecode
status := C.lua_dump(s.L, (*[0]byte)(unsafe.Pointer(C.direct_bytecode_writer)), unsafe.Pointer(&data)) status := C.lua_dump(s.L, (*[0]byte)(unsafe.Pointer(C.buffered_bytecode_writer)), unsafe.Pointer(&cbuf))
if status != 0 {
return nil, fmt.Errorf("failed to dump bytecode: status %d", status)
}
// Get result
var bytecode []byte
if data[0] != nil {
// Create Go slice that references the C memory
length := uintptr(data[1])
bytecode = C.GoBytes(data[0], C.int(length))
C.free(data[0])
}
s.Pop(1) // Remove the function from stack s.Pop(1) // Remove the function from stack
return bytecode, nil if status != 0 {
C.free(unsafe.Pointer(cbuf.buf))
return nil, fmt.Errorf("failed to dump bytecode: status %d", status)
}
// Copy to Go memory and free C buffer
var result []byte
if cbuf.size > 0 {
result = C.GoBytes(unsafe.Pointer(cbuf.buf), C.int(cbuf.size))
}
C.free(unsafe.Pointer(cbuf.buf))
return result, nil
} }
// LoadBytecode loads precompiled bytecode without executing it // LoadBytecode loads precompiled bytecode without executing it
@ -116,7 +150,6 @@ func (s *State) RunBytecode() error {
} }
// RunBytecodeWithResults executes bytecode and keeps nresults on the stack // RunBytecodeWithResults executes bytecode and keeps nresults on the stack
// Use LUA_MULTRET (-1) to keep all results
func (s *State) RunBytecodeWithResults(nresults int) error { func (s *State) RunBytecodeWithResults(nresults int) error {
status := C.lua_pcall(s.L, 0, C.int(nresults), 0) status := C.lua_pcall(s.L, 0, C.int(nresults), 0)
if status != 0 { if status != 0 {
@ -136,13 +169,12 @@ func (s *State) LoadAndRunBytecode(bytecode []byte, name string) error {
cname := C.CString(name) cname := C.CString(name)
defer C.free(unsafe.Pointer(cname)) defer C.free(unsafe.Pointer(cname))
// Use combined load and run function
status := C.load_and_run_bytecode( status := C.load_and_run_bytecode(
s.L, s.L,
(*C.uchar)(unsafe.Pointer(&bytecode[0])), (*C.uchar)(unsafe.Pointer(&bytecode[0])),
C.size_t(len(bytecode)), C.size_t(len(bytecode)),
cname, cname,
0, // No results 0,
) )
if status != 0 { if status != 0 {
@ -163,7 +195,6 @@ func (s *State) LoadAndRunBytecodeWithResults(bytecode []byte, name string, nres
cname := C.CString(name) cname := C.CString(name)
defer C.free(unsafe.Pointer(cname)) defer C.free(unsafe.Pointer(cname))
// Use combined load and run function
status := C.load_and_run_bytecode( status := C.load_and_run_bytecode(
s.L, s.L,
(*C.uchar)(unsafe.Pointer(&bytecode[0])), (*C.uchar)(unsafe.Pointer(&bytecode[0])),

View File

@ -9,7 +9,7 @@ extern int goFunctionWrapper(lua_State* L);
// Helper function to access upvalues // Helper function to access upvalues
static int get_upvalue_index(int i) { static int get_upvalue_index(int i) {
return lua_upvalueindex(i); return lua_upvalueindex(i);
} }
*/ */
import "C" import "C"
@ -34,11 +34,20 @@ var (
}{ }{
funcs: make(map[unsafe.Pointer]GoFunction, initialRegistrySize), funcs: make(map[unsafe.Pointer]GoFunction, initialRegistrySize),
} }
// statePool reuses State structs to avoid allocations
statePool = sync.Pool{
New: func() any {
return &State{}
},
}
) )
//export goFunctionWrapper //export goFunctionWrapper
func goFunctionWrapper(L *C.lua_State) C.int { func goFunctionWrapper(L *C.lua_State) C.int {
state := &State{L: L} state := statePool.Get().(*State)
state.L = L
defer statePool.Put(state)
ptr := C.lua_touserdata(L, C.get_upvalue_index(1)) ptr := C.lua_touserdata(L, C.get_upvalue_index(1))
if ptr == nil { if ptr == nil {
@ -51,8 +60,6 @@ func goFunctionWrapper(L *C.lua_State) C.int {
functionRegistry.RUnlock() functionRegistry.RUnlock()
if !ok { if !ok {
// Debug logging
fmt.Printf("Function not found for pointer %p\n", ptr)
state.PushString("error: function not found in registry") state.PushString("error: function not found in registry")
return -1 return -1
} }

View File

@ -46,14 +46,6 @@ func (e *LuaError) Error() string {
return result return result
} }
// Stack management constants from lua.h
const (
LUA_MINSTACK = 20 // Minimum Lua stack size
LUA_MAXSTACK = 1000000 // Maximum Lua stack size
LUA_REGISTRYINDEX = -10000 // Pseudo-index for the Lua registry
LUA_GLOBALSINDEX = -10002 // Pseudo-index for globals table
)
// GetStackTrace returns the current Lua stack trace // GetStackTrace returns the current Lua stack trace
func (s *State) GetStackTrace() string { func (s *State) GetStackTrace() string {
s.GetGlobal("debug") s.GetGlobal("debug")
@ -64,13 +56,13 @@ func (s *State) GetStackTrace() string {
s.GetField(-1, "traceback") s.GetField(-1, "traceback")
if !s.IsFunction(-1) { if !s.IsFunction(-1) {
s.Pop(2) // Remove debug table and non-function s.Pop(2)
return "debug.traceback not available" return "debug.traceback not available"
} }
s.Call(0, 1) s.Call(0, 1)
trace := s.ToString(-1) trace := s.ToString(-1)
s.Pop(1) // Remove the trace s.Pop(1)
return trace return trace
} }
@ -97,13 +89,11 @@ func (s *State) GetErrorInfo(context string) *LuaError {
if secondColonPos := strings.Index(afterColon, ":"); secondColonPos > 0 { if secondColonPos := strings.Index(afterColon, ":"); secondColonPos > 0 {
file = beforeColon file = beforeColon
if n, err := fmt.Sscanf(afterColon[:secondColonPos], "%d", &line); n == 1 && err == nil { if n, err := fmt.Sscanf(afterColon[:secondColonPos], "%d", &line); n == 1 && err == nil {
// Strip the file:line part from message for cleaner display
message = strings.TrimSpace(afterColon[secondColonPos+1:]) message = strings.TrimSpace(afterColon[secondColonPos+1:])
} }
} }
} }
// Get stack trace
stackTrace := s.GetStackTrace() stackTrace := s.GetStackTrace()
return &LuaError{ return &LuaError{
@ -121,3 +111,9 @@ func (s *State) CreateLuaError(code int, context string) *LuaError {
err.Code = code err.Code = code
return err return err
} }
// PushError pushes an error string and returns -1
func (s *State) PushError(format string, args ...any) int {
s.PushString(fmt.Sprintf(format, args...))
return -1
}

164
table.go
View File

@ -1,164 +0,0 @@
package luajit
/*
#include <lua.h>
#include <lualib.h>
#include <lauxlib.h>
#include <stdlib.h>
// Simple direct length check
size_t get_table_length(lua_State *L, int index) {
return lua_objlen(L, index);
}
*/
import "C"
import (
"fmt"
"strconv"
)
// GetTableLength returns the length of a table at the given index
func (s *State) GetTableLength(index int) int {
return int(C.get_table_length(s.L, C.int(index)))
}
// PushTable pushes a Go map onto the Lua stack as a table
func (s *State) PushTable(table map[string]any) error {
// Fast path for array tables
if arr, ok := table[""]; ok {
if floatArr, ok := arr.([]float64); ok {
s.CreateTable(len(floatArr), 0)
for i, v := range floatArr {
s.PushNumber(float64(i + 1))
s.PushNumber(v)
s.SetTable(-3)
}
return nil
} else if anyArr, ok := arr.([]any); ok {
s.CreateTable(len(anyArr), 0)
for i, v := range anyArr {
s.PushNumber(float64(i + 1))
if err := s.PushValue(v); err != nil {
return err
}
s.SetTable(-3)
}
return nil
}
}
// Regular table case - optimize capacity hint
s.CreateTable(0, len(table))
// Add each key-value pair directly
for k, v := range table {
s.PushString(k)
if err := s.PushValue(v); err != nil {
return err
}
s.SetTable(-3)
}
return nil
}
// ToTable converts a Lua table at the given index to a Go map
func (s *State) ToTable(index int) (map[string]any, error) {
absIdx := s.absIndex(index)
if !s.IsTable(absIdx) {
return nil, fmt.Errorf("value at index %d is not a table", index)
}
// Try to detect array-like tables first
length := s.GetTableLength(absIdx)
if length > 0 {
// Fast path for common array case
allNumbers := true
// Sample first few values to check if it's likely an array of numbers
for i := 1; i <= min(length, 5); i++ {
s.PushNumber(float64(i))
s.GetTable(absIdx)
if !s.IsNumber(-1) {
allNumbers = false
s.Pop(1)
break
}
s.Pop(1)
}
if allNumbers {
// Efficiently extract array values
array := make([]float64, length)
for i := 1; i <= length; i++ {
s.PushNumber(float64(i))
s.GetTable(absIdx)
array[i-1] = s.ToNumber(-1)
s.Pop(1)
}
// Return array as a special table with empty key
result := make(map[string]any, 1)
result[""] = array
return result, nil
}
}
// Handle regular table with pre-allocated capacity
table := make(map[string]any, max(length, 8))
// Iterate through all key-value pairs
s.PushNil() // Start iteration with nil key
for s.Next(absIdx) {
// Stack now has key at -2 and value at -1
// Convert key to string
var key string
keyType := s.GetType(-2)
switch keyType {
case TypeString:
key = s.ToString(-2)
case TypeNumber:
key = strconv.FormatFloat(s.ToNumber(-2), 'g', -1, 64)
default:
// Skip non-string/non-number keys
s.Pop(1) // Pop value, leave key for next iteration
continue
}
// Convert and store the value
value, err := s.ToValue(-1)
if err != nil {
s.Pop(2) // Pop both key and value
return nil, err
}
// Unwrap nested array tables
if m, ok := value.(map[string]any); ok {
if arr, ok := m[""]; ok {
value = arr
}
}
table[key] = value
s.Pop(1) // Pop value, leave key for next iteration
}
return table, nil
}
// Helper functions for min/max operations
func min(a, b int) int {
if a < b {
return a
}
return b
}
func max(a, b int) int {
if a > b {
return a
}
return b
}

View File

@ -19,7 +19,6 @@ func TestGetTableLength(t *testing.T) {
t.Fatalf("Failed to create test table: %v", err) t.Fatalf("Failed to create test table: %v", err)
} }
// Get the table
state.GetGlobal("t") state.GetGlobal("t")
length := state.GetTableLength(-1) length := state.GetTableLength(-1)
if length != 5 { if length != 5 {
@ -32,7 +31,6 @@ func TestGetTableLength(t *testing.T) {
t.Fatalf("Failed to create test table: %v", err) t.Fatalf("Failed to create test table: %v", err)
} }
// Get the table
state.GetGlobal("t2") state.GetGlobal("t2")
length = state.GetTableLength(-1) length = state.GetTableLength(-1)
if length != 0 { if length != 0 {
@ -41,206 +39,234 @@ func TestGetTableLength(t *testing.T) {
state.Pop(1) state.Pop(1)
} }
func TestPushTable(t *testing.T) { func TestPushTypedArrays(t *testing.T) {
state := luajit.New() state := luajit.New()
if state == nil { if state == nil {
t.Fatal("Failed to create Lua state") t.Fatal("Failed to create Lua state")
} }
defer state.Close() defer state.Close()
// Create a test table // Test []int
testTable := map[string]any{ intArr := []int{1, 2, 3, 4, 5}
"int": 42, if err := state.PushValue(intArr); err != nil {
"float": 3.14, t.Fatalf("Failed to push int array: %v", err)
"string": "hello",
"boolean": true,
"nil": nil,
} }
state.SetGlobal("int_arr")
// Push the table onto the stack // Test []string
if err := state.PushTable(testTable); err != nil { stringArr := []string{"hello", "world", "test"}
t.Fatalf("Failed to push table: %v", err) if err := state.PushValue(stringArr); err != nil {
t.Fatalf("Failed to push string array: %v", err)
} }
state.SetGlobal("string_arr")
// Execute Lua code to test the table contents // Test []bool
boolArr := []bool{true, false, true}
if err := state.PushValue(boolArr); err != nil {
t.Fatalf("Failed to push bool array: %v", err)
}
state.SetGlobal("bool_arr")
// Test []float64
floatArr := []float64{1.1, 2.2, 3.3}
if err := state.PushValue(floatArr); err != nil {
t.Fatalf("Failed to push float array: %v", err)
}
state.SetGlobal("float_arr")
// Verify arrays in Lua
if err := state.DoString(` if err := state.DoString(`
function validate_table(t) assert(int_arr[1] == 1 and int_arr[5] == 5)
return t.int == 42 and assert(string_arr[1] == "hello" and string_arr[3] == "test")
math.abs(t.float - 3.14) < 0.0001 and assert(bool_arr[1] == true and bool_arr[2] == false)
t.string == "hello" and assert(math.abs(float_arr[1] - 1.1) < 0.0001)
t.boolean == true and
t["nil"] == nil
end
`); err != nil { `); err != nil {
t.Fatalf("Failed to create validation function: %v", err) t.Fatalf("Array verification failed: %v", err)
} }
// Call the validation function
state.GetGlobal("validate_table")
state.PushCopy(-2) // Copy the table to the top
if err := state.Call(1, 1); err != nil {
t.Fatalf("Failed to call validation function: %v", err)
}
if !state.ToBoolean(-1) {
t.Fatalf("Table validation failed")
}
state.Pop(2) // Pop the result and the table
} }
func TestToTable(t *testing.T) { func TestPushTypedMaps(t *testing.T) {
state := luajit.New() state := luajit.New()
if state == nil { if state == nil {
t.Fatal("Failed to create Lua state") t.Fatal("Failed to create Lua state")
} }
defer state.Close() defer state.Close()
// Test regular table conversion // Test map[string]string
if err := state.DoString(`t = {a=1, b=2.5, c="test", d=true, e=nil}`); err != nil { stringMap := map[string]string{"name": "John", "city": "NYC"}
t.Fatalf("Failed to create test table: %v", err) if err := state.PushValue(stringMap); err != nil {
t.Fatalf("Failed to push string map: %v", err)
} }
state.SetGlobal("string_map")
state.GetGlobal("t") // Test map[string]int
table, err := state.ToTable(-1) intMap := map[string]int{"age": 25, "score": 100}
if err != nil { if err := state.PushValue(intMap); err != nil {
t.Fatalf("Failed to convert table: %v", err) t.Fatalf("Failed to push int map: %v", err)
} }
state.Pop(1) state.SetGlobal("int_map")
expected := map[string]any{ // Test map[int]any
"a": float64(1), intKeyMap := map[int]any{1: "first", 2: 42, 3: true}
"b": 2.5, if err := state.PushValue(intKeyMap); err != nil {
"c": "test", t.Fatalf("Failed to push int key map: %v", err)
"d": true,
} }
state.SetGlobal("int_key_map")
for k, v := range expected { // Verify maps in Lua
if table[k] != v { if err := state.DoString(`
t.Fatalf("Expected table[%s] = %v, got %v", k, v, table[k]) assert(string_map.name == "John" and string_map.city == "NYC")
} assert(int_map.age == 25 and int_map.score == 100)
} assert(int_key_map[1] == "first" and int_key_map[2] == 42 and int_key_map[3] == true)
`); err != nil {
// Test array-like table conversion t.Fatalf("Map verification failed: %v", err)
if err := state.DoString(`arr = {10, 20, 30, 40, 50}`); err != nil {
t.Fatalf("Failed to create test array: %v", err)
}
state.GetGlobal("arr")
table, err = state.ToTable(-1)
if err != nil {
t.Fatalf("Failed to convert array table: %v", err)
}
state.Pop(1)
// For array tables, we should get a special format with an empty key
// and the array as the value
expectedArray := []float64{10, 20, 30, 40, 50}
if arr, ok := table[""].([]float64); !ok {
t.Fatalf("Expected array table to be converted with empty key, got: %v", table)
} else if !reflect.DeepEqual(arr, expectedArray) {
t.Fatalf("Expected %v, got %v", expectedArray, arr)
}
// Test invalid table index
_, err = state.ToTable(100)
if err == nil {
t.Fatalf("Expected error for invalid table index, got nil")
}
// Test non-table value
state.PushNumber(123)
_, err = state.ToTable(-1)
if err == nil {
t.Fatalf("Expected error for non-table value, got nil")
}
state.Pop(1)
// Test mixed array with non-numeric values
if err := state.DoString(`mixed = {10, 20, key="value", 30}`); err != nil {
t.Fatalf("Failed to create mixed table: %v", err)
}
state.GetGlobal("mixed")
table, err = state.ToTable(-1)
if err != nil {
t.Fatalf("Failed to convert mixed table: %v", err)
}
// Let's print the table for debugging
t.Logf("Table contents: %v", table)
state.Pop(1)
// Check if the array part is detected and stored with empty key
if arr, ok := table[""]; !ok {
t.Fatalf("Expected array-like part to be detected, got: %v", table)
} else {
// Verify the array contains the expected values
expectedArr := []float64{10, 20, 30}
actualArr := arr.([]float64)
if len(actualArr) != len(expectedArr) {
t.Fatalf("Expected array length %d, got %d", len(expectedArr), len(actualArr))
}
for i, v := range expectedArr {
if actualArr[i] != v {
t.Fatalf("Expected array[%d] = %v, got %v", i, v, actualArr[i])
}
}
}
// Based on the implementation, we need to create a separate test for string keys
if err := state.DoString(`dict = {foo="bar", baz="qux"}`); err != nil {
t.Fatalf("Failed to create dict table: %v", err)
}
state.GetGlobal("dict")
dictTable, err := state.ToTable(-1)
if err != nil {
t.Fatalf("Failed to convert dict table: %v", err)
}
state.Pop(1)
// Check the string keys
if val, ok := dictTable["foo"]; !ok || val != "bar" {
t.Fatalf("Expected dictTable[\"foo\"] = \"bar\", got: %v", val)
}
if val, ok := dictTable["baz"]; !ok || val != "qux" {
t.Fatalf("Expected dictTable[\"baz\"] = \"qux\", got: %v", val)
} }
} }
func TestTablePooling(t *testing.T) { func TestToTableTypedArrays(t *testing.T) {
state := luajit.New() state := luajit.New()
if state == nil { if state == nil {
t.Fatal("Failed to create Lua state") t.Fatal("Failed to create Lua state")
} }
defer state.Close() defer state.Close()
// Create a Lua table and push it onto the stack // Test integer array detection
if err := state.DoString(`t = {a=1, b=2}`); err != nil { if err := state.DoString("int_arr = {10, 20, 30}"); err != nil {
t.Fatalf("Failed to create test table: %v", err) t.Fatalf("Failed to create int array: %v", err)
} }
state.GetGlobal("int_arr")
state.GetGlobal("t") result, err := state.ToValue(-1)
// First conversion - should get a table from the pool
table1, err := state.ToTable(-1)
if err != nil { if err != nil {
t.Fatalf("Failed to convert table (1): %v", err) t.Fatalf("Failed to convert int array: %v", err)
} }
intArr, ok := result.([]int)
if !ok {
t.Fatalf("Expected []int, got %T", result)
}
expected := []int{10, 20, 30}
if !reflect.DeepEqual(intArr, expected) {
t.Fatalf("Expected %v, got %v", expected, intArr)
}
state.Pop(1)
// Second conversion - should get another table from the pool // Test float array detection
table2, err := state.ToTable(-1) if err := state.DoString("float_arr = {1.5, 2.7, 3.9}"); err != nil {
t.Fatalf("Failed to create float array: %v", err)
}
state.GetGlobal("float_arr")
result, err = state.ToValue(-1)
if err != nil { if err != nil {
t.Fatalf("Failed to convert table (2): %v", err) t.Fatalf("Failed to convert float array: %v", err)
} }
floatArr, ok := result.([]float64)
// Both tables should have the same content if !ok {
if !reflect.DeepEqual(table1, table2) { t.Fatalf("Expected []float64, got %T", result)
t.Fatalf("Tables should have the same content: %v vs %v", table1, table2)
} }
expectedFloat := []float64{1.5, 2.7, 3.9}
if !reflect.DeepEqual(floatArr, expectedFloat) {
t.Fatalf("Expected %v, got %v", expectedFloat, floatArr)
}
state.Pop(1)
// Clean up // Test string array detection
if err := state.DoString(`string_arr = {"hello", "world"}`); err != nil {
t.Fatalf("Failed to create string array: %v", err)
}
state.GetGlobal("string_arr")
result, err = state.ToValue(-1)
if err != nil {
t.Fatalf("Failed to convert string array: %v", err)
}
stringArr, ok := result.([]string)
if !ok {
t.Fatalf("Expected []string, got %T", result)
}
expectedString := []string{"hello", "world"}
if !reflect.DeepEqual(stringArr, expectedString) {
t.Fatalf("Expected %v, got %v", expectedString, stringArr)
}
state.Pop(1)
// Test bool array detection
if err := state.DoString("bool_arr = {true, false, true}"); err != nil {
t.Fatalf("Failed to create bool array: %v", err)
}
state.GetGlobal("bool_arr")
result, err = state.ToValue(-1)
if err != nil {
t.Fatalf("Failed to convert bool array: %v", err)
}
boolArr, ok := result.([]bool)
if !ok {
t.Fatalf("Expected []bool, got %T", result)
}
expectedBool := []bool{true, false, true}
if !reflect.DeepEqual(boolArr, expectedBool) {
t.Fatalf("Expected %v, got %v", expectedBool, boolArr)
}
state.Pop(1)
}
func TestToTableTypedMaps(t *testing.T) {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
// Test string map detection
if err := state.DoString(`string_map = {name="John", city="NYC"}`); err != nil {
t.Fatalf("Failed to create string map: %v", err)
}
state.GetGlobal("string_map")
result, err := state.ToValue(-1)
if err != nil {
t.Fatalf("Failed to convert string map: %v", err)
}
stringMap, ok := result.(map[string]string)
if !ok {
t.Fatalf("Expected map[string]string, got %T", result)
}
expectedStringMap := map[string]string{"name": "John", "city": "NYC"}
if !reflect.DeepEqual(stringMap, expectedStringMap) {
t.Fatalf("Expected %v, got %v", expectedStringMap, stringMap)
}
state.Pop(1)
// Test int map detection
if err := state.DoString("int_map = {age=25, score=100}"); err != nil {
t.Fatalf("Failed to create int map: %v", err)
}
state.GetGlobal("int_map")
result, err = state.ToValue(-1)
if err != nil {
t.Fatalf("Failed to convert int map: %v", err)
}
intMap, ok := result.(map[string]int)
if !ok {
t.Fatalf("Expected map[string]int, got %T", result)
}
expectedIntMap := map[string]int{"age": 25, "score": 100}
if !reflect.DeepEqual(intMap, expectedIntMap) {
t.Fatalf("Expected %v, got %v", expectedIntMap, intMap)
}
state.Pop(1)
// Test mixed map (should fallback to map[string]any)
if err := state.DoString(`mixed_map = {name="John", age=25, active=true}`); err != nil {
t.Fatalf("Failed to create mixed map: %v", err)
}
state.GetGlobal("mixed_map")
result, err = state.ToValue(-1)
if err != nil {
t.Fatalf("Failed to convert mixed map: %v", err)
}
mixedMap, ok := result.(map[string]any)
if !ok {
t.Fatalf("Expected map[string]any, got %T", result)
}
if mixedMap["name"] != "John" || mixedMap["age"] != 25 || mixedMap["active"] != true {
t.Fatalf("Mixed map conversion failed: %v", mixedMap)
}
state.Pop(1) state.Pop(1)
} }

View File

@ -9,70 +9,50 @@ import (
) )
func TestStateLifecycle(t *testing.T) { func TestStateLifecycle(t *testing.T) {
// Test creation
state := luajit.New() state := luajit.New()
if state == nil { if state == nil {
t.Fatal("Failed to create Lua state") t.Fatal("Failed to create Lua state")
} }
// Test close
state.Close()
// Test close is idempotent (doesn't crash)
state.Close() state.Close()
state.Close() // Test idempotent close
} }
func TestStackManipulation(t *testing.T) { func TestStackOperations(t *testing.T) {
state := luajit.New() state := luajit.New()
if state == nil { if state == nil {
t.Fatal("Failed to create Lua state") t.Fatal("Failed to create Lua state")
} }
defer state.Close() defer state.Close()
// Test initial stack size // Test stack manipulation
if state.GetTop() != 0 { if state.GetTop() != 0 {
t.Fatalf("Expected empty stack, got %d elements", state.GetTop()) t.Fatalf("Expected empty stack, got %d", state.GetTop())
} }
// Push values
state.PushNil() state.PushNil()
state.PushBoolean(true) state.PushBoolean(true)
state.PushNumber(42) state.PushNumber(42)
state.PushString("hello") state.PushString("hello")
// Check stack size
if state.GetTop() != 4 { if state.GetTop() != 4 {
t.Fatalf("Expected 4 elements, got %d", state.GetTop()) t.Fatalf("Expected 4 elements, got %d", state.GetTop())
} }
// Test SetTop
state.SetTop(2) state.SetTop(2)
if state.GetTop() != 2 { if state.GetTop() != 2 {
t.Fatalf("Expected 2 elements after SetTop, got %d", state.GetTop()) t.Fatalf("Expected 2 elements after SetTop, got %d", state.GetTop())
} }
// Test PushCopy state.PushCopy(2)
state.PushCopy(2) // Copy the boolean
if !state.IsBoolean(-1) { if !state.IsBoolean(-1) {
t.Fatalf("Expected boolean at top of stack") t.Fatal("Expected boolean at top")
} }
// Test Pop
state.Pop(1) state.Pop(1)
if state.GetTop() != 2 {
t.Fatalf("Expected 2 elements after Pop, got %d", state.GetTop())
}
// Test Remove
state.PushNumber(99) state.PushNumber(99)
state.Remove(1) // Remove the first element (nil) state.Remove(1)
if state.GetTop() != 2 {
t.Fatalf("Expected 2 elements after Remove, got %d", state.GetTop())
}
// Verify first element is now boolean
if !state.IsBoolean(1) { if !state.IsBoolean(1) {
t.Fatalf("Expected boolean at index 1 after Remove") t.Fatal("Expected boolean at index 1 after Remove")
} }
} }
@ -83,52 +63,33 @@ func TestTypeChecking(t *testing.T) {
} }
defer state.Close() defer state.Close()
// Push values of different types values := []struct {
state.PushNil() push func()
state.PushBoolean(true) luaType luajit.LuaType
state.PushNumber(42) checkFn func(int) bool
state.PushString("hello") }{
state.NewTable() {state.PushNil, luajit.TypeNil, state.IsNil},
{func() { state.PushBoolean(true) }, luajit.TypeBoolean, state.IsBoolean},
// Check types with GetType {func() { state.PushNumber(42) }, luajit.TypeNumber, state.IsNumber},
if state.GetType(1) != luajit.TypeNil { {func() { state.PushString("test") }, luajit.TypeString, state.IsString},
t.Fatalf("Expected nil type at index 1, got %s", state.GetType(1)) {state.NewTable, luajit.TypeTable, state.IsTable},
}
if state.GetType(2) != luajit.TypeBoolean {
t.Fatalf("Expected boolean type at index 2, got %s", state.GetType(2))
}
if state.GetType(3) != luajit.TypeNumber {
t.Fatalf("Expected number type at index 3, got %s", state.GetType(3))
}
if state.GetType(4) != luajit.TypeString {
t.Fatalf("Expected string type at index 4, got %s", state.GetType(4))
}
if state.GetType(5) != luajit.TypeTable {
t.Fatalf("Expected table type at index 5, got %s", state.GetType(5))
} }
// Test individual type checking functions for i, v := range values {
if !state.IsNil(1) { v.push()
t.Fatalf("IsNil failed for nil value") idx := i + 1
} if state.GetType(idx) != v.luaType {
if !state.IsBoolean(2) { t.Fatalf("Type mismatch at %d: expected %s, got %s", idx, v.luaType, state.GetType(idx))
t.Fatalf("IsBoolean failed for boolean value") }
} if !v.checkFn(idx) {
if !state.IsNumber(3) { t.Fatalf("Type check failed at %d", idx)
t.Fatalf("IsNumber failed for number value") }
}
if !state.IsString(4) {
t.Fatalf("IsString failed for string value")
}
if !state.IsTable(5) {
t.Fatalf("IsTable failed for table value")
} }
// Function test
state.DoString("function test() return true end") state.DoString("function test() return true end")
state.GetGlobal("test") state.GetGlobal("test")
if !state.IsFunction(-1) { if !state.IsFunction(-1) {
t.Fatalf("IsFunction failed for function value") t.Fatal("IsFunction failed")
} }
} }
@ -139,20 +100,18 @@ func TestValueConversion(t *testing.T) {
} }
defer state.Close() defer state.Close()
// Push values
state.PushBoolean(true) state.PushBoolean(true)
state.PushNumber(42.5) state.PushNumber(42.5)
state.PushString("hello") state.PushString("hello")
// Test conversion
if !state.ToBoolean(1) { if !state.ToBoolean(1) {
t.Fatalf("ToBoolean failed") t.Fatal("ToBoolean failed")
} }
if state.ToNumber(2) != 42.5 { if state.ToNumber(2) != 42.5 {
t.Fatalf("ToNumber failed, expected 42.5, got %f", state.ToNumber(2)) t.Fatalf("ToNumber failed: expected 42.5, got %f", state.ToNumber(2))
} }
if state.ToString(3) != "hello" { if state.ToString(3) != "hello" {
t.Fatalf("ToString failed, expected 'hello', got '%s'", state.ToString(3)) t.Fatalf("ToString failed: expected 'hello', got '%s'", state.ToString(3))
} }
} }
@ -163,46 +122,34 @@ func TestTableOperations(t *testing.T) {
} }
defer state.Close() defer state.Close()
// Test CreateTable
state.CreateTable(0, 3) state.CreateTable(0, 3)
// Add fields using SetField // Set fields
state.PushNumber(42) state.PushNumber(42)
state.SetField(-2, "answer") state.SetField(-2, "answer")
state.PushString("hello") state.PushString("hello")
state.SetField(-2, "greeting") state.SetField(-2, "greeting")
state.PushBoolean(true) state.PushBoolean(true)
state.SetField(-2, "flag") state.SetField(-2, "flag")
// Test GetField // Get fields
state.GetField(-1, "answer") state.GetField(-1, "answer")
if state.ToNumber(-1) != 42 { if state.ToNumber(-1) != 42 {
t.Fatalf("GetField for 'answer' failed") t.Fatal("GetField failed for 'answer'")
} }
state.Pop(1) state.Pop(1)
state.GetField(-1, "greeting") // Test iteration
if state.ToString(-1) != "hello" { state.PushNil()
t.Fatalf("GetField for 'greeting' failed")
}
state.Pop(1)
// Test Next for iteration
state.PushNil() // Start iteration
count := 0 count := 0
for state.Next(-2) { for state.Next(-2) {
count++ count++
state.Pop(1) // Pop value, leave key for next iteration state.Pop(1)
} }
if count != 3 { if count != 3 {
t.Fatalf("Expected 3 table entries, found %d", count) t.Fatalf("Expected 3 entries, found %d", count)
} }
state.Pop(1)
// Clean up
state.Pop(1) // Pop the table
} }
func TestGlobalOperations(t *testing.T) { func TestGlobalOperations(t *testing.T) {
@ -212,21 +159,18 @@ func TestGlobalOperations(t *testing.T) {
} }
defer state.Close() defer state.Close()
// Set a global value
state.PushNumber(42) state.PushNumber(42)
state.SetGlobal("answer") state.SetGlobal("answer")
// Get the global value
state.GetGlobal("answer") state.GetGlobal("answer")
if state.ToNumber(-1) != 42 { if state.ToNumber(-1) != 42 {
t.Fatalf("GetGlobal failed, expected 42, got %f", state.ToNumber(-1)) t.Fatalf("GetGlobal failed: expected 42, got %f", state.ToNumber(-1))
} }
state.Pop(1) state.Pop(1)
// Test non-existent global (should be nil)
state.GetGlobal("nonexistent") state.GetGlobal("nonexistent")
if !state.IsNil(-1) { if !state.IsNil(-1) {
t.Fatalf("Expected nil for non-existent global") t.Fatal("Expected nil for non-existent global")
} }
state.Pop(1) state.Pop(1)
} }
@ -238,18 +182,15 @@ func TestCodeExecution(t *testing.T) {
} }
defer state.Close() defer state.Close()
// Test LoadString // Test LoadString and Call
if err := state.LoadString("return 42"); err != nil { if err := state.LoadString("return 42"); err != nil {
t.Fatalf("LoadString failed: %v", err) t.Fatalf("LoadString failed: %v", err)
} }
// Test Call
if err := state.Call(0, 1); err != nil { if err := state.Call(0, 1); err != nil {
t.Fatalf("Call failed: %v", err) t.Fatalf("Call failed: %v", err)
} }
if state.ToNumber(-1) != 42 { if state.ToNumber(-1) != 42 {
t.Fatalf("Call result incorrect, expected 42, got %f", state.ToNumber(-1)) t.Fatalf("Call result incorrect: expected 42, got %f", state.ToNumber(-1))
} }
state.Pop(1) state.Pop(1)
@ -257,10 +198,9 @@ func TestCodeExecution(t *testing.T) {
if err := state.DoString("answer = 42 + 1"); err != nil { if err := state.DoString("answer = 42 + 1"); err != nil {
t.Fatalf("DoString failed: %v", err) t.Fatalf("DoString failed: %v", err)
} }
state.GetGlobal("answer") state.GetGlobal("answer")
if state.ToNumber(-1) != 43 { if state.ToNumber(-1) != 43 {
t.Fatalf("DoString execution incorrect, expected 43, got %f", state.ToNumber(-1)) t.Fatalf("DoString result incorrect: expected 43, got %f", state.ToNumber(-1))
} }
state.Pop(1) state.Pop(1)
@ -269,13 +209,11 @@ func TestCodeExecution(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("Execute failed: %v", err) t.Fatalf("Execute failed: %v", err)
} }
if nresults != 3 { if nresults != 3 {
t.Fatalf("Execute returned %d results, expected 3", nresults) t.Fatalf("Execute returned %d results, expected 3", nresults)
} }
if state.ToNumber(-3) != 5 || state.ToNumber(-2) != 10 || state.ToNumber(-1) != 15 { if state.ToNumber(-3) != 5 || state.ToNumber(-2) != 10 || state.ToNumber(-1) != 15 {
t.Fatalf("Execute results incorrect") t.Fatal("Execute results incorrect")
} }
state.Pop(3) state.Pop(3)
@ -284,26 +222,24 @@ func TestCodeExecution(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("ExecuteWithResult failed: %v", err) t.Fatalf("ExecuteWithResult failed: %v", err)
} }
if result != "hello" { if result != "hello" {
t.Fatalf("ExecuteWithResult returned %v, expected 'hello'", result) t.Fatalf("ExecuteWithResult returned %v, expected 'hello'", result)
} }
// Test error handling // Test error handling
err = state.DoString("this is not valid lua code") if err := state.DoString("invalid lua code"); err == nil {
if err == nil { t.Fatal("Expected error for invalid code")
t.Fatalf("Expected error for invalid code, got nil")
} }
} }
func TestDoFile(t *testing.T) { func TestFileOperations(t *testing.T) {
state := luajit.New() state := luajit.New()
if state == nil { if state == nil {
t.Fatal("Failed to create Lua state") t.Fatal("Failed to create Lua state")
} }
defer state.Close() defer state.Close()
// Create a temporary Lua file // Create temp file
content := []byte("answer = 42") content := []byte("answer = 42")
tmpfile, err := os.CreateTemp("", "test-*.lua") tmpfile, err := os.CreateTemp("", "test-*.lua")
if err != nil { if err != nil {
@ -312,40 +248,17 @@ func TestDoFile(t *testing.T) {
defer os.Remove(tmpfile.Name()) defer os.Remove(tmpfile.Name())
if _, err := tmpfile.Write(content); err != nil { if _, err := tmpfile.Write(content); err != nil {
t.Fatalf("Failed to write to temp file: %v", err) t.Fatalf("Failed to write temp file: %v", err)
}
if err := tmpfile.Close(); err != nil {
t.Fatalf("Failed to close temp file: %v", err)
}
// Test LoadFile and DoFile
if err := state.LoadFile(tmpfile.Name()); err != nil {
t.Fatalf("LoadFile failed: %v", err)
}
if err := state.Call(0, 0); err != nil {
t.Fatalf("Call failed after LoadFile: %v", err)
}
state.GetGlobal("answer")
if state.ToNumber(-1) != 42 {
t.Fatalf("Incorrect result after LoadFile, expected 42, got %f", state.ToNumber(-1))
}
state.Pop(1)
// Reset global
if err := state.DoString("answer = nil"); err != nil {
t.Fatalf("Failed to reset answer: %v", err)
} }
tmpfile.Close()
// Test DoFile // Test DoFile
if err := state.DoFile(tmpfile.Name()); err != nil { if err := state.DoFile(tmpfile.Name()); err != nil {
t.Fatalf("DoFile failed: %v", err) t.Fatalf("DoFile failed: %v", err)
} }
state.GetGlobal("answer") state.GetGlobal("answer")
if state.ToNumber(-1) != 42 { if state.ToNumber(-1) != 42 {
t.Fatalf("Incorrect result after DoFile, expected 42, got %f", state.ToNumber(-1)) t.Fatalf("DoFile result incorrect: expected 42, got %f", state.ToNumber(-1))
} }
state.Pop(1) state.Pop(1)
} }
@ -357,7 +270,6 @@ func TestPackagePath(t *testing.T) {
} }
defer state.Close() defer state.Close()
// Test SetPackagePath
testPath := "/test/path/?.lua" testPath := "/test/path/?.lua"
if err := state.SetPackagePath(testPath); err != nil { if err := state.SetPackagePath(testPath); err != nil {
t.Fatalf("SetPackagePath failed: %v", err) t.Fatalf("SetPackagePath failed: %v", err)
@ -367,12 +279,10 @@ func TestPackagePath(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("Failed to get package.path: %v", err) t.Fatalf("Failed to get package.path: %v", err)
} }
if result != testPath { if result != testPath {
t.Fatalf("Expected package.path to be '%s', got '%s'", testPath, result) t.Fatalf("SetPackagePath failed: expected '%s', got '%s'", testPath, result)
} }
// Test AddPackagePath
addPath := "/another/path/?.lua" addPath := "/another/path/?.lua"
if err := state.AddPackagePath(addPath); err != nil { if err := state.AddPackagePath(addPath); err != nil {
t.Fatalf("AddPackagePath failed: %v", err) t.Fatalf("AddPackagePath failed: %v", err)
@ -382,92 +292,134 @@ func TestPackagePath(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("Failed to get package.path: %v", err) t.Fatalf("Failed to get package.path: %v", err)
} }
expected := testPath + ";" + addPath expected := testPath + ";" + addPath
if result != expected { if result != expected {
t.Fatalf("Expected package.path to be '%s', got '%s'", expected, result) t.Fatalf("AddPackagePath failed: expected '%s', got '%s'", expected, result)
} }
} }
func TestPushValueAndToValue(t *testing.T) { func TestEnhancedTypes(t *testing.T) {
state := luajit.New() state := luajit.New()
if state == nil { if state == nil {
t.Fatal("Failed to create Lua state") t.Fatal("Failed to create Lua state")
} }
defer state.Close() defer state.Close()
// Test typed arrays
testCases := []struct { testCases := []struct {
value any input any
expected any
}{ }{
{nil}, // Primitive types
{true}, {nil, nil},
{false}, {true, true},
{42}, {42, 42}, // Should preserve as int
{42.5}, {42.5, 42.5}, // Should be float64
{"hello"}, {"hello", "hello"},
{[]float64{1, 2, 3, 4, 5}},
{[]any{1, "test", true}}, // Typed arrays
{map[string]any{"a": 1, "b": "test", "c": true}}, {[]int{1, 2, 3}, []int{1, 2, 3}},
{[]string{"a", "b"}, []string{"a", "b"}},
{[]bool{true, false}, []bool{true, false}},
{[]float64{1.1, 2.2}, []float64{1.1, 2.2}},
// Typed maps
{map[string]string{"name": "John"}, map[string]string{"name": "John"}},
{map[string]int{"age": 25}, map[string]int{"age": 25}},
{map[int]any{10: "first", 20: 42}, map[string]any{"10": "first", "20": 42}},
} }
for i, tc := range testCases { for i, tc := range testCases {
// Push value // Push and retrieve value
err := state.PushValue(tc.value) if err := state.PushValue(tc.input); err != nil {
if err != nil { t.Fatalf("Case %d: PushValue failed: %v", i, err)
t.Fatalf("PushValue failed for testCase %d: %v", i, err)
} }
// Check stack result, err := state.ToValue(-1)
if state.GetTop() != i+1 { if err != nil {
t.Fatalf("Stack size incorrect after push, expected %d, got %d", i+1, state.GetTop()) t.Fatalf("Case %d: ToValue failed: %v", i, err)
} }
if !reflect.DeepEqual(result, tc.expected) {
t.Fatalf("Case %d: expected %v (%T), got %v (%T)",
i, tc.expected, tc.expected, result, result)
}
state.Pop(1)
} }
// Test conversion back to Go // Test mixed array (should become []any)
for i := range testCases { state.DoString("mixed = {1, 'hello', true}")
index := len(testCases) - i state.GetGlobal("mixed")
value, err := state.ToValue(index) result, err := state.ToValue(-1)
if err != nil { if err != nil {
t.Fatalf("ToValue failed for index %d: %v", index, err) t.Fatalf("Mixed array conversion failed: %v", err)
}
// For tables, we need special handling due to how Go types are stored
switch expected := testCases[index-1].value.(type) {
case []float64:
// Arrays come back as map[string]any with empty key
if m, ok := value.(map[string]any); ok {
if arr, ok := m[""].([]float64); ok {
if !reflect.DeepEqual(arr, expected) {
t.Fatalf("Value mismatch for testCase %d: expected %v, got %v", index-1, expected, arr)
}
} else {
t.Fatalf("Invalid array conversion for testCase %d", index-1)
}
} else {
t.Fatalf("Expected map for array value in testCase %d, got %T", index-1, value)
}
case int:
if num, ok := value.(float64); ok {
if float64(expected) == num {
continue // Values match after type conversion
}
}
case []any:
// Skip detailed comparison for mixed arrays
case map[string]any:
// Skip detailed comparison for maps
default:
if !reflect.DeepEqual(value, testCases[index-1].value) {
t.Fatalf("Value mismatch for testCase %d: expected %v, got %v",
index-1, testCases[index-1].value, value)
}
}
} }
if _, ok := result.([]any); !ok {
t.Fatalf("Expected []any for mixed array, got %T", result)
}
state.Pop(1)
// Test mixed map (should become map[string]any)
state.DoString("mixedMap = {name='John', age=25, active=true}")
state.GetGlobal("mixedMap")
result, err = state.ToValue(-1)
if err != nil {
t.Fatalf("Mixed map conversion failed: %v", err)
}
if _, ok := result.(map[string]any); !ok {
t.Fatalf("Expected map[string]any for mixed map, got %T", result)
}
state.Pop(1)
}
func TestIntegerPreservation(t *testing.T) {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
// Test that integers are preserved
state.DoString("num = 42")
state.GetGlobal("num")
result, err := state.ToValue(-1)
if err != nil {
t.Fatalf("Integer conversion failed: %v", err)
}
if val, ok := result.(int); !ok || val != 42 {
t.Fatalf("Expected int 42, got %T %v", result, result)
}
state.Pop(1)
// Test that floats remain floats
state.DoString("fnum = 42.5")
state.GetGlobal("fnum")
result, err = state.ToValue(-1)
if err != nil {
t.Fatalf("Float conversion failed: %v", err)
}
if val, ok := result.(float64); !ok || val != 42.5 {
t.Fatalf("Expected float64 42.5, got %T %v", result, result)
}
state.Pop(1)
}
func TestErrorHandling(t *testing.T) {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
// Test unsupported type // Test unsupported type
complex := complex(1, 2) type customStruct struct{ Field int }
err := state.PushValue(complex) if err := state.PushValue(customStruct{Field: 42}); err == nil {
t.Fatal("Expected error for unsupported type")
}
// Test invalid stack index
_, err := state.ToValue(100)
if err == nil { if err == nil {
t.Fatalf("Expected error for unsupported type") t.Fatal("Expected error for invalid index")
} }
} }

338
types.go
View File

@ -13,7 +13,6 @@ import (
type LuaType int type LuaType int
const ( const (
// These constants match lua.h's LUA_T* values
TypeNone LuaType = -1 TypeNone LuaType = -1
TypeNil LuaType = 0 TypeNil LuaType = 0
TypeBoolean LuaType = 1 TypeBoolean LuaType = 1
@ -26,7 +25,6 @@ const (
TypeThread LuaType = 8 TypeThread LuaType = 8
) )
// String returns the string representation of the Lua type
func (t LuaType) String() string { func (t LuaType) String() string {
switch t { switch t {
case TypeNone: case TypeNone:
@ -54,92 +52,309 @@ func (t LuaType) String() string {
} }
} }
// ConvertValue converts a value to the requested type with proper type conversion // ConvertValue converts a value to the requested type with comprehensive type conversion
func ConvertValue[T any](value any) (T, bool) { func ConvertValue[T any](value any) (T, bool) {
var zero T var zero T
// Handle nil case
if value == nil { if value == nil {
return zero, false return zero, false
} }
// Try direct type assertion first
if result, ok := value.(T); ok { if result, ok := value.(T); ok {
return result, true return result, true
} }
// Type-specific conversions
switch any(zero).(type) { switch any(zero).(type) {
case string: case string:
switch v := value.(type) { return convertToString[T](value)
case float64:
return any(fmt.Sprintf("%g", v)).(T), true
case int:
return any(strconv.Itoa(v)).(T), true
case bool:
if v {
return any("true").(T), true
}
return any("false").(T), true
}
case int: case int:
switch v := value.(type) { return convertToInt[T](value)
case float64:
return any(int(v)).(T), true
case string:
if i, err := strconv.Atoi(v); err == nil {
return any(i).(T), true
}
case bool:
if v {
return any(1).(T), true
}
return any(0).(T), true
}
case float64: case float64:
switch v := value.(type) { return convertToFloat[T](value)
case int:
return any(float64(v)).(T), true
case string:
if f, err := strconv.ParseFloat(v, 64); err == nil {
return any(f).(T), true
}
case bool:
if v {
return any(1.0).(T), true
}
return any(0.0).(T), true
}
case bool: case bool:
switch v := value.(type) { return convertToBool[T](value)
case string: case []int:
switch v { return convertToIntSlice[T](value)
case "true", "yes", "1": case []string:
return any(true).(T), true return convertToStringSlice[T](value)
case "false", "no", "0": case []bool:
return any(false).(T), true return convertToBoolSlice[T](value)
} case []float64:
case int: return convertToFloatSlice[T](value)
return any(v != 0).(T), true case []any:
case float64: return convertToAnySlice[T](value)
return any(v != 0).(T), true case map[string]string:
} return convertToStringMap[T](value)
case map[string]int:
return convertToIntMap[T](value)
case map[int]any:
return convertToIntKeyMap[T](value)
case map[string]any:
return convertToAnyMap[T](value)
} }
return zero, false return zero, false
} }
func convertToString[T any](value any) (T, bool) {
var zero T
switch v := value.(type) {
case float64:
if v == float64(int(v)) {
return any(strconv.Itoa(int(v))).(T), true
}
return any(fmt.Sprintf("%g", v)).(T), true
case int:
return any(strconv.Itoa(v)).(T), true
case bool:
return any(strconv.FormatBool(v)).(T), true
}
return zero, false
}
func convertToInt[T any](value any) (T, bool) {
var zero T
switch v := value.(type) {
case float64:
return any(int(v)).(T), true
case string:
if i, err := strconv.Atoi(v); err == nil {
return any(i).(T), true
}
case bool:
if v {
return any(1).(T), true
}
return any(0).(T), true
}
return zero, false
}
func convertToFloat[T any](value any) (T, bool) {
var zero T
switch v := value.(type) {
case int:
return any(float64(v)).(T), true
case string:
if f, err := strconv.ParseFloat(v, 64); err == nil {
return any(f).(T), true
}
case bool:
if v {
return any(1.0).(T), true
}
return any(0.0).(T), true
}
return zero, false
}
func convertToBool[T any](value any) (T, bool) {
var zero T
switch v := value.(type) {
case string:
switch v {
case "true", "yes", "1":
return any(true).(T), true
case "false", "no", "0":
return any(false).(T), true
}
case int:
return any(v != 0).(T), true
case float64:
return any(v != 0).(T), true
}
return zero, false
}
func convertToIntSlice[T any](value any) (T, bool) {
var zero T
switch v := value.(type) {
case []float64:
result := make([]int, len(v))
for i, f := range v {
result[i] = int(f)
}
return any(result).(T), true
case []any:
result := make([]int, 0, len(v))
for _, item := range v {
if i, ok := ConvertValue[int](item); ok {
result = append(result, i)
} else {
return zero, false
}
}
return any(result).(T), true
}
return zero, false
}
func convertToStringSlice[T any](value any) (T, bool) {
var zero T
if v, ok := value.([]any); ok {
result := make([]string, 0, len(v))
for _, item := range v {
if s, ok := ConvertValue[string](item); ok {
result = append(result, s)
} else {
return zero, false
}
}
return any(result).(T), true
}
return zero, false
}
func convertToBoolSlice[T any](value any) (T, bool) {
var zero T
if v, ok := value.([]any); ok {
result := make([]bool, 0, len(v))
for _, item := range v {
if b, ok := ConvertValue[bool](item); ok {
result = append(result, b)
} else {
return zero, false
}
}
return any(result).(T), true
}
return zero, false
}
func convertToFloatSlice[T any](value any) (T, bool) {
var zero T
switch v := value.(type) {
case []int:
result := make([]float64, len(v))
for i, n := range v {
result[i] = float64(n)
}
return any(result).(T), true
case []any:
result := make([]float64, 0, len(v))
for _, item := range v {
if f, ok := ConvertValue[float64](item); ok {
result = append(result, f)
} else {
return zero, false
}
}
return any(result).(T), true
}
return zero, false
}
func convertToAnySlice[T any](value any) (T, bool) {
var zero T
switch v := value.(type) {
case []int:
result := make([]any, len(v))
for i, n := range v {
result[i] = n
}
return any(result).(T), true
case []string:
result := make([]any, len(v))
for i, s := range v {
result[i] = s
}
return any(result).(T), true
case []bool:
result := make([]any, len(v))
for i, b := range v {
result[i] = b
}
return any(result).(T), true
case []float64:
result := make([]any, len(v))
for i, f := range v {
result[i] = f
}
return any(result).(T), true
}
return zero, false
}
func convertToStringMap[T any](value any) (T, bool) {
var zero T
if v, ok := value.(map[string]any); ok {
result := make(map[string]string, len(v))
for k, val := range v {
if s, ok := ConvertValue[string](val); ok {
result[k] = s
} else {
return zero, false
}
}
return any(result).(T), true
}
return zero, false
}
func convertToIntMap[T any](value any) (T, bool) {
var zero T
if v, ok := value.(map[string]any); ok {
result := make(map[string]int, len(v))
for k, val := range v {
if i, ok := ConvertValue[int](val); ok {
result[k] = i
} else {
return zero, false
}
}
return any(result).(T), true
}
return zero, false
}
func convertToIntKeyMap[T any](value any) (T, bool) {
var zero T
if v, ok := value.(map[string]any); ok {
result := make(map[int]any, len(v))
for k, val := range v {
if i, err := strconv.Atoi(k); err == nil {
result[i] = val
} else {
return zero, false
}
}
return any(result).(T), true
}
return zero, false
}
func convertToAnyMap[T any](value any) (T, bool) {
var zero T
switch v := value.(type) {
case map[string]string:
result := make(map[string]any, len(v))
for k, s := range v {
result[k] = s
}
return any(result).(T), true
case map[string]int:
result := make(map[string]any, len(v))
for k, i := range v {
result[k] = i
}
return any(result).(T), true
case map[int]any:
result := make(map[string]any, len(v))
for k, val := range v {
result[strconv.Itoa(k)] = val
}
return any(result).(T), true
}
return zero, false
}
// GetTypedValue gets a value from the state with type conversion // GetTypedValue gets a value from the state with type conversion
func GetTypedValue[T any](s *State, index int) (T, bool) { func GetTypedValue[T any](s *State, index int) (T, bool) {
var zero T
// Get the value as any type
value, err := s.ToValue(index) value, err := s.ToValue(index)
if err != nil { if err != nil {
var zero T
return zero, false return zero, false
} }
// Convert it to the requested type
return ConvertValue[T](value) return ConvertValue[T](value)
} }
@ -147,6 +362,5 @@ func GetTypedValue[T any](s *State, index int) (T, bool) {
func GetGlobalTyped[T any](s *State, name string) (T, bool) { func GetGlobalTyped[T any](s *State, name string) (T, bool) {
s.GetGlobal(name) s.GetGlobal(name)
defer s.Pop(1) defer s.Pop(1)
return GetTypedValue[T](s, -1) return GetTypedValue[T](s, -1)
} }

59
validation.go Normal file
View File

@ -0,0 +1,59 @@
package luajit
import "fmt"
// ArgSpec defines an argument specification for validation
type ArgSpec struct {
Name string
Type string
Required bool
Check func(*State, int) bool
}
// Common argument checkers
var (
CheckString = func(s *State, i int) bool { return s.IsString(i) }
CheckNumber = func(s *State, i int) bool { return s.IsNumber(i) }
CheckBool = func(s *State, i int) bool { return s.IsBoolean(i) }
CheckTable = func(s *State, i int) bool { return s.IsTable(i) }
CheckFunc = func(s *State, i int) bool { return s.IsFunction(i) }
CheckAny = func(s *State, i int) bool { return true }
)
// CheckArgs validates function arguments against specifications
func (s *State) CheckArgs(specs ...ArgSpec) error {
for i, spec := range specs {
argIdx := i + 1
if argIdx > s.GetTop() {
if spec.Required {
return fmt.Errorf("missing argument %d: %s", argIdx, spec.Name)
}
break
}
if s.IsNil(argIdx) && !spec.Required {
continue
}
if !spec.Check(s, argIdx) {
return fmt.Errorf("argument %d (%s) must be %s", argIdx, spec.Name, spec.Type)
}
}
return nil
}
// CheckMinArgs checks for minimum number of arguments
func (s *State) CheckMinArgs(min int) error {
if s.GetTop() < min {
return fmt.Errorf("expected at least %d arguments, got %d", min, s.GetTop())
}
return nil
}
// CheckExactArgs checks for exact number of arguments
func (s *State) CheckExactArgs(count int) error {
if s.GetTop() != count {
return fmt.Errorf("expected exactly %d arguments, got %d", count, s.GetTop())
}
return nil
}

View File

@ -11,7 +11,6 @@ package luajit
#include <stdlib.h> #include <stdlib.h>
#include <string.h> #include <string.h>
// Direct execution helpers to minimize CGO transitions
static int do_string(lua_State *L, const char *s) { static int do_string(lua_State *L, const char *s) {
int status = luaL_loadstring(L, s); int status = luaL_loadstring(L, s);
if (status == 0) { if (status == 0) {
@ -34,31 +33,65 @@ static int execute_with_results(lua_State *L, const char *code, int store_result
return lua_pcall(L, 0, store_results ? LUA_MULTRET : 0, 0); return lua_pcall(L, 0, store_results ? LUA_MULTRET : 0, 0);
} }
static int has_metatable(lua_State *L, int index) { static size_t get_table_length(lua_State *L, int index) {
return lua_getmetatable(L, index); return lua_objlen(L, index);
}
static int is_integer(lua_State *L, int index) {
if (!lua_isnumber(L, index)) return 0;
lua_Number n = lua_tonumber(L, index);
return n == (lua_Number)(lua_Integer)n;
}
static int sample_array_type(lua_State *L, int index, int count) {
int all_numbers = 1;
int all_integers = 1;
int all_strings = 1;
int all_bools = 1;
for (int i = 1; i <= count && i <= 5; i++) {
lua_pushnumber(L, i);
lua_gettable(L, index);
int type = lua_type(L, -1);
if (type != LUA_TNUMBER) all_numbers = all_integers = 0;
if (type != LUA_TSTRING) all_strings = 0;
if (type != LUA_TBOOLEAN) all_bools = 0;
if (all_numbers && !is_integer(L, -1)) all_integers = 0;
lua_pop(L, 1);
if (!all_numbers && !all_strings && !all_bools) break;
}
if (all_integers) return 1;
if (all_numbers) return 2;
if (all_strings) return 3;
if (all_bools) return 4;
return 0;
} }
*/ */
import "C" import "C"
import ( import (
"fmt" "fmt"
"strconv"
"strings" "strings"
"sync"
"unsafe" "unsafe"
) )
// Type pool for common objects to reduce GC pressure // Stack management constants
var stringBufferPool = sync.Pool{ const (
New: func() any { LUA_MINSTACK = 20
return new(strings.Builder) LUA_MAXSTACK = 1000000
}, LUA_REGISTRYINDEX = -10000
} LUA_GLOBALSINDEX = -10002
)
// State represents a Lua state
type State struct { type State struct {
L *C.lua_State L *C.lua_State
} }
// New creates a new Lua state with optional standard libraries; true if not specified
func New(openLibs ...bool) *State { func New(openLibs ...bool) *State {
L := C.luaL_newstate() L := C.luaL_newstate()
if L == nil { if L == nil {
@ -72,7 +105,6 @@ func New(openLibs ...bool) *State {
return &State{L: L} return &State{L: L}
} }
// Close closes the Lua state and frees resources
func (s *State) Close() { func (s *State) Close() {
if s.L != nil { if s.L != nil {
C.lua_close(s.L) C.lua_close(s.L)
@ -80,34 +112,13 @@ func (s *State) Close() {
} }
} }
// Stack manipulation methods // Stack operations
func (s *State) GetTop() int { return int(C.lua_gettop(s.L)) }
func (s *State) SetTop(index int) { C.lua_settop(s.L, C.int(index)) }
func (s *State) PushCopy(index int) { C.lua_pushvalue(s.L, C.int(index)) }
func (s *State) Pop(n int) { C.lua_settop(s.L, C.int(-n-1)) }
func (s *State) Remove(index int) { C.lua_remove(s.L, C.int(index)) }
// GetTop returns the index of the top element in the stack
func (s *State) GetTop() int {
return int(C.lua_gettop(s.L))
}
// SetTop sets the stack top to a specific index
func (s *State) SetTop(index int) {
C.lua_settop(s.L, C.int(index))
}
// PushCopy pushes a copy of the value at the given index onto the stack
func (s *State) PushCopy(index int) {
C.lua_pushvalue(s.L, C.int(index))
}
// Pop pops n elements from the stack
func (s *State) Pop(n int) {
C.lua_settop(s.L, C.int(-n-1))
}
// Remove removes the element at the given valid index
func (s *State) Remove(index int) {
C.lua_remove(s.L, C.int(index))
}
// absIndex converts a possibly negative index to its absolute position
func (s *State) absIndex(index int) int { func (s *State) absIndex(index int) int {
if index > 0 || index <= LUA_REGISTRYINDEX { if index > 0 || index <= LUA_REGISTRYINDEX {
return index return index
@ -115,56 +126,19 @@ func (s *State) absIndex(index int) int {
return s.GetTop() + index + 1 return s.GetTop() + index + 1
} }
// Type checking methods // Type checking
func (s *State) GetType(index int) LuaType { return LuaType(C.lua_type(s.L, C.int(index))) }
func (s *State) IsNil(index int) bool { return s.GetType(index) == TypeNil }
func (s *State) IsBoolean(index int) bool { return s.GetType(index) == TypeBoolean }
func (s *State) IsNumber(index int) bool { return C.lua_isnumber(s.L, C.int(index)) != 0 }
func (s *State) IsString(index int) bool { return C.lua_isstring(s.L, C.int(index)) != 0 }
func (s *State) IsTable(index int) bool { return s.GetType(index) == TypeTable }
func (s *State) IsFunction(index int) bool { return s.GetType(index) == TypeFunction }
// GetType returns the type of the value at the given index // Value conversion
func (s *State) GetType(index int) LuaType { func (s *State) ToBoolean(index int) bool { return C.lua_toboolean(s.L, C.int(index)) != 0 }
return LuaType(C.lua_type(s.L, C.int(index))) func (s *State) ToNumber(index int) float64 { return float64(C.lua_tonumber(s.L, C.int(index))) }
}
// IsNil checks if the value at the given index is nil
func (s *State) IsNil(index int) bool {
return s.GetType(index) == TypeNil
}
// IsBoolean checks if the value at the given index is a boolean
func (s *State) IsBoolean(index int) bool {
return s.GetType(index) == TypeBoolean
}
// IsNumber checks if the value at the given index is a number
func (s *State) IsNumber(index int) bool {
return C.lua_isnumber(s.L, C.int(index)) != 0
}
// IsString checks if the value at the given index is a string
func (s *State) IsString(index int) bool {
return C.lua_isstring(s.L, C.int(index)) != 0
}
// IsTable checks if the value at the given index is a table
func (s *State) IsTable(index int) bool {
return s.GetType(index) == TypeTable
}
// IsFunction checks if the value at the given index is a function
func (s *State) IsFunction(index int) bool {
return s.GetType(index) == TypeFunction
}
// Value conversion methods
// ToBoolean returns the value at the given index as a boolean
func (s *State) ToBoolean(index int) bool {
return C.lua_toboolean(s.L, C.int(index)) != 0
}
// ToNumber returns the value at the given index as a number
func (s *State) ToNumber(index int) float64 {
return float64(C.lua_tonumber(s.L, C.int(index)))
}
// ToString returns the value at the given index as a string
func (s *State) ToString(index int) string { func (s *State) ToString(index int) string {
var length C.size_t var length C.size_t
cstr := C.lua_tolstring(s.L, C.int(index), &length) cstr := C.lua_tolstring(s.L, C.int(index), &length)
@ -175,170 +149,322 @@ func (s *State) ToString(index int) string {
} }
// Push methods // Push methods
func (s *State) PushNil() { C.lua_pushnil(s.L) }
func (s *State) PushBoolean(b bool) { C.lua_pushboolean(s.L, boolToInt(b)) }
func (s *State) PushNumber(n float64) { C.lua_pushnumber(s.L, C.lua_Number(n)) }
// PushNil pushes a nil value onto the stack
func (s *State) PushNil() {
C.lua_pushnil(s.L)
}
// PushBoolean pushes a boolean value onto the stack
func (s *State) PushBoolean(b bool) {
var value C.int
if b {
value = 1
}
C.lua_pushboolean(s.L, value)
}
// PushNumber pushes a number value onto the stack
func (s *State) PushNumber(n float64) {
C.lua_pushnumber(s.L, C.lua_Number(n))
}
// PushString pushes a string value onto the stack
func (s *State) PushString(str string) { func (s *State) PushString(str string) {
// Use direct C string for short strings (avoid allocations)
if len(str) < 128 { if len(str) < 128 {
cstr := C.CString(str) cstr := C.CString(str)
defer C.free(unsafe.Pointer(cstr)) defer C.free(unsafe.Pointer(cstr))
C.lua_pushlstring(s.L, cstr, C.size_t(len(str))) C.lua_pushlstring(s.L, cstr, C.size_t(len(str)))
return } else {
header := (*struct {
p unsafe.Pointer
len int
cap int
})(unsafe.Pointer(&str))
C.lua_pushlstring(s.L, (*C.char)(header.p), C.size_t(len(str)))
} }
// For longer strings, avoid double copy by using unsafe pointer
header := (*struct {
p unsafe.Pointer
len int
cap int
})(unsafe.Pointer(&str))
C.lua_pushlstring(s.L, (*C.char)(header.p), C.size_t(len(str)))
} }
// Table operations // Table operations
func (s *State) CreateTable(narr, nrec int) { C.lua_createtable(s.L, C.int(narr), C.int(nrec)) }
func (s *State) NewTable() { C.lua_createtable(s.L, 0, 0) }
func (s *State) GetTable(index int) { C.lua_gettable(s.L, C.int(index)) }
func (s *State) SetTable(index int) { C.lua_settable(s.L, C.int(index)) }
func (s *State) Next(index int) bool { return C.lua_next(s.L, C.int(index)) != 0 }
// CreateTable creates a new table and pushes it onto the stack
func (s *State) CreateTable(narr, nrec int) {
C.lua_createtable(s.L, C.int(narr), C.int(nrec))
}
// NewTable creates a new empty table and pushes it onto the stack
func (s *State) NewTable() {
C.lua_createtable(s.L, 0, 0)
}
// GetTable gets a table field (t[k]) where t is at the given index and k is at the top of the stack
func (s *State) GetTable(index int) {
C.lua_gettable(s.L, C.int(index))
}
// SetTable sets a table field (t[k] = v) where t is at the given index, k is at -2, and v is at -1
func (s *State) SetTable(index int) {
C.lua_settable(s.L, C.int(index))
}
// GetField gets a table field t[k] and pushes it onto the stack
func (s *State) GetField(index int, key string) { func (s *State) GetField(index int, key string) {
ckey := C.CString(key) ckey := C.CString(key)
defer C.free(unsafe.Pointer(ckey)) defer C.free(unsafe.Pointer(ckey))
C.lua_getfield(s.L, C.int(index), ckey) C.lua_getfield(s.L, C.int(index), ckey)
} }
// SetField sets a table field t[k] = v, where v is the value at the top of the stack
func (s *State) SetField(index int, key string) { func (s *State) SetField(index int, key string) {
ckey := C.CString(key) ckey := C.CString(key)
defer C.free(unsafe.Pointer(ckey)) defer C.free(unsafe.Pointer(ckey))
C.lua_setfield(s.L, C.int(index), ckey) C.lua_setfield(s.L, C.int(index), ckey)
} }
// Next pops a key from the stack and pushes the next key-value pair from the table at the given index func (s *State) GetTableLength(index int) int {
func (s *State) Next(index int) bool { return int(C.get_table_length(s.L, C.int(index)))
return C.lua_next(s.L, C.int(index)) != 0
} }
// PushValue pushes a Go value onto the stack with proper type conversion // Enhanced PushValue with comprehensive type support
func (s *State) PushValue(v any) error { func (s *State) PushValue(v any) error {
switch v := v.(type) { switch val := v.(type) {
case nil: case nil:
s.PushNil() s.PushNil()
case bool: case bool:
s.PushBoolean(v) s.PushBoolean(val)
case int: case int:
s.PushNumber(float64(v)) s.PushNumber(float64(val))
case int64: case int64:
s.PushNumber(float64(v)) s.PushNumber(float64(val))
case float64: case float64:
s.PushNumber(v) s.PushNumber(val)
case string: case string:
s.PushString(v) s.PushString(val)
case map[string]any: case []byte:
// Special case: handle array stored in map s.PushString(string(val))
if arr, ok := v[""].([]float64); ok { case []int:
s.CreateTable(len(arr), 0) return s.pushIntSlice(val)
for i, elem := range arr { case []string:
s.PushNumber(float64(i + 1)) return s.pushStringSlice(val)
s.PushNumber(elem) case []bool:
s.SetTable(-3) return s.pushBoolSlice(val)
}
return nil
}
return s.PushTable(v)
case []float64: case []float64:
s.CreateTable(len(v), 0) return s.pushFloatSlice(val)
for i, elem := range v {
s.PushNumber(float64(i + 1))
s.PushNumber(elem)
s.SetTable(-3)
}
case []any: case []any:
s.CreateTable(len(v), 0) return s.pushAnySlice(val)
for i, elem := range v { case []map[string]any:
s.PushNumber(float64(i + 1)) return s.pushMapSlice(val)
if err := s.PushValue(elem); err != nil { case map[string]string:
return err return s.pushStringMap(val)
} case map[string]int:
s.SetTable(-3) return s.pushIntMap(val)
} case map[int]any:
return s.pushIntKeyMap(val)
case map[string]any:
return s.pushAnyMap(val)
default: default:
return fmt.Errorf("unsupported type: %T", v) return fmt.Errorf("unsupported type: %T", v)
} }
return nil return nil
} }
// ToValue converts a Lua value at the given index to a Go value func (s *State) pushIntSlice(arr []int) error {
s.CreateTable(len(arr), 0)
for i, v := range arr {
s.PushNumber(float64(i + 1))
s.PushNumber(float64(v))
s.SetTable(-3)
}
return nil
}
func (s *State) pushStringSlice(arr []string) error {
s.CreateTable(len(arr), 0)
for i, v := range arr {
s.PushNumber(float64(i + 1))
s.PushString(v)
s.SetTable(-3)
}
return nil
}
func (s *State) pushBoolSlice(arr []bool) error {
s.CreateTable(len(arr), 0)
for i, v := range arr {
s.PushNumber(float64(i + 1))
s.PushBoolean(v)
s.SetTable(-3)
}
return nil
}
func (s *State) pushFloatSlice(arr []float64) error {
s.CreateTable(len(arr), 0)
for i, v := range arr {
s.PushNumber(float64(i + 1))
s.PushNumber(v)
s.SetTable(-3)
}
return nil
}
func (s *State) pushAnySlice(arr []any) error {
s.CreateTable(len(arr), 0)
for i, v := range arr {
s.PushNumber(float64(i + 1))
if err := s.PushValue(v); err != nil {
return err
}
s.SetTable(-3)
}
return nil
}
func (s *State) pushStringMap(m map[string]string) error {
s.CreateTable(0, len(m))
for k, v := range m {
s.PushString(k)
s.PushString(v)
s.SetTable(-3)
}
return nil
}
func (s *State) pushIntMap(m map[string]int) error {
s.CreateTable(0, len(m))
for k, v := range m {
s.PushString(k)
s.PushNumber(float64(v))
s.SetTable(-3)
}
return nil
}
func (s *State) pushIntKeyMap(m map[int]any) error {
s.CreateTable(0, len(m))
for k, v := range m {
s.PushNumber(float64(k))
if err := s.PushValue(v); err != nil {
return err
}
s.SetTable(-3)
}
return nil
}
func (s *State) pushAnyMap(m map[string]any) error {
s.CreateTable(0, len(m))
for k, v := range m {
s.PushString(k)
if err := s.PushValue(v); err != nil {
return err
}
s.SetTable(-3)
}
return nil
}
// Enhanced ToValue with automatic type detection
func (s *State) ToValue(index int) (any, error) { func (s *State) ToValue(index int) (any, error) {
luaType := s.GetType(index) switch s.GetType(index) {
switch luaType {
case TypeNil: case TypeNil:
return nil, nil return nil, nil
case TypeBoolean: case TypeBoolean:
return s.ToBoolean(index), nil return s.ToBoolean(index), nil
case TypeNumber: case TypeNumber:
return s.ToNumber(index), nil num := s.ToNumber(index)
if num == float64(int(num)) && num >= -2147483648 && num <= 2147483647 {
return int(num), nil
}
return num, nil
case TypeString: case TypeString:
return s.ToString(index), nil return s.ToString(index), nil
case TypeTable: case TypeTable:
return s.ToTable(index) return s.ToTable(index)
default: default:
return nil, fmt.Errorf("unsupported type: %s", luaType) return nil, fmt.Errorf("unsupported type: %s", s.GetType(index))
} }
} }
// ToTable converts a Lua table to optimal Go type
func (s *State) ToTable(index int) (any, error) {
absIdx := s.absIndex(index)
if !s.IsTable(absIdx) {
return nil, fmt.Errorf("value at index %d is not a table", index)
}
length := s.GetTableLength(absIdx)
if length > 0 {
arrayType := int(C.sample_array_type(s.L, C.int(absIdx), C.int(length)))
switch arrayType {
case 1: // int array
return s.extractIntArray(absIdx, length), nil
case 2: // float array
return s.extractFloatArray(absIdx, length), nil
case 3: // string array
return s.extractStringArray(absIdx, length), nil
case 4: // bool array
return s.extractBoolArray(absIdx, length), nil
default: // mixed array
return s.extractAnyArray(absIdx, length), nil
}
}
return s.extractAnyMap(absIdx)
}
func (s *State) extractIntArray(index, length int) []int {
result := make([]int, length)
for i := 1; i <= length; i++ {
s.PushNumber(float64(i))
s.GetTable(index)
result[i-1] = int(s.ToNumber(-1))
s.Pop(1)
}
return result
}
func (s *State) extractFloatArray(index, length int) []float64 {
result := make([]float64, length)
for i := 1; i <= length; i++ {
s.PushNumber(float64(i))
s.GetTable(index)
result[i-1] = s.ToNumber(-1)
s.Pop(1)
}
return result
}
func (s *State) extractStringArray(index, length int) []string {
result := make([]string, length)
for i := 1; i <= length; i++ {
s.PushNumber(float64(i))
s.GetTable(index)
result[i-1] = s.ToString(-1)
s.Pop(1)
}
return result
}
func (s *State) extractBoolArray(index, length int) []bool {
result := make([]bool, length)
for i := 1; i <= length; i++ {
s.PushNumber(float64(i))
s.GetTable(index)
result[i-1] = s.ToBoolean(-1)
s.Pop(1)
}
return result
}
func (s *State) extractAnyArray(index, length int) []any {
result := make([]any, length)
for i := 1; i <= length; i++ {
s.PushNumber(float64(i))
s.GetTable(index)
if val, err := s.ToValue(-1); err == nil {
result[i-1] = val
}
s.Pop(1)
}
return result
}
func (s *State) extractAnyMap(index int) (map[string]any, error) {
result := make(map[string]any)
s.PushNil()
for s.Next(index) {
var key string
switch s.GetType(-2) {
case TypeString:
key = s.ToString(-2)
case TypeNumber:
key = strconv.FormatFloat(s.ToNumber(-2), 'g', -1, 64)
default:
s.Pop(1)
continue
}
if value, err := s.ToValue(-1); err == nil {
result[key] = value
}
s.Pop(1)
}
return result, nil
}
// Global operations // Global operations
func (s *State) GetGlobal(name string) { s.GetField(LUA_GLOBALSINDEX, name) }
func (s *State) SetGlobal(name string) { s.SetField(LUA_GLOBALSINDEX, name) }
// GetGlobal pushes the global variable with the given name onto the stack // Code execution
func (s *State) GetGlobal(name string) {
s.GetField(LUA_GLOBALSINDEX, name)
}
// SetGlobal sets the global variable with the given name to the value at the top of the stack
func (s *State) SetGlobal(name string) {
s.SetField(LUA_GLOBALSINDEX, name)
}
// Code execution methods
// LoadString loads a Lua chunk from a string without executing it
func (s *State) LoadString(code string) error { func (s *State) LoadString(code string) error {
ccode := C.CString(code) ccode := C.CString(code)
defer C.free(unsafe.Pointer(ccode)) defer C.free(unsafe.Pointer(ccode))
@ -346,13 +472,12 @@ func (s *State) LoadString(code string) error {
status := C.luaL_loadstring(s.L, ccode) status := C.luaL_loadstring(s.L, ccode)
if status != 0 { if status != 0 {
err := s.CreateLuaError(int(status), "LoadString") err := s.CreateLuaError(int(status), "LoadString")
s.Pop(1) // Remove error message s.Pop(1)
return err return err
} }
return nil return nil
} }
// LoadFile loads a Lua chunk from a file without executing it
func (s *State) LoadFile(filename string) error { func (s *State) LoadFile(filename string) error {
cfilename := C.CString(filename) cfilename := C.CString(filename)
defer C.free(unsafe.Pointer(cfilename)) defer C.free(unsafe.Pointer(cfilename))
@ -360,24 +485,22 @@ func (s *State) LoadFile(filename string) error {
status := C.luaL_loadfile(s.L, cfilename) status := C.luaL_loadfile(s.L, cfilename)
if status != 0 { if status != 0 {
err := s.CreateLuaError(int(status), fmt.Sprintf("LoadFile(%s)", filename)) err := s.CreateLuaError(int(status), fmt.Sprintf("LoadFile(%s)", filename))
s.Pop(1) // Remove error message s.Pop(1)
return err return err
} }
return nil return nil
} }
// Call calls a function with the given number of arguments and results
func (s *State) Call(nargs, nresults int) error { func (s *State) Call(nargs, nresults int) error {
status := C.lua_pcall(s.L, C.int(nargs), C.int(nresults), 0) status := C.lua_pcall(s.L, C.int(nargs), C.int(nresults), 0)
if status != 0 { if status != 0 {
err := s.CreateLuaError(int(status), fmt.Sprintf("Call(%d,%d)", nargs, nresults)) err := s.CreateLuaError(int(status), fmt.Sprintf("Call(%d,%d)", nargs, nresults))
s.Pop(1) // Remove error message s.Pop(1)
return err return err
} }
return nil return nil
} }
// DoString executes a Lua string and cleans up the stack
func (s *State) DoString(code string) error { func (s *State) DoString(code string) error {
ccode := C.CString(code) ccode := C.CString(code)
defer C.free(unsafe.Pointer(ccode)) defer C.free(unsafe.Pointer(ccode))
@ -385,13 +508,12 @@ func (s *State) DoString(code string) error {
status := C.do_string(s.L, ccode) status := C.do_string(s.L, ccode)
if status != 0 { if status != 0 {
err := s.CreateLuaError(int(status), "DoString") err := s.CreateLuaError(int(status), "DoString")
s.Pop(1) // Remove error message s.Pop(1)
return err return err
} }
return nil return nil
} }
// DoFile executes a Lua file and cleans up the stack
func (s *State) DoFile(filename string) error { func (s *State) DoFile(filename string) error {
cfilename := C.CString(filename) cfilename := C.CString(filename)
defer C.free(unsafe.Pointer(cfilename)) defer C.free(unsafe.Pointer(cfilename))
@ -399,39 +521,35 @@ func (s *State) DoFile(filename string) error {
status := C.do_file(s.L, cfilename) status := C.do_file(s.L, cfilename)
if status != 0 { if status != 0 {
err := s.CreateLuaError(int(status), fmt.Sprintf("DoFile(%s)", filename)) err := s.CreateLuaError(int(status), fmt.Sprintf("DoFile(%s)", filename))
s.Pop(1) // Remove error message s.Pop(1)
return err return err
} }
return nil return nil
} }
// Execute executes a Lua string and returns the number of results left on the stack
func (s *State) Execute(code string) (int, error) { func (s *State) Execute(code string) (int, error) {
baseTop := s.GetTop() baseTop := s.GetTop()
ccode := C.CString(code) ccode := C.CString(code)
defer C.free(unsafe.Pointer(ccode)) defer C.free(unsafe.Pointer(ccode))
status := C.execute_with_results(s.L, ccode, 1) // store_results=true status := C.execute_with_results(s.L, ccode, 1)
if status != 0 { if status != 0 {
err := s.CreateLuaError(int(status), "Execute") err := s.CreateLuaError(int(status), "Execute")
s.Pop(1) // Remove error message s.Pop(1)
return 0, err return 0, err
} }
return s.GetTop() - baseTop, nil return s.GetTop() - baseTop, nil
} }
// ExecuteWithResult executes a Lua string and returns the first result
func (s *State) ExecuteWithResult(code string) (any, error) { func (s *State) ExecuteWithResult(code string) (any, error) {
top := s.GetTop() top := s.GetTop()
defer s.SetTop(top) // Restore stack when done defer s.SetTop(top)
nresults, err := s.Execute(code) nresults, err := s.Execute(code)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if nresults == 0 { if nresults == 0 {
return nil, nil return nil, nil
} }
@ -439,42 +557,173 @@ func (s *State) ExecuteWithResult(code string) (any, error) {
return s.ToValue(-nresults) return s.ToValue(-nresults)
} }
// BatchExecute executes multiple statements with a single CGO transition
func (s *State) BatchExecute(statements []string) error { func (s *State) BatchExecute(statements []string) error {
// Join statements with semicolons return s.DoString(strings.Join(statements, "; "))
combinedCode := ""
for i, stmt := range statements {
combinedCode += stmt
if i < len(statements)-1 {
combinedCode += "; "
}
}
return s.DoString(combinedCode)
} }
// Package path operations // Package path operations
// SetPackagePath sets the Lua package.path
func (s *State) SetPackagePath(path string) error { func (s *State) SetPackagePath(path string) error {
path = strings.ReplaceAll(path, "\\", "/") // Convert Windows paths path = strings.ReplaceAll(path, "\\", "/")
code := fmt.Sprintf(`package.path = %q`, path) return s.DoString(fmt.Sprintf(`package.path = %q`, path))
return s.DoString(code)
} }
// AddPackagePath adds a path to package.path
func (s *State) AddPackagePath(path string) error { func (s *State) AddPackagePath(path string) error {
path = strings.ReplaceAll(path, "\\", "/") // Convert Windows paths path = strings.ReplaceAll(path, "\\", "/")
code := fmt.Sprintf(`package.path = package.path .. ";%s"`, path) return s.DoString(fmt.Sprintf(`package.path = package.path .. ";%s"`, path))
return s.DoString(code)
} }
// SetMetatable sets the metatable for the value at the given index // Metatable operations
func (s *State) SetMetatable(index int) { func (s *State) SetMetatable(index int) { C.lua_setmetatable(s.L, C.int(index)) }
C.lua_setmetatable(s.L, C.int(index)) func (s *State) GetMetatable(index int) bool { return C.lua_getmetatable(s.L, C.int(index)) != 0 }
// Helper functions
func boolToInt(b bool) C.int {
if b {
return 1
}
return 0
} }
// GetMetatable gets the metatable of the value at the given index // GetFieldString gets a string field from a table with default
func (s *State) GetMetatable(index int) bool { func (s *State) GetFieldString(index int, key string, defaultVal string) string {
return C.lua_getmetatable(s.L, C.int(index)) != 0 s.GetField(index, key)
defer s.Pop(1)
if s.IsString(-1) {
return s.ToString(-1)
}
return defaultVal
}
// GetFieldNumber gets a number field from a table with default
func (s *State) GetFieldNumber(index int, key string, defaultVal float64) float64 {
s.GetField(index, key)
defer s.Pop(1)
if s.IsNumber(-1) {
return s.ToNumber(-1)
}
return defaultVal
}
// GetFieldBool gets a boolean field from a table with default
func (s *State) GetFieldBool(index int, key string, defaultVal bool) bool {
s.GetField(index, key)
defer s.Pop(1)
if s.IsBoolean(-1) {
return s.ToBoolean(-1)
}
return defaultVal
}
// GetFieldTable gets a table field from a table
func (s *State) GetFieldTable(index int, key string) (any, bool) {
s.GetField(index, key)
defer s.Pop(1)
if s.IsTable(-1) {
val, err := s.ToTable(-1)
return val, err == nil
}
return nil, false
}
// ForEachTableKV iterates over string key-value pairs in a table
func (s *State) ForEachTableKV(index int, fn func(key, value string) bool) {
absIdx := s.absIndex(index)
s.PushNil()
for s.Next(absIdx) {
if s.IsString(-2) && s.IsString(-1) {
if !fn(s.ToString(-2), s.ToString(-1)) {
s.Pop(2)
return
}
}
s.Pop(1)
}
}
// ForEachArray iterates over array elements
func (s *State) ForEachArray(index int, fn func(i int, state *State) bool) {
absIdx := s.absIndex(index)
length := s.GetTableLength(absIdx)
for i := 1; i <= length; i++ {
s.PushNumber(float64(i))
s.GetTable(absIdx)
if !fn(i, s) {
s.Pop(1)
return
}
s.Pop(1)
}
}
// SafeToString safely converts value to string with error
func (s *State) SafeToString(index int) (string, error) {
if !s.IsString(index) && !s.IsNumber(index) {
return "", fmt.Errorf("value at index %d is not a string", index)
}
return s.ToString(index), nil
}
// SafeToNumber safely converts value to number with error
func (s *State) SafeToNumber(index int) (float64, error) {
if !s.IsNumber(index) {
return 0, fmt.Errorf("value at index %d is not a number", index)
}
return s.ToNumber(index), nil
}
// SafeToTable safely converts value to table with error
func (s *State) SafeToTable(index int) (any, error) {
if !s.IsTable(index) {
return nil, fmt.Errorf("value at index %d is not a table", index)
}
return s.ToTable(index)
}
// CallGlobal calls a global function with arguments
func (s *State) CallGlobal(name string, args ...any) ([]any, error) {
s.GetGlobal(name)
if !s.IsFunction(-1) {
s.Pop(1)
return nil, fmt.Errorf("global '%s' is not a function", name)
}
for i, arg := range args {
if err := s.PushValue(arg); err != nil {
s.Pop(i + 1)
return nil, fmt.Errorf("failed to push argument %d: %w", i+1, err)
}
}
baseTop := s.GetTop() - len(args) - 1
if err := s.Call(len(args), C.LUA_MULTRET); err != nil {
return nil, err
}
newTop := s.GetTop()
nresults := newTop - baseTop
results := make([]any, nresults)
for i := 0; i < nresults; i++ {
val, err := s.ToValue(baseTop + i + 1)
if err != nil {
results[i] = nil
} else {
results[i] = val
}
}
s.SetTop(baseTop)
return results, nil
}
func (s *State) pushMapSlice(arr []map[string]any) error {
s.CreateTable(len(arr), 0)
for i, m := range arr {
s.PushNumber(float64(i + 1))
if err := s.PushValue(m); err != nil {
return err
}
s.SetTable(-3)
}
return nil
} }