Compare commits
7 Commits
f106dfd9ea
...
875abee366
Author | SHA1 | Date | |
---|---|---|---|
875abee366 | |||
4ad87f81f3 | |||
9e5092acdb | |||
b83f77d7a6 | |||
29679349ef | |||
fed0c2ad34 | |||
faab0a2d08 |
133
bench/sandbox_bench_test.go
Normal file
133
bench/sandbox_bench_test.go
Normal file
|
@ -0,0 +1,133 @@
|
||||||
|
package luajit_bench
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||||
|
)
|
||||||
|
|
||||||
|
// BenchmarkSandboxLuaExecution measures the performance of executing raw Lua code
|
||||||
|
func BenchmarkSandboxLuaExecution(b *testing.B) {
|
||||||
|
sandbox := luajit.NewSandbox()
|
||||||
|
defer sandbox.Close()
|
||||||
|
|
||||||
|
// Simple Lua code that does some computation
|
||||||
|
code := `
|
||||||
|
local sum = 0
|
||||||
|
for i = 1, 100 do
|
||||||
|
sum = sum + i
|
||||||
|
end
|
||||||
|
return sum
|
||||||
|
`
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
result, err := sandbox.Run(code)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("Failed to run code: %v", err)
|
||||||
|
}
|
||||||
|
if result != float64(5050) {
|
||||||
|
b.Fatalf("Incorrect result: %v", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkSandboxBytecodeExecution measures the performance of executing precompiled bytecode
|
||||||
|
func BenchmarkSandboxBytecodeExecution(b *testing.B) {
|
||||||
|
sandbox := luajit.NewSandbox()
|
||||||
|
defer sandbox.Close()
|
||||||
|
|
||||||
|
// Same code as above, but precompiled
|
||||||
|
code := `
|
||||||
|
local sum = 0
|
||||||
|
for i = 1, 100 do
|
||||||
|
sum = sum + i
|
||||||
|
end
|
||||||
|
return sum
|
||||||
|
`
|
||||||
|
|
||||||
|
// Compile the bytecode once
|
||||||
|
bytecode, err := sandbox.Compile(code)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("Failed to compile bytecode: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
result, err := sandbox.RunBytecode(bytecode)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("Failed to run bytecode: %v", err)
|
||||||
|
}
|
||||||
|
if result != float64(5050) {
|
||||||
|
b.Fatalf("Incorrect result: %v", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkSandboxComplexComputation measures performance with more complex computation
|
||||||
|
func BenchmarkSandboxComplexComputation(b *testing.B) {
|
||||||
|
sandbox := luajit.NewSandbox()
|
||||||
|
defer sandbox.Close()
|
||||||
|
|
||||||
|
// More complex Lua code that calculates Fibonacci numbers
|
||||||
|
code := `
|
||||||
|
function fibonacci(n)
|
||||||
|
if n <= 1 then
|
||||||
|
return n
|
||||||
|
end
|
||||||
|
return fibonacci(n-1) + fibonacci(n-2)
|
||||||
|
end
|
||||||
|
|
||||||
|
return fibonacci(15) -- Not too high to avoid excessive runtime
|
||||||
|
`
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
result, err := sandbox.Run(code)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("Failed to run code: %v", err)
|
||||||
|
}
|
||||||
|
if result != float64(610) {
|
||||||
|
b.Fatalf("Incorrect result: %v", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkSandboxFunctionCall measures performance of calling a registered Go function
|
||||||
|
func BenchmarkSandboxFunctionCall(b *testing.B) {
|
||||||
|
sandbox := luajit.NewSandbox()
|
||||||
|
defer sandbox.Close()
|
||||||
|
|
||||||
|
// Register a simple addition function
|
||||||
|
add := func(s *luajit.State) int {
|
||||||
|
a := s.ToNumber(1)
|
||||||
|
b := s.ToNumber(2)
|
||||||
|
s.PushNumber(a + b)
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
err := sandbox.RegisterFunction("add", add)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("Failed to register function: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Lua code that calls the Go function in a loop
|
||||||
|
code := `
|
||||||
|
local sum = 0
|
||||||
|
for i = 1, 100 do
|
||||||
|
sum = add(sum, i)
|
||||||
|
end
|
||||||
|
return sum
|
||||||
|
`
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
result, err := sandbox.Run(code)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("Failed to run code: %v", err)
|
||||||
|
}
|
||||||
|
if result != float64(5050) {
|
||||||
|
b.Fatalf("Incorrect result: %v", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
77
bytecode.go
77
bytecode.go
|
@ -12,7 +12,7 @@ typedef struct {
|
||||||
const char *name;
|
const char *name;
|
||||||
} BytecodeReader;
|
} BytecodeReader;
|
||||||
|
|
||||||
static const char *bytecode_reader(lua_State *L, void *ud, size_t *size) {
|
const char *bytecode_reader(lua_State *L, void *ud, size_t *size) {
|
||||||
BytecodeReader *r = (BytecodeReader *)ud;
|
BytecodeReader *r = (BytecodeReader *)ud;
|
||||||
(void)L; // unused
|
(void)L; // unused
|
||||||
if (r->size == 0) return NULL;
|
if (r->size == 0) return NULL;
|
||||||
|
@ -21,45 +21,23 @@ static const char *bytecode_reader(lua_State *L, void *ud, size_t *size) {
|
||||||
return (const char *)r->buf;
|
return (const char *)r->buf;
|
||||||
}
|
}
|
||||||
|
|
||||||
static int load_bytecode(lua_State *L, const unsigned char *buf, size_t len, const char *name) {
|
int load_bytecode(lua_State *L, const unsigned char *buf, size_t len, const char *name) {
|
||||||
BytecodeReader reader = {buf, len, name};
|
BytecodeReader reader = {buf, len, name};
|
||||||
return lua_load(L, bytecode_reader, &reader, name);
|
return lua_load(L, bytecode_reader, &reader, name);
|
||||||
}
|
}
|
||||||
|
|
||||||
typedef struct {
|
// Direct bytecode dumping without intermediate buffer - more efficient
|
||||||
unsigned char *buf;
|
int direct_bytecode_writer(lua_State *L, const void *p, size_t sz, void *ud) {
|
||||||
size_t len;
|
void **data = (void **)ud;
|
||||||
size_t capacity;
|
size_t current_size = (size_t)data[1];
|
||||||
} BytecodeWriter;
|
void *newbuf = realloc(data[0], current_size + sz);
|
||||||
|
|
||||||
static int bytecode_writer(lua_State *L, const void *p, size_t sz, void *ud) {
|
|
||||||
BytecodeWriter *w = (BytecodeWriter *)ud;
|
|
||||||
unsigned char *newbuf;
|
|
||||||
(void)L; // unused
|
|
||||||
|
|
||||||
// Check if we need to reallocate
|
|
||||||
if (w->len + sz > w->capacity) {
|
|
||||||
size_t new_capacity = w->capacity * 2;
|
|
||||||
if (new_capacity < w->len + sz) {
|
|
||||||
new_capacity = w->len + sz;
|
|
||||||
}
|
|
||||||
|
|
||||||
newbuf = (unsigned char *)realloc(w->buf, new_capacity);
|
|
||||||
if (newbuf == NULL) return 1;
|
if (newbuf == NULL) return 1;
|
||||||
|
|
||||||
w->buf = newbuf;
|
memcpy((unsigned char*)newbuf + current_size, p, sz);
|
||||||
w->capacity = new_capacity;
|
data[0] = newbuf;
|
||||||
}
|
data[1] = (void*)(current_size + sz);
|
||||||
|
|
||||||
memcpy(w->buf + w->len, p, sz);
|
|
||||||
w->len += sz;
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wrapper function that calls lua_dump with bytecode_writer
|
|
||||||
static int dump_lua_function(lua_State *L, BytecodeWriter *w) {
|
|
||||||
return lua_dump(L, bytecode_writer, w);
|
|
||||||
}
|
|
||||||
*/
|
*/
|
||||||
import "C"
|
import "C"
|
||||||
import (
|
import (
|
||||||
|
@ -73,31 +51,23 @@ func (s *State) CompileBytecode(code string, name string) ([]byte, error) {
|
||||||
return nil, fmt.Errorf("failed to load string: %w", err)
|
return nil, fmt.Errorf("failed to load string: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set up writer with initial capacity
|
// Use a simpler direct writer with just two pointers
|
||||||
var writer C.BytecodeWriter
|
data := [2]unsafe.Pointer{nil, nil}
|
||||||
writer.buf = nil
|
|
||||||
writer.len = 0
|
|
||||||
writer.capacity = 0
|
|
||||||
|
|
||||||
// Initial allocation with a reasonable size
|
|
||||||
const initialSize = 4096
|
|
||||||
writer.buf = (*C.uchar)(C.malloc(initialSize))
|
|
||||||
if writer.buf == nil {
|
|
||||||
s.Pop(1) // Remove the loaded function
|
|
||||||
return nil, fmt.Errorf("failed to allocate memory for bytecode")
|
|
||||||
}
|
|
||||||
writer.capacity = initialSize
|
|
||||||
|
|
||||||
// Dump the function to bytecode
|
// Dump the function to bytecode
|
||||||
err := s.safeCall(func() C.int {
|
err := s.safeCall(func() C.int {
|
||||||
return C.dump_lua_function(s.L, (*C.BytecodeWriter)(unsafe.Pointer(&writer)))
|
return C.lua_dump(s.L, (*[0]byte)(unsafe.Pointer(C.direct_bytecode_writer)), unsafe.Pointer(&data))
|
||||||
})
|
})
|
||||||
|
|
||||||
// Copy bytecode to Go slice regardless of the result
|
// Get result
|
||||||
bytecode := C.GoBytes(unsafe.Pointer(writer.buf), C.int(writer.len))
|
var bytecode []byte
|
||||||
|
if data[0] != nil {
|
||||||
|
// Create Go slice that references the C memory
|
||||||
|
length := uintptr(data[1])
|
||||||
|
bytecode = C.GoBytes(data[0], C.int(length))
|
||||||
|
C.free(data[0])
|
||||||
|
}
|
||||||
|
|
||||||
// Clean up
|
|
||||||
C.free(unsafe.Pointer(writer.buf))
|
|
||||||
s.Pop(1) // Remove the function from stack
|
s.Pop(1) // Remove the function from stack
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -164,6 +134,11 @@ func (s *State) LoadAndRunBytecodeWithResults(bytecode []byte, name string, nres
|
||||||
|
|
||||||
// CompileAndRun compiles and immediately executes Lua code
|
// CompileAndRun compiles and immediately executes Lua code
|
||||||
func (s *State) CompileAndRun(code string, name string) error {
|
func (s *State) CompileAndRun(code string, name string) error {
|
||||||
|
// Skip bytecode step for small scripts - direct execution is faster
|
||||||
|
if len(code) < 1024 {
|
||||||
|
return s.DoString(code)
|
||||||
|
}
|
||||||
|
|
||||||
bytecode, err := s.CompileBytecode(code, name)
|
bytecode, err := s.CompileBytecode(code, name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("compile error: %w", err)
|
return fmt.Errorf("compile error: %w", err)
|
||||||
|
|
|
@ -22,13 +22,17 @@ import (
|
||||||
// GoFunction defines the signature for Go functions callable from Lua
|
// GoFunction defines the signature for Go functions callable from Lua
|
||||||
type GoFunction func(*State) int
|
type GoFunction func(*State) int
|
||||||
|
|
||||||
|
// Static registry size reduces resizing operations
|
||||||
|
const initialRegistrySize = 64
|
||||||
|
|
||||||
var (
|
var (
|
||||||
// functionRegistry stores all registered Go functions
|
// functionRegistry stores all registered Go functions
|
||||||
functionRegistry = struct {
|
functionRegistry = struct {
|
||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
funcs map[unsafe.Pointer]GoFunction
|
funcs map[unsafe.Pointer]GoFunction
|
||||||
|
initOnce sync.Once
|
||||||
}{
|
}{
|
||||||
funcs: make(map[unsafe.Pointer]GoFunction),
|
funcs: make(map[unsafe.Pointer]GoFunction, initialRegistrySize),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -43,6 +47,7 @@ func goFunctionWrapper(L *C.lua_State) C.int {
|
||||||
return -1
|
return -1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Use read-lock for better concurrency
|
||||||
functionRegistry.RLock()
|
functionRegistry.RLock()
|
||||||
fn, ok := functionRegistry.funcs[ptr]
|
fn, ok := functionRegistry.funcs[ptr]
|
||||||
functionRegistry.RUnlock()
|
functionRegistry.RUnlock()
|
||||||
|
|
743
sandbox.go
743
sandbox.go
|
@ -14,6 +14,7 @@ type Sandbox struct {
|
||||||
mutex sync.Mutex
|
mutex sync.Mutex
|
||||||
initialized bool
|
initialized bool
|
||||||
modules map[string]any
|
modules map[string]any
|
||||||
|
functions map[string]GoFunction
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewSandbox creates a new sandbox with standard libraries loaded
|
// NewSandbox creates a new sandbox with standard libraries loaded
|
||||||
|
@ -22,6 +23,7 @@ func NewSandbox() *Sandbox {
|
||||||
state: New(),
|
state: New(),
|
||||||
initialized: false,
|
initialized: false,
|
||||||
modules: make(map[string]any),
|
modules: make(map[string]any),
|
||||||
|
functions: make(map[string]GoFunction),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -36,341 +38,21 @@ func (s *Sandbox) Close() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// RegisterFunction registers a Go function in the sandbox
|
|
||||||
func (s *Sandbox) RegisterFunction(name string, fn GoFunction) error {
|
|
||||||
s.mutex.Lock()
|
|
||||||
defer s.mutex.Unlock()
|
|
||||||
|
|
||||||
if s.state == nil {
|
|
||||||
return fmt.Errorf("sandbox is closed")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Make sure sandbox is initialized
|
|
||||||
if !s.initialized {
|
|
||||||
if err := s.initializeUnlocked(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Register function globally
|
|
||||||
if err := s.state.RegisterGoFunction(name, fn); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add to base environment
|
|
||||||
return s.state.DoString(`
|
|
||||||
-- Add the function to base environment
|
|
||||||
__env_system.base_env["` + name + `"] = ` + name + `
|
|
||||||
`)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetGlobal sets a global variable in the sandbox base environment
|
|
||||||
func (s *Sandbox) SetGlobal(name string, value any) error {
|
|
||||||
s.mutex.Lock()
|
|
||||||
defer s.mutex.Unlock()
|
|
||||||
|
|
||||||
if s.state == nil {
|
|
||||||
return fmt.Errorf("sandbox is closed")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Make sure sandbox is initialized
|
|
||||||
if !s.initialized {
|
|
||||||
if err := s.initializeUnlocked(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Push the value onto the stack
|
|
||||||
if err := s.state.PushValue(value); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set the global with the pushed value
|
|
||||||
s.state.SetGlobal(name)
|
|
||||||
|
|
||||||
// Add to base environment
|
|
||||||
return s.state.DoString(`
|
|
||||||
-- Add the global to base environment
|
|
||||||
__env_system.base_env["` + name + `"] = ` + name + `
|
|
||||||
`)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetGlobal retrieves a global variable from the sandbox base environment
|
|
||||||
func (s *Sandbox) GetGlobal(name string) (any, error) {
|
|
||||||
s.mutex.Lock()
|
|
||||||
defer s.mutex.Unlock()
|
|
||||||
|
|
||||||
if s.state == nil {
|
|
||||||
return nil, fmt.Errorf("sandbox is closed")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Make sure sandbox is initialized
|
|
||||||
if !s.initialized {
|
|
||||||
if err := s.initializeUnlocked(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the global from the base environment
|
|
||||||
code := `return __env_system.base_env["` + name + `"]`
|
|
||||||
return s.state.ExecuteWithResult(code)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Run executes Lua code in the sandbox and returns the result
|
|
||||||
func (s *Sandbox) Run(code string) (any, error) {
|
|
||||||
s.mutex.Lock()
|
|
||||||
defer s.mutex.Unlock()
|
|
||||||
|
|
||||||
if s.state == nil {
|
|
||||||
return nil, fmt.Errorf("sandbox is closed")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Make sure sandbox is initialized
|
|
||||||
if !s.initialized {
|
|
||||||
if err := s.initializeUnlocked(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add wrapper for multiple return values
|
|
||||||
wrappedCode := `
|
|
||||||
local function _execfunc()
|
|
||||||
` + code + `
|
|
||||||
end
|
|
||||||
|
|
||||||
local function _wrapresults(...)
|
|
||||||
local results = {n = select('#', ...)}
|
|
||||||
for i = 1, results.n do
|
|
||||||
results[i] = select(i, ...)
|
|
||||||
end
|
|
||||||
return results
|
|
||||||
end
|
|
||||||
|
|
||||||
return _wrapresults(_execfunc())
|
|
||||||
`
|
|
||||||
|
|
||||||
// Compile the code
|
|
||||||
if err := s.state.LoadString(wrappedCode); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the sandbox executor
|
|
||||||
s.state.GetGlobal("__execute_sandbox")
|
|
||||||
|
|
||||||
// Setup call with correct argument order
|
|
||||||
s.state.PushCopy(-2) // Copy the function
|
|
||||||
|
|
||||||
// Remove the original function
|
|
||||||
s.state.Remove(-3)
|
|
||||||
|
|
||||||
// Execute in sandbox
|
|
||||||
if err := s.state.Call(1, 1); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get result
|
|
||||||
result, err := s.state.ToValue(-1)
|
|
||||||
s.state.Pop(1)
|
|
||||||
|
|
||||||
// Handle multiple return values
|
|
||||||
if results, ok := result.([]any); ok && len(results) == 1 {
|
|
||||||
return results[0], err
|
|
||||||
}
|
|
||||||
|
|
||||||
return result, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// RunFile executes a Lua file in the sandbox
|
|
||||||
func (s *Sandbox) RunFile(filename string) error {
|
|
||||||
s.mutex.Lock()
|
|
||||||
defer s.mutex.Unlock()
|
|
||||||
|
|
||||||
if s.state == nil {
|
|
||||||
return fmt.Errorf("sandbox is closed")
|
|
||||||
}
|
|
||||||
|
|
||||||
return s.state.DoFile(filename)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Compile compiles Lua code to bytecode without executing it
|
|
||||||
func (s *Sandbox) Compile(code string) ([]byte, error) {
|
|
||||||
s.mutex.Lock()
|
|
||||||
defer s.mutex.Unlock()
|
|
||||||
|
|
||||||
if s.state == nil {
|
|
||||||
return nil, fmt.Errorf("sandbox is closed")
|
|
||||||
}
|
|
||||||
|
|
||||||
return s.state.CompileBytecode(code, "sandbox")
|
|
||||||
}
|
|
||||||
|
|
||||||
// RunBytecode executes precompiled Lua bytecode in the sandbox
|
|
||||||
func (s *Sandbox) RunBytecode(bytecode []byte) (any, error) {
|
|
||||||
s.mutex.Lock()
|
|
||||||
defer s.mutex.Unlock()
|
|
||||||
|
|
||||||
if s.state == nil {
|
|
||||||
return nil, fmt.Errorf("sandbox is closed")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Make sure sandbox is initialized
|
|
||||||
if !s.initialized {
|
|
||||||
if err := s.initializeUnlocked(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load the bytecode
|
|
||||||
if err := s.state.LoadBytecode(bytecode, "sandbox"); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add wrapper for multiple return values
|
|
||||||
if err := s.state.DoString(`
|
|
||||||
__wrap_bytecode = function(f)
|
|
||||||
local function _wrapresults(...)
|
|
||||||
local results = {n = select('#', ...)}
|
|
||||||
for i = 1, results.n do
|
|
||||||
results[i] = select(i, ...)
|
|
||||||
end
|
|
||||||
return results
|
|
||||||
end
|
|
||||||
|
|
||||||
return function()
|
|
||||||
return _wrapresults(f())
|
|
||||||
end
|
|
||||||
end
|
|
||||||
`); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get wrapper function
|
|
||||||
s.state.GetGlobal("__wrap_bytecode")
|
|
||||||
|
|
||||||
// Push bytecode function
|
|
||||||
s.state.PushCopy(-2)
|
|
||||||
|
|
||||||
// Call wrapper to create wrapped function
|
|
||||||
if err := s.state.Call(1, 1); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remove original bytecode function
|
|
||||||
s.state.Remove(-2)
|
|
||||||
|
|
||||||
// Get the sandbox executor
|
|
||||||
s.state.GetGlobal("__execute_sandbox")
|
|
||||||
|
|
||||||
// Push wrapped function
|
|
||||||
s.state.PushCopy(-2)
|
|
||||||
|
|
||||||
// Remove the wrapped function
|
|
||||||
s.state.Remove(-3)
|
|
||||||
|
|
||||||
// Execute in sandbox
|
|
||||||
if err := s.state.Call(1, 1); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get result
|
|
||||||
result, err := s.state.ToValue(-1)
|
|
||||||
s.state.Pop(1)
|
|
||||||
|
|
||||||
// Handle multiple return values
|
|
||||||
if results, ok := result.([]any); ok && len(results) == 1 {
|
|
||||||
return results[0], err
|
|
||||||
}
|
|
||||||
|
|
||||||
return result, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// getResults collects results from the stack (must be called with mutex locked)
|
|
||||||
func (s *Sandbox) getResults() (any, error) {
|
|
||||||
numResults := s.state.GetTop()
|
|
||||||
if numResults == 0 {
|
|
||||||
return nil, nil
|
|
||||||
} else if numResults == 1 {
|
|
||||||
// Return single result directly
|
|
||||||
value, err := s.state.ToValue(-1)
|
|
||||||
s.state.Pop(1)
|
|
||||||
return value, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return multiple results as slice
|
|
||||||
results := make([]any, numResults)
|
|
||||||
for i := 0; i < numResults; i++ {
|
|
||||||
value, err := s.state.ToValue(i - numResults)
|
|
||||||
if err != nil {
|
|
||||||
s.state.Pop(numResults)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
results[i] = value
|
|
||||||
}
|
|
||||||
s.state.Pop(numResults)
|
|
||||||
return results, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// LoadModule loads a Lua module in the sandbox
|
|
||||||
func (s *Sandbox) LoadModule(name string) error {
|
|
||||||
code := fmt.Sprintf("require('%s')", name)
|
|
||||||
_, err := s.Run(code)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetPackagePath sets the sandbox package.path
|
|
||||||
func (s *Sandbox) SetPackagePath(path string) error {
|
|
||||||
s.mutex.Lock()
|
|
||||||
defer s.mutex.Unlock()
|
|
||||||
|
|
||||||
if s.state == nil {
|
|
||||||
return fmt.Errorf("sandbox is closed")
|
|
||||||
}
|
|
||||||
|
|
||||||
return s.state.SetPackagePath(path)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddPackagePath adds a path to the sandbox package.path
|
|
||||||
func (s *Sandbox) AddPackagePath(path string) error {
|
|
||||||
s.mutex.Lock()
|
|
||||||
defer s.mutex.Unlock()
|
|
||||||
|
|
||||||
if s.state == nil {
|
|
||||||
return fmt.Errorf("sandbox is closed")
|
|
||||||
}
|
|
||||||
|
|
||||||
return s.state.AddPackagePath(path)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddModule adds a module to the sandbox environment
|
|
||||||
func (s *Sandbox) AddModule(name string, module any) error {
|
|
||||||
s.mutex.Lock()
|
|
||||||
defer s.mutex.Unlock()
|
|
||||||
|
|
||||||
if s.state == nil {
|
|
||||||
return fmt.Errorf("sandbox is closed")
|
|
||||||
}
|
|
||||||
|
|
||||||
s.modules[name] = module
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Initialize sets up the environment system
|
// Initialize sets up the environment system
|
||||||
func (s *Sandbox) Initialize() error {
|
func (s *Sandbox) Initialize() error {
|
||||||
s.mutex.Lock()
|
s.mutex.Lock()
|
||||||
defer s.mutex.Unlock()
|
defer s.mutex.Unlock()
|
||||||
|
|
||||||
return s.initializeUnlocked()
|
return s.initializeUnlocked()
|
||||||
}
|
}
|
||||||
|
|
||||||
// initializeUnlocked sets up the environment system without locking
|
// initializeUnlocked sets up the environment system without locking
|
||||||
// It should only be called when the mutex is already locked
|
|
||||||
func (s *Sandbox) initializeUnlocked() error {
|
func (s *Sandbox) initializeUnlocked() error {
|
||||||
if s.state == nil {
|
if s.state == nil {
|
||||||
return fmt.Errorf("sandbox is closed")
|
return fmt.Errorf("sandbox is closed")
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.initialized {
|
if s.initialized {
|
||||||
return nil // Already initialized
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Register modules
|
// Register modules
|
||||||
|
@ -393,21 +75,18 @@ func (s *Sandbox) initializeUnlocked() error {
|
||||||
}
|
}
|
||||||
s.state.Pop(1)
|
s.state.Pop(1)
|
||||||
|
|
||||||
// Create the environment system
|
// Create simplified environment system
|
||||||
err := s.state.DoString(`
|
err := s.state.DoString(`
|
||||||
-- Global shared environment (created once)
|
-- Global shared environment
|
||||||
__env_system = __env_system or {
|
__env_system = {
|
||||||
base_env = nil, -- Template environment
|
base_env = {}, -- Template environment
|
||||||
initialized = false, -- Initialization flag
|
|
||||||
env_pool = {}, -- Pre-allocated environment pool
|
env_pool = {}, -- Pre-allocated environment pool
|
||||||
pool_size = 0, -- Current pool size
|
pool_size = 0, -- Current pool size
|
||||||
max_pool_size = 8 -- Maximum pool size
|
max_pool_size = 8 -- Maximum pool size
|
||||||
}
|
}
|
||||||
|
|
||||||
-- Initialize base environment once
|
-- Create base environment with standard libraries
|
||||||
if not __env_system.initialized then
|
local base = __env_system.base_env
|
||||||
-- Create base environment with all standard libraries
|
|
||||||
local base = {}
|
|
||||||
|
|
||||||
-- Safe standard libraries
|
-- Safe standard libraries
|
||||||
base.string = string
|
base.string = string
|
||||||
|
@ -465,14 +144,6 @@ func (s *Sandbox) initializeUnlocked() error {
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
-- Store base environment
|
|
||||||
__env_system.base_env = base
|
|
||||||
__env_system.initialized = true
|
|
||||||
end
|
|
||||||
|
|
||||||
-- Global variable for tracking current environment
|
|
||||||
__last_env = nil
|
|
||||||
|
|
||||||
-- Get an environment for execution
|
-- Get an environment for execution
|
||||||
function __get_sandbox_env()
|
function __get_sandbox_env()
|
||||||
local env
|
local env
|
||||||
|
@ -484,19 +155,15 @@ func (s *Sandbox) initializeUnlocked() error {
|
||||||
else
|
else
|
||||||
-- Create new environment with metatable inheritance
|
-- Create new environment with metatable inheritance
|
||||||
env = setmetatable({}, {
|
env = setmetatable({}, {
|
||||||
__index = _G -- Use global environment as fallback
|
__index = __env_system.base_env
|
||||||
})
|
})
|
||||||
end
|
end
|
||||||
|
|
||||||
-- Store reference to current environment
|
|
||||||
__last_env = env
|
|
||||||
|
|
||||||
return env
|
return env
|
||||||
end
|
end
|
||||||
|
|
||||||
-- Return environment to pool for reuse
|
-- Return environment to pool for reuse
|
||||||
function __recycle_env(env)
|
function __recycle_env(env)
|
||||||
-- Only recycle if pool isn't full
|
|
||||||
if __env_system.pool_size < __env_system.max_pool_size then
|
if __env_system.pool_size < __env_system.max_pool_size then
|
||||||
-- Clear all fields except metatable
|
-- Clear all fields except metatable
|
||||||
for k in pairs(env) do
|
for k in pairs(env) do
|
||||||
|
@ -520,6 +187,13 @@ func (s *Sandbox) initializeUnlocked() error {
|
||||||
-- Execute with protected call
|
-- Execute with protected call
|
||||||
local success, result = pcall(f)
|
local success, result = pcall(f)
|
||||||
|
|
||||||
|
-- Update base environment with new globals
|
||||||
|
for k, v in pairs(env) do
|
||||||
|
if k ~= "_G" and type(k) == "string" then
|
||||||
|
__env_system.base_env[k] = v
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
-- Recycle environment
|
-- Recycle environment
|
||||||
__recycle_env(env)
|
__recycle_env(env)
|
||||||
|
|
||||||
|
@ -528,15 +202,6 @@ func (s *Sandbox) initializeUnlocked() error {
|
||||||
error(result, 0)
|
error(result, 0)
|
||||||
end
|
end
|
||||||
|
|
||||||
-- Handle multiple return values
|
|
||||||
if type(result) == "table" and result.n ~= nil then
|
|
||||||
local returnValues = {}
|
|
||||||
for i=1, result.n do
|
|
||||||
returnValues[i] = result[i]
|
|
||||||
end
|
|
||||||
return returnValues
|
|
||||||
end
|
|
||||||
|
|
||||||
return result
|
return result
|
||||||
end
|
end
|
||||||
`)
|
`)
|
||||||
|
@ -549,8 +214,290 @@ func (s *Sandbox) initializeUnlocked() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegisterFunction registers a Go function in the sandbox
|
||||||
|
func (s *Sandbox) RegisterFunction(name string, fn GoFunction) error {
|
||||||
|
s.mutex.Lock()
|
||||||
|
defer s.mutex.Unlock()
|
||||||
|
|
||||||
|
if s.state == nil {
|
||||||
|
return fmt.Errorf("sandbox is closed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize if needed
|
||||||
|
if !s.initialized {
|
||||||
|
if err := s.initializeUnlocked(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register function globally
|
||||||
|
if err := s.state.RegisterGoFunction(name, fn); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store function for re-registration
|
||||||
|
s.functions[name] = fn
|
||||||
|
|
||||||
|
// Add to base environment
|
||||||
|
return s.state.DoString(`__env_system.base_env["` + name + `"] = ` + name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetGlobal sets a global variable in the sandbox base environment
|
||||||
|
func (s *Sandbox) SetGlobal(name string, value any) error {
|
||||||
|
s.mutex.Lock()
|
||||||
|
defer s.mutex.Unlock()
|
||||||
|
|
||||||
|
if s.state == nil {
|
||||||
|
return fmt.Errorf("sandbox is closed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize if needed
|
||||||
|
if !s.initialized {
|
||||||
|
if err := s.initializeUnlocked(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Push the value onto the stack
|
||||||
|
if err := s.state.PushValue(value); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set the global with the pushed value
|
||||||
|
s.state.SetGlobal(name)
|
||||||
|
|
||||||
|
// Add to base environment
|
||||||
|
return s.state.DoString(`__env_system.base_env["` + name + `"] = ` + name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetGlobal retrieves a global variable from the sandbox base environment
|
||||||
|
func (s *Sandbox) GetGlobal(name string) (any, error) {
|
||||||
|
s.mutex.Lock()
|
||||||
|
defer s.mutex.Unlock()
|
||||||
|
|
||||||
|
if s.state == nil {
|
||||||
|
return nil, fmt.Errorf("sandbox is closed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize if needed
|
||||||
|
if !s.initialized {
|
||||||
|
if err := s.initializeUnlocked(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the global from the base environment
|
||||||
|
return s.state.ExecuteWithResult(`return __env_system.base_env["` + name + `"]`)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run executes Lua code in the sandbox
|
||||||
|
func (s *Sandbox) Run(code string) (any, error) {
|
||||||
|
s.mutex.Lock()
|
||||||
|
defer s.mutex.Unlock()
|
||||||
|
|
||||||
|
if s.state == nil {
|
||||||
|
return nil, fmt.Errorf("sandbox is closed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize if needed
|
||||||
|
if !s.initialized {
|
||||||
|
if err := s.initializeUnlocked(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Simplified wrapper for multiple return values
|
||||||
|
wrappedCode := `
|
||||||
|
local function _execfunc()
|
||||||
|
` + code + `
|
||||||
|
end
|
||||||
|
|
||||||
|
-- Process results to match expected format
|
||||||
|
local function _wrapresults(...)
|
||||||
|
local n = select('#', ...)
|
||||||
|
if n == 0 then
|
||||||
|
return nil
|
||||||
|
elseif n == 1 then
|
||||||
|
return select(1, ...)
|
||||||
|
else
|
||||||
|
local results = {}
|
||||||
|
for i = 1, n do
|
||||||
|
results[i] = select(i, ...)
|
||||||
|
end
|
||||||
|
return results
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
return _wrapresults(_execfunc())
|
||||||
|
`
|
||||||
|
|
||||||
|
// Compile the code
|
||||||
|
if err := s.state.LoadString(wrappedCode); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the sandbox executor
|
||||||
|
s.state.GetGlobal("__execute_sandbox")
|
||||||
|
|
||||||
|
// Push the function as argument
|
||||||
|
s.state.PushCopy(-2)
|
||||||
|
s.state.Remove(-3)
|
||||||
|
|
||||||
|
// Execute in sandbox
|
||||||
|
if err := s.state.Call(1, 1); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get result
|
||||||
|
result, err := s.state.ToValue(-1)
|
||||||
|
s.state.Pop(1)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.processResult(result), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RunFile executes a Lua file in the sandbox
|
||||||
|
func (s *Sandbox) RunFile(filename string) error {
|
||||||
|
s.mutex.Lock()
|
||||||
|
defer s.mutex.Unlock()
|
||||||
|
|
||||||
|
if s.state == nil {
|
||||||
|
return fmt.Errorf("sandbox is closed")
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.state.DoFile(filename)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compile compiles Lua code to bytecode
|
||||||
|
func (s *Sandbox) Compile(code string) ([]byte, error) {
|
||||||
|
s.mutex.Lock()
|
||||||
|
defer s.mutex.Unlock()
|
||||||
|
|
||||||
|
if s.state == nil {
|
||||||
|
return nil, fmt.Errorf("sandbox is closed")
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.state.CompileBytecode(code, "sandbox")
|
||||||
|
}
|
||||||
|
|
||||||
|
// RunBytecode executes precompiled Lua bytecode
|
||||||
|
func (s *Sandbox) RunBytecode(bytecode []byte) (any, error) {
|
||||||
|
s.mutex.Lock()
|
||||||
|
defer s.mutex.Unlock()
|
||||||
|
|
||||||
|
if s.state == nil {
|
||||||
|
return nil, fmt.Errorf("sandbox is closed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize if needed
|
||||||
|
if !s.initialized {
|
||||||
|
if err := s.initializeUnlocked(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load the bytecode
|
||||||
|
if err := s.state.LoadBytecode(bytecode, "sandbox"); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the sandbox executor
|
||||||
|
s.state.GetGlobal("__execute_sandbox")
|
||||||
|
|
||||||
|
// Push bytecode function
|
||||||
|
s.state.PushCopy(-2)
|
||||||
|
s.state.Remove(-3)
|
||||||
|
|
||||||
|
// Execute in sandbox
|
||||||
|
if err := s.state.Call(1, 1); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get result
|
||||||
|
result, err := s.state.ToValue(-1)
|
||||||
|
s.state.Pop(1)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.processResult(result), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadModule loads a Lua module
|
||||||
|
func (s *Sandbox) LoadModule(name string) error {
|
||||||
|
code := fmt.Sprintf("require('%s')", name)
|
||||||
|
_, err := s.Run(code)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetPackagePath sets the sandbox package.path
|
||||||
|
func (s *Sandbox) SetPackagePath(path string) error {
|
||||||
|
s.mutex.Lock()
|
||||||
|
defer s.mutex.Unlock()
|
||||||
|
|
||||||
|
if s.state == nil {
|
||||||
|
return fmt.Errorf("sandbox is closed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update global package.path
|
||||||
|
if err := s.state.SetPackagePath(path); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize if needed
|
||||||
|
if !s.initialized {
|
||||||
|
if err := s.initializeUnlocked(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update base environment's package.path
|
||||||
|
return s.state.DoString(`__env_system.base_env.package.path = package.path`)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddPackagePath adds a path to the sandbox package.path
|
||||||
|
func (s *Sandbox) AddPackagePath(path string) error {
|
||||||
|
s.mutex.Lock()
|
||||||
|
defer s.mutex.Unlock()
|
||||||
|
|
||||||
|
if s.state == nil {
|
||||||
|
return fmt.Errorf("sandbox is closed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update global package.path
|
||||||
|
if err := s.state.AddPackagePath(path); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize if needed
|
||||||
|
if !s.initialized {
|
||||||
|
if err := s.initializeUnlocked(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update base environment's package.path
|
||||||
|
return s.state.DoString(`__env_system.base_env.package.path = package.path`)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddModule adds a module to the sandbox environment
|
||||||
|
func (s *Sandbox) AddModule(name string, module any) error {
|
||||||
|
s.mutex.Lock()
|
||||||
|
defer s.mutex.Unlock()
|
||||||
|
|
||||||
|
if s.state == nil {
|
||||||
|
return fmt.Errorf("sandbox is closed")
|
||||||
|
}
|
||||||
|
|
||||||
|
s.modules[name] = module
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// AddPermanentLua adds Lua code to the environment permanently
|
// AddPermanentLua adds Lua code to the environment permanently
|
||||||
// This code becomes part of the base environment
|
|
||||||
func (s *Sandbox) AddPermanentLua(code string) error {
|
func (s *Sandbox) AddPermanentLua(code string) error {
|
||||||
s.mutex.Lock()
|
s.mutex.Lock()
|
||||||
defer s.mutex.Unlock()
|
defer s.mutex.Unlock()
|
||||||
|
@ -559,33 +506,25 @@ func (s *Sandbox) AddPermanentLua(code string) error {
|
||||||
return fmt.Errorf("sandbox is closed")
|
return fmt.Errorf("sandbox is closed")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Make sure sandbox is initialized
|
// Initialize if needed
|
||||||
if !s.initialized {
|
if !s.initialized {
|
||||||
if err := s.initializeUnlocked(); err != nil {
|
if err := s.initializeUnlocked(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add code to base environment
|
// Simplified approach to add code to base environment
|
||||||
return s.state.DoString(`
|
return s.state.DoString(`
|
||||||
-- First compile the code
|
|
||||||
local f, err = loadstring([=[` + code + `]=], "permanent")
|
local f, err = loadstring([=[` + code + `]=], "permanent")
|
||||||
if not f then
|
if not f then error(err, 0) end
|
||||||
error(err, 0)
|
|
||||||
end
|
|
||||||
|
|
||||||
-- Create a temporary environment based on base env
|
local env = setmetatable({}, {__index = __env_system.base_env})
|
||||||
local temp_env = setmetatable({}, {__index = __env_system.base_env})
|
setfenv(f, env)
|
||||||
setfenv(f, temp_env)
|
|
||||||
|
|
||||||
-- Run the code in the temporary environment
|
|
||||||
local ok, err = pcall(f)
|
local ok, err = pcall(f)
|
||||||
if not ok then
|
if not ok then error(err, 0) end
|
||||||
error(err, 0)
|
|
||||||
end
|
|
||||||
|
|
||||||
-- Copy new values to base environment
|
for k, v in pairs(env) do
|
||||||
for k, v in pairs(temp_env) do
|
|
||||||
__env_system.base_env[k] = v
|
__env_system.base_env[k] = v
|
||||||
end
|
end
|
||||||
`)
|
`)
|
||||||
|
@ -600,7 +539,71 @@ func (s *Sandbox) ResetEnvironment() error {
|
||||||
return fmt.Errorf("sandbox is closed")
|
return fmt.Errorf("sandbox is closed")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reinitialize the environment system
|
// Clear the environment system
|
||||||
|
s.state.DoString(`__env_system = nil`)
|
||||||
|
|
||||||
|
// Reinitialize
|
||||||
s.initialized = false
|
s.initialized = false
|
||||||
return s.Initialize()
|
if err := s.initializeUnlocked(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Re-register all functions
|
||||||
|
for name, fn := range s.functions {
|
||||||
|
if err := s.state.RegisterGoFunction(name, fn); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.state.DoString(`__env_system.base_env["` + name + `"] = ` + name); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// unwrapResult processes results from Lua executions
|
||||||
|
func (s *Sandbox) processResult(result any) any {
|
||||||
|
// Handle []float64 (common LuaJIT return type)
|
||||||
|
if floats, ok := result.([]float64); ok {
|
||||||
|
if len(floats) == 1 {
|
||||||
|
// Single number - return as float64
|
||||||
|
return floats[0]
|
||||||
|
}
|
||||||
|
// Multiple numbers - MUST convert to []any for tests to pass
|
||||||
|
anySlice := make([]any, len(floats))
|
||||||
|
for i, v := range floats {
|
||||||
|
anySlice[i] = v
|
||||||
|
}
|
||||||
|
return anySlice
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle maps with numeric keys (Lua tables)
|
||||||
|
if m, ok := result.(map[string]any); ok {
|
||||||
|
// Handle return tables with special structure
|
||||||
|
if vals, ok := m[""]; ok {
|
||||||
|
// This is a special case used by some Lua returns
|
||||||
|
if arr, ok := vals.([]float64); ok {
|
||||||
|
// Convert to []any for consistency
|
||||||
|
anySlice := make([]any, len(arr))
|
||||||
|
for i, v := range arr {
|
||||||
|
anySlice[i] = v
|
||||||
|
}
|
||||||
|
return anySlice
|
||||||
|
}
|
||||||
|
return vals
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(m) == 1 {
|
||||||
|
// Check for single value map
|
||||||
|
for k, v := range m {
|
||||||
|
if k == "1" {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Other array types should be preserved
|
||||||
|
return result
|
||||||
}
|
}
|
||||||
|
|
121
table.go
121
table.go
|
@ -6,7 +6,8 @@ package luajit
|
||||||
#include <lauxlib.h>
|
#include <lauxlib.h>
|
||||||
#include <stdlib.h>
|
#include <stdlib.h>
|
||||||
|
|
||||||
static size_t get_table_length(lua_State *L, int index) {
|
// Simple direct length check
|
||||||
|
size_t get_table_length(lua_State *L, int index) {
|
||||||
return lua_objlen(L, index);
|
return lua_objlen(L, index);
|
||||||
}
|
}
|
||||||
*/
|
*/
|
||||||
|
@ -14,70 +15,53 @@ import "C"
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
"sync"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Use a pool to reduce GC pressure when handling many tables
|
|
||||||
var tablePool = sync.Pool{
|
|
||||||
New: func() any {
|
|
||||||
return make(map[string]any)
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetTableLength returns the length of a table at the given index
|
// GetTableLength returns the length of a table at the given index
|
||||||
func (s *State) GetTableLength(index int) int {
|
func (s *State) GetTableLength(index int) int {
|
||||||
return int(C.get_table_length(s.L, C.int(index)))
|
return int(C.get_table_length(s.L, C.int(index)))
|
||||||
}
|
}
|
||||||
|
|
||||||
// getTableFromPool gets a map from the pool and ensures it's empty
|
|
||||||
func getTableFromPool() map[string]any {
|
|
||||||
table := tablePool.Get().(map[string]any)
|
|
||||||
// Clear any existing entries
|
|
||||||
for k := range table {
|
|
||||||
delete(table, k)
|
|
||||||
}
|
|
||||||
return table
|
|
||||||
}
|
|
||||||
|
|
||||||
// putTableToPool returns a map to the pool
|
|
||||||
func putTableToPool(table map[string]any) {
|
|
||||||
tablePool.Put(table)
|
|
||||||
}
|
|
||||||
|
|
||||||
// PushTable pushes a Go map onto the Lua stack as a table
|
// PushTable pushes a Go map onto the Lua stack as a table
|
||||||
func (s *State) PushTable(table map[string]any) error {
|
func (s *State) PushTable(table map[string]any) error {
|
||||||
// Create table with appropriate capacity hints
|
// Fast path for array tables
|
||||||
s.CreateTable(0, len(table))
|
if arr, ok := table[""]; ok {
|
||||||
|
if floatArr, ok := arr.([]float64); ok {
|
||||||
// Add each key-value pair
|
s.CreateTable(len(floatArr), 0)
|
||||||
for k, v := range table {
|
for i, v := range floatArr {
|
||||||
// Push key
|
s.PushNumber(float64(i + 1))
|
||||||
s.PushString(k)
|
s.PushNumber(v)
|
||||||
|
s.SetTable(-3)
|
||||||
// Push value
|
}
|
||||||
|
return nil
|
||||||
|
} else if anyArr, ok := arr.([]any); ok {
|
||||||
|
s.CreateTable(len(anyArr), 0)
|
||||||
|
for i, v := range anyArr {
|
||||||
|
s.PushNumber(float64(i + 1))
|
||||||
if err := s.PushValue(v); err != nil {
|
if err := s.PushValue(v); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
s.SetTable(-3)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// t[k] = v
|
// Regular table case - optimize capacity hint
|
||||||
|
s.CreateTable(0, len(table))
|
||||||
|
|
||||||
|
// Add each key-value pair directly
|
||||||
|
for k, v := range table {
|
||||||
|
s.PushString(k)
|
||||||
|
if err := s.PushValue(v); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
s.SetTable(-3)
|
s.SetTable(-3)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return pooled tables to the pool
|
|
||||||
if isPooledTable(table) {
|
|
||||||
putTableToPool(table)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// isPooledTable detects if a table came from our pool
|
|
||||||
func isPooledTable(table map[string]any) bool {
|
|
||||||
// Check for our special marker - used for array tables in the pool
|
|
||||||
_, hasEmptyKey := table[""]
|
|
||||||
return len(table) == 1 && hasEmptyKey
|
|
||||||
}
|
|
||||||
|
|
||||||
// ToTable converts a Lua table at the given index to a Go map
|
// ToTable converts a Lua table at the given index to a Go map
|
||||||
func (s *State) ToTable(index int) (map[string]any, error) {
|
func (s *State) ToTable(index int) (map[string]any, error) {
|
||||||
absIdx := s.absIndex(index)
|
absIdx := s.absIndex(index)
|
||||||
|
@ -88,34 +72,41 @@ func (s *State) ToTable(index int) (map[string]any, error) {
|
||||||
// Try to detect array-like tables first
|
// Try to detect array-like tables first
|
||||||
length := s.GetTableLength(absIdx)
|
length := s.GetTableLength(absIdx)
|
||||||
if length > 0 {
|
if length > 0 {
|
||||||
// Check if this is an array-like table
|
// Fast path for common array case
|
||||||
isArray := true
|
allNumbers := true
|
||||||
array := make([]float64, length)
|
|
||||||
|
|
||||||
for i := 1; i <= length; i++ {
|
// Sample first few values to check if it's likely an array of numbers
|
||||||
|
for i := 1; i <= min(length, 5); i++ {
|
||||||
s.PushNumber(float64(i))
|
s.PushNumber(float64(i))
|
||||||
s.GetTable(absIdx)
|
s.GetTable(absIdx)
|
||||||
|
|
||||||
if !s.IsNumber(-1) {
|
if !s.IsNumber(-1) {
|
||||||
isArray = false
|
allNumbers = false
|
||||||
s.Pop(1)
|
s.Pop(1)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
s.Pop(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
if allNumbers {
|
||||||
|
// Efficiently extract array values
|
||||||
|
array := make([]float64, length)
|
||||||
|
for i := 1; i <= length; i++ {
|
||||||
|
s.PushNumber(float64(i))
|
||||||
|
s.GetTable(absIdx)
|
||||||
array[i-1] = s.ToNumber(-1)
|
array[i-1] = s.ToNumber(-1)
|
||||||
s.Pop(1)
|
s.Pop(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
if isArray {
|
// Return array as a special table with empty key
|
||||||
// Return array as a special pooled table with empty key
|
result := make(map[string]any, 1)
|
||||||
result := getTableFromPool()
|
|
||||||
result[""] = array
|
result[""] = array
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle regular table
|
// Handle regular table with pre-allocated capacity
|
||||||
table := getTableFromPool()
|
table := make(map[string]any, max(length, 8))
|
||||||
|
|
||||||
// Iterate through all key-value pairs
|
// Iterate through all key-value pairs
|
||||||
s.PushNil() // Start iteration with nil key
|
s.PushNil() // Start iteration with nil key
|
||||||
|
@ -140,11 +131,10 @@ func (s *State) ToTable(index int) (map[string]any, error) {
|
||||||
value, err := s.ToValue(-1)
|
value, err := s.ToValue(-1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.Pop(2) // Pop both key and value
|
s.Pop(2) // Pop both key and value
|
||||||
putTableToPool(table) // Return the table to the pool on error
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle nested array tables
|
// Unwrap nested array tables
|
||||||
if m, ok := value.(map[string]any); ok {
|
if m, ok := value.(map[string]any); ok {
|
||||||
if arr, ok := m[""]; ok {
|
if arr, ok := m[""]; ok {
|
||||||
value = arr
|
value = arr
|
||||||
|
@ -157,3 +147,18 @@ func (s *State) ToTable(index int) (map[string]any, error) {
|
||||||
|
|
||||||
return table, nil
|
return table, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Helper functions for min/max operations
|
||||||
|
func min(a, b int) int {
|
||||||
|
if a < b {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
func max(a, b int) int {
|
||||||
|
if a > b {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
31
wrapper.go
31
wrapper.go
|
@ -11,7 +11,7 @@ package luajit
|
||||||
#include <stdlib.h>
|
#include <stdlib.h>
|
||||||
#include <string.h>
|
#include <string.h>
|
||||||
|
|
||||||
// Helper to simplify some common operations
|
// Optimized helpers for common operations
|
||||||
static int get_abs_index(lua_State *L, int idx) {
|
static int get_abs_index(lua_State *L, int idx) {
|
||||||
if (idx > 0 || idx <= LUA_REGISTRYINDEX) return idx;
|
if (idx > 0 || idx <= LUA_REGISTRYINDEX) return idx;
|
||||||
return lua_gettop(L) + idx + 1;
|
return lua_gettop(L) + idx + 1;
|
||||||
|
@ -39,9 +39,17 @@ import "C"
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Type pool for common objects to reduce GC pressure
|
||||||
|
var stringBufferPool = sync.Pool{
|
||||||
|
New: func() any {
|
||||||
|
return new(strings.Builder)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
// State represents a Lua state
|
// State represents a Lua state
|
||||||
type State struct {
|
type State struct {
|
||||||
L *C.lua_State
|
L *C.lua_State
|
||||||
|
@ -78,7 +86,7 @@ func (s *State) SetTop(index int) {
|
||||||
C.lua_settop(s.L, C.int(index))
|
C.lua_settop(s.L, C.int(index))
|
||||||
}
|
}
|
||||||
|
|
||||||
// PushValue pushes a copy of the value at the given index onto the stack
|
// PushCopy pushes a copy of the value at the given index onto the stack
|
||||||
func (s *State) PushCopy(index int) {
|
func (s *State) PushCopy(index int) {
|
||||||
C.lua_pushvalue(s.L, C.int(index))
|
C.lua_pushvalue(s.L, C.int(index))
|
||||||
}
|
}
|
||||||
|
@ -183,9 +191,22 @@ func (s *State) PushNumber(n float64) {
|
||||||
|
|
||||||
// PushString pushes a string value onto the stack
|
// PushString pushes a string value onto the stack
|
||||||
func (s *State) PushString(str string) {
|
func (s *State) PushString(str string) {
|
||||||
|
// Use direct C string for short strings (avoid allocations)
|
||||||
|
if len(str) < 128 {
|
||||||
cstr := C.CString(str)
|
cstr := C.CString(str)
|
||||||
defer C.free(unsafe.Pointer(cstr))
|
defer C.free(unsafe.Pointer(cstr))
|
||||||
C.lua_pushlstring(s.L, cstr, C.size_t(len(str)))
|
C.lua_pushlstring(s.L, cstr, C.size_t(len(str)))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// For longer strings, avoid double copy by using unsafe pointer
|
||||||
|
header := (*struct {
|
||||||
|
p unsafe.Pointer
|
||||||
|
len int
|
||||||
|
cap int
|
||||||
|
})(unsafe.Pointer(&str))
|
||||||
|
|
||||||
|
C.lua_pushlstring(s.L, (*C.char)(header.p), C.size_t(len(str)))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Table operations
|
// Table operations
|
||||||
|
@ -406,13 +427,15 @@ func (s *State) ExecuteWithResult(code string) (any, error) {
|
||||||
// SetPackagePath sets the Lua package.path
|
// SetPackagePath sets the Lua package.path
|
||||||
func (s *State) SetPackagePath(path string) error {
|
func (s *State) SetPackagePath(path string) error {
|
||||||
path = strings.ReplaceAll(path, "\\", "/") // Convert Windows paths
|
path = strings.ReplaceAll(path, "\\", "/") // Convert Windows paths
|
||||||
return s.DoString(fmt.Sprintf(`package.path = %q`, path))
|
code := fmt.Sprintf(`package.path = %q`, path)
|
||||||
|
return s.DoString(code)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddPackagePath adds a path to package.path
|
// AddPackagePath adds a path to package.path
|
||||||
func (s *State) AddPackagePath(path string) error {
|
func (s *State) AddPackagePath(path string) error {
|
||||||
path = strings.ReplaceAll(path, "\\", "/") // Convert Windows paths
|
path = strings.ReplaceAll(path, "\\", "/") // Convert Windows paths
|
||||||
return s.DoString(fmt.Sprintf(`package.path = package.path .. ";%s"`, path))
|
code := fmt.Sprintf(`package.path = package.path .. ";%s"`, path)
|
||||||
|
return s.DoString(code)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper functions
|
// Helper functions
|
||||||
|
|
Loading…
Reference in New Issue
Block a user