From faab0a2d085eca0284df4e14429ff33fe47149f4 Mon Sep 17 00:00:00 2001 From: Sky Johnson Date: Thu, 27 Mar 2025 21:31:41 -0500 Subject: [PATCH] sandbox 2 --- sandbox.go | 181 ++++++++++++++++++++++++++++++------------ tests/sandbox_test.go | 126 +++++++++++++++++++++++++++++ 2 files changed, 256 insertions(+), 51 deletions(-) diff --git a/sandbox.go b/sandbox.go index 04fb667..ce093de 100644 --- a/sandbox.go +++ b/sandbox.go @@ -14,6 +14,7 @@ type Sandbox struct { mutex sync.Mutex initialized bool modules map[string]any + functions map[string]GoFunction } // NewSandbox creates a new sandbox with standard libraries loaded @@ -22,6 +23,7 @@ func NewSandbox() *Sandbox { state: New(), initialized: false, modules: make(map[string]any), + functions: make(map[string]GoFunction), } } @@ -57,6 +59,9 @@ func (s *Sandbox) RegisterFunction(name string, fn GoFunction) error { return err } + // Store function for re-registration + s.functions[name] = fn + // Add to base environment return s.state.DoString(` -- Add the function to base environment @@ -172,12 +177,7 @@ func (s *Sandbox) Run(code string) (any, error) { result, err := s.state.ToValue(-1) s.state.Pop(1) - // Handle multiple return values - if results, ok := result.([]any); ok && len(results) == 1 { - return results[0], err - } - - return result, err + return s.unwrapResult(result, err) } // RunFile executes a Lua file in the sandbox @@ -225,46 +225,13 @@ func (s *Sandbox) RunBytecode(bytecode []byte) (any, error) { return nil, err } - // Add wrapper for multiple return values - if err := s.state.DoString(` - __wrap_bytecode = function(f) - local function _wrapresults(...) - local results = {n = select('#', ...)} - for i = 1, results.n do - results[i] = select(i, ...) - end - return results - end - - return function() - return _wrapresults(f()) - end - end - `); err != nil { - return nil, err - } - - // Get wrapper function - s.state.GetGlobal("__wrap_bytecode") + // Get the sandbox executor + s.state.GetGlobal("__execute_sandbox") // Push bytecode function s.state.PushCopy(-2) - // Call wrapper to create wrapped function - if err := s.state.Call(1, 1); err != nil { - return nil, err - } - // Remove original bytecode function - s.state.Remove(-2) - - // Get the sandbox executor - s.state.GetGlobal("__execute_sandbox") - - // Push wrapped function - s.state.PushCopy(-2) - - // Remove the wrapped function s.state.Remove(-3) // Execute in sandbox @@ -276,12 +243,7 @@ func (s *Sandbox) RunBytecode(bytecode []byte) (any, error) { result, err := s.state.ToValue(-1) s.state.Pop(1) - // Handle multiple return values - if results, ok := result.([]any); ok && len(results) == 1 { - return results[0], err - } - - return result, err + return s.unwrapResult(result, err) } // getResults collects results from the stack (must be called with mutex locked) @@ -326,7 +288,22 @@ func (s *Sandbox) SetPackagePath(path string) error { return fmt.Errorf("sandbox is closed") } - return s.state.SetPackagePath(path) + // Update global package.path + if err := s.state.SetPackagePath(path); err != nil { + return err + } + + // Make sure sandbox is initialized + if !s.initialized { + if err := s.initializeUnlocked(); err != nil { + return err + } + } + + // Update base environment's package.path + return s.state.DoString(` + __env_system.base_env.package.path = package.path + `) } // AddPackagePath adds a path to the sandbox package.path @@ -338,7 +315,22 @@ func (s *Sandbox) AddPackagePath(path string) error { return fmt.Errorf("sandbox is closed") } - return s.state.AddPackagePath(path) + // Update global package.path + if err := s.state.AddPackagePath(path); err != nil { + return err + } + + // Make sure sandbox is initialized + if !s.initialized { + if err := s.initializeUnlocked(); err != nil { + return err + } + } + + // Update base environment's package.path + return s.state.DoString(` + __env_system.base_env.package.path = package.path + `) } // AddModule adds a module to the sandbox environment @@ -484,7 +476,7 @@ func (s *Sandbox) initializeUnlocked() error { else -- Create new environment with metatable inheritance env = setmetatable({}, { - __index = _G -- Use global environment as fallback + __index = __env_system.base_env -- Use base env instead of _G }) end @@ -520,6 +512,13 @@ func (s *Sandbox) initializeUnlocked() error { -- Execute with protected call local success, result = pcall(f) + -- Copy all globals to base environment + for k, v in pairs(env) do + if k ~= "_G" and type(k) == "string" then + __env_system.base_env[k] = v + end + end + -- Recycle environment __recycle_env(env) @@ -600,7 +599,87 @@ func (s *Sandbox) ResetEnvironment() error { return fmt.Errorf("sandbox is closed") } + // Clear the environment system completely + err := s.state.DoString(` + -- Reset environment system + __env_system = nil + __wrap_bytecode = nil + __last_env = nil + `) + if err != nil { + return err + } + // Reinitialize the environment system s.initialized = false - return s.Initialize() + if err := s.initializeUnlocked(); err != nil { + return err + } + + // Re-register all functions + for name, fn := range s.functions { + if err := s.state.RegisterGoFunction(name, fn); err != nil { + return err + } + + if err := s.state.DoString(` + -- Add the function to base environment + __env_system.base_env["` + name + `"] = ` + name + ` + `); err != nil { + return err + } + } + + return nil +} + +// unwrapResult processes the raw result value from Lua +// and unwraps single values from special map structures +func (s *Sandbox) unwrapResult(result any, err error) (any, error) { + // Unwrap array stored in map with empty key + if m, ok := result.(map[string]any); ok { + // Check for special array format + if arr, ok := m[""]; ok { + // If the array has only one element, return that element + if slice, ok := arr.([]float64); ok { + if len(slice) == 1 { + return slice[0], err + } + // Convert []float64 to []any for consistency + anySlice := make([]any, len(slice)) + for i, v := range slice { + anySlice[i] = v + } + return anySlice, err + } + if slice, ok := arr.([]any); ok && len(slice) == 1 { + return slice[0], err + } + result = arr + } else if len(m) == 1 { + // When there's exactly one item, return its value directly + for _, v := range m { + return v, err + } + } + } + + // Convert []float64 to []any for consistency with multiple returns + if slice, ok := result.([]float64); ok { + if len(slice) == 1 { + return slice[0], err + } + anySlice := make([]any, len(slice)) + for i, v := range slice { + anySlice[i] = v + } + return anySlice, err + } + + // Handle multiple return values + if results, ok := result.([]any); ok && len(results) == 1 { + return results[0], err + } + + return result, err } diff --git a/tests/sandbox_test.go b/tests/sandbox_test.go index 2784b54..fa1ba35 100644 --- a/tests/sandbox_test.go +++ b/tests/sandbox_test.go @@ -648,3 +648,129 @@ func createTempLuaFile(content string) (string, error) { func removeTempFile(path string) { os.Remove(path) } + +// BenchmarkSandboxLuaExecution measures the performance of executing raw Lua code +func BenchmarkSandboxLuaExecution(b *testing.B) { + sandbox := luajit.NewSandbox() + defer sandbox.Close() + + // Simple Lua code that does some computation + code := ` + local sum = 0 + for i = 1, 100 do + sum = sum + i + end + return sum + ` + + b.ResetTimer() + for i := 0; i < b.N; i++ { + result, err := sandbox.Run(code) + if err != nil { + b.Fatalf("Failed to run code: %v", err) + } + if result != float64(5050) { + b.Fatalf("Incorrect result: %v", result) + } + } +} + +// BenchmarkSandboxBytecodeExecution measures the performance of executing precompiled bytecode +func BenchmarkSandboxBytecodeExecution(b *testing.B) { + sandbox := luajit.NewSandbox() + defer sandbox.Close() + + // Same code as above, but precompiled + code := ` + local sum = 0 + for i = 1, 100 do + sum = sum + i + end + return sum + ` + + // Compile the bytecode once + bytecode, err := sandbox.Compile(code) + if err != nil { + b.Fatalf("Failed to compile bytecode: %v", err) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + result, err := sandbox.RunBytecode(bytecode) + if err != nil { + b.Fatalf("Failed to run bytecode: %v", err) + } + if result != float64(5050) { + b.Fatalf("Incorrect result: %v", result) + } + } +} + +// BenchmarkSandboxComplexComputation measures performance with more complex computation +func BenchmarkSandboxComplexComputation(b *testing.B) { + sandbox := luajit.NewSandbox() + defer sandbox.Close() + + // More complex Lua code that calculates Fibonacci numbers + code := ` + function fibonacci(n) + if n <= 1 then + return n + end + return fibonacci(n-1) + fibonacci(n-2) + end + + return fibonacci(15) -- Not too high to avoid excessive runtime + ` + + b.ResetTimer() + for i := 0; i < b.N; i++ { + result, err := sandbox.Run(code) + if err != nil { + b.Fatalf("Failed to run code: %v", err) + } + if result != float64(610) { + b.Fatalf("Incorrect result: %v", result) + } + } +} + +// BenchmarkSandboxFunctionCall measures performance of calling a registered Go function +func BenchmarkSandboxFunctionCall(b *testing.B) { + sandbox := luajit.NewSandbox() + defer sandbox.Close() + + // Register a simple addition function + add := func(s *luajit.State) int { + a := s.ToNumber(1) + b := s.ToNumber(2) + s.PushNumber(a + b) + return 1 + } + + err := sandbox.RegisterFunction("add", add) + if err != nil { + b.Fatalf("Failed to register function: %v", err) + } + + // Lua code that calls the Go function in a loop + code := ` + local sum = 0 + for i = 1, 100 do + sum = add(sum, i) + end + return sum + ` + + b.ResetTimer() + for i := 0; i < b.N; i++ { + result, err := sandbox.Run(code) + if err != nil { + b.Fatalf("Failed to run code: %v", err) + } + if result != float64(5050) { + b.Fatalf("Incorrect result: %v", result) + } + } +}