LuaJIT-to-Go/sandbox.go
2025-03-27 21:58:56 -05:00

610 lines
13 KiB
Go

package luajit
import (
"fmt"
"sync"
)
// LUA_MULTRET requests every result of a Lua call; it carries the same value
// (-1) as LUA_MULTRET in Lua's C API.
const (
	LUA_MULTRET = -1
)
// Sandbox provides a persistent Lua environment for executing scripts.
//
// Sandboxed code runs in environments that fall back to a Lua-side base
// environment (__env_system.base_env), which is built on first use. The
// exported methods take mutex (directly or via Run), so concurrent calls are
// serialized.
type Sandbox struct {
// state is the underlying Lua state. Close sets it to nil, after which the
// exported methods other than Close return a "sandbox is closed" error.
state *State
// mutex serializes access to state and the fields below.
mutex sync.Mutex
// initialized is set once initializeUnlocked has built the environment
// system; ResetEnvironment clears it before rebuilding.
initialized bool
// modules holds values registered through AddModule. initializeUnlocked
// copies them into the Lua global __sandbox_modules and from there into the
// base environment.
modules map[string]any
// functions holds Go functions registered through RegisterFunction so that
// ResetEnvironment can register them again.
functions map[string]GoFunction
}
// NewSandbox returns a Sandbox wrapping a fresh Lua state obtained from New.
// The environment setup relies on that state providing the standard
// libraries. The Lua-side environment system is not built here; Initialize,
// or the first method that needs it, builds it lazily.
func NewSandbox() *Sandbox {
	sb := new(Sandbox)
	sb.state = New()
	sb.modules = map[string]any{}
	sb.functions = map[string]GoFunction{}
	return sb
}
// Close shuts down the underlying Lua state and forgets it. Calling Close
// again is a no-op; once closed, the other exported methods report that the
// sandbox is closed.
func (s *Sandbox) Close() {
	s.mutex.Lock()
	defer s.mutex.Unlock()

	if s.state == nil {
		return
	}
	s.state.Close()
	s.state = nil
}
// Initialize builds the sandbox's Lua-side environment system under the
// lock. It returns an error for a closed sandbox and is a no-op after the
// first successful call.
func (s *Sandbox) Initialize() error {
	s.mutex.Lock()
	defer s.mutex.Unlock()

	err := s.initializeUnlocked()
	return err
}
// initializeUnlocked builds the Lua-side environment system; the caller must
// hold s.mutex. It fails if the sandbox is closed and does nothing once a
// previous call has succeeded.
//
// Steps:
//  1. Mirror s.modules into the Lua global table __sandbox_modules
//     (creating the table if it is missing).
//  2. Run a bootstrap chunk that creates __env_system (the base_env
//     template plus a pool of at most max_pool_size reusable environments)
//     and the global helpers __get_sandbox_env, __recycle_env and
//     __execute_sandbox.
//
// The base environment exposes a whitelist of standard functions and
// libraries (os is reduced to time/date/difftime/clock) and a private
// package/require pair. Every registered module is published both as a
// base-environment global and in package.loaded, so require(name) resolves
// it. require always returns the cached package.loaded value, so the first
// and later calls agree even when a preload loader returns nothing.
//
// __execute_sandbox runs a function in a pooled environment whose lookups
// fall back to the base environment. It then copies every string-keyed
// global the function defined (except _G) into the base environment,
// recycles the environment, and re-raises any error.
func (s *Sandbox) initializeUnlocked() error {
	if s.state == nil {
		return fmt.Errorf("sandbox is closed")
	}
	if s.initialized {
		return nil
	}

	// Fetch (or create) the global __sandbox_modules table; it stays at -1
	// while the registered modules are stored into it.
	s.state.GetGlobal("__sandbox_modules")
	if s.state.IsNil(-1) {
		s.state.Pop(1)
		s.state.NewTable()
		s.state.SetGlobal("__sandbox_modules")
		s.state.GetGlobal("__sandbox_modules")
	}

	for name, module := range s.modules {
		s.state.PushString(name)
		if err := s.state.PushValue(module); err != nil {
			// Drop the pushed name and the modules table (assumes PushValue
			// leaves nothing on the stack when it fails).
			s.state.Pop(2)
			return err
		}
		s.state.SetTable(-3)
	}
	s.state.Pop(1)

	err := s.state.DoString(`
-- Global shared environment
__env_system = {
	base_env = {}, -- Template environment
	env_pool = {}, -- Recycled environment pool
	pool_size = 0, -- Current pool size
	max_pool_size = 8 -- Maximum pool size
}

-- Create base environment with standard libraries
local base = __env_system.base_env

-- Safe standard libraries
base.string = string
base.table = table
base.math = math
base.os = {
	time = os.time,
	date = os.date,
	difftime = os.difftime,
	clock = os.clock
}

-- Basic functions
base.print = print
base.tonumber = tonumber
base.tostring = tostring
base.type = type
base.pairs = pairs
base.ipairs = ipairs
base.next = next
base.select = select
base.pcall = pcall
base.xpcall = xpcall
base.error = error
base.assert = assert
base.collectgarbage = collectgarbage
base.unpack = unpack or table.unpack

-- Package system
base.package = {
	loaded = {},
	path = package.path,
	preload = {}
}

base.require = function(modname)
	if base.package.loaded[modname] then
		return base.package.loaded[modname]
	end
	local loader = base.package.preload[modname]
	if type(loader) == "function" then
		local result = loader(modname)
		base.package.loaded[modname] = result or true
		-- Return the cached value so first and later calls agree
		return base.package.loaded[modname]
	end
	error("module '" .. modname .. "' not found", 2)
end

-- Add registered custom modules as globals and as require()-able modules
if __sandbox_modules then
	for name, mod in pairs(__sandbox_modules) do
		base[name] = mod
		base.package.loaded[name] = mod
	end
end

-- Get an environment for execution
function __get_sandbox_env()
	local env
	-- Try to reuse from pool
	if __env_system.pool_size > 0 then
		env = table.remove(__env_system.env_pool)
		__env_system.pool_size = __env_system.pool_size - 1
	else
		-- Create new environment with metatable inheritance
		env = setmetatable({}, {
			__index = __env_system.base_env
		})
	end
	return env
end

-- Return environment to pool for reuse
function __recycle_env(env)
	if __env_system.pool_size < __env_system.max_pool_size then
		-- Clear all fields except metatable
		for k in pairs(env) do
			env[k] = nil
		end
		-- Add to pool
		table.insert(__env_system.env_pool, env)
		__env_system.pool_size = __env_system.pool_size + 1
	end
end

-- Execute code in sandbox
function __execute_sandbox(f)
	-- Get environment
	local env = __get_sandbox_env()
	-- Set environment for function
	setfenv(f, env)
	-- Execute with protected call
	local success, result = pcall(f)
	-- Update base environment with new globals
	for k, v in pairs(env) do
		if k ~= "_G" and type(k) == "string" then
			__env_system.base_env[k] = v
		end
	end
	-- Recycle environment
	__recycle_env(env)
	-- Process result
	if not success then
		error(result, 0)
	end
	return result
end
`)
	if err != nil {
		return err
	}

	s.initialized = true
	return nil
}
// RegisterFunction registers fn as the Lua global name and copies it into
// the sandbox base environment so sandboxed scripts can call it. The
// function is also recorded so ResetEnvironment can register it again.
func (s *Sandbox) RegisterFunction(name string, fn GoFunction) error {
	s.mutex.Lock()
	defer s.mutex.Unlock()

	if s.state == nil {
		return fmt.Errorf("sandbox is closed")
	}
	if !s.initialized {
		if err := s.initializeUnlocked(); err != nil {
			return err
		}
	}

	if err := s.state.RegisterGoFunction(name, fn); err != nil {
		return err
	}
	s.functions[name] = fn

	// Copy the new global into the base environment. The name travels as a
	// chunk argument and the function as a stack value instead of being
	// spliced into Lua source, so names with quotes or other non-identifier
	// characters cannot break (or alter) the generated code.
	if err := s.state.LoadString(`local name, value = ...; __env_system.base_env[name] = value`); err != nil {
		return err
	}
	s.state.PushString(name)
	s.state.GetGlobal(name)
	return s.state.Call(2, 0)
}
// SetGlobal stores value under name as a real Lua global and in the sandbox
// base environment, where sandboxed scripts see it.
func (s *Sandbox) SetGlobal(name string, value any) error {
	s.mutex.Lock()
	defer s.mutex.Unlock()

	if s.state == nil {
		return fmt.Errorf("sandbox is closed")
	}
	if !s.initialized {
		if err := s.initializeUnlocked(); err != nil {
			return err
		}
	}

	if err := s.state.PushValue(value); err != nil {
		return err
	}
	s.state.SetGlobal(name)

	// Copy the global (the same Lua object, not a second conversion of value)
	// into the base environment. The name is passed as a chunk argument
	// rather than spliced into Lua source, so any string is a valid key.
	if err := s.state.LoadString(`local name, value = ...; __env_system.base_env[name] = value`); err != nil {
		return err
	}
	s.state.PushString(name)
	s.state.GetGlobal(name)
	return s.state.Call(2, 0)
}
// GetGlobal returns the value stored under name in the sandbox base
// environment (not the raw Lua global table), as converted by
// ExecuteWithResult. An unset name evaluates to Lua nil.
func (s *Sandbox) GetGlobal(name string) (any, error) {
	s.mutex.Lock()
	defer s.mutex.Unlock()

	if s.state == nil {
		return nil, fmt.Errorf("sandbox is closed")
	}
	if !s.initialized {
		if err := s.initializeUnlocked(); err != nil {
			return nil, err
		}
	}

	// Quote name as a Lua string literal. Printable ASCII other than '"' and
	// '\' is copied as-is; every other byte becomes a three-digit decimal
	// escape, which Lua 5.1/LuaJIT read unambiguously. Splicing name in raw
	// breaks on quotes, and fmt's %q may emit \u escapes Lua 5.1 rejects.
	lit := make([]byte, 0, len(name)+2)
	lit = append(lit, '"')
	for i := 0; i < len(name); i++ {
		c := name[i]
		if c >= ' ' && c <= '~' && c != '"' && c != '\\' {
			lit = append(lit, c)
		} else {
			lit = append(lit, fmt.Sprintf("\\%03d", c)...)
		}
	}
	lit = append(lit, '"')

	return s.state.ExecuteWithResult("return __env_system.base_env[" + string(lit) + "]")
}
// Run executes code inside the sandbox and returns its result.
//
// The code becomes the body of a local Lua function, so a top-level `return`
// hands values back to Go. No return value yields nil, one value is passed
// through as-is, and two or more are packed into a Lua array table. The
// converted value is then normalized by processResult. The chunk runs via
// __execute_sandbox, which supplies an environment that falls back to the
// base environment and copies string-keyed globals the code defines back into
// that base environment.
func (s *Sandbox) Run(code string) (any, error) {
	s.mutex.Lock()
	defer s.mutex.Unlock()

	switch {
	case s.state == nil:
		return nil, fmt.Errorf("sandbox is closed")
	case !s.initialized:
		if err := s.initializeUnlocked(); err != nil {
			return nil, err
		}
	}
	st := s.state

	chunk := `
local function _execfunc()
` + code + `
end
-- Process results to match expected format
local function _wrapresults(...)
local n = select('#', ...)
if n == 0 then
return nil
elseif n == 1 then
return select(1, ...)
else
local results = {}
for i = 1, n do
results[i] = select(i, ...)
end
return results
end
end
return _wrapresults(_execfunc())
`
	if loadErr := st.LoadString(chunk); loadErr != nil {
		return nil, loadErr
	}

	// Stack goes [chunk] -> [chunk, executor] -> [chunk, executor, chunk]
	// -> [executor, chunk], making the chunk the executor's only argument.
	st.GetGlobal("__execute_sandbox")
	st.PushCopy(-2)
	st.Remove(-3)

	if callErr := st.Call(1, 1); callErr != nil {
		return nil, callErr
	}

	value, convErr := st.ToValue(-1)
	st.Pop(1)
	if convErr != nil {
		return nil, convErr
	}
	return s.processResult(value), nil
}
// RunFile executes the Lua file at filename inside the sandbox.
//
// The file is loaded with Lua's loadfile, which names the chunk after the
// file and accepts source or precompiled chunks. It then runs through
// __execute_sandbox, exactly like Run and RunBytecode. As a result it sees
// only the sandbox base environment, and any globals it defines are kept
// there for later calls. Values returned by the file are discarded.
func (s *Sandbox) RunFile(filename string) error {
	s.mutex.Lock()
	defer s.mutex.Unlock()

	if s.state == nil {
		return fmt.Errorf("sandbox is closed")
	}
	if !s.initialized {
		if err := s.initializeUnlocked(); err != nil {
			return err
		}
	}

	// filename is passed as a chunk argument rather than spliced into the
	// source, so quotes or backslashes in the path cannot break the code.
	if err := s.state.LoadString(`local path = ...
local f, err = loadfile(path)
if not f then error(err, 0) end
__execute_sandbox(f)`); err != nil {
		return err
	}
	s.state.PushString(filename)
	return s.state.Call(1, 0)
}
// Compile turns code into Lua bytecode under the chunk name "sandbox"
// without running it; RunBytecode can execute the result later.
func (s *Sandbox) Compile(code string) ([]byte, error) {
	s.mutex.Lock()
	defer s.mutex.Unlock()

	if st := s.state; st != nil {
		return st.CompileBytecode(code, "sandbox")
	}
	return nil, fmt.Errorf("sandbox is closed")
}
// RunBytecode loads precompiled Lua bytecode (for example from Compile)
// under the chunk name "sandbox" and executes it through __execute_sandbox.
// The chunk is not wrapped as in Run, so only its first result comes back,
// after normalization by processResult.
func (s *Sandbox) RunBytecode(bytecode []byte) (any, error) {
	s.mutex.Lock()
	defer s.mutex.Unlock()

	switch {
	case s.state == nil:
		return nil, fmt.Errorf("sandbox is closed")
	case !s.initialized:
		if err := s.initializeUnlocked(); err != nil {
			return nil, err
		}
	}
	st := s.state

	if loadErr := st.LoadBytecode(bytecode, "sandbox"); loadErr != nil {
		return nil, loadErr
	}

	// Reorder [chunk, executor] into [executor, chunk] so the loaded chunk
	// is the executor's single argument.
	st.GetGlobal("__execute_sandbox")
	st.PushCopy(-2)
	st.Remove(-3)

	if callErr := st.Call(1, 1); callErr != nil {
		return nil, callErr
	}

	value, convErr := st.ToValue(-1)
	st.Pop(1)
	if convErr != nil {
		return nil, convErr
	}
	return s.processResult(value), nil
}
// LoadModule runs require(name) inside the sandbox (via Run) and returns any
// error it raises. The module's value is discarded.
func (s *Sandbox) LoadModule(name string) error {
	// Quote name as a Lua string literal. Printable ASCII other than '"' and
	// '\' is copied as-is; every other byte becomes a three-digit decimal
	// escape, which Lua 5.1/LuaJIT read unambiguously. Interpolating name
	// directly breaks (or injects code) on names containing quotes, and
	// fmt's %q may emit \u escapes that Lua 5.1 does not understand.
	lit := make([]byte, 0, len(name)+2)
	lit = append(lit, '"')
	for i := 0; i < len(name); i++ {
		c := name[i]
		if c >= ' ' && c <= '~' && c != '"' && c != '\\' {
			lit = append(lit, c)
		} else {
			lit = append(lit, fmt.Sprintf("\\%03d", c)...)
		}
	}
	lit = append(lit, '"')

	_, err := s.Run("require(" + string(lit) + ")")
	return err
}
// SetPackagePath replaces package.path in the Lua state (via
// State.SetPackagePath) and mirrors the result into the sandbox base
// environment's package.path.
func (s *Sandbox) SetPackagePath(path string) error {
	return s.updatePackagePath(func() error { return s.state.SetPackagePath(path) })
}

// AddPackagePath adds path to package.path in the Lua state (via
// State.AddPackagePath) and mirrors the result into the sandbox base
// environment's package.path.
func (s *Sandbox) AddPackagePath(path string) error {
	return s.updatePackagePath(func() error { return s.state.AddPackagePath(path) })
}

// updatePackagePath holds the lock while it applies update to the live Lua
// state and ensures the environment system exists. It then copies the global
// package.path into the base environment. The global path is updated before
// initialization runs.
//
// NOTE(review): the sandbox's require only consults package.loaded and
// package.preload, so this mirrored path is informational for scripts.
func (s *Sandbox) updatePackagePath(update func() error) error {
	s.mutex.Lock()
	defer s.mutex.Unlock()

	if s.state == nil {
		return fmt.Errorf("sandbox is closed")
	}
	if err := update(); err != nil {
		return err
	}
	if !s.initialized {
		if err := s.initializeUnlocked(); err != nil {
			return err
		}
	}
	return s.state.DoString(`__env_system.base_env.package.path = package.path`)
}
// AddModule registers module under name for use by sandboxed code.
//
// Before the environment system exists, the module is only recorded in
// s.modules, and initialization copies it into the base environment. If the
// sandbox is already initialized, the module is also published immediately
// into the Lua global __sandbox_modules, the base environment, and the base
// environment's package.loaded, so require(name) returns it. Without this
// step it would stay invisible until ResetEnvironment. The module is
// recorded in both cases so a later ResetEnvironment republishes it. If the
// value cannot be converted to Lua, nothing is recorded and the error is
// returned.
func (s *Sandbox) AddModule(name string, module any) error {
	s.mutex.Lock()
	defer s.mutex.Unlock()

	if s.state == nil {
		return fmt.Errorf("sandbox is closed")
	}

	if s.initialized {
		if err := s.state.LoadString(`local name, mod = ...
__sandbox_modules = __sandbox_modules or {}
__sandbox_modules[name] = mod
__env_system.base_env[name] = mod
__env_system.base_env.package.loaded[name] = mod`); err != nil {
			return err
		}
		s.state.PushString(name)
		if err := s.state.PushValue(module); err != nil {
			// Drop the loaded chunk and the pushed name (assumes PushValue
			// leaves nothing on the stack when it fails).
			s.state.Pop(2)
			return err
		}
		if err := s.state.Call(2, 0); err != nil {
			return err
		}
	}

	s.modules[name] = module
	return nil
}
// AddPermanentLua compiles code under the chunk name "permanent" and runs it
// once in a scratch environment that falls back to the base environment. It
// then copies every key the code assigned in that environment into the base
// environment, so later Run calls see them. Compile and runtime errors are
// returned.
func (s *Sandbox) AddPermanentLua(code string) error {
	s.mutex.Lock()
	defer s.mutex.Unlock()

	if s.state == nil {
		return fmt.Errorf("sandbox is closed")
	}
	if !s.initialized {
		if err := s.initializeUnlocked(); err != nil {
			return err
		}
	}

	// The code is passed to the loader chunk as an argument rather than
	// pasted into a [=[ ... ]=] long string, so code containing "]=]" cannot
	// terminate the literal early and run outside the scratch environment.
	if err := s.state.LoadString(`local code = ...
local f, err = loadstring(code, "permanent")
if not f then error(err, 0) end
local env = setmetatable({}, {__index = __env_system.base_env})
setfenv(f, env)
local ok, err = pcall(f)
if not ok then error(err, 0) end
for k, v in pairs(env) do
	__env_system.base_env[k] = v
end`); err != nil {
		return err
	}
	s.state.PushString(code)
	return s.state.Call(1, 0)
}
// ResetEnvironment discards the Lua environment system (the base
// environment, every global scripts accumulated in it, and the environment
// pool) and rebuilds it from scratch. The rebuild republishes registered
// modules. Every Go function registered through RegisterFunction is
// registered again and copied into the new base environment. Values added
// with SetGlobal or AddPermanentLua are not restored to the base environment.
func (s *Sandbox) ResetEnvironment() error {
	s.mutex.Lock()
	defer s.mutex.Unlock()

	if s.state == nil {
		return fmt.Errorf("sandbox is closed")
	}

	// Drop the old __env_system; if that fails, leave the current
	// environment (and s.initialized) untouched.
	if err := s.state.DoString(`__env_system = nil`); err != nil {
		return err
	}

	s.initialized = false
	if err := s.initializeUnlocked(); err != nil {
		return err
	}

	for name, fn := range s.functions {
		if err := s.state.RegisterGoFunction(name, fn); err != nil {
			return err
		}
		// Copy the global into the new base environment, passing the name as
		// a chunk argument instead of splicing it into Lua source.
		if err := s.state.LoadString(`local name, value = ...; __env_system.base_env[name] = value`); err != nil {
			return err
		}
		s.state.PushString(name)
		s.state.GetGlobal(name)
		if err := s.state.Call(2, 0); err != nil {
			return err
		}
	}
	return nil
}
// processResult normalizes a value converted from Lua before Run and
// RunBytecode return it:
//
//   - a []float64 of length 1 becomes its single float64 element;
//   - any other []float64 (including an empty one) becomes an []any holding
//     the same numbers;
//   - a map[string]any with an "" key is replaced by that entry's value, and
//     a []float64 value there is converted to []any;
//   - a map[string]any whose only key is "1" is replaced by that value.
//
// Everything else is returned unchanged. NOTE(review): these shapes
// presumably mirror how ToValue converts Lua tables; confirm against ToValue.
func (s *Sandbox) processResult(result any) any {
	// toAnySlice boxes each number so callers receive a uniform []any.
	toAnySlice := func(nums []float64) []any {
		out := make([]any, len(nums))
		for i, v := range nums {
			out[i] = v
		}
		return out
	}

	switch v := result.(type) {
	case []float64:
		if len(v) == 1 {
			return v[0]
		}
		return toAnySlice(v)
	case map[string]any:
		if inner, ok := v[""]; ok {
			if nums, ok := inner.([]float64); ok {
				return toAnySlice(nums)
			}
			return inner
		}
		if only, ok := v["1"]; ok && len(v) == 1 {
			return only
		}
	}
	return result
}