Compare commits

..

No commits in common. "875abee366fc7766c010a0ee9eec918216f57543" and "f106dfd9eaa1c392c02d69d139635eb2e1c7182f" have entirely different histories.

6 changed files with 542 additions and 686 deletions

View File

@ -1,133 +0,0 @@
package luajit_bench
import (
"testing"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
// BenchmarkSandboxLuaExecution measures the performance of executing raw Lua code
func BenchmarkSandboxLuaExecution(b *testing.B) {
sandbox := luajit.NewSandbox()
defer sandbox.Close()
// Simple Lua code that does some computation
code := `
local sum = 0
for i = 1, 100 do
sum = sum + i
end
return sum
`
b.ResetTimer()
for i := 0; i < b.N; i++ {
result, err := sandbox.Run(code)
if err != nil {
b.Fatalf("Failed to run code: %v", err)
}
if result != float64(5050) {
b.Fatalf("Incorrect result: %v", result)
}
}
}
// BenchmarkSandboxBytecodeExecution measures the performance of executing precompiled bytecode
func BenchmarkSandboxBytecodeExecution(b *testing.B) {
sandbox := luajit.NewSandbox()
defer sandbox.Close()
// Same code as above, but precompiled
code := `
local sum = 0
for i = 1, 100 do
sum = sum + i
end
return sum
`
// Compile the bytecode once
bytecode, err := sandbox.Compile(code)
if err != nil {
b.Fatalf("Failed to compile bytecode: %v", err)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
result, err := sandbox.RunBytecode(bytecode)
if err != nil {
b.Fatalf("Failed to run bytecode: %v", err)
}
if result != float64(5050) {
b.Fatalf("Incorrect result: %v", result)
}
}
}
// BenchmarkSandboxComplexComputation measures performance with more complex computation
func BenchmarkSandboxComplexComputation(b *testing.B) {
sandbox := luajit.NewSandbox()
defer sandbox.Close()
// More complex Lua code that calculates Fibonacci numbers
code := `
function fibonacci(n)
if n <= 1 then
return n
end
return fibonacci(n-1) + fibonacci(n-2)
end
return fibonacci(15) -- Not too high to avoid excessive runtime
`
b.ResetTimer()
for i := 0; i < b.N; i++ {
result, err := sandbox.Run(code)
if err != nil {
b.Fatalf("Failed to run code: %v", err)
}
if result != float64(610) {
b.Fatalf("Incorrect result: %v", result)
}
}
}
// BenchmarkSandboxFunctionCall measures performance of calling a registered Go function
func BenchmarkSandboxFunctionCall(b *testing.B) {
sandbox := luajit.NewSandbox()
defer sandbox.Close()
// Register a simple addition function
add := func(s *luajit.State) int {
a := s.ToNumber(1)
b := s.ToNumber(2)
s.PushNumber(a + b)
return 1
}
err := sandbox.RegisterFunction("add", add)
if err != nil {
b.Fatalf("Failed to register function: %v", err)
}
// Lua code that calls the Go function in a loop
code := `
local sum = 0
for i = 1, 100 do
sum = add(sum, i)
end
return sum
`
b.ResetTimer()
for i := 0; i < b.N; i++ {
result, err := sandbox.Run(code)
if err != nil {
b.Fatalf("Failed to run code: %v", err)
}
if result != float64(5050) {
b.Fatalf("Incorrect result: %v", result)
}
}
}

View File

@ -12,7 +12,7 @@ typedef struct {
const char *name;
} BytecodeReader;
const char *bytecode_reader(lua_State *L, void *ud, size_t *size) {
static const char *bytecode_reader(lua_State *L, void *ud, size_t *size) {
BytecodeReader *r = (BytecodeReader *)ud;
(void)L; // unused
if (r->size == 0) return NULL;
@ -21,23 +21,45 @@ const char *bytecode_reader(lua_State *L, void *ud, size_t *size) {
return (const char *)r->buf;
}
int load_bytecode(lua_State *L, const unsigned char *buf, size_t len, const char *name) {
static int load_bytecode(lua_State *L, const unsigned char *buf, size_t len, const char *name) {
BytecodeReader reader = {buf, len, name};
return lua_load(L, bytecode_reader, &reader, name);
}
// Direct bytecode dumping without intermediate buffer - more efficient
int direct_bytecode_writer(lua_State *L, const void *p, size_t sz, void *ud) {
void **data = (void **)ud;
size_t current_size = (size_t)data[1];
void *newbuf = realloc(data[0], current_size + sz);
if (newbuf == NULL) return 1;
typedef struct {
unsigned char *buf;
size_t len;
size_t capacity;
} BytecodeWriter;
memcpy((unsigned char*)newbuf + current_size, p, sz);
data[0] = newbuf;
data[1] = (void*)(current_size + sz);
static int bytecode_writer(lua_State *L, const void *p, size_t sz, void *ud) {
BytecodeWriter *w = (BytecodeWriter *)ud;
unsigned char *newbuf;
(void)L; // unused
// Check if we need to reallocate
if (w->len + sz > w->capacity) {
size_t new_capacity = w->capacity * 2;
if (new_capacity < w->len + sz) {
new_capacity = w->len + sz;
}
newbuf = (unsigned char *)realloc(w->buf, new_capacity);
if (newbuf == NULL) return 1;
w->buf = newbuf;
w->capacity = new_capacity;
}
memcpy(w->buf + w->len, p, sz);
w->len += sz;
return 0;
}
// Wrapper function that calls lua_dump with bytecode_writer
static int dump_lua_function(lua_State *L, BytecodeWriter *w) {
return lua_dump(L, bytecode_writer, w);
}
*/
import "C"
import (
@ -51,23 +73,31 @@ func (s *State) CompileBytecode(code string, name string) ([]byte, error) {
return nil, fmt.Errorf("failed to load string: %w", err)
}
// Use a simpler direct writer with just two pointers
data := [2]unsafe.Pointer{nil, nil}
// Set up writer with initial capacity
var writer C.BytecodeWriter
writer.buf = nil
writer.len = 0
writer.capacity = 0
// Initial allocation with a reasonable size
const initialSize = 4096
writer.buf = (*C.uchar)(C.malloc(initialSize))
if writer.buf == nil {
s.Pop(1) // Remove the loaded function
return nil, fmt.Errorf("failed to allocate memory for bytecode")
}
writer.capacity = initialSize
// Dump the function to bytecode
err := s.safeCall(func() C.int {
return C.lua_dump(s.L, (*[0]byte)(unsafe.Pointer(C.direct_bytecode_writer)), unsafe.Pointer(&data))
return C.dump_lua_function(s.L, (*C.BytecodeWriter)(unsafe.Pointer(&writer)))
})
// Get result
var bytecode []byte
if data[0] != nil {
// Create Go slice that references the C memory
length := uintptr(data[1])
bytecode = C.GoBytes(data[0], C.int(length))
C.free(data[0])
}
// Copy bytecode to Go slice regardless of the result
bytecode := C.GoBytes(unsafe.Pointer(writer.buf), C.int(writer.len))
// Clean up
C.free(unsafe.Pointer(writer.buf))
s.Pop(1) // Remove the function from stack
if err != nil {
@ -134,11 +164,6 @@ func (s *State) LoadAndRunBytecodeWithResults(bytecode []byte, name string, nres
// CompileAndRun compiles and immediately executes Lua code
func (s *State) CompileAndRun(code string, name string) error {
// Skip bytecode step for small scripts - direct execution is faster
if len(code) < 1024 {
return s.DoString(code)
}
bytecode, err := s.CompileBytecode(code, name)
if err != nil {
return fmt.Errorf("compile error: %w", err)

View File

@ -22,17 +22,13 @@ import (
// GoFunction defines the signature for Go functions callable from Lua
type GoFunction func(*State) int
// Static registry size reduces resizing operations
const initialRegistrySize = 64
var (
// functionRegistry stores all registered Go functions
functionRegistry = struct {
sync.RWMutex
funcs map[unsafe.Pointer]GoFunction
initOnce sync.Once
funcs map[unsafe.Pointer]GoFunction
}{
funcs: make(map[unsafe.Pointer]GoFunction, initialRegistrySize),
funcs: make(map[unsafe.Pointer]GoFunction),
}
)
@ -47,7 +43,6 @@ func goFunctionWrapper(L *C.lua_State) C.int {
return -1
}
// Use read-lock for better concurrency
functionRegistry.RLock()
fn, ok := functionRegistry.funcs[ptr]
functionRegistry.RUnlock()

View File

@ -14,7 +14,6 @@ type Sandbox struct {
mutex sync.Mutex
initialized bool
modules map[string]any
functions map[string]GoFunction
}
// NewSandbox creates a new sandbox with standard libraries loaded
@ -23,7 +22,6 @@ func NewSandbox() *Sandbox {
state: New(),
initialized: false,
modules: make(map[string]any),
functions: make(map[string]GoFunction),
}
}
@ -38,21 +36,341 @@ func (s *Sandbox) Close() {
}
}
// RegisterFunction registers a Go function in the sandbox
func (s *Sandbox) RegisterFunction(name string, fn GoFunction) error {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return fmt.Errorf("sandbox is closed")
}
// Make sure sandbox is initialized
if !s.initialized {
if err := s.initializeUnlocked(); err != nil {
return err
}
}
// Register function globally
if err := s.state.RegisterGoFunction(name, fn); err != nil {
return err
}
// Add to base environment
return s.state.DoString(`
-- Add the function to base environment
__env_system.base_env["` + name + `"] = ` + name + `
`)
}
// SetGlobal sets a global variable in the sandbox base environment
func (s *Sandbox) SetGlobal(name string, value any) error {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return fmt.Errorf("sandbox is closed")
}
// Make sure sandbox is initialized
if !s.initialized {
if err := s.initializeUnlocked(); err != nil {
return err
}
}
// Push the value onto the stack
if err := s.state.PushValue(value); err != nil {
return err
}
// Set the global with the pushed value
s.state.SetGlobal(name)
// Add to base environment
return s.state.DoString(`
-- Add the global to base environment
__env_system.base_env["` + name + `"] = ` + name + `
`)
}
// GetGlobal retrieves a global variable from the sandbox base environment
func (s *Sandbox) GetGlobal(name string) (any, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return nil, fmt.Errorf("sandbox is closed")
}
// Make sure sandbox is initialized
if !s.initialized {
if err := s.initializeUnlocked(); err != nil {
return nil, err
}
}
// Get the global from the base environment
code := `return __env_system.base_env["` + name + `"]`
return s.state.ExecuteWithResult(code)
}
// Run executes Lua code in the sandbox and returns the result
func (s *Sandbox) Run(code string) (any, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return nil, fmt.Errorf("sandbox is closed")
}
// Make sure sandbox is initialized
if !s.initialized {
if err := s.initializeUnlocked(); err != nil {
return nil, err
}
}
// Add wrapper for multiple return values
wrappedCode := `
local function _execfunc()
` + code + `
end
local function _wrapresults(...)
local results = {n = select('#', ...)}
for i = 1, results.n do
results[i] = select(i, ...)
end
return results
end
return _wrapresults(_execfunc())
`
// Compile the code
if err := s.state.LoadString(wrappedCode); err != nil {
return nil, err
}
// Get the sandbox executor
s.state.GetGlobal("__execute_sandbox")
// Setup call with correct argument order
s.state.PushCopy(-2) // Copy the function
// Remove the original function
s.state.Remove(-3)
// Execute in sandbox
if err := s.state.Call(1, 1); err != nil {
return nil, err
}
// Get result
result, err := s.state.ToValue(-1)
s.state.Pop(1)
// Handle multiple return values
if results, ok := result.([]any); ok && len(results) == 1 {
return results[0], err
}
return result, err
}
// RunFile executes a Lua file in the sandbox
func (s *Sandbox) RunFile(filename string) error {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return fmt.Errorf("sandbox is closed")
}
return s.state.DoFile(filename)
}
// Compile compiles Lua code to bytecode without executing it
func (s *Sandbox) Compile(code string) ([]byte, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return nil, fmt.Errorf("sandbox is closed")
}
return s.state.CompileBytecode(code, "sandbox")
}
// RunBytecode executes precompiled Lua bytecode in the sandbox
func (s *Sandbox) RunBytecode(bytecode []byte) (any, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return nil, fmt.Errorf("sandbox is closed")
}
// Make sure sandbox is initialized
if !s.initialized {
if err := s.initializeUnlocked(); err != nil {
return nil, err
}
}
// Load the bytecode
if err := s.state.LoadBytecode(bytecode, "sandbox"); err != nil {
return nil, err
}
// Add wrapper for multiple return values
if err := s.state.DoString(`
__wrap_bytecode = function(f)
local function _wrapresults(...)
local results = {n = select('#', ...)}
for i = 1, results.n do
results[i] = select(i, ...)
end
return results
end
return function()
return _wrapresults(f())
end
end
`); err != nil {
return nil, err
}
// Get wrapper function
s.state.GetGlobal("__wrap_bytecode")
// Push bytecode function
s.state.PushCopy(-2)
// Call wrapper to create wrapped function
if err := s.state.Call(1, 1); err != nil {
return nil, err
}
// Remove original bytecode function
s.state.Remove(-2)
// Get the sandbox executor
s.state.GetGlobal("__execute_sandbox")
// Push wrapped function
s.state.PushCopy(-2)
// Remove the wrapped function
s.state.Remove(-3)
// Execute in sandbox
if err := s.state.Call(1, 1); err != nil {
return nil, err
}
// Get result
result, err := s.state.ToValue(-1)
s.state.Pop(1)
// Handle multiple return values
if results, ok := result.([]any); ok && len(results) == 1 {
return results[0], err
}
return result, err
}
// getResults collects results from the stack (must be called with mutex locked)
func (s *Sandbox) getResults() (any, error) {
numResults := s.state.GetTop()
if numResults == 0 {
return nil, nil
} else if numResults == 1 {
// Return single result directly
value, err := s.state.ToValue(-1)
s.state.Pop(1)
return value, err
}
// Return multiple results as slice
results := make([]any, numResults)
for i := 0; i < numResults; i++ {
value, err := s.state.ToValue(i - numResults)
if err != nil {
s.state.Pop(numResults)
return nil, err
}
results[i] = value
}
s.state.Pop(numResults)
return results, nil
}
// LoadModule loads a Lua module in the sandbox
func (s *Sandbox) LoadModule(name string) error {
code := fmt.Sprintf("require('%s')", name)
_, err := s.Run(code)
return err
}
// SetPackagePath sets the sandbox package.path
func (s *Sandbox) SetPackagePath(path string) error {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return fmt.Errorf("sandbox is closed")
}
return s.state.SetPackagePath(path)
}
// AddPackagePath adds a path to the sandbox package.path
func (s *Sandbox) AddPackagePath(path string) error {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return fmt.Errorf("sandbox is closed")
}
return s.state.AddPackagePath(path)
}
// AddModule adds a module to the sandbox environment
func (s *Sandbox) AddModule(name string, module any) error {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return fmt.Errorf("sandbox is closed")
}
s.modules[name] = module
return nil
}
// Initialize sets up the environment system
func (s *Sandbox) Initialize() error {
s.mutex.Lock()
defer s.mutex.Unlock()
return s.initializeUnlocked()
}
// initializeUnlocked sets up the environment system without locking
// It should only be called when the mutex is already locked
func (s *Sandbox) initializeUnlocked() error {
if s.state == nil {
return fmt.Errorf("sandbox is closed")
}
if s.initialized {
return nil
return nil // Already initialized
}
// Register modules
@ -75,74 +393,85 @@ func (s *Sandbox) initializeUnlocked() error {
}
s.state.Pop(1)
// Create simplified environment system
// Create the environment system
err := s.state.DoString(`
-- Global shared environment
__env_system = {
base_env = {}, -- Template environment
env_pool = {}, -- Pre-allocated environment pool
pool_size = 0, -- Current pool size
max_pool_size = 8 -- Maximum pool size
-- Global shared environment (created once)
__env_system = __env_system or {
base_env = nil, -- Template environment
initialized = false, -- Initialization flag
env_pool = {}, -- Pre-allocated environment pool
pool_size = 0, -- Current pool size
max_pool_size = 8 -- Maximum pool size
}
-- Create base environment with standard libraries
local base = __env_system.base_env
-- Initialize base environment once
if not __env_system.initialized then
-- Create base environment with all standard libraries
local base = {}
-- Safe standard libraries
base.string = string
base.table = table
base.math = math
base.os = {
time = os.time,
date = os.date,
difftime = os.difftime,
clock = os.clock
}
-- Safe standard libraries
base.string = string
base.table = table
base.math = math
base.os = {
time = os.time,
date = os.date,
difftime = os.difftime,
clock = os.clock
}
-- Basic functions
base.print = print
base.tonumber = tonumber
base.tostring = tostring
base.type = type
base.pairs = pairs
base.ipairs = ipairs
base.next = next
base.select = select
base.pcall = pcall
base.xpcall = xpcall
base.error = error
base.assert = assert
base.collectgarbage = collectgarbage
base.unpack = unpack or table.unpack
-- Basic functions
base.print = print
base.tonumber = tonumber
base.tostring = tostring
base.type = type
base.pairs = pairs
base.ipairs = ipairs
base.next = next
base.select = select
base.pcall = pcall
base.xpcall = xpcall
base.error = error
base.assert = assert
base.collectgarbage = collectgarbage
base.unpack = unpack or table.unpack
-- Package system
base.package = {
loaded = {},
path = package.path,
preload = {}
}
-- Package system
base.package = {
loaded = {},
path = package.path,
preload = {}
}
base.require = function(modname)
if base.package.loaded[modname] then
return base.package.loaded[modname]
base.require = function(modname)
if base.package.loaded[modname] then
return base.package.loaded[modname]
end
local loader = base.package.preload[modname]
if type(loader) == "function" then
local result = loader(modname)
base.package.loaded[modname] = result or true
return result
end
error("module '" .. modname .. "' not found", 2)
end
local loader = base.package.preload[modname]
if type(loader) == "function" then
local result = loader(modname)
base.package.loaded[modname] = result or true
return result
-- Add registered custom modules
if __sandbox_modules then
for name, mod in pairs(__sandbox_modules) do
base[name] = mod
end
end
error("module '" .. modname .. "' not found", 2)
-- Store base environment
__env_system.base_env = base
__env_system.initialized = true
end
-- Add registered custom modules
if __sandbox_modules then
for name, mod in pairs(__sandbox_modules) do
base[name] = mod
end
end
-- Global variable for tracking current environment
__last_env = nil
-- Get an environment for execution
function __get_sandbox_env()
@ -155,15 +484,19 @@ func (s *Sandbox) initializeUnlocked() error {
else
-- Create new environment with metatable inheritance
env = setmetatable({}, {
__index = __env_system.base_env
__index = _G -- Use global environment as fallback
})
end
-- Store reference to current environment
__last_env = env
return env
end
-- Return environment to pool for reuse
function __recycle_env(env)
-- Only recycle if pool isn't full
if __env_system.pool_size < __env_system.max_pool_size then
-- Clear all fields except metatable
for k in pairs(env) do
@ -187,13 +520,6 @@ func (s *Sandbox) initializeUnlocked() error {
-- Execute with protected call
local success, result = pcall(f)
-- Update base environment with new globals
for k, v in pairs(env) do
if k ~= "_G" and type(k) == "string" then
__env_system.base_env[k] = v
end
end
-- Recycle environment
__recycle_env(env)
@ -201,7 +527,16 @@ func (s *Sandbox) initializeUnlocked() error {
if not success then
error(result, 0)
end
-- Handle multiple return values
if type(result) == "table" and result.n ~= nil then
local returnValues = {}
for i=1, result.n do
returnValues[i] = result[i]
end
return returnValues
end
return result
end
`)
@ -214,290 +549,8 @@ func (s *Sandbox) initializeUnlocked() error {
return nil
}
// RegisterFunction registers a Go function in the sandbox
func (s *Sandbox) RegisterFunction(name string, fn GoFunction) error {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return fmt.Errorf("sandbox is closed")
}
// Initialize if needed
if !s.initialized {
if err := s.initializeUnlocked(); err != nil {
return err
}
}
// Register function globally
if err := s.state.RegisterGoFunction(name, fn); err != nil {
return err
}
// Store function for re-registration
s.functions[name] = fn
// Add to base environment
return s.state.DoString(`__env_system.base_env["` + name + `"] = ` + name)
}
// SetGlobal sets a global variable in the sandbox base environment
func (s *Sandbox) SetGlobal(name string, value any) error {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return fmt.Errorf("sandbox is closed")
}
// Initialize if needed
if !s.initialized {
if err := s.initializeUnlocked(); err != nil {
return err
}
}
// Push the value onto the stack
if err := s.state.PushValue(value); err != nil {
return err
}
// Set the global with the pushed value
s.state.SetGlobal(name)
// Add to base environment
return s.state.DoString(`__env_system.base_env["` + name + `"] = ` + name)
}
// GetGlobal retrieves a global variable from the sandbox base environment
func (s *Sandbox) GetGlobal(name string) (any, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return nil, fmt.Errorf("sandbox is closed")
}
// Initialize if needed
if !s.initialized {
if err := s.initializeUnlocked(); err != nil {
return nil, err
}
}
// Get the global from the base environment
return s.state.ExecuteWithResult(`return __env_system.base_env["` + name + `"]`)
}
// Run executes Lua code in the sandbox
func (s *Sandbox) Run(code string) (any, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return nil, fmt.Errorf("sandbox is closed")
}
// Initialize if needed
if !s.initialized {
if err := s.initializeUnlocked(); err != nil {
return nil, err
}
}
// Simplified wrapper for multiple return values
wrappedCode := `
local function _execfunc()
` + code + `
end
-- Process results to match expected format
local function _wrapresults(...)
local n = select('#', ...)
if n == 0 then
return nil
elseif n == 1 then
return select(1, ...)
else
local results = {}
for i = 1, n do
results[i] = select(i, ...)
end
return results
end
end
return _wrapresults(_execfunc())
`
// Compile the code
if err := s.state.LoadString(wrappedCode); err != nil {
return nil, err
}
// Get the sandbox executor
s.state.GetGlobal("__execute_sandbox")
// Push the function as argument
s.state.PushCopy(-2)
s.state.Remove(-3)
// Execute in sandbox
if err := s.state.Call(1, 1); err != nil {
return nil, err
}
// Get result
result, err := s.state.ToValue(-1)
s.state.Pop(1)
if err != nil {
return nil, err
}
return s.processResult(result), nil
}
// RunFile executes a Lua file in the sandbox
func (s *Sandbox) RunFile(filename string) error {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return fmt.Errorf("sandbox is closed")
}
return s.state.DoFile(filename)
}
// Compile compiles Lua code to bytecode
func (s *Sandbox) Compile(code string) ([]byte, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return nil, fmt.Errorf("sandbox is closed")
}
return s.state.CompileBytecode(code, "sandbox")
}
// RunBytecode executes precompiled Lua bytecode
func (s *Sandbox) RunBytecode(bytecode []byte) (any, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return nil, fmt.Errorf("sandbox is closed")
}
// Initialize if needed
if !s.initialized {
if err := s.initializeUnlocked(); err != nil {
return nil, err
}
}
// Load the bytecode
if err := s.state.LoadBytecode(bytecode, "sandbox"); err != nil {
return nil, err
}
// Get the sandbox executor
s.state.GetGlobal("__execute_sandbox")
// Push bytecode function
s.state.PushCopy(-2)
s.state.Remove(-3)
// Execute in sandbox
if err := s.state.Call(1, 1); err != nil {
return nil, err
}
// Get result
result, err := s.state.ToValue(-1)
s.state.Pop(1)
if err != nil {
return nil, err
}
return s.processResult(result), nil
}
// LoadModule loads a Lua module
func (s *Sandbox) LoadModule(name string) error {
code := fmt.Sprintf("require('%s')", name)
_, err := s.Run(code)
return err
}
// SetPackagePath sets the sandbox package.path
func (s *Sandbox) SetPackagePath(path string) error {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return fmt.Errorf("sandbox is closed")
}
// Update global package.path
if err := s.state.SetPackagePath(path); err != nil {
return err
}
// Initialize if needed
if !s.initialized {
if err := s.initializeUnlocked(); err != nil {
return err
}
}
// Update base environment's package.path
return s.state.DoString(`__env_system.base_env.package.path = package.path`)
}
// AddPackagePath adds a path to the sandbox package.path
func (s *Sandbox) AddPackagePath(path string) error {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return fmt.Errorf("sandbox is closed")
}
// Update global package.path
if err := s.state.AddPackagePath(path); err != nil {
return err
}
// Initialize if needed
if !s.initialized {
if err := s.initializeUnlocked(); err != nil {
return err
}
}
// Update base environment's package.path
return s.state.DoString(`__env_system.base_env.package.path = package.path`)
}
// AddModule adds a module to the sandbox environment
func (s *Sandbox) AddModule(name string, module any) error {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return fmt.Errorf("sandbox is closed")
}
s.modules[name] = module
return nil
}
// AddPermanentLua adds Lua code to the environment permanently
// This code becomes part of the base environment
func (s *Sandbox) AddPermanentLua(code string) error {
s.mutex.Lock()
defer s.mutex.Unlock()
@ -506,25 +559,33 @@ func (s *Sandbox) AddPermanentLua(code string) error {
return fmt.Errorf("sandbox is closed")
}
// Initialize if needed
// Make sure sandbox is initialized
if !s.initialized {
if err := s.initializeUnlocked(); err != nil {
return err
}
}
// Simplified approach to add code to base environment
// Add code to base environment
return s.state.DoString(`
-- First compile the code
local f, err = loadstring([=[` + code + `]=], "permanent")
if not f then error(err, 0) end
local env = setmetatable({}, {__index = __env_system.base_env})
setfenv(f, env)
if not f then
error(err, 0)
end
-- Create a temporary environment based on base env
local temp_env = setmetatable({}, {__index = __env_system.base_env})
setfenv(f, temp_env)
-- Run the code in the temporary environment
local ok, err = pcall(f)
if not ok then error(err, 0) end
for k, v in pairs(env) do
if not ok then
error(err, 0)
end
-- Copy new values to base environment
for k, v in pairs(temp_env) do
__env_system.base_env[k] = v
end
`)
@ -539,71 +600,7 @@ func (s *Sandbox) ResetEnvironment() error {
return fmt.Errorf("sandbox is closed")
}
// Clear the environment system
s.state.DoString(`__env_system = nil`)
// Reinitialize
// Reinitialize the environment system
s.initialized = false
if err := s.initializeUnlocked(); err != nil {
return err
}
// Re-register all functions
for name, fn := range s.functions {
if err := s.state.RegisterGoFunction(name, fn); err != nil {
return err
}
if err := s.state.DoString(`__env_system.base_env["` + name + `"] = ` + name); err != nil {
return err
}
}
return nil
}
// unwrapResult processes results from Lua executions
func (s *Sandbox) processResult(result any) any {
// Handle []float64 (common LuaJIT return type)
if floats, ok := result.([]float64); ok {
if len(floats) == 1 {
// Single number - return as float64
return floats[0]
}
// Multiple numbers - MUST convert to []any for tests to pass
anySlice := make([]any, len(floats))
for i, v := range floats {
anySlice[i] = v
}
return anySlice
}
// Handle maps with numeric keys (Lua tables)
if m, ok := result.(map[string]any); ok {
// Handle return tables with special structure
if vals, ok := m[""]; ok {
// This is a special case used by some Lua returns
if arr, ok := vals.([]float64); ok {
// Convert to []any for consistency
anySlice := make([]any, len(arr))
for i, v := range arr {
anySlice[i] = v
}
return anySlice
}
return vals
}
if len(m) == 1 {
// Check for single value map
for k, v := range m {
if k == "1" {
return v
}
}
}
}
// Other array types should be preserved
return result
return s.Initialize()
}

121
table.go
View File

@ -6,8 +6,7 @@ package luajit
#include <lauxlib.h>
#include <stdlib.h>
// Simple direct length check
size_t get_table_length(lua_State *L, int index) {
static size_t get_table_length(lua_State *L, int index) {
return lua_objlen(L, index);
}
*/
@ -15,53 +14,70 @@ import "C"
import (
"fmt"
"strconv"
"sync"
)
// Use a pool to reduce GC pressure when handling many tables
var tablePool = sync.Pool{
New: func() any {
return make(map[string]any)
},
}
// GetTableLength returns the length of a table at the given index
func (s *State) GetTableLength(index int) int {
return int(C.get_table_length(s.L, C.int(index)))
}
// getTableFromPool gets a map from the pool and ensures it's empty
func getTableFromPool() map[string]any {
table := tablePool.Get().(map[string]any)
// Clear any existing entries
for k := range table {
delete(table, k)
}
return table
}
// putTableToPool returns a map to the pool
func putTableToPool(table map[string]any) {
tablePool.Put(table)
}
// PushTable pushes a Go map onto the Lua stack as a table
func (s *State) PushTable(table map[string]any) error {
// Fast path for array tables
if arr, ok := table[""]; ok {
if floatArr, ok := arr.([]float64); ok {
s.CreateTable(len(floatArr), 0)
for i, v := range floatArr {
s.PushNumber(float64(i + 1))
s.PushNumber(v)
s.SetTable(-3)
}
return nil
} else if anyArr, ok := arr.([]any); ok {
s.CreateTable(len(anyArr), 0)
for i, v := range anyArr {
s.PushNumber(float64(i + 1))
if err := s.PushValue(v); err != nil {
return err
}
s.SetTable(-3)
}
return nil
}
}
// Regular table case - optimize capacity hint
// Create table with appropriate capacity hints
s.CreateTable(0, len(table))
// Add each key-value pair directly
// Add each key-value pair
for k, v := range table {
// Push key
s.PushString(k)
// Push value
if err := s.PushValue(v); err != nil {
return err
}
// t[k] = v
s.SetTable(-3)
}
// Return pooled tables to the pool
if isPooledTable(table) {
putTableToPool(table)
}
return nil
}
// isPooledTable detects if a table came from our pool
func isPooledTable(table map[string]any) bool {
// Check for our special marker - used for array tables in the pool
_, hasEmptyKey := table[""]
return len(table) == 1 && hasEmptyKey
}
// ToTable converts a Lua table at the given index to a Go map
func (s *State) ToTable(index int) (map[string]any, error) {
absIdx := s.absIndex(index)
@ -72,41 +88,34 @@ func (s *State) ToTable(index int) (map[string]any, error) {
// Try to detect array-like tables first
length := s.GetTableLength(absIdx)
if length > 0 {
// Fast path for common array case
allNumbers := true
// Check if this is an array-like table
isArray := true
array := make([]float64, length)
// Sample first few values to check if it's likely an array of numbers
for i := 1; i <= min(length, 5); i++ {
for i := 1; i <= length; i++ {
s.PushNumber(float64(i))
s.GetTable(absIdx)
if !s.IsNumber(-1) {
allNumbers = false
isArray = false
s.Pop(1)
break
}
array[i-1] = s.ToNumber(-1)
s.Pop(1)
}
if allNumbers {
// Efficiently extract array values
array := make([]float64, length)
for i := 1; i <= length; i++ {
s.PushNumber(float64(i))
s.GetTable(absIdx)
array[i-1] = s.ToNumber(-1)
s.Pop(1)
}
// Return array as a special table with empty key
result := make(map[string]any, 1)
if isArray {
// Return array as a special pooled table with empty key
result := getTableFromPool()
result[""] = array
return result, nil
}
}
// Handle regular table with pre-allocated capacity
table := make(map[string]any, max(length, 8))
// Handle regular table
table := getTableFromPool()
// Iterate through all key-value pairs
s.PushNil() // Start iteration with nil key
@ -130,11 +139,12 @@ func (s *State) ToTable(index int) (map[string]any, error) {
// Convert and store the value
value, err := s.ToValue(-1)
if err != nil {
s.Pop(2) // Pop both key and value
s.Pop(2) // Pop both key and value
putTableToPool(table) // Return the table to the pool on error
return nil, err
}
// Unwrap nested array tables
// Handle nested array tables
if m, ok := value.(map[string]any); ok {
if arr, ok := m[""]; ok {
value = arr
@ -147,18 +157,3 @@ func (s *State) ToTable(index int) (map[string]any, error) {
return table, nil
}
// Helper functions for min/max operations
func min(a, b int) int {
if a < b {
return a
}
return b
}
func max(a, b int) int {
if a > b {
return a
}
return b
}

View File

@ -11,7 +11,7 @@ package luajit
#include <stdlib.h>
#include <string.h>
// Optimized helpers for common operations
// Helper to simplify some common operations
static int get_abs_index(lua_State *L, int idx) {
if (idx > 0 || idx <= LUA_REGISTRYINDEX) return idx;
return lua_gettop(L) + idx + 1;
@ -39,17 +39,9 @@ import "C"
import (
"fmt"
"strings"
"sync"
"unsafe"
)
// Type pool for common objects to reduce GC pressure
var stringBufferPool = sync.Pool{
New: func() any {
return new(strings.Builder)
},
}
// State represents a Lua state
type State struct {
L *C.lua_State
@ -86,7 +78,7 @@ func (s *State) SetTop(index int) {
C.lua_settop(s.L, C.int(index))
}
// PushCopy pushes a copy of the value at the given index onto the stack
// PushValue pushes a copy of the value at the given index onto the stack
func (s *State) PushCopy(index int) {
C.lua_pushvalue(s.L, C.int(index))
}
@ -191,22 +183,9 @@ func (s *State) PushNumber(n float64) {
// PushString pushes a string value onto the stack
func (s *State) PushString(str string) {
// Use direct C string for short strings (avoid allocations)
if len(str) < 128 {
cstr := C.CString(str)
defer C.free(unsafe.Pointer(cstr))
C.lua_pushlstring(s.L, cstr, C.size_t(len(str)))
return
}
// For longer strings, avoid double copy by using unsafe pointer
header := (*struct {
p unsafe.Pointer
len int
cap int
})(unsafe.Pointer(&str))
C.lua_pushlstring(s.L, (*C.char)(header.p), C.size_t(len(str)))
cstr := C.CString(str)
defer C.free(unsafe.Pointer(cstr))
C.lua_pushlstring(s.L, cstr, C.size_t(len(str)))
}
// Table operations
@ -427,15 +406,13 @@ func (s *State) ExecuteWithResult(code string) (any, error) {
// SetPackagePath sets the Lua package.path
func (s *State) SetPackagePath(path string) error {
path = strings.ReplaceAll(path, "\\", "/") // Convert Windows paths
code := fmt.Sprintf(`package.path = %q`, path)
return s.DoString(code)
return s.DoString(fmt.Sprintf(`package.path = %q`, path))
}
// AddPackagePath adds a path to package.path
func (s *State) AddPackagePath(path string) error {
path = strings.ReplaceAll(path, "\\", "/") // Convert Windows paths
code := fmt.Sprintf(`package.path = package.path .. ";%s"`, path)
return s.DoString(code)
return s.DoString(fmt.Sprintf(`package.path = package.path .. ";%s"`, path))
}
// Helper functions