optimize sandbox

This commit is contained in:
Sky Johnson 2025-03-27 21:58:56 -05:00
parent 4ad87f81f3
commit 875abee366

View File

@ -38,331 +38,21 @@ func (s *Sandbox) Close() {
}
}
// RegisterFunction registers a Go function in the sandbox
func (s *Sandbox) RegisterFunction(name string, fn GoFunction) error {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return fmt.Errorf("sandbox is closed")
}
// Make sure sandbox is initialized
if !s.initialized {
if err := s.initializeUnlocked(); err != nil {
return err
}
}
// Register function globally
if err := s.state.RegisterGoFunction(name, fn); err != nil {
return err
}
// Store function for re-registration
s.functions[name] = fn
// Add to base environment
return s.state.DoString(`
-- Add the function to base environment
__env_system.base_env["` + name + `"] = ` + name + `
`)
}
// SetGlobal sets a global variable in the sandbox base environment
func (s *Sandbox) SetGlobal(name string, value any) error {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return fmt.Errorf("sandbox is closed")
}
// Make sure sandbox is initialized
if !s.initialized {
if err := s.initializeUnlocked(); err != nil {
return err
}
}
// Push the value onto the stack
if err := s.state.PushValue(value); err != nil {
return err
}
// Set the global with the pushed value
s.state.SetGlobal(name)
// Add to base environment
return s.state.DoString(`
-- Add the global to base environment
__env_system.base_env["` + name + `"] = ` + name + `
`)
}
// GetGlobal retrieves a global variable from the sandbox base environment
func (s *Sandbox) GetGlobal(name string) (any, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return nil, fmt.Errorf("sandbox is closed")
}
// Make sure sandbox is initialized
if !s.initialized {
if err := s.initializeUnlocked(); err != nil {
return nil, err
}
}
// Get the global from the base environment
code := `return __env_system.base_env["` + name + `"]`
return s.state.ExecuteWithResult(code)
}
// Run executes Lua code in the sandbox and returns the result
func (s *Sandbox) Run(code string) (any, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return nil, fmt.Errorf("sandbox is closed")
}
// Make sure sandbox is initialized
if !s.initialized {
if err := s.initializeUnlocked(); err != nil {
return nil, err
}
}
// Add wrapper for multiple return values
wrappedCode := `
local function _execfunc()
` + code + `
end
local function _wrapresults(...)
local results = {n = select('#', ...)}
for i = 1, results.n do
results[i] = select(i, ...)
end
return results
end
return _wrapresults(_execfunc())
`
// Compile the code
if err := s.state.LoadString(wrappedCode); err != nil {
return nil, err
}
// Get the sandbox executor
s.state.GetGlobal("__execute_sandbox")
// Setup call with correct argument order
s.state.PushCopy(-2) // Copy the function
// Remove the original function
s.state.Remove(-3)
// Execute in sandbox
if err := s.state.Call(1, 1); err != nil {
return nil, err
}
// Get result
result, err := s.state.ToValue(-1)
s.state.Pop(1)
return s.unwrapResult(result, err)
}
// RunFile executes a Lua file in the sandbox
func (s *Sandbox) RunFile(filename string) error {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return fmt.Errorf("sandbox is closed")
}
return s.state.DoFile(filename)
}
// Compile compiles Lua code to bytecode without executing it
func (s *Sandbox) Compile(code string) ([]byte, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return nil, fmt.Errorf("sandbox is closed")
}
return s.state.CompileBytecode(code, "sandbox")
}
// RunBytecode executes precompiled Lua bytecode in the sandbox
func (s *Sandbox) RunBytecode(bytecode []byte) (any, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return nil, fmt.Errorf("sandbox is closed")
}
// Make sure sandbox is initialized
if !s.initialized {
if err := s.initializeUnlocked(); err != nil {
return nil, err
}
}
// Load the bytecode
if err := s.state.LoadBytecode(bytecode, "sandbox"); err != nil {
return nil, err
}
// Get the sandbox executor
s.state.GetGlobal("__execute_sandbox")
// Push bytecode function
s.state.PushCopy(-2)
// Remove original bytecode function
s.state.Remove(-3)
// Execute in sandbox
if err := s.state.Call(1, 1); err != nil {
return nil, err
}
// Get result
result, err := s.state.ToValue(-1)
s.state.Pop(1)
return s.unwrapResult(result, err)
}
// getResults collects results from the stack (must be called with mutex locked)
func (s *Sandbox) getResults() (any, error) {
numResults := s.state.GetTop()
if numResults == 0 {
return nil, nil
} else if numResults == 1 {
// Return single result directly
value, err := s.state.ToValue(-1)
s.state.Pop(1)
return value, err
}
// Return multiple results as slice
results := make([]any, numResults)
for i := 0; i < numResults; i++ {
value, err := s.state.ToValue(i - numResults)
if err != nil {
s.state.Pop(numResults)
return nil, err
}
results[i] = value
}
s.state.Pop(numResults)
return results, nil
}
// LoadModule loads a Lua module in the sandbox
func (s *Sandbox) LoadModule(name string) error {
code := fmt.Sprintf("require('%s')", name)
_, err := s.Run(code)
return err
}
// SetPackagePath sets the sandbox package.path
func (s *Sandbox) SetPackagePath(path string) error {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return fmt.Errorf("sandbox is closed")
}
// Update global package.path
if err := s.state.SetPackagePath(path); err != nil {
return err
}
// Make sure sandbox is initialized
if !s.initialized {
if err := s.initializeUnlocked(); err != nil {
return err
}
}
// Update base environment's package.path
return s.state.DoString(`
__env_system.base_env.package.path = package.path
`)
}
// AddPackagePath adds a path to the sandbox package.path
func (s *Sandbox) AddPackagePath(path string) error {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return fmt.Errorf("sandbox is closed")
}
// Update global package.path
if err := s.state.AddPackagePath(path); err != nil {
return err
}
// Make sure sandbox is initialized
if !s.initialized {
if err := s.initializeUnlocked(); err != nil {
return err
}
}
// Update base environment's package.path
return s.state.DoString(`
__env_system.base_env.package.path = package.path
`)
}
// AddModule adds a module to the sandbox environment
func (s *Sandbox) AddModule(name string, module any) error {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return fmt.Errorf("sandbox is closed")
}
s.modules[name] = module
return nil
}
// Initialize sets up the environment system
func (s *Sandbox) Initialize() error {
s.mutex.Lock()
defer s.mutex.Unlock()
return s.initializeUnlocked()
}
// initializeUnlocked sets up the environment system without locking
// It should only be called when the mutex is already locked
func (s *Sandbox) initializeUnlocked() error {
if s.state == nil {
return fmt.Errorf("sandbox is closed")
}
if s.initialized {
return nil // Already initialized
return nil
}
// Register modules
@ -385,85 +75,74 @@ func (s *Sandbox) initializeUnlocked() error {
}
s.state.Pop(1)
// Create the environment system
// Create simplified environment system
err := s.state.DoString(`
-- Global shared environment (created once)
__env_system = __env_system or {
base_env = nil, -- Template environment
initialized = false, -- Initialization flag
env_pool = {}, -- Pre-allocated environment pool
pool_size = 0, -- Current pool size
max_pool_size = 8 -- Maximum pool size
-- Global shared environment
__env_system = {
base_env = {}, -- Template environment
env_pool = {}, -- Pre-allocated environment pool
pool_size = 0, -- Current pool size
max_pool_size = 8 -- Maximum pool size
}
-- Initialize base environment once
if not __env_system.initialized then
-- Create base environment with all standard libraries
local base = {}
-- Create base environment with standard libraries
local base = __env_system.base_env
-- Safe standard libraries
base.string = string
base.table = table
base.math = math
base.os = {
time = os.time,
date = os.date,
difftime = os.difftime,
clock = os.clock
}
-- Safe standard libraries
base.string = string
base.table = table
base.math = math
base.os = {
time = os.time,
date = os.date,
difftime = os.difftime,
clock = os.clock
}
-- Basic functions
base.print = print
base.tonumber = tonumber
base.tostring = tostring
base.type = type
base.pairs = pairs
base.ipairs = ipairs
base.next = next
base.select = select
base.pcall = pcall
base.xpcall = xpcall
base.error = error
base.assert = assert
base.collectgarbage = collectgarbage
base.unpack = unpack or table.unpack
-- Basic functions
base.print = print
base.tonumber = tonumber
base.tostring = tostring
base.type = type
base.pairs = pairs
base.ipairs = ipairs
base.next = next
base.select = select
base.pcall = pcall
base.xpcall = xpcall
base.error = error
base.assert = assert
base.collectgarbage = collectgarbage
base.unpack = unpack or table.unpack
-- Package system
base.package = {
loaded = {},
path = package.path,
preload = {}
}
-- Package system
base.package = {
loaded = {},
path = package.path,
preload = {}
}
base.require = function(modname)
if base.package.loaded[modname] then
return base.package.loaded[modname]
end
local loader = base.package.preload[modname]
if type(loader) == "function" then
local result = loader(modname)
base.package.loaded[modname] = result or true
return result
end
error("module '" .. modname .. "' not found", 2)
base.require = function(modname)
if base.package.loaded[modname] then
return base.package.loaded[modname]
end
-- Add registered custom modules
if __sandbox_modules then
for name, mod in pairs(__sandbox_modules) do
base[name] = mod
end
local loader = base.package.preload[modname]
if type(loader) == "function" then
local result = loader(modname)
base.package.loaded[modname] = result or true
return result
end
-- Store base environment
__env_system.base_env = base
__env_system.initialized = true
error("module '" .. modname .. "' not found", 2)
end
-- Global variable for tracking current environment
__last_env = nil
-- Add registered custom modules
if __sandbox_modules then
for name, mod in pairs(__sandbox_modules) do
base[name] = mod
end
end
-- Get an environment for execution
function __get_sandbox_env()
@ -476,19 +155,15 @@ func (s *Sandbox) initializeUnlocked() error {
else
-- Create new environment with metatable inheritance
env = setmetatable({}, {
__index = __env_system.base_env -- Use base env instead of _G
__index = __env_system.base_env
})
end
-- Store reference to current environment
__last_env = env
return env
end
-- Return environment to pool for reuse
function __recycle_env(env)
-- Only recycle if pool isn't full
if __env_system.pool_size < __env_system.max_pool_size then
-- Clear all fields except metatable
for k in pairs(env) do
@ -512,7 +187,7 @@ func (s *Sandbox) initializeUnlocked() error {
-- Execute with protected call
local success, result = pcall(f)
-- Copy all globals to base environment
-- Update base environment with new globals
for k, v in pairs(env) do
if k ~= "_G" and type(k) == "string" then
__env_system.base_env[k] = v
@ -526,16 +201,7 @@ func (s *Sandbox) initializeUnlocked() error {
if not success then
error(result, 0)
end
-- Handle multiple return values
if type(result) == "table" and result.n ~= nil then
local returnValues = {}
for i=1, result.n do
returnValues[i] = result[i]
end
return returnValues
end
return result
end
`)
@ -548,8 +214,290 @@ func (s *Sandbox) initializeUnlocked() error {
return nil
}
// RegisterFunction registers a Go function in the sandbox
func (s *Sandbox) RegisterFunction(name string, fn GoFunction) error {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return fmt.Errorf("sandbox is closed")
}
// Initialize if needed
if !s.initialized {
if err := s.initializeUnlocked(); err != nil {
return err
}
}
// Register function globally
if err := s.state.RegisterGoFunction(name, fn); err != nil {
return err
}
// Store function for re-registration
s.functions[name] = fn
// Add to base environment
return s.state.DoString(`__env_system.base_env["` + name + `"] = ` + name)
}
// SetGlobal sets a global variable in the sandbox base environment
func (s *Sandbox) SetGlobal(name string, value any) error {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return fmt.Errorf("sandbox is closed")
}
// Initialize if needed
if !s.initialized {
if err := s.initializeUnlocked(); err != nil {
return err
}
}
// Push the value onto the stack
if err := s.state.PushValue(value); err != nil {
return err
}
// Set the global with the pushed value
s.state.SetGlobal(name)
// Add to base environment
return s.state.DoString(`__env_system.base_env["` + name + `"] = ` + name)
}
// GetGlobal retrieves a global variable from the sandbox base environment
func (s *Sandbox) GetGlobal(name string) (any, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return nil, fmt.Errorf("sandbox is closed")
}
// Initialize if needed
if !s.initialized {
if err := s.initializeUnlocked(); err != nil {
return nil, err
}
}
// Get the global from the base environment
return s.state.ExecuteWithResult(`return __env_system.base_env["` + name + `"]`)
}
// Run executes Lua code in the sandbox
func (s *Sandbox) Run(code string) (any, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return nil, fmt.Errorf("sandbox is closed")
}
// Initialize if needed
if !s.initialized {
if err := s.initializeUnlocked(); err != nil {
return nil, err
}
}
// Simplified wrapper for multiple return values
wrappedCode := `
local function _execfunc()
` + code + `
end
-- Process results to match expected format
local function _wrapresults(...)
local n = select('#', ...)
if n == 0 then
return nil
elseif n == 1 then
return select(1, ...)
else
local results = {}
for i = 1, n do
results[i] = select(i, ...)
end
return results
end
end
return _wrapresults(_execfunc())
`
// Compile the code
if err := s.state.LoadString(wrappedCode); err != nil {
return nil, err
}
// Get the sandbox executor
s.state.GetGlobal("__execute_sandbox")
// Push the function as argument
s.state.PushCopy(-2)
s.state.Remove(-3)
// Execute in sandbox
if err := s.state.Call(1, 1); err != nil {
return nil, err
}
// Get result
result, err := s.state.ToValue(-1)
s.state.Pop(1)
if err != nil {
return nil, err
}
return s.processResult(result), nil
}
// RunFile executes a Lua file in the sandbox
func (s *Sandbox) RunFile(filename string) error {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return fmt.Errorf("sandbox is closed")
}
return s.state.DoFile(filename)
}
// Compile compiles Lua code to bytecode
func (s *Sandbox) Compile(code string) ([]byte, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return nil, fmt.Errorf("sandbox is closed")
}
return s.state.CompileBytecode(code, "sandbox")
}
// RunBytecode executes precompiled Lua bytecode
func (s *Sandbox) RunBytecode(bytecode []byte) (any, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return nil, fmt.Errorf("sandbox is closed")
}
// Initialize if needed
if !s.initialized {
if err := s.initializeUnlocked(); err != nil {
return nil, err
}
}
// Load the bytecode
if err := s.state.LoadBytecode(bytecode, "sandbox"); err != nil {
return nil, err
}
// Get the sandbox executor
s.state.GetGlobal("__execute_sandbox")
// Push bytecode function
s.state.PushCopy(-2)
s.state.Remove(-3)
// Execute in sandbox
if err := s.state.Call(1, 1); err != nil {
return nil, err
}
// Get result
result, err := s.state.ToValue(-1)
s.state.Pop(1)
if err != nil {
return nil, err
}
return s.processResult(result), nil
}
// LoadModule loads a Lua module
func (s *Sandbox) LoadModule(name string) error {
code := fmt.Sprintf("require('%s')", name)
_, err := s.Run(code)
return err
}
// SetPackagePath sets the sandbox package.path
func (s *Sandbox) SetPackagePath(path string) error {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return fmt.Errorf("sandbox is closed")
}
// Update global package.path
if err := s.state.SetPackagePath(path); err != nil {
return err
}
// Initialize if needed
if !s.initialized {
if err := s.initializeUnlocked(); err != nil {
return err
}
}
// Update base environment's package.path
return s.state.DoString(`__env_system.base_env.package.path = package.path`)
}
// AddPackagePath adds a path to the sandbox package.path
func (s *Sandbox) AddPackagePath(path string) error {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return fmt.Errorf("sandbox is closed")
}
// Update global package.path
if err := s.state.AddPackagePath(path); err != nil {
return err
}
// Initialize if needed
if !s.initialized {
if err := s.initializeUnlocked(); err != nil {
return err
}
}
// Update base environment's package.path
return s.state.DoString(`__env_system.base_env.package.path = package.path`)
}
// AddModule adds a module to the sandbox environment
func (s *Sandbox) AddModule(name string, module any) error {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return fmt.Errorf("sandbox is closed")
}
s.modules[name] = module
return nil
}
// AddPermanentLua adds Lua code to the environment permanently
// This code becomes part of the base environment
func (s *Sandbox) AddPermanentLua(code string) error {
s.mutex.Lock()
defer s.mutex.Unlock()
@ -558,33 +506,25 @@ func (s *Sandbox) AddPermanentLua(code string) error {
return fmt.Errorf("sandbox is closed")
}
// Make sure sandbox is initialized
// Initialize if needed
if !s.initialized {
if err := s.initializeUnlocked(); err != nil {
return err
}
}
// Add code to base environment
// Simplified approach to add code to base environment
return s.state.DoString(`
-- First compile the code
local f, err = loadstring([=[` + code + `]=], "permanent")
if not f then
error(err, 0)
end
-- Create a temporary environment based on base env
local temp_env = setmetatable({}, {__index = __env_system.base_env})
setfenv(f, temp_env)
-- Run the code in the temporary environment
if not f then error(err, 0) end
local env = setmetatable({}, {__index = __env_system.base_env})
setfenv(f, env)
local ok, err = pcall(f)
if not ok then
error(err, 0)
end
-- Copy new values to base environment
for k, v in pairs(temp_env) do
if not ok then error(err, 0) end
for k, v in pairs(env) do
__env_system.base_env[k] = v
end
`)
@ -599,18 +539,10 @@ func (s *Sandbox) ResetEnvironment() error {
return fmt.Errorf("sandbox is closed")
}
// Clear the environment system completely
err := s.state.DoString(`
-- Reset environment system
__env_system = nil
__wrap_bytecode = nil
__last_env = nil
`)
if err != nil {
return err
}
// Clear the environment system
s.state.DoString(`__env_system = nil`)
// Reinitialize the environment system
// Reinitialize
s.initialized = false
if err := s.initializeUnlocked(); err != nil {
return err
@ -622,10 +554,7 @@ func (s *Sandbox) ResetEnvironment() error {
return err
}
if err := s.state.DoString(`
-- Add the function to base environment
__env_system.base_env["` + name + `"] = ` + name + `
`); err != nil {
if err := s.state.DoString(`__env_system.base_env["` + name + `"] = ` + name); err != nil {
return err
}
}
@ -633,53 +562,48 @@ func (s *Sandbox) ResetEnvironment() error {
return nil
}
// unwrapResult processes the raw result value from Lua
// and unwraps single values from special map structures
func (s *Sandbox) unwrapResult(result any, err error) (any, error) {
// Unwrap array stored in map with empty key
if m, ok := result.(map[string]any); ok {
// Check for special array format
if arr, ok := m[""]; ok {
// If the array has only one element, return that element
if slice, ok := arr.([]float64); ok {
if len(slice) == 1 {
return slice[0], err
}
// Convert []float64 to []any for consistency
anySlice := make([]any, len(slice))
for i, v := range slice {
anySlice[i] = v
}
return anySlice, err
}
if slice, ok := arr.([]any); ok && len(slice) == 1 {
return slice[0], err
}
result = arr
} else if len(m) == 1 {
// When there's exactly one item, return its value directly
for _, v := range m {
return v, err
}
// unwrapResult processes results from Lua executions
func (s *Sandbox) processResult(result any) any {
// Handle []float64 (common LuaJIT return type)
if floats, ok := result.([]float64); ok {
if len(floats) == 1 {
// Single number - return as float64
return floats[0]
}
}
// Convert []float64 to []any for consistency with multiple returns
if slice, ok := result.([]float64); ok {
if len(slice) == 1 {
return slice[0], err
}
anySlice := make([]any, len(slice))
for i, v := range slice {
// Multiple numbers - MUST convert to []any for tests to pass
anySlice := make([]any, len(floats))
for i, v := range floats {
anySlice[i] = v
}
return anySlice, err
return anySlice
}
// Handle multiple return values
if results, ok := result.([]any); ok && len(results) == 1 {
return results[0], err
// Handle maps with numeric keys (Lua tables)
if m, ok := result.(map[string]any); ok {
// Handle return tables with special structure
if vals, ok := m[""]; ok {
// This is a special case used by some Lua returns
if arr, ok := vals.([]float64); ok {
// Convert to []any for consistency
anySlice := make([]any, len(arr))
for i, v := range arr {
anySlice[i] = v
}
return anySlice
}
return vals
}
if len(m) == 1 {
// Check for single value map
for k, v := range m {
if k == "1" {
return v
}
}
}
}
return result, err
// Other array types should be preserved
return result
}