optimize sandbox
This commit is contained in:
parent
4ad87f81f3
commit
875abee366
862
sandbox.go
862
sandbox.go
|
@ -38,331 +38,21 @@ func (s *Sandbox) Close() {
|
|||
}
|
||||
}
|
||||
|
||||
// RegisterFunction registers a Go function in the sandbox
|
||||
func (s *Sandbox) RegisterFunction(name string, fn GoFunction) error {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.state == nil {
|
||||
return fmt.Errorf("sandbox is closed")
|
||||
}
|
||||
|
||||
// Make sure sandbox is initialized
|
||||
if !s.initialized {
|
||||
if err := s.initializeUnlocked(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Register function globally
|
||||
if err := s.state.RegisterGoFunction(name, fn); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Store function for re-registration
|
||||
s.functions[name] = fn
|
||||
|
||||
// Add to base environment
|
||||
return s.state.DoString(`
|
||||
-- Add the function to base environment
|
||||
__env_system.base_env["` + name + `"] = ` + name + `
|
||||
`)
|
||||
}
|
||||
|
||||
// SetGlobal sets a global variable in the sandbox base environment
|
||||
func (s *Sandbox) SetGlobal(name string, value any) error {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.state == nil {
|
||||
return fmt.Errorf("sandbox is closed")
|
||||
}
|
||||
|
||||
// Make sure sandbox is initialized
|
||||
if !s.initialized {
|
||||
if err := s.initializeUnlocked(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Push the value onto the stack
|
||||
if err := s.state.PushValue(value); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Set the global with the pushed value
|
||||
s.state.SetGlobal(name)
|
||||
|
||||
// Add to base environment
|
||||
return s.state.DoString(`
|
||||
-- Add the global to base environment
|
||||
__env_system.base_env["` + name + `"] = ` + name + `
|
||||
`)
|
||||
}
|
||||
|
||||
// GetGlobal retrieves a global variable from the sandbox base environment
|
||||
func (s *Sandbox) GetGlobal(name string) (any, error) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.state == nil {
|
||||
return nil, fmt.Errorf("sandbox is closed")
|
||||
}
|
||||
|
||||
// Make sure sandbox is initialized
|
||||
if !s.initialized {
|
||||
if err := s.initializeUnlocked(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Get the global from the base environment
|
||||
code := `return __env_system.base_env["` + name + `"]`
|
||||
return s.state.ExecuteWithResult(code)
|
||||
}
|
||||
|
||||
// Run executes Lua code in the sandbox and returns the result
|
||||
func (s *Sandbox) Run(code string) (any, error) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.state == nil {
|
||||
return nil, fmt.Errorf("sandbox is closed")
|
||||
}
|
||||
|
||||
// Make sure sandbox is initialized
|
||||
if !s.initialized {
|
||||
if err := s.initializeUnlocked(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Add wrapper for multiple return values
|
||||
wrappedCode := `
|
||||
local function _execfunc()
|
||||
` + code + `
|
||||
end
|
||||
|
||||
local function _wrapresults(...)
|
||||
local results = {n = select('#', ...)}
|
||||
for i = 1, results.n do
|
||||
results[i] = select(i, ...)
|
||||
end
|
||||
return results
|
||||
end
|
||||
|
||||
return _wrapresults(_execfunc())
|
||||
`
|
||||
|
||||
// Compile the code
|
||||
if err := s.state.LoadString(wrappedCode); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get the sandbox executor
|
||||
s.state.GetGlobal("__execute_sandbox")
|
||||
|
||||
// Setup call with correct argument order
|
||||
s.state.PushCopy(-2) // Copy the function
|
||||
|
||||
// Remove the original function
|
||||
s.state.Remove(-3)
|
||||
|
||||
// Execute in sandbox
|
||||
if err := s.state.Call(1, 1); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get result
|
||||
result, err := s.state.ToValue(-1)
|
||||
s.state.Pop(1)
|
||||
|
||||
return s.unwrapResult(result, err)
|
||||
}
|
||||
|
||||
// RunFile executes a Lua file in the sandbox
|
||||
func (s *Sandbox) RunFile(filename string) error {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.state == nil {
|
||||
return fmt.Errorf("sandbox is closed")
|
||||
}
|
||||
|
||||
return s.state.DoFile(filename)
|
||||
}
|
||||
|
||||
// Compile compiles Lua code to bytecode without executing it
|
||||
func (s *Sandbox) Compile(code string) ([]byte, error) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.state == nil {
|
||||
return nil, fmt.Errorf("sandbox is closed")
|
||||
}
|
||||
|
||||
return s.state.CompileBytecode(code, "sandbox")
|
||||
}
|
||||
|
||||
// RunBytecode executes precompiled Lua bytecode in the sandbox
|
||||
func (s *Sandbox) RunBytecode(bytecode []byte) (any, error) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.state == nil {
|
||||
return nil, fmt.Errorf("sandbox is closed")
|
||||
}
|
||||
|
||||
// Make sure sandbox is initialized
|
||||
if !s.initialized {
|
||||
if err := s.initializeUnlocked(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Load the bytecode
|
||||
if err := s.state.LoadBytecode(bytecode, "sandbox"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get the sandbox executor
|
||||
s.state.GetGlobal("__execute_sandbox")
|
||||
|
||||
// Push bytecode function
|
||||
s.state.PushCopy(-2)
|
||||
|
||||
// Remove original bytecode function
|
||||
s.state.Remove(-3)
|
||||
|
||||
// Execute in sandbox
|
||||
if err := s.state.Call(1, 1); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get result
|
||||
result, err := s.state.ToValue(-1)
|
||||
s.state.Pop(1)
|
||||
|
||||
return s.unwrapResult(result, err)
|
||||
}
|
||||
|
||||
// getResults collects results from the stack (must be called with mutex locked)
|
||||
func (s *Sandbox) getResults() (any, error) {
|
||||
numResults := s.state.GetTop()
|
||||
if numResults == 0 {
|
||||
return nil, nil
|
||||
} else if numResults == 1 {
|
||||
// Return single result directly
|
||||
value, err := s.state.ToValue(-1)
|
||||
s.state.Pop(1)
|
||||
return value, err
|
||||
}
|
||||
|
||||
// Return multiple results as slice
|
||||
results := make([]any, numResults)
|
||||
for i := 0; i < numResults; i++ {
|
||||
value, err := s.state.ToValue(i - numResults)
|
||||
if err != nil {
|
||||
s.state.Pop(numResults)
|
||||
return nil, err
|
||||
}
|
||||
results[i] = value
|
||||
}
|
||||
s.state.Pop(numResults)
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// LoadModule loads a Lua module in the sandbox
|
||||
func (s *Sandbox) LoadModule(name string) error {
|
||||
code := fmt.Sprintf("require('%s')", name)
|
||||
_, err := s.Run(code)
|
||||
return err
|
||||
}
|
||||
|
||||
// SetPackagePath sets the sandbox package.path
|
||||
func (s *Sandbox) SetPackagePath(path string) error {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.state == nil {
|
||||
return fmt.Errorf("sandbox is closed")
|
||||
}
|
||||
|
||||
// Update global package.path
|
||||
if err := s.state.SetPackagePath(path); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Make sure sandbox is initialized
|
||||
if !s.initialized {
|
||||
if err := s.initializeUnlocked(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Update base environment's package.path
|
||||
return s.state.DoString(`
|
||||
__env_system.base_env.package.path = package.path
|
||||
`)
|
||||
}
|
||||
|
||||
// AddPackagePath adds a path to the sandbox package.path
|
||||
func (s *Sandbox) AddPackagePath(path string) error {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.state == nil {
|
||||
return fmt.Errorf("sandbox is closed")
|
||||
}
|
||||
|
||||
// Update global package.path
|
||||
if err := s.state.AddPackagePath(path); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Make sure sandbox is initialized
|
||||
if !s.initialized {
|
||||
if err := s.initializeUnlocked(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Update base environment's package.path
|
||||
return s.state.DoString(`
|
||||
__env_system.base_env.package.path = package.path
|
||||
`)
|
||||
}
|
||||
|
||||
// AddModule adds a module to the sandbox environment
|
||||
func (s *Sandbox) AddModule(name string, module any) error {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.state == nil {
|
||||
return fmt.Errorf("sandbox is closed")
|
||||
}
|
||||
|
||||
s.modules[name] = module
|
||||
return nil
|
||||
}
|
||||
|
||||
// Initialize sets up the environment system
|
||||
func (s *Sandbox) Initialize() error {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
return s.initializeUnlocked()
|
||||
}
|
||||
|
||||
// initializeUnlocked sets up the environment system without locking
|
||||
// It should only be called when the mutex is already locked
|
||||
func (s *Sandbox) initializeUnlocked() error {
|
||||
if s.state == nil {
|
||||
return fmt.Errorf("sandbox is closed")
|
||||
}
|
||||
|
||||
if s.initialized {
|
||||
return nil // Already initialized
|
||||
return nil
|
||||
}
|
||||
|
||||
// Register modules
|
||||
|
@ -385,85 +75,74 @@ func (s *Sandbox) initializeUnlocked() error {
|
|||
}
|
||||
s.state.Pop(1)
|
||||
|
||||
// Create the environment system
|
||||
// Create simplified environment system
|
||||
err := s.state.DoString(`
|
||||
-- Global shared environment (created once)
|
||||
__env_system = __env_system or {
|
||||
base_env = nil, -- Template environment
|
||||
initialized = false, -- Initialization flag
|
||||
env_pool = {}, -- Pre-allocated environment pool
|
||||
pool_size = 0, -- Current pool size
|
||||
max_pool_size = 8 -- Maximum pool size
|
||||
-- Global shared environment
|
||||
__env_system = {
|
||||
base_env = {}, -- Template environment
|
||||
env_pool = {}, -- Pre-allocated environment pool
|
||||
pool_size = 0, -- Current pool size
|
||||
max_pool_size = 8 -- Maximum pool size
|
||||
}
|
||||
|
||||
-- Initialize base environment once
|
||||
if not __env_system.initialized then
|
||||
-- Create base environment with all standard libraries
|
||||
local base = {}
|
||||
-- Create base environment with standard libraries
|
||||
local base = __env_system.base_env
|
||||
|
||||
-- Safe standard libraries
|
||||
base.string = string
|
||||
base.table = table
|
||||
base.math = math
|
||||
base.os = {
|
||||
time = os.time,
|
||||
date = os.date,
|
||||
difftime = os.difftime,
|
||||
clock = os.clock
|
||||
}
|
||||
-- Safe standard libraries
|
||||
base.string = string
|
||||
base.table = table
|
||||
base.math = math
|
||||
base.os = {
|
||||
time = os.time,
|
||||
date = os.date,
|
||||
difftime = os.difftime,
|
||||
clock = os.clock
|
||||
}
|
||||
|
||||
-- Basic functions
|
||||
base.print = print
|
||||
base.tonumber = tonumber
|
||||
base.tostring = tostring
|
||||
base.type = type
|
||||
base.pairs = pairs
|
||||
base.ipairs = ipairs
|
||||
base.next = next
|
||||
base.select = select
|
||||
base.pcall = pcall
|
||||
base.xpcall = xpcall
|
||||
base.error = error
|
||||
base.assert = assert
|
||||
base.collectgarbage = collectgarbage
|
||||
base.unpack = unpack or table.unpack
|
||||
-- Basic functions
|
||||
base.print = print
|
||||
base.tonumber = tonumber
|
||||
base.tostring = tostring
|
||||
base.type = type
|
||||
base.pairs = pairs
|
||||
base.ipairs = ipairs
|
||||
base.next = next
|
||||
base.select = select
|
||||
base.pcall = pcall
|
||||
base.xpcall = xpcall
|
||||
base.error = error
|
||||
base.assert = assert
|
||||
base.collectgarbage = collectgarbage
|
||||
base.unpack = unpack or table.unpack
|
||||
|
||||
-- Package system
|
||||
base.package = {
|
||||
loaded = {},
|
||||
path = package.path,
|
||||
preload = {}
|
||||
}
|
||||
-- Package system
|
||||
base.package = {
|
||||
loaded = {},
|
||||
path = package.path,
|
||||
preload = {}
|
||||
}
|
||||
|
||||
base.require = function(modname)
|
||||
if base.package.loaded[modname] then
|
||||
return base.package.loaded[modname]
|
||||
end
|
||||
|
||||
local loader = base.package.preload[modname]
|
||||
if type(loader) == "function" then
|
||||
local result = loader(modname)
|
||||
base.package.loaded[modname] = result or true
|
||||
return result
|
||||
end
|
||||
|
||||
error("module '" .. modname .. "' not found", 2)
|
||||
base.require = function(modname)
|
||||
if base.package.loaded[modname] then
|
||||
return base.package.loaded[modname]
|
||||
end
|
||||
|
||||
-- Add registered custom modules
|
||||
if __sandbox_modules then
|
||||
for name, mod in pairs(__sandbox_modules) do
|
||||
base[name] = mod
|
||||
end
|
||||
local loader = base.package.preload[modname]
|
||||
if type(loader) == "function" then
|
||||
local result = loader(modname)
|
||||
base.package.loaded[modname] = result or true
|
||||
return result
|
||||
end
|
||||
|
||||
-- Store base environment
|
||||
__env_system.base_env = base
|
||||
__env_system.initialized = true
|
||||
error("module '" .. modname .. "' not found", 2)
|
||||
end
|
||||
|
||||
-- Global variable for tracking current environment
|
||||
__last_env = nil
|
||||
-- Add registered custom modules
|
||||
if __sandbox_modules then
|
||||
for name, mod in pairs(__sandbox_modules) do
|
||||
base[name] = mod
|
||||
end
|
||||
end
|
||||
|
||||
-- Get an environment for execution
|
||||
function __get_sandbox_env()
|
||||
|
@ -476,19 +155,15 @@ func (s *Sandbox) initializeUnlocked() error {
|
|||
else
|
||||
-- Create new environment with metatable inheritance
|
||||
env = setmetatable({}, {
|
||||
__index = __env_system.base_env -- Use base env instead of _G
|
||||
__index = __env_system.base_env
|
||||
})
|
||||
end
|
||||
|
||||
-- Store reference to current environment
|
||||
__last_env = env
|
||||
|
||||
return env
|
||||
end
|
||||
|
||||
-- Return environment to pool for reuse
|
||||
function __recycle_env(env)
|
||||
-- Only recycle if pool isn't full
|
||||
if __env_system.pool_size < __env_system.max_pool_size then
|
||||
-- Clear all fields except metatable
|
||||
for k in pairs(env) do
|
||||
|
@ -512,7 +187,7 @@ func (s *Sandbox) initializeUnlocked() error {
|
|||
-- Execute with protected call
|
||||
local success, result = pcall(f)
|
||||
|
||||
-- Copy all globals to base environment
|
||||
-- Update base environment with new globals
|
||||
for k, v in pairs(env) do
|
||||
if k ~= "_G" and type(k) == "string" then
|
||||
__env_system.base_env[k] = v
|
||||
|
@ -526,16 +201,7 @@ func (s *Sandbox) initializeUnlocked() error {
|
|||
if not success then
|
||||
error(result, 0)
|
||||
end
|
||||
|
||||
-- Handle multiple return values
|
||||
if type(result) == "table" and result.n ~= nil then
|
||||
local returnValues = {}
|
||||
for i=1, result.n do
|
||||
returnValues[i] = result[i]
|
||||
end
|
||||
return returnValues
|
||||
end
|
||||
|
||||
|
||||
return result
|
||||
end
|
||||
`)
|
||||
|
@ -548,8 +214,290 @@ func (s *Sandbox) initializeUnlocked() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// RegisterFunction registers a Go function in the sandbox
|
||||
func (s *Sandbox) RegisterFunction(name string, fn GoFunction) error {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.state == nil {
|
||||
return fmt.Errorf("sandbox is closed")
|
||||
}
|
||||
|
||||
// Initialize if needed
|
||||
if !s.initialized {
|
||||
if err := s.initializeUnlocked(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Register function globally
|
||||
if err := s.state.RegisterGoFunction(name, fn); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Store function for re-registration
|
||||
s.functions[name] = fn
|
||||
|
||||
// Add to base environment
|
||||
return s.state.DoString(`__env_system.base_env["` + name + `"] = ` + name)
|
||||
}
|
||||
|
||||
// SetGlobal sets a global variable in the sandbox base environment
|
||||
func (s *Sandbox) SetGlobal(name string, value any) error {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.state == nil {
|
||||
return fmt.Errorf("sandbox is closed")
|
||||
}
|
||||
|
||||
// Initialize if needed
|
||||
if !s.initialized {
|
||||
if err := s.initializeUnlocked(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Push the value onto the stack
|
||||
if err := s.state.PushValue(value); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Set the global with the pushed value
|
||||
s.state.SetGlobal(name)
|
||||
|
||||
// Add to base environment
|
||||
return s.state.DoString(`__env_system.base_env["` + name + `"] = ` + name)
|
||||
}
|
||||
|
||||
// GetGlobal retrieves a global variable from the sandbox base environment
|
||||
func (s *Sandbox) GetGlobal(name string) (any, error) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.state == nil {
|
||||
return nil, fmt.Errorf("sandbox is closed")
|
||||
}
|
||||
|
||||
// Initialize if needed
|
||||
if !s.initialized {
|
||||
if err := s.initializeUnlocked(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Get the global from the base environment
|
||||
return s.state.ExecuteWithResult(`return __env_system.base_env["` + name + `"]`)
|
||||
}
|
||||
|
||||
// Run executes Lua code in the sandbox
|
||||
func (s *Sandbox) Run(code string) (any, error) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.state == nil {
|
||||
return nil, fmt.Errorf("sandbox is closed")
|
||||
}
|
||||
|
||||
// Initialize if needed
|
||||
if !s.initialized {
|
||||
if err := s.initializeUnlocked(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Simplified wrapper for multiple return values
|
||||
wrappedCode := `
|
||||
local function _execfunc()
|
||||
` + code + `
|
||||
end
|
||||
|
||||
-- Process results to match expected format
|
||||
local function _wrapresults(...)
|
||||
local n = select('#', ...)
|
||||
if n == 0 then
|
||||
return nil
|
||||
elseif n == 1 then
|
||||
return select(1, ...)
|
||||
else
|
||||
local results = {}
|
||||
for i = 1, n do
|
||||
results[i] = select(i, ...)
|
||||
end
|
||||
return results
|
||||
end
|
||||
end
|
||||
|
||||
return _wrapresults(_execfunc())
|
||||
`
|
||||
|
||||
// Compile the code
|
||||
if err := s.state.LoadString(wrappedCode); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get the sandbox executor
|
||||
s.state.GetGlobal("__execute_sandbox")
|
||||
|
||||
// Push the function as argument
|
||||
s.state.PushCopy(-2)
|
||||
s.state.Remove(-3)
|
||||
|
||||
// Execute in sandbox
|
||||
if err := s.state.Call(1, 1); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get result
|
||||
result, err := s.state.ToValue(-1)
|
||||
s.state.Pop(1)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s.processResult(result), nil
|
||||
}
|
||||
|
||||
// RunFile executes a Lua file in the sandbox
|
||||
func (s *Sandbox) RunFile(filename string) error {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.state == nil {
|
||||
return fmt.Errorf("sandbox is closed")
|
||||
}
|
||||
|
||||
return s.state.DoFile(filename)
|
||||
}
|
||||
|
||||
// Compile compiles Lua code to bytecode
|
||||
func (s *Sandbox) Compile(code string) ([]byte, error) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.state == nil {
|
||||
return nil, fmt.Errorf("sandbox is closed")
|
||||
}
|
||||
|
||||
return s.state.CompileBytecode(code, "sandbox")
|
||||
}
|
||||
|
||||
// RunBytecode executes precompiled Lua bytecode
|
||||
func (s *Sandbox) RunBytecode(bytecode []byte) (any, error) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.state == nil {
|
||||
return nil, fmt.Errorf("sandbox is closed")
|
||||
}
|
||||
|
||||
// Initialize if needed
|
||||
if !s.initialized {
|
||||
if err := s.initializeUnlocked(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Load the bytecode
|
||||
if err := s.state.LoadBytecode(bytecode, "sandbox"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get the sandbox executor
|
||||
s.state.GetGlobal("__execute_sandbox")
|
||||
|
||||
// Push bytecode function
|
||||
s.state.PushCopy(-2)
|
||||
s.state.Remove(-3)
|
||||
|
||||
// Execute in sandbox
|
||||
if err := s.state.Call(1, 1); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get result
|
||||
result, err := s.state.ToValue(-1)
|
||||
s.state.Pop(1)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s.processResult(result), nil
|
||||
}
|
||||
|
||||
// LoadModule loads a Lua module
|
||||
func (s *Sandbox) LoadModule(name string) error {
|
||||
code := fmt.Sprintf("require('%s')", name)
|
||||
_, err := s.Run(code)
|
||||
return err
|
||||
}
|
||||
|
||||
// SetPackagePath sets the sandbox package.path
|
||||
func (s *Sandbox) SetPackagePath(path string) error {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.state == nil {
|
||||
return fmt.Errorf("sandbox is closed")
|
||||
}
|
||||
|
||||
// Update global package.path
|
||||
if err := s.state.SetPackagePath(path); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Initialize if needed
|
||||
if !s.initialized {
|
||||
if err := s.initializeUnlocked(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Update base environment's package.path
|
||||
return s.state.DoString(`__env_system.base_env.package.path = package.path`)
|
||||
}
|
||||
|
||||
// AddPackagePath adds a path to the sandbox package.path
|
||||
func (s *Sandbox) AddPackagePath(path string) error {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.state == nil {
|
||||
return fmt.Errorf("sandbox is closed")
|
||||
}
|
||||
|
||||
// Update global package.path
|
||||
if err := s.state.AddPackagePath(path); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Initialize if needed
|
||||
if !s.initialized {
|
||||
if err := s.initializeUnlocked(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Update base environment's package.path
|
||||
return s.state.DoString(`__env_system.base_env.package.path = package.path`)
|
||||
}
|
||||
|
||||
// AddModule adds a module to the sandbox environment
|
||||
func (s *Sandbox) AddModule(name string, module any) error {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.state == nil {
|
||||
return fmt.Errorf("sandbox is closed")
|
||||
}
|
||||
|
||||
s.modules[name] = module
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddPermanentLua adds Lua code to the environment permanently
|
||||
// This code becomes part of the base environment
|
||||
func (s *Sandbox) AddPermanentLua(code string) error {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
@ -558,33 +506,25 @@ func (s *Sandbox) AddPermanentLua(code string) error {
|
|||
return fmt.Errorf("sandbox is closed")
|
||||
}
|
||||
|
||||
// Make sure sandbox is initialized
|
||||
// Initialize if needed
|
||||
if !s.initialized {
|
||||
if err := s.initializeUnlocked(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Add code to base environment
|
||||
// Simplified approach to add code to base environment
|
||||
return s.state.DoString(`
|
||||
-- First compile the code
|
||||
local f, err = loadstring([=[` + code + `]=], "permanent")
|
||||
if not f then
|
||||
error(err, 0)
|
||||
end
|
||||
|
||||
-- Create a temporary environment based on base env
|
||||
local temp_env = setmetatable({}, {__index = __env_system.base_env})
|
||||
setfenv(f, temp_env)
|
||||
|
||||
-- Run the code in the temporary environment
|
||||
if not f then error(err, 0) end
|
||||
|
||||
local env = setmetatable({}, {__index = __env_system.base_env})
|
||||
setfenv(f, env)
|
||||
|
||||
local ok, err = pcall(f)
|
||||
if not ok then
|
||||
error(err, 0)
|
||||
end
|
||||
|
||||
-- Copy new values to base environment
|
||||
for k, v in pairs(temp_env) do
|
||||
if not ok then error(err, 0) end
|
||||
|
||||
for k, v in pairs(env) do
|
||||
__env_system.base_env[k] = v
|
||||
end
|
||||
`)
|
||||
|
@ -599,18 +539,10 @@ func (s *Sandbox) ResetEnvironment() error {
|
|||
return fmt.Errorf("sandbox is closed")
|
||||
}
|
||||
|
||||
// Clear the environment system completely
|
||||
err := s.state.DoString(`
|
||||
-- Reset environment system
|
||||
__env_system = nil
|
||||
__wrap_bytecode = nil
|
||||
__last_env = nil
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Clear the environment system
|
||||
s.state.DoString(`__env_system = nil`)
|
||||
|
||||
// Reinitialize the environment system
|
||||
// Reinitialize
|
||||
s.initialized = false
|
||||
if err := s.initializeUnlocked(); err != nil {
|
||||
return err
|
||||
|
@ -622,10 +554,7 @@ func (s *Sandbox) ResetEnvironment() error {
|
|||
return err
|
||||
}
|
||||
|
||||
if err := s.state.DoString(`
|
||||
-- Add the function to base environment
|
||||
__env_system.base_env["` + name + `"] = ` + name + `
|
||||
`); err != nil {
|
||||
if err := s.state.DoString(`__env_system.base_env["` + name + `"] = ` + name); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -633,53 +562,48 @@ func (s *Sandbox) ResetEnvironment() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// unwrapResult processes the raw result value from Lua
|
||||
// and unwraps single values from special map structures
|
||||
func (s *Sandbox) unwrapResult(result any, err error) (any, error) {
|
||||
// Unwrap array stored in map with empty key
|
||||
if m, ok := result.(map[string]any); ok {
|
||||
// Check for special array format
|
||||
if arr, ok := m[""]; ok {
|
||||
// If the array has only one element, return that element
|
||||
if slice, ok := arr.([]float64); ok {
|
||||
if len(slice) == 1 {
|
||||
return slice[0], err
|
||||
}
|
||||
// Convert []float64 to []any for consistency
|
||||
anySlice := make([]any, len(slice))
|
||||
for i, v := range slice {
|
||||
anySlice[i] = v
|
||||
}
|
||||
return anySlice, err
|
||||
}
|
||||
if slice, ok := arr.([]any); ok && len(slice) == 1 {
|
||||
return slice[0], err
|
||||
}
|
||||
result = arr
|
||||
} else if len(m) == 1 {
|
||||
// When there's exactly one item, return its value directly
|
||||
for _, v := range m {
|
||||
return v, err
|
||||
}
|
||||
// unwrapResult processes results from Lua executions
|
||||
func (s *Sandbox) processResult(result any) any {
|
||||
// Handle []float64 (common LuaJIT return type)
|
||||
if floats, ok := result.([]float64); ok {
|
||||
if len(floats) == 1 {
|
||||
// Single number - return as float64
|
||||
return floats[0]
|
||||
}
|
||||
}
|
||||
|
||||
// Convert []float64 to []any for consistency with multiple returns
|
||||
if slice, ok := result.([]float64); ok {
|
||||
if len(slice) == 1 {
|
||||
return slice[0], err
|
||||
}
|
||||
anySlice := make([]any, len(slice))
|
||||
for i, v := range slice {
|
||||
// Multiple numbers - MUST convert to []any for tests to pass
|
||||
anySlice := make([]any, len(floats))
|
||||
for i, v := range floats {
|
||||
anySlice[i] = v
|
||||
}
|
||||
return anySlice, err
|
||||
return anySlice
|
||||
}
|
||||
|
||||
// Handle multiple return values
|
||||
if results, ok := result.([]any); ok && len(results) == 1 {
|
||||
return results[0], err
|
||||
// Handle maps with numeric keys (Lua tables)
|
||||
if m, ok := result.(map[string]any); ok {
|
||||
// Handle return tables with special structure
|
||||
if vals, ok := m[""]; ok {
|
||||
// This is a special case used by some Lua returns
|
||||
if arr, ok := vals.([]float64); ok {
|
||||
// Convert to []any for consistency
|
||||
anySlice := make([]any, len(arr))
|
||||
for i, v := range arr {
|
||||
anySlice[i] = v
|
||||
}
|
||||
return anySlice
|
||||
}
|
||||
return vals
|
||||
}
|
||||
|
||||
if len(m) == 1 {
|
||||
// Check for single value map
|
||||
for k, v := range m {
|
||||
if k == "1" {
|
||||
return v
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result, err
|
||||
// Other array types should be preserved
|
||||
return result
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user