optimize sandbox

This commit is contained in:
Sky Johnson 2025-03-27 21:58:56 -05:00
parent 4ad87f81f3
commit 875abee366

View File

@ -38,331 +38,21 @@ func (s *Sandbox) Close() {
} }
} }
// RegisterFunction registers a Go function in the sandbox
func (s *Sandbox) RegisterFunction(name string, fn GoFunction) error {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return fmt.Errorf("sandbox is closed")
}
// Make sure sandbox is initialized
if !s.initialized {
if err := s.initializeUnlocked(); err != nil {
return err
}
}
// Register function globally
if err := s.state.RegisterGoFunction(name, fn); err != nil {
return err
}
// Store function for re-registration
s.functions[name] = fn
// Add to base environment
return s.state.DoString(`
-- Add the function to base environment
__env_system.base_env["` + name + `"] = ` + name + `
`)
}
// SetGlobal sets a global variable in the sandbox base environment
func (s *Sandbox) SetGlobal(name string, value any) error {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return fmt.Errorf("sandbox is closed")
}
// Make sure sandbox is initialized
if !s.initialized {
if err := s.initializeUnlocked(); err != nil {
return err
}
}
// Push the value onto the stack
if err := s.state.PushValue(value); err != nil {
return err
}
// Set the global with the pushed value
s.state.SetGlobal(name)
// Add to base environment
return s.state.DoString(`
-- Add the global to base environment
__env_system.base_env["` + name + `"] = ` + name + `
`)
}
// GetGlobal retrieves a global variable from the sandbox base environment
func (s *Sandbox) GetGlobal(name string) (any, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return nil, fmt.Errorf("sandbox is closed")
}
// Make sure sandbox is initialized
if !s.initialized {
if err := s.initializeUnlocked(); err != nil {
return nil, err
}
}
// Get the global from the base environment
code := `return __env_system.base_env["` + name + `"]`
return s.state.ExecuteWithResult(code)
}
// Run executes Lua code in the sandbox and returns the result
func (s *Sandbox) Run(code string) (any, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return nil, fmt.Errorf("sandbox is closed")
}
// Make sure sandbox is initialized
if !s.initialized {
if err := s.initializeUnlocked(); err != nil {
return nil, err
}
}
// Add wrapper for multiple return values
wrappedCode := `
local function _execfunc()
` + code + `
end
local function _wrapresults(...)
local results = {n = select('#', ...)}
for i = 1, results.n do
results[i] = select(i, ...)
end
return results
end
return _wrapresults(_execfunc())
`
// Compile the code
if err := s.state.LoadString(wrappedCode); err != nil {
return nil, err
}
// Get the sandbox executor
s.state.GetGlobal("__execute_sandbox")
// Setup call with correct argument order
s.state.PushCopy(-2) // Copy the function
// Remove the original function
s.state.Remove(-3)
// Execute in sandbox
if err := s.state.Call(1, 1); err != nil {
return nil, err
}
// Get result
result, err := s.state.ToValue(-1)
s.state.Pop(1)
return s.unwrapResult(result, err)
}
// RunFile executes a Lua file in the sandbox
func (s *Sandbox) RunFile(filename string) error {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return fmt.Errorf("sandbox is closed")
}
return s.state.DoFile(filename)
}
// Compile compiles Lua code to bytecode without executing it
func (s *Sandbox) Compile(code string) ([]byte, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return nil, fmt.Errorf("sandbox is closed")
}
return s.state.CompileBytecode(code, "sandbox")
}
// RunBytecode executes precompiled Lua bytecode in the sandbox
func (s *Sandbox) RunBytecode(bytecode []byte) (any, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return nil, fmt.Errorf("sandbox is closed")
}
// Make sure sandbox is initialized
if !s.initialized {
if err := s.initializeUnlocked(); err != nil {
return nil, err
}
}
// Load the bytecode
if err := s.state.LoadBytecode(bytecode, "sandbox"); err != nil {
return nil, err
}
// Get the sandbox executor
s.state.GetGlobal("__execute_sandbox")
// Push bytecode function
s.state.PushCopy(-2)
// Remove original bytecode function
s.state.Remove(-3)
// Execute in sandbox
if err := s.state.Call(1, 1); err != nil {
return nil, err
}
// Get result
result, err := s.state.ToValue(-1)
s.state.Pop(1)
return s.unwrapResult(result, err)
}
// getResults collects results from the stack (must be called with mutex locked)
func (s *Sandbox) getResults() (any, error) {
numResults := s.state.GetTop()
if numResults == 0 {
return nil, nil
} else if numResults == 1 {
// Return single result directly
value, err := s.state.ToValue(-1)
s.state.Pop(1)
return value, err
}
// Return multiple results as slice
results := make([]any, numResults)
for i := 0; i < numResults; i++ {
value, err := s.state.ToValue(i - numResults)
if err != nil {
s.state.Pop(numResults)
return nil, err
}
results[i] = value
}
s.state.Pop(numResults)
return results, nil
}
// LoadModule loads a Lua module in the sandbox
func (s *Sandbox) LoadModule(name string) error {
code := fmt.Sprintf("require('%s')", name)
_, err := s.Run(code)
return err
}
// SetPackagePath sets the sandbox package.path
func (s *Sandbox) SetPackagePath(path string) error {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return fmt.Errorf("sandbox is closed")
}
// Update global package.path
if err := s.state.SetPackagePath(path); err != nil {
return err
}
// Make sure sandbox is initialized
if !s.initialized {
if err := s.initializeUnlocked(); err != nil {
return err
}
}
// Update base environment's package.path
return s.state.DoString(`
__env_system.base_env.package.path = package.path
`)
}
// AddPackagePath adds a path to the sandbox package.path
func (s *Sandbox) AddPackagePath(path string) error {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return fmt.Errorf("sandbox is closed")
}
// Update global package.path
if err := s.state.AddPackagePath(path); err != nil {
return err
}
// Make sure sandbox is initialized
if !s.initialized {
if err := s.initializeUnlocked(); err != nil {
return err
}
}
// Update base environment's package.path
return s.state.DoString(`
__env_system.base_env.package.path = package.path
`)
}
// AddModule adds a module to the sandbox environment
func (s *Sandbox) AddModule(name string, module any) error {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return fmt.Errorf("sandbox is closed")
}
s.modules[name] = module
return nil
}
// Initialize sets up the environment system // Initialize sets up the environment system
func (s *Sandbox) Initialize() error { func (s *Sandbox) Initialize() error {
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock() defer s.mutex.Unlock()
return s.initializeUnlocked() return s.initializeUnlocked()
} }
// initializeUnlocked sets up the environment system without locking // initializeUnlocked sets up the environment system without locking
// It should only be called when the mutex is already locked
func (s *Sandbox) initializeUnlocked() error { func (s *Sandbox) initializeUnlocked() error {
if s.state == nil { if s.state == nil {
return fmt.Errorf("sandbox is closed") return fmt.Errorf("sandbox is closed")
} }
if s.initialized { if s.initialized {
return nil // Already initialized return nil
} }
// Register modules // Register modules
@ -385,21 +75,18 @@ func (s *Sandbox) initializeUnlocked() error {
} }
s.state.Pop(1) s.state.Pop(1)
// Create the environment system // Create simplified environment system
err := s.state.DoString(` err := s.state.DoString(`
-- Global shared environment (created once) -- Global shared environment
__env_system = __env_system or { __env_system = {
base_env = nil, -- Template environment base_env = {}, -- Template environment
initialized = false, -- Initialization flag
env_pool = {}, -- Pre-allocated environment pool env_pool = {}, -- Pre-allocated environment pool
pool_size = 0, -- Current pool size pool_size = 0, -- Current pool size
max_pool_size = 8 -- Maximum pool size max_pool_size = 8 -- Maximum pool size
} }
-- Initialize base environment once -- Create base environment with standard libraries
if not __env_system.initialized then local base = __env_system.base_env
-- Create base environment with all standard libraries
local base = {}
-- Safe standard libraries -- Safe standard libraries
base.string = string base.string = string
@ -457,14 +144,6 @@ func (s *Sandbox) initializeUnlocked() error {
end end
end end
-- Store base environment
__env_system.base_env = base
__env_system.initialized = true
end
-- Global variable for tracking current environment
__last_env = nil
-- Get an environment for execution -- Get an environment for execution
function __get_sandbox_env() function __get_sandbox_env()
local env local env
@ -476,19 +155,15 @@ func (s *Sandbox) initializeUnlocked() error {
else else
-- Create new environment with metatable inheritance -- Create new environment with metatable inheritance
env = setmetatable({}, { env = setmetatable({}, {
__index = __env_system.base_env -- Use base env instead of _G __index = __env_system.base_env
}) })
end end
-- Store reference to current environment
__last_env = env
return env return env
end end
-- Return environment to pool for reuse -- Return environment to pool for reuse
function __recycle_env(env) function __recycle_env(env)
-- Only recycle if pool isn't full
if __env_system.pool_size < __env_system.max_pool_size then if __env_system.pool_size < __env_system.max_pool_size then
-- Clear all fields except metatable -- Clear all fields except metatable
for k in pairs(env) do for k in pairs(env) do
@ -512,7 +187,7 @@ func (s *Sandbox) initializeUnlocked() error {
-- Execute with protected call -- Execute with protected call
local success, result = pcall(f) local success, result = pcall(f)
-- Copy all globals to base environment -- Update base environment with new globals
for k, v in pairs(env) do for k, v in pairs(env) do
if k ~= "_G" and type(k) == "string" then if k ~= "_G" and type(k) == "string" then
__env_system.base_env[k] = v __env_system.base_env[k] = v
@ -527,15 +202,6 @@ func (s *Sandbox) initializeUnlocked() error {
error(result, 0) error(result, 0)
end end
-- Handle multiple return values
if type(result) == "table" and result.n ~= nil then
local returnValues = {}
for i=1, result.n do
returnValues[i] = result[i]
end
return returnValues
end
return result return result
end end
`) `)
@ -548,8 +214,290 @@ func (s *Sandbox) initializeUnlocked() error {
return nil return nil
} }
// RegisterFunction registers a Go function in the sandbox
func (s *Sandbox) RegisterFunction(name string, fn GoFunction) error {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return fmt.Errorf("sandbox is closed")
}
// Initialize if needed
if !s.initialized {
if err := s.initializeUnlocked(); err != nil {
return err
}
}
// Register function globally
if err := s.state.RegisterGoFunction(name, fn); err != nil {
return err
}
// Store function for re-registration
s.functions[name] = fn
// Add to base environment
return s.state.DoString(`__env_system.base_env["` + name + `"] = ` + name)
}
// SetGlobal sets a global variable in the sandbox base environment
func (s *Sandbox) SetGlobal(name string, value any) error {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return fmt.Errorf("sandbox is closed")
}
// Initialize if needed
if !s.initialized {
if err := s.initializeUnlocked(); err != nil {
return err
}
}
// Push the value onto the stack
if err := s.state.PushValue(value); err != nil {
return err
}
// Set the global with the pushed value
s.state.SetGlobal(name)
// Add to base environment
return s.state.DoString(`__env_system.base_env["` + name + `"] = ` + name)
}
// GetGlobal retrieves a global variable from the sandbox base environment
func (s *Sandbox) GetGlobal(name string) (any, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return nil, fmt.Errorf("sandbox is closed")
}
// Initialize if needed
if !s.initialized {
if err := s.initializeUnlocked(); err != nil {
return nil, err
}
}
// Get the global from the base environment
return s.state.ExecuteWithResult(`return __env_system.base_env["` + name + `"]`)
}
// Run executes Lua code in the sandbox
func (s *Sandbox) Run(code string) (any, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return nil, fmt.Errorf("sandbox is closed")
}
// Initialize if needed
if !s.initialized {
if err := s.initializeUnlocked(); err != nil {
return nil, err
}
}
// Simplified wrapper for multiple return values
wrappedCode := `
local function _execfunc()
` + code + `
end
-- Process results to match expected format
local function _wrapresults(...)
local n = select('#', ...)
if n == 0 then
return nil
elseif n == 1 then
return select(1, ...)
else
local results = {}
for i = 1, n do
results[i] = select(i, ...)
end
return results
end
end
return _wrapresults(_execfunc())
`
// Compile the code
if err := s.state.LoadString(wrappedCode); err != nil {
return nil, err
}
// Get the sandbox executor
s.state.GetGlobal("__execute_sandbox")
// Push the function as argument
s.state.PushCopy(-2)
s.state.Remove(-3)
// Execute in sandbox
if err := s.state.Call(1, 1); err != nil {
return nil, err
}
// Get result
result, err := s.state.ToValue(-1)
s.state.Pop(1)
if err != nil {
return nil, err
}
return s.processResult(result), nil
}
// RunFile executes a Lua file in the sandbox
func (s *Sandbox) RunFile(filename string) error {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return fmt.Errorf("sandbox is closed")
}
return s.state.DoFile(filename)
}
// Compile compiles Lua code to bytecode
func (s *Sandbox) Compile(code string) ([]byte, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return nil, fmt.Errorf("sandbox is closed")
}
return s.state.CompileBytecode(code, "sandbox")
}
// RunBytecode executes precompiled Lua bytecode
func (s *Sandbox) RunBytecode(bytecode []byte) (any, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return nil, fmt.Errorf("sandbox is closed")
}
// Initialize if needed
if !s.initialized {
if err := s.initializeUnlocked(); err != nil {
return nil, err
}
}
// Load the bytecode
if err := s.state.LoadBytecode(bytecode, "sandbox"); err != nil {
return nil, err
}
// Get the sandbox executor
s.state.GetGlobal("__execute_sandbox")
// Push bytecode function
s.state.PushCopy(-2)
s.state.Remove(-3)
// Execute in sandbox
if err := s.state.Call(1, 1); err != nil {
return nil, err
}
// Get result
result, err := s.state.ToValue(-1)
s.state.Pop(1)
if err != nil {
return nil, err
}
return s.processResult(result), nil
}
// LoadModule loads a Lua module
func (s *Sandbox) LoadModule(name string) error {
code := fmt.Sprintf("require('%s')", name)
_, err := s.Run(code)
return err
}
// SetPackagePath sets the sandbox package.path
func (s *Sandbox) SetPackagePath(path string) error {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return fmt.Errorf("sandbox is closed")
}
// Update global package.path
if err := s.state.SetPackagePath(path); err != nil {
return err
}
// Initialize if needed
if !s.initialized {
if err := s.initializeUnlocked(); err != nil {
return err
}
}
// Update base environment's package.path
return s.state.DoString(`__env_system.base_env.package.path = package.path`)
}
// AddPackagePath adds a path to the sandbox package.path
func (s *Sandbox) AddPackagePath(path string) error {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return fmt.Errorf("sandbox is closed")
}
// Update global package.path
if err := s.state.AddPackagePath(path); err != nil {
return err
}
// Initialize if needed
if !s.initialized {
if err := s.initializeUnlocked(); err != nil {
return err
}
}
// Update base environment's package.path
return s.state.DoString(`__env_system.base_env.package.path = package.path`)
}
// AddModule adds a module to the sandbox environment
func (s *Sandbox) AddModule(name string, module any) error {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return fmt.Errorf("sandbox is closed")
}
s.modules[name] = module
return nil
}
// AddPermanentLua adds Lua code to the environment permanently // AddPermanentLua adds Lua code to the environment permanently
// This code becomes part of the base environment
func (s *Sandbox) AddPermanentLua(code string) error { func (s *Sandbox) AddPermanentLua(code string) error {
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock() defer s.mutex.Unlock()
@ -558,33 +506,25 @@ func (s *Sandbox) AddPermanentLua(code string) error {
return fmt.Errorf("sandbox is closed") return fmt.Errorf("sandbox is closed")
} }
// Make sure sandbox is initialized // Initialize if needed
if !s.initialized { if !s.initialized {
if err := s.initializeUnlocked(); err != nil { if err := s.initializeUnlocked(); err != nil {
return err return err
} }
} }
// Add code to base environment // Simplified approach to add code to base environment
return s.state.DoString(` return s.state.DoString(`
-- First compile the code
local f, err = loadstring([=[` + code + `]=], "permanent") local f, err = loadstring([=[` + code + `]=], "permanent")
if not f then if not f then error(err, 0) end
error(err, 0)
end
-- Create a temporary environment based on base env local env = setmetatable({}, {__index = __env_system.base_env})
local temp_env = setmetatable({}, {__index = __env_system.base_env}) setfenv(f, env)
setfenv(f, temp_env)
-- Run the code in the temporary environment
local ok, err = pcall(f) local ok, err = pcall(f)
if not ok then if not ok then error(err, 0) end
error(err, 0)
end
-- Copy new values to base environment for k, v in pairs(env) do
for k, v in pairs(temp_env) do
__env_system.base_env[k] = v __env_system.base_env[k] = v
end end
`) `)
@ -599,18 +539,10 @@ func (s *Sandbox) ResetEnvironment() error {
return fmt.Errorf("sandbox is closed") return fmt.Errorf("sandbox is closed")
} }
// Clear the environment system completely // Clear the environment system
err := s.state.DoString(` s.state.DoString(`__env_system = nil`)
-- Reset environment system
__env_system = nil
__wrap_bytecode = nil
__last_env = nil
`)
if err != nil {
return err
}
// Reinitialize the environment system // Reinitialize
s.initialized = false s.initialized = false
if err := s.initializeUnlocked(); err != nil { if err := s.initializeUnlocked(); err != nil {
return err return err
@ -622,10 +554,7 @@ func (s *Sandbox) ResetEnvironment() error {
return err return err
} }
if err := s.state.DoString(` if err := s.state.DoString(`__env_system.base_env["` + name + `"] = ` + name); err != nil {
-- Add the function to base environment
__env_system.base_env["` + name + `"] = ` + name + `
`); err != nil {
return err return err
} }
} }
@ -633,53 +562,48 @@ func (s *Sandbox) ResetEnvironment() error {
return nil return nil
} }
// unwrapResult processes the raw result value from Lua // unwrapResult processes results from Lua executions
// and unwraps single values from special map structures func (s *Sandbox) processResult(result any) any {
func (s *Sandbox) unwrapResult(result any, err error) (any, error) { // Handle []float64 (common LuaJIT return type)
// Unwrap array stored in map with empty key if floats, ok := result.([]float64); ok {
if len(floats) == 1 {
// Single number - return as float64
return floats[0]
}
// Multiple numbers - MUST convert to []any for tests to pass
anySlice := make([]any, len(floats))
for i, v := range floats {
anySlice[i] = v
}
return anySlice
}
// Handle maps with numeric keys (Lua tables)
if m, ok := result.(map[string]any); ok { if m, ok := result.(map[string]any); ok {
// Check for special array format // Handle return tables with special structure
if arr, ok := m[""]; ok { if vals, ok := m[""]; ok {
// If the array has only one element, return that element // This is a special case used by some Lua returns
if slice, ok := arr.([]float64); ok { if arr, ok := vals.([]float64); ok {
if len(slice) == 1 { // Convert to []any for consistency
return slice[0], err anySlice := make([]any, len(arr))
} for i, v := range arr {
// Convert []float64 to []any for consistency
anySlice := make([]any, len(slice))
for i, v := range slice {
anySlice[i] = v anySlice[i] = v
} }
return anySlice, err return anySlice
} }
if slice, ok := arr.([]any); ok && len(slice) == 1 { return vals
return slice[0], err }
if len(m) == 1 {
// Check for single value map
for k, v := range m {
if k == "1" {
return v
} }
result = arr
} else if len(m) == 1 {
// When there's exactly one item, return its value directly
for _, v := range m {
return v, err
} }
} }
} }
// Convert []float64 to []any for consistency with multiple returns // Other array types should be preserved
if slice, ok := result.([]float64); ok { return result
if len(slice) == 1 {
return slice[0], err
}
anySlice := make([]any, len(slice))
for i, v := range slice {
anySlice[i] = v
}
return anySlice, err
}
// Handle multiple return values
if results, ok := result.([]any); ok && len(results) == 1 {
return results[0], err
}
return result, err
} }