package luajit import ( "fmt" "sync" ) // LUA_MULTRET is the constant for multiple return values const LUA_MULTRET = -1 // Sandbox provides a persistent Lua environment for executing scripts type Sandbox struct { state *State mutex sync.Mutex initialized bool modules map[string]any functions map[string]GoFunction } // NewSandbox creates a new sandbox with standard libraries loaded func NewSandbox() *Sandbox { return &Sandbox{ state: New(), initialized: false, modules: make(map[string]any), functions: make(map[string]GoFunction), } } // Close releases all resources used by the sandbox func (s *Sandbox) Close() { s.mutex.Lock() defer s.mutex.Unlock() if s.state != nil { s.state.Close() s.state = nil } } // Initialize sets up the environment system func (s *Sandbox) Initialize() error { s.mutex.Lock() defer s.mutex.Unlock() return s.initializeUnlocked() } // initializeUnlocked sets up the environment system without locking func (s *Sandbox) initializeUnlocked() error { if s.state == nil { return fmt.Errorf("sandbox is closed") } if s.initialized { return nil } // Register modules s.state.GetGlobal("__sandbox_modules") if s.state.IsNil(-1) { s.state.Pop(1) s.state.NewTable() s.state.SetGlobal("__sandbox_modules") s.state.GetGlobal("__sandbox_modules") } // Add modules for name, module := range s.modules { s.state.PushString(name) if err := s.state.PushValue(module); err != nil { s.state.Pop(2) return err } s.state.SetTable(-3) } s.state.Pop(1) // Create simplified environment system err := s.state.DoString(` -- Global shared environment __env_system = { base_env = {}, -- Template environment env_pool = {}, -- Pre-allocated environment pool pool_size = 0, -- Current pool size max_pool_size = 8 -- Maximum pool size } -- Create base environment with standard libraries local base = __env_system.base_env -- Safe standard libraries base.string = string base.table = table base.math = math base.os = { time = os.time, date = os.date, difftime = os.difftime, clock = os.clock } -- Basic functions base.print = print base.tonumber = tonumber base.tostring = tostring base.type = type base.pairs = pairs base.ipairs = ipairs base.next = next base.select = select base.pcall = pcall base.xpcall = xpcall base.error = error base.assert = assert base.collectgarbage = collectgarbage base.unpack = unpack or table.unpack -- Package system base.package = { loaded = {}, path = package.path, preload = {} } base.require = function(modname) if base.package.loaded[modname] then return base.package.loaded[modname] end local loader = base.package.preload[modname] if type(loader) == "function" then local result = loader(modname) base.package.loaded[modname] = result or true return result end error("module '" .. modname .. "' not found", 2) end -- Add registered custom modules if __sandbox_modules then for name, mod in pairs(__sandbox_modules) do base[name] = mod end end -- Get an environment for execution function __get_sandbox_env() local env -- Try to reuse from pool if __env_system.pool_size > 0 then env = table.remove(__env_system.env_pool) __env_system.pool_size = __env_system.pool_size - 1 else -- Create new environment with metatable inheritance env = setmetatable({}, { __index = __env_system.base_env }) end return env end -- Return environment to pool for reuse function __recycle_env(env) if __env_system.pool_size < __env_system.max_pool_size then -- Clear all fields except metatable for k in pairs(env) do env[k] = nil end -- Add to pool table.insert(__env_system.env_pool, env) __env_system.pool_size = __env_system.pool_size + 1 end end -- Execute code in sandbox function __execute_sandbox(f) -- Get environment local env = __get_sandbox_env() -- Set environment for function setfenv(f, env) -- Execute with protected call local success, result = pcall(f) -- Update base environment with new globals for k, v in pairs(env) do if k ~= "_G" and type(k) == "string" then __env_system.base_env[k] = v end end -- Recycle environment __recycle_env(env) -- Process result if not success then error(result, 0) end return result end `) if err != nil { return err } s.initialized = true return nil } // RegisterFunction registers a Go function in the sandbox func (s *Sandbox) RegisterFunction(name string, fn GoFunction) error { s.mutex.Lock() defer s.mutex.Unlock() if s.state == nil { return fmt.Errorf("sandbox is closed") } // Initialize if needed if !s.initialized { if err := s.initializeUnlocked(); err != nil { return err } } // Register function globally if err := s.state.RegisterGoFunction(name, fn); err != nil { return err } // Store function for re-registration s.functions[name] = fn // Add to base environment return s.state.DoString(`__env_system.base_env["` + name + `"] = ` + name) } // SetGlobal sets a global variable in the sandbox base environment func (s *Sandbox) SetGlobal(name string, value any) error { s.mutex.Lock() defer s.mutex.Unlock() if s.state == nil { return fmt.Errorf("sandbox is closed") } // Initialize if needed if !s.initialized { if err := s.initializeUnlocked(); err != nil { return err } } // Push the value onto the stack if err := s.state.PushValue(value); err != nil { return err } // Set the global with the pushed value s.state.SetGlobal(name) // Add to base environment return s.state.DoString(`__env_system.base_env["` + name + `"] = ` + name) } // GetGlobal retrieves a global variable from the sandbox base environment func (s *Sandbox) GetGlobal(name string) (any, error) { s.mutex.Lock() defer s.mutex.Unlock() if s.state == nil { return nil, fmt.Errorf("sandbox is closed") } // Initialize if needed if !s.initialized { if err := s.initializeUnlocked(); err != nil { return nil, err } } // Get the global from the base environment return s.state.ExecuteWithResult(`return __env_system.base_env["` + name + `"]`) } // Run executes Lua code in the sandbox func (s *Sandbox) Run(code string) (any, error) { s.mutex.Lock() defer s.mutex.Unlock() if s.state == nil { return nil, fmt.Errorf("sandbox is closed") } // Initialize if needed if !s.initialized { if err := s.initializeUnlocked(); err != nil { return nil, err } } // Simplified wrapper for multiple return values wrappedCode := ` local function _execfunc() ` + code + ` end -- Process results to match expected format local function _wrapresults(...) local n = select('#', ...) if n == 0 then return nil elseif n == 1 then return select(1, ...) else local results = {} for i = 1, n do results[i] = select(i, ...) end return results end end return _wrapresults(_execfunc()) ` // Compile the code if err := s.state.LoadString(wrappedCode); err != nil { return nil, err } // Get the sandbox executor s.state.GetGlobal("__execute_sandbox") // Push the function as argument s.state.PushCopy(-2) s.state.Remove(-3) // Execute in sandbox if err := s.state.Call(1, 1); err != nil { return nil, err } // Get result result, err := s.state.ToValue(-1) s.state.Pop(1) if err != nil { return nil, err } return s.processResult(result), nil } // RunFile executes a Lua file in the sandbox func (s *Sandbox) RunFile(filename string) error { s.mutex.Lock() defer s.mutex.Unlock() if s.state == nil { return fmt.Errorf("sandbox is closed") } return s.state.DoFile(filename) } // Compile compiles Lua code to bytecode func (s *Sandbox) Compile(code string) ([]byte, error) { s.mutex.Lock() defer s.mutex.Unlock() if s.state == nil { return nil, fmt.Errorf("sandbox is closed") } return s.state.CompileBytecode(code, "sandbox") } // RunBytecode executes precompiled Lua bytecode func (s *Sandbox) RunBytecode(bytecode []byte) (any, error) { s.mutex.Lock() defer s.mutex.Unlock() if s.state == nil { return nil, fmt.Errorf("sandbox is closed") } // Initialize if needed if !s.initialized { if err := s.initializeUnlocked(); err != nil { return nil, err } } // Load the bytecode if err := s.state.LoadBytecode(bytecode, "sandbox"); err != nil { return nil, err } // Get the sandbox executor s.state.GetGlobal("__execute_sandbox") // Push bytecode function s.state.PushCopy(-2) s.state.Remove(-3) // Execute in sandbox if err := s.state.Call(1, 1); err != nil { return nil, err } // Get result result, err := s.state.ToValue(-1) s.state.Pop(1) if err != nil { return nil, err } return s.processResult(result), nil } // LoadModule loads a Lua module func (s *Sandbox) LoadModule(name string) error { code := fmt.Sprintf("require('%s')", name) _, err := s.Run(code) return err } // SetPackagePath sets the sandbox package.path func (s *Sandbox) SetPackagePath(path string) error { s.mutex.Lock() defer s.mutex.Unlock() if s.state == nil { return fmt.Errorf("sandbox is closed") } // Update global package.path if err := s.state.SetPackagePath(path); err != nil { return err } // Initialize if needed if !s.initialized { if err := s.initializeUnlocked(); err != nil { return err } } // Update base environment's package.path return s.state.DoString(`__env_system.base_env.package.path = package.path`) } // AddPackagePath adds a path to the sandbox package.path func (s *Sandbox) AddPackagePath(path string) error { s.mutex.Lock() defer s.mutex.Unlock() if s.state == nil { return fmt.Errorf("sandbox is closed") } // Update global package.path if err := s.state.AddPackagePath(path); err != nil { return err } // Initialize if needed if !s.initialized { if err := s.initializeUnlocked(); err != nil { return err } } // Update base environment's package.path return s.state.DoString(`__env_system.base_env.package.path = package.path`) } // AddModule adds a module to the sandbox environment func (s *Sandbox) AddModule(name string, module any) error { s.mutex.Lock() defer s.mutex.Unlock() if s.state == nil { return fmt.Errorf("sandbox is closed") } s.modules[name] = module return nil } // AddPermanentLua adds Lua code to the environment permanently func (s *Sandbox) AddPermanentLua(code string) error { s.mutex.Lock() defer s.mutex.Unlock() if s.state == nil { return fmt.Errorf("sandbox is closed") } // Initialize if needed if !s.initialized { if err := s.initializeUnlocked(); err != nil { return err } } // Simplified approach to add code to base environment return s.state.DoString(` local f, err = loadstring([=[` + code + `]=], "permanent") if not f then error(err, 0) end local env = setmetatable({}, {__index = __env_system.base_env}) setfenv(f, env) local ok, err = pcall(f) if not ok then error(err, 0) end for k, v in pairs(env) do __env_system.base_env[k] = v end `) } // ResetEnvironment resets the sandbox to its initial state func (s *Sandbox) ResetEnvironment() error { s.mutex.Lock() defer s.mutex.Unlock() if s.state == nil { return fmt.Errorf("sandbox is closed") } // Clear the environment system s.state.DoString(`__env_system = nil`) // Reinitialize s.initialized = false if err := s.initializeUnlocked(); err != nil { return err } // Re-register all functions for name, fn := range s.functions { if err := s.state.RegisterGoFunction(name, fn); err != nil { return err } if err := s.state.DoString(`__env_system.base_env["` + name + `"] = ` + name); err != nil { return err } } return nil } // unwrapResult processes results from Lua executions func (s *Sandbox) processResult(result any) any { // Handle []float64 (common LuaJIT return type) if floats, ok := result.([]float64); ok { if len(floats) == 1 { // Single number - return as float64 return floats[0] } // Multiple numbers - MUST convert to []any for tests to pass anySlice := make([]any, len(floats)) for i, v := range floats { anySlice[i] = v } return anySlice } // Handle maps with numeric keys (Lua tables) if m, ok := result.(map[string]any); ok { // Handle return tables with special structure if vals, ok := m[""]; ok { // This is a special case used by some Lua returns if arr, ok := vals.([]float64); ok { // Convert to []any for consistency anySlice := make([]any, len(arr)) for i, v := range arr { anySlice[i] = v } return anySlice } return vals } if len(m) == 1 { // Check for single value map for k, v := range m { if k == "1" { return v } } } } // Other array types should be preserved return result }