sandbox 2

This commit is contained in:
Sky Johnson 2025-03-27 21:31:41 -05:00
parent f106dfd9ea
commit faab0a2d08
2 changed files with 256 additions and 51 deletions

View File

@ -14,6 +14,7 @@ type Sandbox struct {
mutex sync.Mutex
initialized bool
modules map[string]any
functions map[string]GoFunction
}
// NewSandbox creates a new sandbox with standard libraries loaded
@ -22,6 +23,7 @@ func NewSandbox() *Sandbox {
state: New(),
initialized: false,
modules: make(map[string]any),
functions: make(map[string]GoFunction),
}
}
@ -57,6 +59,9 @@ func (s *Sandbox) RegisterFunction(name string, fn GoFunction) error {
return err
}
// Store function for re-registration
s.functions[name] = fn
// Add to base environment
return s.state.DoString(`
-- Add the function to base environment
@ -172,12 +177,7 @@ func (s *Sandbox) Run(code string) (any, error) {
result, err := s.state.ToValue(-1)
s.state.Pop(1)
// Handle multiple return values
if results, ok := result.([]any); ok && len(results) == 1 {
return results[0], err
}
return result, err
return s.unwrapResult(result, err)
}
// RunFile executes a Lua file in the sandbox
@ -225,46 +225,13 @@ func (s *Sandbox) RunBytecode(bytecode []byte) (any, error) {
return nil, err
}
// Add wrapper for multiple return values
if err := s.state.DoString(`
__wrap_bytecode = function(f)
local function _wrapresults(...)
local results = {n = select('#', ...)}
for i = 1, results.n do
results[i] = select(i, ...)
end
return results
end
return function()
return _wrapresults(f())
end
end
`); err != nil {
return nil, err
}
// Get wrapper function
s.state.GetGlobal("__wrap_bytecode")
// Get the sandbox executor
s.state.GetGlobal("__execute_sandbox")
// Push bytecode function
s.state.PushCopy(-2)
// Call wrapper to create wrapped function
if err := s.state.Call(1, 1); err != nil {
return nil, err
}
// Remove original bytecode function
s.state.Remove(-2)
// Get the sandbox executor
s.state.GetGlobal("__execute_sandbox")
// Push wrapped function
s.state.PushCopy(-2)
// Remove the wrapped function
s.state.Remove(-3)
// Execute in sandbox
@ -276,12 +243,7 @@ func (s *Sandbox) RunBytecode(bytecode []byte) (any, error) {
result, err := s.state.ToValue(-1)
s.state.Pop(1)
// Handle multiple return values
if results, ok := result.([]any); ok && len(results) == 1 {
return results[0], err
}
return result, err
return s.unwrapResult(result, err)
}
// getResults collects results from the stack (must be called with mutex locked)
@ -326,7 +288,22 @@ func (s *Sandbox) SetPackagePath(path string) error {
return fmt.Errorf("sandbox is closed")
}
return s.state.SetPackagePath(path)
// Update global package.path
if err := s.state.SetPackagePath(path); err != nil {
return err
}
// Make sure sandbox is initialized
if !s.initialized {
if err := s.initializeUnlocked(); err != nil {
return err
}
}
// Update base environment's package.path
return s.state.DoString(`
__env_system.base_env.package.path = package.path
`)
}
// AddPackagePath adds a path to the sandbox package.path
@ -338,7 +315,22 @@ func (s *Sandbox) AddPackagePath(path string) error {
return fmt.Errorf("sandbox is closed")
}
return s.state.AddPackagePath(path)
// Update global package.path
if err := s.state.AddPackagePath(path); err != nil {
return err
}
// Make sure sandbox is initialized
if !s.initialized {
if err := s.initializeUnlocked(); err != nil {
return err
}
}
// Update base environment's package.path
return s.state.DoString(`
__env_system.base_env.package.path = package.path
`)
}
// AddModule adds a module to the sandbox environment
@ -484,7 +476,7 @@ func (s *Sandbox) initializeUnlocked() error {
else
-- Create new environment with metatable inheritance
env = setmetatable({}, {
__index = _G -- Use global environment as fallback
__index = __env_system.base_env -- Use base env instead of _G
})
end
@ -520,6 +512,13 @@ func (s *Sandbox) initializeUnlocked() error {
-- Execute with protected call
local success, result = pcall(f)
-- Copy all globals to base environment
for k, v in pairs(env) do
if k ~= "_G" and type(k) == "string" then
__env_system.base_env[k] = v
end
end
-- Recycle environment
__recycle_env(env)
@ -600,7 +599,87 @@ func (s *Sandbox) ResetEnvironment() error {
return fmt.Errorf("sandbox is closed")
}
// Clear the environment system completely
err := s.state.DoString(`
-- Reset environment system
__env_system = nil
__wrap_bytecode = nil
__last_env = nil
`)
if err != nil {
return err
}
// Reinitialize the environment system
s.initialized = false
return s.Initialize()
if err := s.initializeUnlocked(); err != nil {
return err
}
// Re-register all functions
for name, fn := range s.functions {
if err := s.state.RegisterGoFunction(name, fn); err != nil {
return err
}
if err := s.state.DoString(`
-- Add the function to base environment
__env_system.base_env["` + name + `"] = ` + name + `
`); err != nil {
return err
}
}
return nil
}
// unwrapResult processes the raw result value from Lua
// and unwraps single values from special map structures
func (s *Sandbox) unwrapResult(result any, err error) (any, error) {
// Unwrap array stored in map with empty key
if m, ok := result.(map[string]any); ok {
// Check for special array format
if arr, ok := m[""]; ok {
// If the array has only one element, return that element
if slice, ok := arr.([]float64); ok {
if len(slice) == 1 {
return slice[0], err
}
// Convert []float64 to []any for consistency
anySlice := make([]any, len(slice))
for i, v := range slice {
anySlice[i] = v
}
return anySlice, err
}
if slice, ok := arr.([]any); ok && len(slice) == 1 {
return slice[0], err
}
result = arr
} else if len(m) == 1 {
// When there's exactly one item, return its value directly
for _, v := range m {
return v, err
}
}
}
// Convert []float64 to []any for consistency with multiple returns
if slice, ok := result.([]float64); ok {
if len(slice) == 1 {
return slice[0], err
}
anySlice := make([]any, len(slice))
for i, v := range slice {
anySlice[i] = v
}
return anySlice, err
}
// Handle multiple return values
if results, ok := result.([]any); ok && len(results) == 1 {
return results[0], err
}
return result, err
}

View File

@ -648,3 +648,129 @@ func createTempLuaFile(content string) (string, error) {
func removeTempFile(path string) {
os.Remove(path)
}
// BenchmarkSandboxLuaExecution measures the performance of executing raw Lua code
func BenchmarkSandboxLuaExecution(b *testing.B) {
sandbox := luajit.NewSandbox()
defer sandbox.Close()
// Simple Lua code that does some computation
code := `
local sum = 0
for i = 1, 100 do
sum = sum + i
end
return sum
`
b.ResetTimer()
for i := 0; i < b.N; i++ {
result, err := sandbox.Run(code)
if err != nil {
b.Fatalf("Failed to run code: %v", err)
}
if result != float64(5050) {
b.Fatalf("Incorrect result: %v", result)
}
}
}
// BenchmarkSandboxBytecodeExecution measures the performance of executing precompiled bytecode
func BenchmarkSandboxBytecodeExecution(b *testing.B) {
sandbox := luajit.NewSandbox()
defer sandbox.Close()
// Same code as above, but precompiled
code := `
local sum = 0
for i = 1, 100 do
sum = sum + i
end
return sum
`
// Compile the bytecode once
bytecode, err := sandbox.Compile(code)
if err != nil {
b.Fatalf("Failed to compile bytecode: %v", err)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
result, err := sandbox.RunBytecode(bytecode)
if err != nil {
b.Fatalf("Failed to run bytecode: %v", err)
}
if result != float64(5050) {
b.Fatalf("Incorrect result: %v", result)
}
}
}
// BenchmarkSandboxComplexComputation measures performance with more complex computation
func BenchmarkSandboxComplexComputation(b *testing.B) {
sandbox := luajit.NewSandbox()
defer sandbox.Close()
// More complex Lua code that calculates Fibonacci numbers
code := `
function fibonacci(n)
if n <= 1 then
return n
end
return fibonacci(n-1) + fibonacci(n-2)
end
return fibonacci(15) -- Not too high to avoid excessive runtime
`
b.ResetTimer()
for i := 0; i < b.N; i++ {
result, err := sandbox.Run(code)
if err != nil {
b.Fatalf("Failed to run code: %v", err)
}
if result != float64(610) {
b.Fatalf("Incorrect result: %v", result)
}
}
}
// BenchmarkSandboxFunctionCall measures performance of calling a registered Go function
func BenchmarkSandboxFunctionCall(b *testing.B) {
sandbox := luajit.NewSandbox()
defer sandbox.Close()
// Register a simple addition function
add := func(s *luajit.State) int {
a := s.ToNumber(1)
b := s.ToNumber(2)
s.PushNumber(a + b)
return 1
}
err := sandbox.RegisterFunction("add", add)
if err != nil {
b.Fatalf("Failed to register function: %v", err)
}
// Lua code that calls the Go function in a loop
code := `
local sum = 0
for i = 1, 100 do
sum = add(sum, i)
end
return sum
`
b.ResetTimer()
for i := 0; i < b.N; i++ {
result, err := sandbox.Run(code)
if err != nil {
b.Fatalf("Failed to run code: %v", err)
}
if result != float64(5050) {
b.Fatalf("Incorrect result: %v", result)
}
}
}