610 lines
13 KiB
Go
610 lines
13 KiB
Go
package luajit
|
|
|
|
import (
|
|
"fmt"
|
|
"sync"
|
|
)
|
|
|
|
// LUA_MULTRET is the result-count value (-1) that asks a Lua call to keep
// every value the callee returns. It is deliberately left untyped so it can
// be passed wherever an integer count is expected.
const (
	LUA_MULTRET = -1
)
|
|
|
|
// Sandbox provides a persistent Lua environment for executing scripts
|
|
type Sandbox struct {
|
|
state *State
|
|
mutex sync.Mutex
|
|
initialized bool
|
|
modules map[string]any
|
|
functions map[string]GoFunction
|
|
}
|
|
|
|
// NewSandbox creates a new sandbox with standard libraries loaded
|
|
func NewSandbox() *Sandbox {
|
|
return &Sandbox{
|
|
state: New(),
|
|
initialized: false,
|
|
modules: make(map[string]any),
|
|
functions: make(map[string]GoFunction),
|
|
}
|
|
}
|
|
|
|
// Close releases all resources used by the sandbox
|
|
func (s *Sandbox) Close() {
|
|
s.mutex.Lock()
|
|
defer s.mutex.Unlock()
|
|
|
|
if s.state != nil {
|
|
s.state.Close()
|
|
s.state = nil
|
|
}
|
|
}
|
|
|
|
// Initialize sets up the environment system
|
|
func (s *Sandbox) Initialize() error {
|
|
s.mutex.Lock()
|
|
defer s.mutex.Unlock()
|
|
return s.initializeUnlocked()
|
|
}
|
|
|
|
// initializeUnlocked sets up the environment system without locking
|
|
func (s *Sandbox) initializeUnlocked() error {
|
|
if s.state == nil {
|
|
return fmt.Errorf("sandbox is closed")
|
|
}
|
|
|
|
if s.initialized {
|
|
return nil
|
|
}
|
|
|
|
// Register modules
|
|
s.state.GetGlobal("__sandbox_modules")
|
|
if s.state.IsNil(-1) {
|
|
s.state.Pop(1)
|
|
s.state.NewTable()
|
|
s.state.SetGlobal("__sandbox_modules")
|
|
s.state.GetGlobal("__sandbox_modules")
|
|
}
|
|
|
|
// Add modules
|
|
for name, module := range s.modules {
|
|
s.state.PushString(name)
|
|
if err := s.state.PushValue(module); err != nil {
|
|
s.state.Pop(2)
|
|
return err
|
|
}
|
|
s.state.SetTable(-3)
|
|
}
|
|
s.state.Pop(1)
|
|
|
|
// Create simplified environment system
|
|
err := s.state.DoString(`
|
|
-- Global shared environment
|
|
__env_system = {
|
|
base_env = {}, -- Template environment
|
|
env_pool = {}, -- Pre-allocated environment pool
|
|
pool_size = 0, -- Current pool size
|
|
max_pool_size = 8 -- Maximum pool size
|
|
}
|
|
|
|
-- Create base environment with standard libraries
|
|
local base = __env_system.base_env
|
|
|
|
-- Safe standard libraries
|
|
base.string = string
|
|
base.table = table
|
|
base.math = math
|
|
base.os = {
|
|
time = os.time,
|
|
date = os.date,
|
|
difftime = os.difftime,
|
|
clock = os.clock
|
|
}
|
|
|
|
-- Basic functions
|
|
base.print = print
|
|
base.tonumber = tonumber
|
|
base.tostring = tostring
|
|
base.type = type
|
|
base.pairs = pairs
|
|
base.ipairs = ipairs
|
|
base.next = next
|
|
base.select = select
|
|
base.pcall = pcall
|
|
base.xpcall = xpcall
|
|
base.error = error
|
|
base.assert = assert
|
|
base.collectgarbage = collectgarbage
|
|
base.unpack = unpack or table.unpack
|
|
|
|
-- Package system
|
|
base.package = {
|
|
loaded = {},
|
|
path = package.path,
|
|
preload = {}
|
|
}
|
|
|
|
base.require = function(modname)
|
|
if base.package.loaded[modname] then
|
|
return base.package.loaded[modname]
|
|
end
|
|
|
|
local loader = base.package.preload[modname]
|
|
if type(loader) == "function" then
|
|
local result = loader(modname)
|
|
base.package.loaded[modname] = result or true
|
|
return result
|
|
end
|
|
|
|
error("module '" .. modname .. "' not found", 2)
|
|
end
|
|
|
|
-- Add registered custom modules
|
|
if __sandbox_modules then
|
|
for name, mod in pairs(__sandbox_modules) do
|
|
base[name] = mod
|
|
end
|
|
end
|
|
|
|
-- Get an environment for execution
|
|
function __get_sandbox_env()
|
|
local env
|
|
|
|
-- Try to reuse from pool
|
|
if __env_system.pool_size > 0 then
|
|
env = table.remove(__env_system.env_pool)
|
|
__env_system.pool_size = __env_system.pool_size - 1
|
|
else
|
|
-- Create new environment with metatable inheritance
|
|
env = setmetatable({}, {
|
|
__index = __env_system.base_env
|
|
})
|
|
end
|
|
|
|
return env
|
|
end
|
|
|
|
-- Return environment to pool for reuse
|
|
function __recycle_env(env)
|
|
if __env_system.pool_size < __env_system.max_pool_size then
|
|
-- Clear all fields except metatable
|
|
for k in pairs(env) do
|
|
env[k] = nil
|
|
end
|
|
|
|
-- Add to pool
|
|
table.insert(__env_system.env_pool, env)
|
|
__env_system.pool_size = __env_system.pool_size + 1
|
|
end
|
|
end
|
|
|
|
-- Execute code in sandbox
|
|
function __execute_sandbox(f)
|
|
-- Get environment
|
|
local env = __get_sandbox_env()
|
|
|
|
-- Set environment for function
|
|
setfenv(f, env)
|
|
|
|
-- Execute with protected call
|
|
local success, result = pcall(f)
|
|
|
|
-- Update base environment with new globals
|
|
for k, v in pairs(env) do
|
|
if k ~= "_G" and type(k) == "string" then
|
|
__env_system.base_env[k] = v
|
|
end
|
|
end
|
|
|
|
-- Recycle environment
|
|
__recycle_env(env)
|
|
|
|
-- Process result
|
|
if not success then
|
|
error(result, 0)
|
|
end
|
|
|
|
return result
|
|
end
|
|
`)
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
s.initialized = true
|
|
return nil
|
|
}
|
|
|
|
// RegisterFunction registers a Go function in the sandbox
|
|
func (s *Sandbox) RegisterFunction(name string, fn GoFunction) error {
|
|
s.mutex.Lock()
|
|
defer s.mutex.Unlock()
|
|
|
|
if s.state == nil {
|
|
return fmt.Errorf("sandbox is closed")
|
|
}
|
|
|
|
// Initialize if needed
|
|
if !s.initialized {
|
|
if err := s.initializeUnlocked(); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
// Register function globally
|
|
if err := s.state.RegisterGoFunction(name, fn); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Store function for re-registration
|
|
s.functions[name] = fn
|
|
|
|
// Add to base environment
|
|
return s.state.DoString(`__env_system.base_env["` + name + `"] = ` + name)
|
|
}
|
|
|
|
// SetGlobal sets a global variable in the sandbox base environment
|
|
func (s *Sandbox) SetGlobal(name string, value any) error {
|
|
s.mutex.Lock()
|
|
defer s.mutex.Unlock()
|
|
|
|
if s.state == nil {
|
|
return fmt.Errorf("sandbox is closed")
|
|
}
|
|
|
|
// Initialize if needed
|
|
if !s.initialized {
|
|
if err := s.initializeUnlocked(); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
// Push the value onto the stack
|
|
if err := s.state.PushValue(value); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Set the global with the pushed value
|
|
s.state.SetGlobal(name)
|
|
|
|
// Add to base environment
|
|
return s.state.DoString(`__env_system.base_env["` + name + `"] = ` + name)
|
|
}
|
|
|
|
// GetGlobal retrieves a global variable from the sandbox base environment
|
|
func (s *Sandbox) GetGlobal(name string) (any, error) {
|
|
s.mutex.Lock()
|
|
defer s.mutex.Unlock()
|
|
|
|
if s.state == nil {
|
|
return nil, fmt.Errorf("sandbox is closed")
|
|
}
|
|
|
|
// Initialize if needed
|
|
if !s.initialized {
|
|
if err := s.initializeUnlocked(); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
// Get the global from the base environment
|
|
return s.state.ExecuteWithResult(`return __env_system.base_env["` + name + `"]`)
|
|
}
|
|
|
|
// Run executes Lua code in the sandbox
|
|
func (s *Sandbox) Run(code string) (any, error) {
|
|
s.mutex.Lock()
|
|
defer s.mutex.Unlock()
|
|
|
|
if s.state == nil {
|
|
return nil, fmt.Errorf("sandbox is closed")
|
|
}
|
|
|
|
// Initialize if needed
|
|
if !s.initialized {
|
|
if err := s.initializeUnlocked(); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
// Simplified wrapper for multiple return values
|
|
wrappedCode := `
|
|
local function _execfunc()
|
|
` + code + `
|
|
end
|
|
|
|
-- Process results to match expected format
|
|
local function _wrapresults(...)
|
|
local n = select('#', ...)
|
|
if n == 0 then
|
|
return nil
|
|
elseif n == 1 then
|
|
return select(1, ...)
|
|
else
|
|
local results = {}
|
|
for i = 1, n do
|
|
results[i] = select(i, ...)
|
|
end
|
|
return results
|
|
end
|
|
end
|
|
|
|
return _wrapresults(_execfunc())
|
|
`
|
|
|
|
// Compile the code
|
|
if err := s.state.LoadString(wrappedCode); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Get the sandbox executor
|
|
s.state.GetGlobal("__execute_sandbox")
|
|
|
|
// Push the function as argument
|
|
s.state.PushCopy(-2)
|
|
s.state.Remove(-3)
|
|
|
|
// Execute in sandbox
|
|
if err := s.state.Call(1, 1); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Get result
|
|
result, err := s.state.ToValue(-1)
|
|
s.state.Pop(1)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return s.processResult(result), nil
|
|
}
|
|
|
|
// RunFile executes a Lua file in the sandbox
|
|
func (s *Sandbox) RunFile(filename string) error {
|
|
s.mutex.Lock()
|
|
defer s.mutex.Unlock()
|
|
|
|
if s.state == nil {
|
|
return fmt.Errorf("sandbox is closed")
|
|
}
|
|
|
|
return s.state.DoFile(filename)
|
|
}
|
|
|
|
// Compile compiles Lua code to bytecode
|
|
func (s *Sandbox) Compile(code string) ([]byte, error) {
|
|
s.mutex.Lock()
|
|
defer s.mutex.Unlock()
|
|
|
|
if s.state == nil {
|
|
return nil, fmt.Errorf("sandbox is closed")
|
|
}
|
|
|
|
return s.state.CompileBytecode(code, "sandbox")
|
|
}
|
|
|
|
// RunBytecode executes precompiled Lua bytecode
|
|
func (s *Sandbox) RunBytecode(bytecode []byte) (any, error) {
|
|
s.mutex.Lock()
|
|
defer s.mutex.Unlock()
|
|
|
|
if s.state == nil {
|
|
return nil, fmt.Errorf("sandbox is closed")
|
|
}
|
|
|
|
// Initialize if needed
|
|
if !s.initialized {
|
|
if err := s.initializeUnlocked(); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
// Load the bytecode
|
|
if err := s.state.LoadBytecode(bytecode, "sandbox"); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Get the sandbox executor
|
|
s.state.GetGlobal("__execute_sandbox")
|
|
|
|
// Push bytecode function
|
|
s.state.PushCopy(-2)
|
|
s.state.Remove(-3)
|
|
|
|
// Execute in sandbox
|
|
if err := s.state.Call(1, 1); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Get result
|
|
result, err := s.state.ToValue(-1)
|
|
s.state.Pop(1)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return s.processResult(result), nil
|
|
}
|
|
|
|
// LoadModule loads a Lua module
|
|
func (s *Sandbox) LoadModule(name string) error {
|
|
code := fmt.Sprintf("require('%s')", name)
|
|
_, err := s.Run(code)
|
|
return err
|
|
}
|
|
|
|
// SetPackagePath sets the sandbox package.path
|
|
func (s *Sandbox) SetPackagePath(path string) error {
|
|
s.mutex.Lock()
|
|
defer s.mutex.Unlock()
|
|
|
|
if s.state == nil {
|
|
return fmt.Errorf("sandbox is closed")
|
|
}
|
|
|
|
// Update global package.path
|
|
if err := s.state.SetPackagePath(path); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Initialize if needed
|
|
if !s.initialized {
|
|
if err := s.initializeUnlocked(); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
// Update base environment's package.path
|
|
return s.state.DoString(`__env_system.base_env.package.path = package.path`)
|
|
}
|
|
|
|
// AddPackagePath adds a path to the sandbox package.path
|
|
func (s *Sandbox) AddPackagePath(path string) error {
|
|
s.mutex.Lock()
|
|
defer s.mutex.Unlock()
|
|
|
|
if s.state == nil {
|
|
return fmt.Errorf("sandbox is closed")
|
|
}
|
|
|
|
// Update global package.path
|
|
if err := s.state.AddPackagePath(path); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Initialize if needed
|
|
if !s.initialized {
|
|
if err := s.initializeUnlocked(); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
// Update base environment's package.path
|
|
return s.state.DoString(`__env_system.base_env.package.path = package.path`)
|
|
}
|
|
|
|
// AddModule adds a module to the sandbox environment
|
|
func (s *Sandbox) AddModule(name string, module any) error {
|
|
s.mutex.Lock()
|
|
defer s.mutex.Unlock()
|
|
|
|
if s.state == nil {
|
|
return fmt.Errorf("sandbox is closed")
|
|
}
|
|
|
|
s.modules[name] = module
|
|
return nil
|
|
}
|
|
|
|
// AddPermanentLua adds Lua code to the environment permanently
|
|
func (s *Sandbox) AddPermanentLua(code string) error {
|
|
s.mutex.Lock()
|
|
defer s.mutex.Unlock()
|
|
|
|
if s.state == nil {
|
|
return fmt.Errorf("sandbox is closed")
|
|
}
|
|
|
|
// Initialize if needed
|
|
if !s.initialized {
|
|
if err := s.initializeUnlocked(); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
// Simplified approach to add code to base environment
|
|
return s.state.DoString(`
|
|
local f, err = loadstring([=[` + code + `]=], "permanent")
|
|
if not f then error(err, 0) end
|
|
|
|
local env = setmetatable({}, {__index = __env_system.base_env})
|
|
setfenv(f, env)
|
|
|
|
local ok, err = pcall(f)
|
|
if not ok then error(err, 0) end
|
|
|
|
for k, v in pairs(env) do
|
|
__env_system.base_env[k] = v
|
|
end
|
|
`)
|
|
}
|
|
|
|
// ResetEnvironment resets the sandbox to its initial state
|
|
func (s *Sandbox) ResetEnvironment() error {
|
|
s.mutex.Lock()
|
|
defer s.mutex.Unlock()
|
|
|
|
if s.state == nil {
|
|
return fmt.Errorf("sandbox is closed")
|
|
}
|
|
|
|
// Clear the environment system
|
|
s.state.DoString(`__env_system = nil`)
|
|
|
|
// Reinitialize
|
|
s.initialized = false
|
|
if err := s.initializeUnlocked(); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Re-register all functions
|
|
for name, fn := range s.functions {
|
|
if err := s.state.RegisterGoFunction(name, fn); err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := s.state.DoString(`__env_system.base_env["` + name + `"] = ` + name); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// unwrapResult processes results from Lua executions
|
|
func (s *Sandbox) processResult(result any) any {
|
|
// Handle []float64 (common LuaJIT return type)
|
|
if floats, ok := result.([]float64); ok {
|
|
if len(floats) == 1 {
|
|
// Single number - return as float64
|
|
return floats[0]
|
|
}
|
|
// Multiple numbers - MUST convert to []any for tests to pass
|
|
anySlice := make([]any, len(floats))
|
|
for i, v := range floats {
|
|
anySlice[i] = v
|
|
}
|
|
return anySlice
|
|
}
|
|
|
|
// Handle maps with numeric keys (Lua tables)
|
|
if m, ok := result.(map[string]any); ok {
|
|
// Handle return tables with special structure
|
|
if vals, ok := m[""]; ok {
|
|
// This is a special case used by some Lua returns
|
|
if arr, ok := vals.([]float64); ok {
|
|
// Convert to []any for consistency
|
|
anySlice := make([]any, len(arr))
|
|
for i, v := range arr {
|
|
anySlice[i] = v
|
|
}
|
|
return anySlice
|
|
}
|
|
return vals
|
|
}
|
|
|
|
if len(m) == 1 {
|
|
// Check for single value map
|
|
for k, v := range m {
|
|
if k == "1" {
|
|
return v
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Other array types should be preserved
|
|
return result
|
|
}
|