LuaJIT-to-Go/sandbox.go
2025-03-27 21:31:41 -05:00

686 lines
15 KiB
Go

package luajit
import (
"fmt"
"sync"
)
// LUA_MULTRET is the constant for multiple return values.
// NOTE(review): presumably mirrors the Lua C API's LUA_MULTRET (-1), used as a
// result count meaning "keep all results"; it is not referenced in this part of
// the file. The ALL_CAPS spelling is kept for compatibility with existing callers.
const LUA_MULTRET = -1
// Sandbox provides a persistent Lua environment for executing scripts.
// Every public method acquires mutex (directly, or via Run) before touching
// state, so access to the underlying Lua state is serialized.
type Sandbox struct {
// state is the underlying Lua state; Close sets it to nil, after which
// methods needing it report "sandbox is closed".
state *State
// mutex serializes all access to state and the fields below.
mutex sync.Mutex
// initialized records whether initializeUnlocked has installed the Lua
// environment system (__env_system and the __execute_sandbox helpers).
initialized bool
// modules holds values registered with AddModule; initializeUnlocked
// copies them into the Lua table __sandbox_modules.
modules map[string]any
// functions holds Go functions registered with RegisterFunction so that
// ResetEnvironment can register them again.
functions map[string]GoFunction
}
// NewSandbox returns a sandbox backed by a fresh Lua state (created with New).
// The environment system itself is not built yet; it is set up lazily on first
// use or eagerly via Initialize.
func NewSandbox() *Sandbox {
	sb := new(Sandbox) // initialized starts out false
	sb.state = New()
	sb.modules = map[string]any{}
	sb.functions = map[string]GoFunction{}
	return sb
}
// Close releases the underlying Lua state. It is safe to call repeatedly;
// once closed, methods that need the Lua state report "sandbox is closed".
func (s *Sandbox) Close() {
	s.mutex.Lock()
	defer s.mutex.Unlock()

	if s.state == nil {
		return // already closed
	}
	s.state.Close()
	s.state = nil
}
// RegisterFunction registers fn as the Lua global name and exposes it to
// sandboxed code by copying it into the base environment. The function is
// remembered so that ResetEnvironment can register it again.
//
// The name is handed to Lua as a string value instead of being spliced into
// Lua source, so names containing quotes, backslashes or characters that are
// not valid in a Lua identifier cannot break (or inject code into) the chunk.
// Like the previous implementation, this relies on RegisterGoFunction storing
// fn as the global name.
func (s *Sandbox) RegisterFunction(name string, fn GoFunction) error {
	s.mutex.Lock()
	defer s.mutex.Unlock()

	if s.state == nil {
		return fmt.Errorf("sandbox is closed")
	}
	// base_env must exist before we can add to it.
	if !s.initialized {
		if err := s.initializeUnlocked(); err != nil {
			return err
		}
	}
	if err := s.state.RegisterGoFunction(name, fn); err != nil {
		return err
	}
	// Remember for re-registration by ResetEnvironment.
	s.functions[name] = fn

	// base_env[name] = _G[name], with name passed via a temporary global that
	// is cleared before anything else in the chunk can fail.
	s.state.PushString(name)
	s.state.SetGlobal("__sandbox_tmp_name")
	return s.state.DoString(`
local name = __sandbox_tmp_name
__sandbox_tmp_name = nil
__env_system.base_env[name] = _G[name]
`)
}
// SetGlobal sets the Lua global name to value and copies it into the sandbox
// base environment, so sandboxed code sees it as a global.
//
// The name is handed to Lua as a string value instead of being spliced into
// Lua source, so names containing quotes, backslashes or non-identifier
// characters cannot break (or inject code into) the chunk.
func (s *Sandbox) SetGlobal(name string, value any) error {
	s.mutex.Lock()
	defer s.mutex.Unlock()

	if s.state == nil {
		return fmt.Errorf("sandbox is closed")
	}
	// base_env must exist before we can add to it.
	if !s.initialized {
		if err := s.initializeUnlocked(); err != nil {
			return err
		}
	}
	// Convert value to Lua and store it as _G[name].
	if err := s.state.PushValue(value); err != nil {
		return err
	}
	s.state.SetGlobal(name)

	// base_env[name] = _G[name], with name passed via a temporary global.
	s.state.PushString(name)
	s.state.SetGlobal("__sandbox_tmp_name")
	return s.state.DoString(`
local name = __sandbox_tmp_name
__sandbox_tmp_name = nil
__env_system.base_env[name] = _G[name]
`)
}
// GetGlobal returns the value stored under name in the sandbox base
// environment, converted to Go by ExecuteWithResult.
//
// The name is handed to Lua as a string value instead of being spliced into
// the returned expression, so arbitrary names are looked up literally.
func (s *Sandbox) GetGlobal(name string) (any, error) {
	s.mutex.Lock()
	defer s.mutex.Unlock()

	if s.state == nil {
		return nil, fmt.Errorf("sandbox is closed")
	}
	// base_env must exist before we can read from it.
	if !s.initialized {
		if err := s.initializeUnlocked(); err != nil {
			return nil, err
		}
	}

	s.state.PushString(name)
	s.state.SetGlobal("__sandbox_tmp_name")
	return s.state.ExecuteWithResult(`
local name = __sandbox_tmp_name
__sandbox_tmp_name = nil
return __env_system.base_env[name]
`)
}
// Run executes a Lua source chunk inside the sandbox and returns its result.
//
// The code is wrapped in a local function so that all of its return values
// can be packed into a table ({n = count, ...}). The compiled wrapper is then
// passed to the Lua helper __execute_sandbox, which runs it in a sandbox
// environment and persists any globals it set into the base environment.
// The returned table is converted to Go and simplified by unwrapResult.
func (s *Sandbox) Run(code string) (any, error) {
	s.mutex.Lock()
	defer s.mutex.Unlock()

	if s.state == nil {
		return nil, fmt.Errorf("sandbox is closed")
	}
	// initializeUnlocked returns immediately once the sandbox is initialized.
	if err := s.initializeUnlocked(); err != nil {
		return nil, err
	}

	const head = `
local function _execfunc()
`
	const tail = `
end
local function _wrapresults(...)
local results = {n = select('#', ...)}
for i = 1, results.n do
results[i] = select(i, ...)
end
return results
end
return _wrapresults(_execfunc())
`
	if err := s.state.LoadString(head + code + tail); err != nil {
		return nil, err
	}

	// Stack is [chunk]; rearrange it to [__execute_sandbox, chunk].
	s.state.GetGlobal("__execute_sandbox")
	s.state.PushCopy(-2)
	s.state.Remove(-3)

	if err := s.state.Call(1, 1); err != nil {
		return nil, err
	}

	value, convErr := s.state.ToValue(-1)
	s.state.Pop(1)
	return s.unwrapResult(value, convErr)
}
// RunFile executes the Lua file at filename inside the sandbox, the same way
// Run executes source strings. The chunk runs in a sandbox environment (with
// only the sandbox's restricted libraries), and any globals it defines are
// persisted into the base environment, where later Run calls can see them.
// The chunk is loaded with Lua's loadfile. Any value returned by the file is
// discarded; only errors are reported.
func (s *Sandbox) RunFile(filename string) error {
	s.mutex.Lock()
	defer s.mutex.Unlock()

	if s.state == nil {
		return fmt.Errorf("sandbox is closed")
	}
	// __execute_sandbox only exists once the environment system is set up.
	if !s.initialized {
		if err := s.initializeUnlocked(); err != nil {
			return err
		}
	}

	// Load via loadfile + __execute_sandbox rather than DoFile. DoFile would
	// run against the real globals, which bypasses the sandbox. Globals it
	// defined would also be invisible to Run, because sandbox environments
	// index base_env, not _G. The path is passed as a string value, not
	// spliced into source.
	s.state.PushString(filename)
	s.state.SetGlobal("__sandbox_tmp_path")
	return s.state.DoString(`
local path = __sandbox_tmp_path
__sandbox_tmp_path = nil
local f, err = loadfile(path)
if not f then
	error(err, 0)
end
__execute_sandbox(f)
`)
}
// Compile turns Lua source into bytecode (chunk name "sandbox") using the
// sandbox's Lua state, without running it. It does not require the sandbox
// environment to be initialized.
func (s *Sandbox) Compile(code string) ([]byte, error) {
	s.mutex.Lock()
	defer s.mutex.Unlock()

	if st := s.state; st != nil {
		return st.CompileBytecode(code, "sandbox")
	}
	return nil, fmt.Errorf("sandbox is closed")
}
// RunBytecode loads precompiled bytecode (chunk name "sandbox") and runs it
// through the Lua helper __execute_sandbox, just as Run does for source.
// Unlike Run, the chunk is not wrapped to pack multiple results.
// __execute_sandbox keeps only the first value returned by its pcall, so the
// result is that value, passed through unwrapResult.
func (s *Sandbox) RunBytecode(bytecode []byte) (any, error) {
	s.mutex.Lock()
	defer s.mutex.Unlock()

	if s.state == nil {
		return nil, fmt.Errorf("sandbox is closed")
	}
	// initializeUnlocked returns immediately once the sandbox is initialized.
	if err := s.initializeUnlocked(); err != nil {
		return nil, err
	}
	if err := s.state.LoadBytecode(bytecode, "sandbox"); err != nil {
		return nil, err
	}

	// Stack is [chunk]; rearrange it to [__execute_sandbox, chunk].
	s.state.GetGlobal("__execute_sandbox")
	s.state.PushCopy(-2)
	s.state.Remove(-3)

	if err := s.state.Call(1, 1); err != nil {
		return nil, err
	}

	value, convErr := s.state.ToValue(-1)
	s.state.Pop(1)
	return s.unwrapResult(value, convErr)
}
// getResults drains the whole Lua stack and converts it to Go. The caller
// must hold s.mutex.
//
// An empty stack yields nil. A single value is returned as-is. Several values
// are returned as a []any ordered bottom to top. The stack is emptied in every
// case, including when a conversion fails.
// NOTE(review): no caller is visible in this file.
func (s *Sandbox) getResults() (any, error) {
	top := s.state.GetTop()
	switch top {
	case 0:
		return nil, nil
	case 1:
		v, err := s.state.ToValue(-1)
		s.state.Pop(1)
		return v, err
	}

	out := make([]any, top)
	for idx := range out {
		// idx-top walks relative indices from the bottom (-top) up to -1.
		v, err := s.state.ToValue(idx - top)
		if err != nil {
			s.state.Pop(top)
			return nil, err
		}
		out[idx] = v
	}
	s.state.Pop(top)
	return out, nil
}
// LoadModule calls require(name) inside the sandbox via Run. The sandbox's
// require consults package.loaded and package.preload. The value returned by
// require is discarded; only errors are reported.
func (s *Sandbox) LoadModule(name string) error {
	// Quote the name as a proper Lua string literal instead of interpolating
	// it raw between single quotes. A name containing a quote would otherwise
	// break the chunk or inject code into it.
	_, err := s.Run("require(" + luaQuoteString(name) + ")")
	return err
}

// luaQuoteString returns v as a double-quoted Lua 5.1 string literal.
// Double quotes, backslashes and control bytes are written as three-digit
// decimal escapes (\ddd), which every Lua 5.1/LuaJIT lexer accepts. Always
// emitting three digits means a following digit in v cannot be absorbed
// into the escape. All other bytes, including UTF-8 sequences, are copied
// through unchanged.
func luaQuoteString(v string) string {
	buf := make([]byte, 0, len(v)+2)
	buf = append(buf, '"')
	for i := 0; i < len(v); i++ {
		c := v[i]
		if c == '"' || c == '\\' || c < ' ' || c == 0x7f {
			buf = append(buf, '\\', '0'+c/100, '0'+c/10%10, '0'+c%10)
			continue
		}
		buf = append(buf, c)
	}
	buf = append(buf, '"')
	return string(buf)
}
// SetPackagePath replaces the global package.path and then copies it into the
// sandbox base environment's package table. The global path is updated before
// the sandbox is (lazily) initialized, so a first-time initialization already
// picks up the new value.
func (s *Sandbox) SetPackagePath(path string) error {
	s.mutex.Lock()
	defer s.mutex.Unlock()

	if s.state == nil {
		return fmt.Errorf("sandbox is closed")
	}

	err := s.state.SetPackagePath(path)
	if err == nil {
		// No-op when already initialized.
		err = s.initializeUnlocked()
	}
	if err != nil {
		return err
	}

	return s.state.DoString(`
__env_system.base_env.package.path = package.path
`)
}
// AddPackagePath extends the global package.path with path and then copies
// the result into the sandbox base environment's package table. Each step
// runs only if the previous one succeeded.
func (s *Sandbox) AddPackagePath(path string) error {
	s.mutex.Lock()
	defer s.mutex.Unlock()

	if s.state == nil {
		return fmt.Errorf("sandbox is closed")
	}

	steps := []func() error{
		func() error { return s.state.AddPackagePath(path) },
		s.initializeUnlocked, // no-op when already initialized
		func() error {
			return s.state.DoString(`
__env_system.base_env.package.path = package.path
`)
		},
	}
	for _, step := range steps {
		if err := step(); err != nil {
			return err
		}
	}
	return nil
}
// AddModule exposes module to sandboxed code as the global name.
//
// Before initialization the module is only recorded; initializeUnlocked then
// publishes every recorded module. If the sandbox is already initialized, the
// module is installed into the live environment right away. That means
// __sandbox_modules, the base environment and base package.loaded, so
// require(name) resolves it. Previously such modules stayed invisible until
// ResetEnvironment. In that case a value that cannot be converted to Lua is
// reported as an error and the module is not recorded.
func (s *Sandbox) AddModule(name string, module any) error {
	s.mutex.Lock()
	defer s.mutex.Unlock()

	if s.state == nil {
		return fmt.Errorf("sandbox is closed")
	}
	if !s.initialized {
		// initializeUnlocked copies s.modules into the environment later.
		s.modules[name] = module
		return nil
	}

	// Ensure __sandbox_modules exists and leave it on the stack.
	s.state.GetGlobal("__sandbox_modules")
	if s.state.IsNil(-1) {
		s.state.Pop(1)
		s.state.NewTable()
		s.state.SetGlobal("__sandbox_modules")
		s.state.GetGlobal("__sandbox_modules")
	}
	// __sandbox_modules[name] = module
	s.state.PushString(name)
	if err := s.state.PushValue(module); err != nil {
		// Mirrors initializeUnlocked; assumes PushValue pushes nothing on failure.
		s.state.Pop(2)
		return err
	}
	s.state.SetTable(-3)
	s.state.Pop(1)

	// Publish into the base environment; name goes through a temporary
	// global rather than being spliced into Lua source.
	s.state.PushString(name)
	s.state.SetGlobal("__sandbox_tmp_name")
	if err := s.state.DoString(`
local name = __sandbox_tmp_name
__sandbox_tmp_name = nil
local mod = __sandbox_modules[name]
local base = __env_system.base_env
base[name] = mod
local pkg = base.package
if type(pkg) == "table" and type(pkg.loaded) == "table" then
	pkg.loaded[name] = mod
end
`); err != nil {
		return err
	}

	// Keep the module across ResetEnvironment.
	s.modules[name] = module
	return nil
}
// Initialize eagerly builds the sandbox environment system. Most operations
// do this lazily on first use, so calling Initialize is optional; once the
// sandbox is initialized, further calls do nothing.
func (s *Sandbox) Initialize() error {
	s.mutex.Lock()
	defer s.mutex.Unlock()

	err := s.initializeUnlocked()
	return err
}
// initializeUnlocked installs the sandbox environment system into the Lua
// state. It must only be called while s.mutex is held, and it returns
// immediately if the sandbox is already initialized.
//
// It first publishes every module recorded by AddModule in the global table
// __sandbox_modules. It then runs a Lua bootstrap that:
//   - builds __env_system.base_env once per environment system; this is the
//     template table that every sandboxed chunk sees through __index;
//   - (re)defines __get_sandbox_env, __recycle_env and __execute_sandbox.
//
// The sandbox require follows Lua 5.1 semantics: it always returns
// package.loaded[modname], storing true when a loader returns nothing, and it
// keeps a value the loader stored in package.loaded itself. Registered modules
// are also placed in package.loaded, so require(name) resolves them.
func (s *Sandbox) initializeUnlocked() error {
	if s.state == nil {
		return fmt.Errorf("sandbox is closed")
	}
	if s.initialized {
		return nil // Already initialized
	}

	// Ensure __sandbox_modules exists and leave it on the stack.
	s.state.GetGlobal("__sandbox_modules")
	if s.state.IsNil(-1) {
		s.state.Pop(1)
		s.state.NewTable()
		s.state.SetGlobal("__sandbox_modules")
		s.state.GetGlobal("__sandbox_modules")
	}

	// __sandbox_modules[name] = module for every recorded module.
	for name, module := range s.modules {
		s.state.PushString(name)
		if err := s.state.PushValue(module); err != nil {
			// Drop name and the modules table (assumes PushValue pushes
			// nothing on failure).
			s.state.Pop(2)
			return err
		}
		s.state.SetTable(-3)
	}
	s.state.Pop(1) // the modules table

	err := s.state.DoString(`
-- Global shared environment (created once)
__env_system = __env_system or {
	base_env = nil,      -- Template environment
	initialized = false, -- Initialization flag
	env_pool = {},       -- Pre-allocated environment pool
	pool_size = 0,       -- Current pool size
	max_pool_size = 8    -- Maximum pool size
}

-- Initialize base environment once
if not __env_system.initialized then
	-- Create base environment with all standard libraries
	local base = {}

	-- Safe standard libraries
	base.string = string
	base.table = table
	base.math = math
	base.os = {
		time = os.time,
		date = os.date,
		difftime = os.difftime,
		clock = os.clock
	}

	-- Basic functions
	base.print = print
	base.tonumber = tonumber
	base.tostring = tostring
	base.type = type
	base.pairs = pairs
	base.ipairs = ipairs
	base.next = next
	base.select = select
	base.pcall = pcall
	base.xpcall = xpcall
	base.error = error
	base.assert = assert
	base.collectgarbage = collectgarbage
	base.unpack = unpack or table.unpack

	-- Package system
	base.package = {
		loaded = {},
		path = package.path,
		preload = {}
	}

	-- require with Lua 5.1 semantics: the result is always
	-- package.loaded[modname]; a loader returning nothing yields true
	-- unless the loader stored a value there itself.
	base.require = function(modname)
		if base.package.loaded[modname] then
			return base.package.loaded[modname]
		end
		local loader = base.package.preload[modname]
		if type(loader) == "function" then
			local result = loader(modname)
			if result ~= nil then
				base.package.loaded[modname] = result
			end
			if base.package.loaded[modname] == nil then
				base.package.loaded[modname] = true
			end
			return base.package.loaded[modname]
		end
		error("module '" .. modname .. "' not found", 2)
	end

	-- Add registered custom modules (also resolvable through require)
	if __sandbox_modules then
		for name, mod in pairs(__sandbox_modules) do
			base[name] = mod
			base.package.loaded[name] = mod
		end
	end

	-- Store base environment
	__env_system.base_env = base
	__env_system.initialized = true
end

-- Global variable for tracking current environment
__last_env = nil

-- Get an environment for execution
function __get_sandbox_env()
	local env
	-- Try to reuse from pool
	if __env_system.pool_size > 0 then
		env = table.remove(__env_system.env_pool)
		__env_system.pool_size = __env_system.pool_size - 1
	else
		-- Create new environment with metatable inheritance
		env = setmetatable({}, {
			__index = __env_system.base_env -- Use base env instead of _G
		})
	end
	-- Store reference to current environment
	__last_env = env
	return env
end

-- Return environment to pool for reuse
function __recycle_env(env)
	-- Only recycle if pool isn't full
	if __env_system.pool_size < __env_system.max_pool_size then
		-- Clear all fields except metatable
		for k in pairs(env) do
			env[k] = nil
		end
		-- Add to pool
		table.insert(__env_system.env_pool, env)
		__env_system.pool_size = __env_system.pool_size + 1
	end
end

-- Execute code in sandbox
function __execute_sandbox(f)
	-- Get environment
	local env = __get_sandbox_env()
	-- Set environment for function
	setfenv(f, env)
	-- Execute with protected call
	local success, result = pcall(f)
	-- Copy all globals to base environment
	for k, v in pairs(env) do
		if k ~= "_G" and type(k) == "string" then
			__env_system.base_env[k] = v
		end
	end
	-- Recycle environment
	__recycle_env(env)
	-- Process result
	if not success then
		error(result, 0)
	end
	-- Handle multiple return values
	if type(result) == "table" and result.n ~= nil then
		local returnValues = {}
		for i = 1, result.n do
			returnValues[i] = result[i]
		end
		return returnValues
	end
	return result
end
`)
	if err != nil {
		return err
	}
	s.initialized = true
	return nil
}
// AddPermanentLua runs code once and makes the globals it defines part of
// the sandbox base environment.
//
// The chunk (named "permanent") runs in a temporary environment that indexes
// the base environment. Afterwards every field it set is copied into the base
// environment. Compile and runtime errors are returned.
//
// The source is handed to Lua as a string value instead of being spliced into
// a [=[ ... ]=] long bracket. Code containing "]=]" would otherwise end the
// bracket early, and the remainder would run unsandboxed in the host chunk.
func (s *Sandbox) AddPermanentLua(code string) error {
	s.mutex.Lock()
	defer s.mutex.Unlock()

	if s.state == nil {
		return fmt.Errorf("sandbox is closed")
	}
	// base_env must exist before we can add to it.
	if !s.initialized {
		if err := s.initializeUnlocked(); err != nil {
			return err
		}
	}

	s.state.PushString(code)
	s.state.SetGlobal("__sandbox_tmp_code")
	return s.state.DoString(`
local src = __sandbox_tmp_code
__sandbox_tmp_code = nil

-- First compile the code
local f, err = loadstring(src, "permanent")
if not f then
	error(err, 0)
end

-- Create a temporary environment based on base env
local temp_env = setmetatable({}, {__index = __env_system.base_env})
setfenv(f, temp_env)

-- Run the code in the temporary environment
local ok, run_err = pcall(f)
if not ok then
	error(run_err, 0)
end

-- Copy new values to base environment
for k, v in pairs(temp_env) do
	__env_system.base_env[k] = v
end
`)
}
// ResetEnvironment discards the environment system and rebuilds it from
// scratch.
//
// The base environment is recreated, so values added with SetGlobal or
// AddPermanentLua are no longer visible to sandboxed code. Modules recorded by
// AddModule are republished during reinitialization. Functions registered
// with RegisterFunction are registered again and copied back into the new
// base environment.
func (s *Sandbox) ResetEnvironment() error {
	s.mutex.Lock()
	defer s.mutex.Unlock()

	if s.state == nil {
		return fmt.Errorf("sandbox is closed")
	}

	// Clear the environment system completely.
	err := s.state.DoString(`
-- Reset environment system
__env_system = nil
__wrap_bytecode = nil
__last_env = nil
`)
	if err != nil {
		return err
	}

	// Reinitialize the environment system.
	s.initialized = false
	if err := s.initializeUnlocked(); err != nil {
		return err
	}

	// Re-register all functions. Each name is passed to Lua as a string value
	// rather than spliced into source, so arbitrary names are handled safely.
	for name, fn := range s.functions {
		if err := s.state.RegisterGoFunction(name, fn); err != nil {
			return err
		}
		s.state.PushString(name)
		s.state.SetGlobal("__sandbox_tmp_name")
		if err := s.state.DoString(`
local name = __sandbox_tmp_name
__sandbox_tmp_name = nil
__env_system.base_env[name] = _G[name]
`); err != nil {
			return err
		}
	}
	return nil
}
// unwrapResult simplifies a value converted from Lua and passes err through
// unchanged.
//
// Rules, applied in order:
//   - a map[string]any with an "" entry is replaced by that entry;
//   - any other map with exactly one entry collapses to that entry's value;
//   - a []float64 with one element collapses to the element, and otherwise
//     becomes an equivalent []any;
//   - a []any with one element collapses to the element.
//
// Anything else is returned as-is.
func (s *Sandbox) unwrapResult(result any, err error) (any, error) {
	if m, isMap := result.(map[string]any); isMap {
		if payload, hasArray := m[""]; hasArray {
			// The array payload gets the same treatment as a top-level slice.
			result = payload
		} else if len(m) == 1 {
			for _, only := range m {
				return only, err
			}
		}
	}

	switch v := result.(type) {
	case []float64:
		if len(v) == 1 {
			return v[0], err
		}
		converted := make([]any, len(v))
		for i := range v {
			converted[i] = v[i]
		}
		return converted, err
	case []any:
		if len(v) == 1 {
			return v[0], err
		}
	}
	return result, err
}