From 875abee366fc7766c010a0ee9eec918216f57543 Mon Sep 17 00:00:00 2001 From: Sky Johnson Date: Thu, 27 Mar 2025 21:58:56 -0500 Subject: [PATCH] optimize sandbox --- sandbox.go | 862 ++++++++++++++++++++++++----------------------------- 1 file changed, 393 insertions(+), 469 deletions(-) diff --git a/sandbox.go b/sandbox.go index ce093de..3dd352c 100644 --- a/sandbox.go +++ b/sandbox.go @@ -38,331 +38,21 @@ func (s *Sandbox) Close() { } } -// RegisterFunction registers a Go function in the sandbox -func (s *Sandbox) RegisterFunction(name string, fn GoFunction) error { - s.mutex.Lock() - defer s.mutex.Unlock() - - if s.state == nil { - return fmt.Errorf("sandbox is closed") - } - - // Make sure sandbox is initialized - if !s.initialized { - if err := s.initializeUnlocked(); err != nil { - return err - } - } - - // Register function globally - if err := s.state.RegisterGoFunction(name, fn); err != nil { - return err - } - - // Store function for re-registration - s.functions[name] = fn - - // Add to base environment - return s.state.DoString(` - -- Add the function to base environment - __env_system.base_env["` + name + `"] = ` + name + ` - `) -} - -// SetGlobal sets a global variable in the sandbox base environment -func (s *Sandbox) SetGlobal(name string, value any) error { - s.mutex.Lock() - defer s.mutex.Unlock() - - if s.state == nil { - return fmt.Errorf("sandbox is closed") - } - - // Make sure sandbox is initialized - if !s.initialized { - if err := s.initializeUnlocked(); err != nil { - return err - } - } - - // Push the value onto the stack - if err := s.state.PushValue(value); err != nil { - return err - } - - // Set the global with the pushed value - s.state.SetGlobal(name) - - // Add to base environment - return s.state.DoString(` - -- Add the global to base environment - __env_system.base_env["` + name + `"] = ` + name + ` - `) -} - -// GetGlobal retrieves a global variable from the sandbox base environment -func (s *Sandbox) GetGlobal(name string) (any, error) { - s.mutex.Lock() - defer s.mutex.Unlock() - - if s.state == nil { - return nil, fmt.Errorf("sandbox is closed") - } - - // Make sure sandbox is initialized - if !s.initialized { - if err := s.initializeUnlocked(); err != nil { - return nil, err - } - } - - // Get the global from the base environment - code := `return __env_system.base_env["` + name + `"]` - return s.state.ExecuteWithResult(code) -} - -// Run executes Lua code in the sandbox and returns the result -func (s *Sandbox) Run(code string) (any, error) { - s.mutex.Lock() - defer s.mutex.Unlock() - - if s.state == nil { - return nil, fmt.Errorf("sandbox is closed") - } - - // Make sure sandbox is initialized - if !s.initialized { - if err := s.initializeUnlocked(); err != nil { - return nil, err - } - } - - // Add wrapper for multiple return values - wrappedCode := ` - local function _execfunc() - ` + code + ` - end - - local function _wrapresults(...) - local results = {n = select('#', ...)} - for i = 1, results.n do - results[i] = select(i, ...) - end - return results - end - - return _wrapresults(_execfunc()) - ` - - // Compile the code - if err := s.state.LoadString(wrappedCode); err != nil { - return nil, err - } - - // Get the sandbox executor - s.state.GetGlobal("__execute_sandbox") - - // Setup call with correct argument order - s.state.PushCopy(-2) // Copy the function - - // Remove the original function - s.state.Remove(-3) - - // Execute in sandbox - if err := s.state.Call(1, 1); err != nil { - return nil, err - } - - // Get result - result, err := s.state.ToValue(-1) - s.state.Pop(1) - - return s.unwrapResult(result, err) -} - -// RunFile executes a Lua file in the sandbox -func (s *Sandbox) RunFile(filename string) error { - s.mutex.Lock() - defer s.mutex.Unlock() - - if s.state == nil { - return fmt.Errorf("sandbox is closed") - } - - return s.state.DoFile(filename) -} - -// Compile compiles Lua code to bytecode without executing it -func (s *Sandbox) Compile(code string) ([]byte, error) { - s.mutex.Lock() - defer s.mutex.Unlock() - - if s.state == nil { - return nil, fmt.Errorf("sandbox is closed") - } - - return s.state.CompileBytecode(code, "sandbox") -} - -// RunBytecode executes precompiled Lua bytecode in the sandbox -func (s *Sandbox) RunBytecode(bytecode []byte) (any, error) { - s.mutex.Lock() - defer s.mutex.Unlock() - - if s.state == nil { - return nil, fmt.Errorf("sandbox is closed") - } - - // Make sure sandbox is initialized - if !s.initialized { - if err := s.initializeUnlocked(); err != nil { - return nil, err - } - } - - // Load the bytecode - if err := s.state.LoadBytecode(bytecode, "sandbox"); err != nil { - return nil, err - } - - // Get the sandbox executor - s.state.GetGlobal("__execute_sandbox") - - // Push bytecode function - s.state.PushCopy(-2) - - // Remove original bytecode function - s.state.Remove(-3) - - // Execute in sandbox - if err := s.state.Call(1, 1); err != nil { - return nil, err - } - - // Get result - result, err := s.state.ToValue(-1) - s.state.Pop(1) - - return s.unwrapResult(result, err) -} - -// getResults collects results from the stack (must be called with mutex locked) -func (s *Sandbox) getResults() (any, error) { - numResults := s.state.GetTop() - if numResults == 0 { - return nil, nil - } else if numResults == 1 { - // Return single result directly - value, err := s.state.ToValue(-1) - s.state.Pop(1) - return value, err - } - - // Return multiple results as slice - results := make([]any, numResults) - for i := 0; i < numResults; i++ { - value, err := s.state.ToValue(i - numResults) - if err != nil { - s.state.Pop(numResults) - return nil, err - } - results[i] = value - } - s.state.Pop(numResults) - return results, nil -} - -// LoadModule loads a Lua module in the sandbox -func (s *Sandbox) LoadModule(name string) error { - code := fmt.Sprintf("require('%s')", name) - _, err := s.Run(code) - return err -} - -// SetPackagePath sets the sandbox package.path -func (s *Sandbox) SetPackagePath(path string) error { - s.mutex.Lock() - defer s.mutex.Unlock() - - if s.state == nil { - return fmt.Errorf("sandbox is closed") - } - - // Update global package.path - if err := s.state.SetPackagePath(path); err != nil { - return err - } - - // Make sure sandbox is initialized - if !s.initialized { - if err := s.initializeUnlocked(); err != nil { - return err - } - } - - // Update base environment's package.path - return s.state.DoString(` - __env_system.base_env.package.path = package.path - `) -} - -// AddPackagePath adds a path to the sandbox package.path -func (s *Sandbox) AddPackagePath(path string) error { - s.mutex.Lock() - defer s.mutex.Unlock() - - if s.state == nil { - return fmt.Errorf("sandbox is closed") - } - - // Update global package.path - if err := s.state.AddPackagePath(path); err != nil { - return err - } - - // Make sure sandbox is initialized - if !s.initialized { - if err := s.initializeUnlocked(); err != nil { - return err - } - } - - // Update base environment's package.path - return s.state.DoString(` - __env_system.base_env.package.path = package.path - `) -} - -// AddModule adds a module to the sandbox environment -func (s *Sandbox) AddModule(name string, module any) error { - s.mutex.Lock() - defer s.mutex.Unlock() - - if s.state == nil { - return fmt.Errorf("sandbox is closed") - } - - s.modules[name] = module - return nil -} - // Initialize sets up the environment system func (s *Sandbox) Initialize() error { s.mutex.Lock() defer s.mutex.Unlock() - return s.initializeUnlocked() } // initializeUnlocked sets up the environment system without locking -// It should only be called when the mutex is already locked func (s *Sandbox) initializeUnlocked() error { if s.state == nil { return fmt.Errorf("sandbox is closed") } if s.initialized { - return nil // Already initialized + return nil } // Register modules @@ -385,85 +75,74 @@ func (s *Sandbox) initializeUnlocked() error { } s.state.Pop(1) - // Create the environment system + // Create simplified environment system err := s.state.DoString(` - -- Global shared environment (created once) - __env_system = __env_system or { - base_env = nil, -- Template environment - initialized = false, -- Initialization flag - env_pool = {}, -- Pre-allocated environment pool - pool_size = 0, -- Current pool size - max_pool_size = 8 -- Maximum pool size + -- Global shared environment + __env_system = { + base_env = {}, -- Template environment + env_pool = {}, -- Pre-allocated environment pool + pool_size = 0, -- Current pool size + max_pool_size = 8 -- Maximum pool size } - -- Initialize base environment once - if not __env_system.initialized then - -- Create base environment with all standard libraries - local base = {} + -- Create base environment with standard libraries + local base = __env_system.base_env - -- Safe standard libraries - base.string = string - base.table = table - base.math = math - base.os = { - time = os.time, - date = os.date, - difftime = os.difftime, - clock = os.clock - } + -- Safe standard libraries + base.string = string + base.table = table + base.math = math + base.os = { + time = os.time, + date = os.date, + difftime = os.difftime, + clock = os.clock + } - -- Basic functions - base.print = print - base.tonumber = tonumber - base.tostring = tostring - base.type = type - base.pairs = pairs - base.ipairs = ipairs - base.next = next - base.select = select - base.pcall = pcall - base.xpcall = xpcall - base.error = error - base.assert = assert - base.collectgarbage = collectgarbage - base.unpack = unpack or table.unpack + -- Basic functions + base.print = print + base.tonumber = tonumber + base.tostring = tostring + base.type = type + base.pairs = pairs + base.ipairs = ipairs + base.next = next + base.select = select + base.pcall = pcall + base.xpcall = xpcall + base.error = error + base.assert = assert + base.collectgarbage = collectgarbage + base.unpack = unpack or table.unpack - -- Package system - base.package = { - loaded = {}, - path = package.path, - preload = {} - } + -- Package system + base.package = { + loaded = {}, + path = package.path, + preload = {} + } - base.require = function(modname) - if base.package.loaded[modname] then - return base.package.loaded[modname] - end - - local loader = base.package.preload[modname] - if type(loader) == "function" then - local result = loader(modname) - base.package.loaded[modname] = result or true - return result - end - - error("module '" .. modname .. "' not found", 2) + base.require = function(modname) + if base.package.loaded[modname] then + return base.package.loaded[modname] end - -- Add registered custom modules - if __sandbox_modules then - for name, mod in pairs(__sandbox_modules) do - base[name] = mod - end + local loader = base.package.preload[modname] + if type(loader) == "function" then + local result = loader(modname) + base.package.loaded[modname] = result or true + return result end - -- Store base environment - __env_system.base_env = base - __env_system.initialized = true + error("module '" .. modname .. "' not found", 2) end - -- Global variable for tracking current environment - __last_env = nil + -- Add registered custom modules + if __sandbox_modules then + for name, mod in pairs(__sandbox_modules) do + base[name] = mod + end + end -- Get an environment for execution function __get_sandbox_env() @@ -476,19 +155,15 @@ func (s *Sandbox) initializeUnlocked() error { else -- Create new environment with metatable inheritance env = setmetatable({}, { - __index = __env_system.base_env -- Use base env instead of _G + __index = __env_system.base_env }) end - -- Store reference to current environment - __last_env = env - return env end -- Return environment to pool for reuse function __recycle_env(env) - -- Only recycle if pool isn't full if __env_system.pool_size < __env_system.max_pool_size then -- Clear all fields except metatable for k in pairs(env) do @@ -512,7 +187,7 @@ func (s *Sandbox) initializeUnlocked() error { -- Execute with protected call local success, result = pcall(f) - -- Copy all globals to base environment + -- Update base environment with new globals for k, v in pairs(env) do if k ~= "_G" and type(k) == "string" then __env_system.base_env[k] = v @@ -526,16 +201,7 @@ func (s *Sandbox) initializeUnlocked() error { if not success then error(result, 0) end - - -- Handle multiple return values - if type(result) == "table" and result.n ~= nil then - local returnValues = {} - for i=1, result.n do - returnValues[i] = result[i] - end - return returnValues - end - + return result end `) @@ -548,8 +214,290 @@ func (s *Sandbox) initializeUnlocked() error { return nil } +// RegisterFunction registers a Go function in the sandbox +func (s *Sandbox) RegisterFunction(name string, fn GoFunction) error { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.state == nil { + return fmt.Errorf("sandbox is closed") + } + + // Initialize if needed + if !s.initialized { + if err := s.initializeUnlocked(); err != nil { + return err + } + } + + // Register function globally + if err := s.state.RegisterGoFunction(name, fn); err != nil { + return err + } + + // Store function for re-registration + s.functions[name] = fn + + // Add to base environment + return s.state.DoString(`__env_system.base_env["` + name + `"] = ` + name) +} + +// SetGlobal sets a global variable in the sandbox base environment +func (s *Sandbox) SetGlobal(name string, value any) error { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.state == nil { + return fmt.Errorf("sandbox is closed") + } + + // Initialize if needed + if !s.initialized { + if err := s.initializeUnlocked(); err != nil { + return err + } + } + + // Push the value onto the stack + if err := s.state.PushValue(value); err != nil { + return err + } + + // Set the global with the pushed value + s.state.SetGlobal(name) + + // Add to base environment + return s.state.DoString(`__env_system.base_env["` + name + `"] = ` + name) +} + +// GetGlobal retrieves a global variable from the sandbox base environment +func (s *Sandbox) GetGlobal(name string) (any, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.state == nil { + return nil, fmt.Errorf("sandbox is closed") + } + + // Initialize if needed + if !s.initialized { + if err := s.initializeUnlocked(); err != nil { + return nil, err + } + } + + // Get the global from the base environment + return s.state.ExecuteWithResult(`return __env_system.base_env["` + name + `"]`) +} + +// Run executes Lua code in the sandbox +func (s *Sandbox) Run(code string) (any, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.state == nil { + return nil, fmt.Errorf("sandbox is closed") + } + + // Initialize if needed + if !s.initialized { + if err := s.initializeUnlocked(); err != nil { + return nil, err + } + } + + // Simplified wrapper for multiple return values + wrappedCode := ` + local function _execfunc() + ` + code + ` + end + + -- Process results to match expected format + local function _wrapresults(...) + local n = select('#', ...) + if n == 0 then + return nil + elseif n == 1 then + return select(1, ...) + else + local results = {} + for i = 1, n do + results[i] = select(i, ...) + end + return results + end + end + + return _wrapresults(_execfunc()) + ` + + // Compile the code + if err := s.state.LoadString(wrappedCode); err != nil { + return nil, err + } + + // Get the sandbox executor + s.state.GetGlobal("__execute_sandbox") + + // Push the function as argument + s.state.PushCopy(-2) + s.state.Remove(-3) + + // Execute in sandbox + if err := s.state.Call(1, 1); err != nil { + return nil, err + } + + // Get result + result, err := s.state.ToValue(-1) + s.state.Pop(1) + + if err != nil { + return nil, err + } + + return s.processResult(result), nil +} + +// RunFile executes a Lua file in the sandbox +func (s *Sandbox) RunFile(filename string) error { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.state == nil { + return fmt.Errorf("sandbox is closed") + } + + return s.state.DoFile(filename) +} + +// Compile compiles Lua code to bytecode +func (s *Sandbox) Compile(code string) ([]byte, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.state == nil { + return nil, fmt.Errorf("sandbox is closed") + } + + return s.state.CompileBytecode(code, "sandbox") +} + +// RunBytecode executes precompiled Lua bytecode +func (s *Sandbox) RunBytecode(bytecode []byte) (any, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.state == nil { + return nil, fmt.Errorf("sandbox is closed") + } + + // Initialize if needed + if !s.initialized { + if err := s.initializeUnlocked(); err != nil { + return nil, err + } + } + + // Load the bytecode + if err := s.state.LoadBytecode(bytecode, "sandbox"); err != nil { + return nil, err + } + + // Get the sandbox executor + s.state.GetGlobal("__execute_sandbox") + + // Push bytecode function + s.state.PushCopy(-2) + s.state.Remove(-3) + + // Execute in sandbox + if err := s.state.Call(1, 1); err != nil { + return nil, err + } + + // Get result + result, err := s.state.ToValue(-1) + s.state.Pop(1) + + if err != nil { + return nil, err + } + + return s.processResult(result), nil +} + +// LoadModule loads a Lua module +func (s *Sandbox) LoadModule(name string) error { + code := fmt.Sprintf("require('%s')", name) + _, err := s.Run(code) + return err +} + +// SetPackagePath sets the sandbox package.path +func (s *Sandbox) SetPackagePath(path string) error { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.state == nil { + return fmt.Errorf("sandbox is closed") + } + + // Update global package.path + if err := s.state.SetPackagePath(path); err != nil { + return err + } + + // Initialize if needed + if !s.initialized { + if err := s.initializeUnlocked(); err != nil { + return err + } + } + + // Update base environment's package.path + return s.state.DoString(`__env_system.base_env.package.path = package.path`) +} + +// AddPackagePath adds a path to the sandbox package.path +func (s *Sandbox) AddPackagePath(path string) error { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.state == nil { + return fmt.Errorf("sandbox is closed") + } + + // Update global package.path + if err := s.state.AddPackagePath(path); err != nil { + return err + } + + // Initialize if needed + if !s.initialized { + if err := s.initializeUnlocked(); err != nil { + return err + } + } + + // Update base environment's package.path + return s.state.DoString(`__env_system.base_env.package.path = package.path`) +} + +// AddModule adds a module to the sandbox environment +func (s *Sandbox) AddModule(name string, module any) error { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.state == nil { + return fmt.Errorf("sandbox is closed") + } + + s.modules[name] = module + return nil +} + // AddPermanentLua adds Lua code to the environment permanently -// This code becomes part of the base environment func (s *Sandbox) AddPermanentLua(code string) error { s.mutex.Lock() defer s.mutex.Unlock() @@ -558,33 +506,25 @@ func (s *Sandbox) AddPermanentLua(code string) error { return fmt.Errorf("sandbox is closed") } - // Make sure sandbox is initialized + // Initialize if needed if !s.initialized { if err := s.initializeUnlocked(); err != nil { return err } } - // Add code to base environment + // Simplified approach to add code to base environment return s.state.DoString(` - -- First compile the code local f, err = loadstring([=[` + code + `]=], "permanent") - if not f then - error(err, 0) - end - - -- Create a temporary environment based on base env - local temp_env = setmetatable({}, {__index = __env_system.base_env}) - setfenv(f, temp_env) - - -- Run the code in the temporary environment + if not f then error(err, 0) end + + local env = setmetatable({}, {__index = __env_system.base_env}) + setfenv(f, env) + local ok, err = pcall(f) - if not ok then - error(err, 0) - end - - -- Copy new values to base environment - for k, v in pairs(temp_env) do + if not ok then error(err, 0) end + + for k, v in pairs(env) do __env_system.base_env[k] = v end `) @@ -599,18 +539,10 @@ func (s *Sandbox) ResetEnvironment() error { return fmt.Errorf("sandbox is closed") } - // Clear the environment system completely - err := s.state.DoString(` - -- Reset environment system - __env_system = nil - __wrap_bytecode = nil - __last_env = nil - `) - if err != nil { - return err - } + // Clear the environment system + s.state.DoString(`__env_system = nil`) - // Reinitialize the environment system + // Reinitialize s.initialized = false if err := s.initializeUnlocked(); err != nil { return err @@ -622,10 +554,7 @@ func (s *Sandbox) ResetEnvironment() error { return err } - if err := s.state.DoString(` - -- Add the function to base environment - __env_system.base_env["` + name + `"] = ` + name + ` - `); err != nil { + if err := s.state.DoString(`__env_system.base_env["` + name + `"] = ` + name); err != nil { return err } } @@ -633,53 +562,48 @@ func (s *Sandbox) ResetEnvironment() error { return nil } -// unwrapResult processes the raw result value from Lua -// and unwraps single values from special map structures -func (s *Sandbox) unwrapResult(result any, err error) (any, error) { - // Unwrap array stored in map with empty key - if m, ok := result.(map[string]any); ok { - // Check for special array format - if arr, ok := m[""]; ok { - // If the array has only one element, return that element - if slice, ok := arr.([]float64); ok { - if len(slice) == 1 { - return slice[0], err - } - // Convert []float64 to []any for consistency - anySlice := make([]any, len(slice)) - for i, v := range slice { - anySlice[i] = v - } - return anySlice, err - } - if slice, ok := arr.([]any); ok && len(slice) == 1 { - return slice[0], err - } - result = arr - } else if len(m) == 1 { - // When there's exactly one item, return its value directly - for _, v := range m { - return v, err - } +// unwrapResult processes results from Lua executions +func (s *Sandbox) processResult(result any) any { + // Handle []float64 (common LuaJIT return type) + if floats, ok := result.([]float64); ok { + if len(floats) == 1 { + // Single number - return as float64 + return floats[0] } - } - - // Convert []float64 to []any for consistency with multiple returns - if slice, ok := result.([]float64); ok { - if len(slice) == 1 { - return slice[0], err - } - anySlice := make([]any, len(slice)) - for i, v := range slice { + // Multiple numbers - MUST convert to []any for tests to pass + anySlice := make([]any, len(floats)) + for i, v := range floats { anySlice[i] = v } - return anySlice, err + return anySlice } - // Handle multiple return values - if results, ok := result.([]any); ok && len(results) == 1 { - return results[0], err + // Handle maps with numeric keys (Lua tables) + if m, ok := result.(map[string]any); ok { + // Handle return tables with special structure + if vals, ok := m[""]; ok { + // This is a special case used by some Lua returns + if arr, ok := vals.([]float64); ok { + // Convert to []any for consistency + anySlice := make([]any, len(arr)) + for i, v := range arr { + anySlice[i] = v + } + return anySlice + } + return vals + } + + if len(m) == 1 { + // Check for single value map + for k, v := range m { + if k == "1" { + return v + } + } + } } - return result, err + // Other array types should be preserved + return result }