From 656ac1a70336d77721851fe1fb44760a5b1668a9 Mon Sep 17 00:00:00 2001 From: Sky Johnson Date: Fri, 28 Mar 2025 20:27:51 -0500 Subject: [PATCH] remove sandbox --- bench/sandbox_bench_test.go | 133 -------- sandbox.go | 609 --------------------------------- tests/sandbox_test.go | 650 ------------------------------------ 3 files changed, 1392 deletions(-) delete mode 100644 bench/sandbox_bench_test.go delete mode 100644 sandbox.go delete mode 100644 tests/sandbox_test.go diff --git a/bench/sandbox_bench_test.go b/bench/sandbox_bench_test.go deleted file mode 100644 index 312225e..0000000 --- a/bench/sandbox_bench_test.go +++ /dev/null @@ -1,133 +0,0 @@ -package luajit_bench - -import ( - "testing" - - luajit "git.sharkk.net/Sky/LuaJIT-to-Go" -) - -// BenchmarkSandboxLuaExecution measures the performance of executing raw Lua code -func BenchmarkSandboxLuaExecution(b *testing.B) { - sandbox := luajit.NewSandbox() - defer sandbox.Close() - - // Simple Lua code that does some computation - code := ` - local sum = 0 - for i = 1, 100 do - sum = sum + i - end - return sum - ` - - b.ResetTimer() - for i := 0; i < b.N; i++ { - result, err := sandbox.Run(code) - if err != nil { - b.Fatalf("Failed to run code: %v", err) - } - if result != float64(5050) { - b.Fatalf("Incorrect result: %v", result) - } - } -} - -// BenchmarkSandboxBytecodeExecution measures the performance of executing precompiled bytecode -func BenchmarkSandboxBytecodeExecution(b *testing.B) { - sandbox := luajit.NewSandbox() - defer sandbox.Close() - - // Same code as above, but precompiled - code := ` - local sum = 0 - for i = 1, 100 do - sum = sum + i - end - return sum - ` - - // Compile the bytecode once - bytecode, err := sandbox.Compile(code) - if err != nil { - b.Fatalf("Failed to compile bytecode: %v", err) - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - result, err := sandbox.RunBytecode(bytecode) - if err != nil { - b.Fatalf("Failed to run bytecode: %v", err) - } - if result != float64(5050) { - b.Fatalf("Incorrect result: %v", result) - } - } -} - -// BenchmarkSandboxComplexComputation measures performance with more complex computation -func BenchmarkSandboxComplexComputation(b *testing.B) { - sandbox := luajit.NewSandbox() - defer sandbox.Close() - - // More complex Lua code that calculates Fibonacci numbers - code := ` - function fibonacci(n) - if n <= 1 then - return n - end - return fibonacci(n-1) + fibonacci(n-2) - end - - return fibonacci(15) -- Not too high to avoid excessive runtime - ` - - b.ResetTimer() - for i := 0; i < b.N; i++ { - result, err := sandbox.Run(code) - if err != nil { - b.Fatalf("Failed to run code: %v", err) - } - if result != float64(610) { - b.Fatalf("Incorrect result: %v", result) - } - } -} - -// BenchmarkSandboxFunctionCall measures performance of calling a registered Go function -func BenchmarkSandboxFunctionCall(b *testing.B) { - sandbox := luajit.NewSandbox() - defer sandbox.Close() - - // Register a simple addition function - add := func(s *luajit.State) int { - a := s.ToNumber(1) - b := s.ToNumber(2) - s.PushNumber(a + b) - return 1 - } - - err := sandbox.RegisterFunction("add", add) - if err != nil { - b.Fatalf("Failed to register function: %v", err) - } - - // Lua code that calls the Go function in a loop - code := ` - local sum = 0 - for i = 1, 100 do - sum = add(sum, i) - end - return sum - ` - - b.ResetTimer() - for i := 0; i < b.N; i++ { - result, err := sandbox.Run(code) - if err != nil { - b.Fatalf("Failed to run code: %v", err) - } - if result != float64(5050) { - b.Fatalf("Incorrect result: %v", result) - } - } -} diff --git a/sandbox.go b/sandbox.go deleted file mode 100644 index 3dd352c..0000000 --- a/sandbox.go +++ /dev/null @@ -1,609 +0,0 @@ -package luajit - -import ( - "fmt" - "sync" -) - -// LUA_MULTRET is the constant for multiple return values -const LUA_MULTRET = -1 - -// Sandbox provides a persistent Lua environment for executing scripts -type Sandbox struct { - state *State - mutex sync.Mutex - initialized bool - modules map[string]any - functions map[string]GoFunction -} - -// NewSandbox creates a new sandbox with standard libraries loaded -func NewSandbox() *Sandbox { - return &Sandbox{ - state: New(), - initialized: false, - modules: make(map[string]any), - functions: make(map[string]GoFunction), - } -} - -// Close releases all resources used by the sandbox -func (s *Sandbox) Close() { - s.mutex.Lock() - defer s.mutex.Unlock() - - if s.state != nil { - s.state.Close() - s.state = nil - } -} - -// Initialize sets up the environment system -func (s *Sandbox) Initialize() error { - s.mutex.Lock() - defer s.mutex.Unlock() - return s.initializeUnlocked() -} - -// initializeUnlocked sets up the environment system without locking -func (s *Sandbox) initializeUnlocked() error { - if s.state == nil { - return fmt.Errorf("sandbox is closed") - } - - if s.initialized { - return nil - } - - // Register modules - s.state.GetGlobal("__sandbox_modules") - if s.state.IsNil(-1) { - s.state.Pop(1) - s.state.NewTable() - s.state.SetGlobal("__sandbox_modules") - s.state.GetGlobal("__sandbox_modules") - } - - // Add modules - for name, module := range s.modules { - s.state.PushString(name) - if err := s.state.PushValue(module); err != nil { - s.state.Pop(2) - return err - } - s.state.SetTable(-3) - } - s.state.Pop(1) - - // Create simplified environment system - err := s.state.DoString(` - -- Global shared environment - __env_system = { - base_env = {}, -- Template environment - env_pool = {}, -- Pre-allocated environment pool - pool_size = 0, -- Current pool size - max_pool_size = 8 -- Maximum pool size - } - - -- Create base environment with standard libraries - local base = __env_system.base_env - - -- Safe standard libraries - base.string = string - base.table = table - base.math = math - base.os = { - time = os.time, - date = os.date, - difftime = os.difftime, - clock = os.clock - } - - -- Basic functions - base.print = print - base.tonumber = tonumber - base.tostring = tostring - base.type = type - base.pairs = pairs - base.ipairs = ipairs - base.next = next - base.select = select - base.pcall = pcall - base.xpcall = xpcall - base.error = error - base.assert = assert - base.collectgarbage = collectgarbage - base.unpack = unpack or table.unpack - - -- Package system - base.package = { - loaded = {}, - path = package.path, - preload = {} - } - - base.require = function(modname) - if base.package.loaded[modname] then - return base.package.loaded[modname] - end - - local loader = base.package.preload[modname] - if type(loader) == "function" then - local result = loader(modname) - base.package.loaded[modname] = result or true - return result - end - - error("module '" .. modname .. "' not found", 2) - end - - -- Add registered custom modules - if __sandbox_modules then - for name, mod in pairs(__sandbox_modules) do - base[name] = mod - end - end - - -- Get an environment for execution - function __get_sandbox_env() - local env - - -- Try to reuse from pool - if __env_system.pool_size > 0 then - env = table.remove(__env_system.env_pool) - __env_system.pool_size = __env_system.pool_size - 1 - else - -- Create new environment with metatable inheritance - env = setmetatable({}, { - __index = __env_system.base_env - }) - end - - return env - end - - -- Return environment to pool for reuse - function __recycle_env(env) - if __env_system.pool_size < __env_system.max_pool_size then - -- Clear all fields except metatable - for k in pairs(env) do - env[k] = nil - end - - -- Add to pool - table.insert(__env_system.env_pool, env) - __env_system.pool_size = __env_system.pool_size + 1 - end - end - - -- Execute code in sandbox - function __execute_sandbox(f) - -- Get environment - local env = __get_sandbox_env() - - -- Set environment for function - setfenv(f, env) - - -- Execute with protected call - local success, result = pcall(f) - - -- Update base environment with new globals - for k, v in pairs(env) do - if k ~= "_G" and type(k) == "string" then - __env_system.base_env[k] = v - end - end - - -- Recycle environment - __recycle_env(env) - - -- Process result - if not success then - error(result, 0) - end - - return result - end - `) - - if err != nil { - return err - } - - s.initialized = true - return nil -} - -// RegisterFunction registers a Go function in the sandbox -func (s *Sandbox) RegisterFunction(name string, fn GoFunction) error { - s.mutex.Lock() - defer s.mutex.Unlock() - - if s.state == nil { - return fmt.Errorf("sandbox is closed") - } - - // Initialize if needed - if !s.initialized { - if err := s.initializeUnlocked(); err != nil { - return err - } - } - - // Register function globally - if err := s.state.RegisterGoFunction(name, fn); err != nil { - return err - } - - // Store function for re-registration - s.functions[name] = fn - - // Add to base environment - return s.state.DoString(`__env_system.base_env["` + name + `"] = ` + name) -} - -// SetGlobal sets a global variable in the sandbox base environment -func (s *Sandbox) SetGlobal(name string, value any) error { - s.mutex.Lock() - defer s.mutex.Unlock() - - if s.state == nil { - return fmt.Errorf("sandbox is closed") - } - - // Initialize if needed - if !s.initialized { - if err := s.initializeUnlocked(); err != nil { - return err - } - } - - // Push the value onto the stack - if err := s.state.PushValue(value); err != nil { - return err - } - - // Set the global with the pushed value - s.state.SetGlobal(name) - - // Add to base environment - return s.state.DoString(`__env_system.base_env["` + name + `"] = ` + name) -} - -// GetGlobal retrieves a global variable from the sandbox base environment -func (s *Sandbox) GetGlobal(name string) (any, error) { - s.mutex.Lock() - defer s.mutex.Unlock() - - if s.state == nil { - return nil, fmt.Errorf("sandbox is closed") - } - - // Initialize if needed - if !s.initialized { - if err := s.initializeUnlocked(); err != nil { - return nil, err - } - } - - // Get the global from the base environment - return s.state.ExecuteWithResult(`return __env_system.base_env["` + name + `"]`) -} - -// Run executes Lua code in the sandbox -func (s *Sandbox) Run(code string) (any, error) { - s.mutex.Lock() - defer s.mutex.Unlock() - - if s.state == nil { - return nil, fmt.Errorf("sandbox is closed") - } - - // Initialize if needed - if !s.initialized { - if err := s.initializeUnlocked(); err != nil { - return nil, err - } - } - - // Simplified wrapper for multiple return values - wrappedCode := ` - local function _execfunc() - ` + code + ` - end - - -- Process results to match expected format - local function _wrapresults(...) - local n = select('#', ...) - if n == 0 then - return nil - elseif n == 1 then - return select(1, ...) - else - local results = {} - for i = 1, n do - results[i] = select(i, ...) - end - return results - end - end - - return _wrapresults(_execfunc()) - ` - - // Compile the code - if err := s.state.LoadString(wrappedCode); err != nil { - return nil, err - } - - // Get the sandbox executor - s.state.GetGlobal("__execute_sandbox") - - // Push the function as argument - s.state.PushCopy(-2) - s.state.Remove(-3) - - // Execute in sandbox - if err := s.state.Call(1, 1); err != nil { - return nil, err - } - - // Get result - result, err := s.state.ToValue(-1) - s.state.Pop(1) - - if err != nil { - return nil, err - } - - return s.processResult(result), nil -} - -// RunFile executes a Lua file in the sandbox -func (s *Sandbox) RunFile(filename string) error { - s.mutex.Lock() - defer s.mutex.Unlock() - - if s.state == nil { - return fmt.Errorf("sandbox is closed") - } - - return s.state.DoFile(filename) -} - -// Compile compiles Lua code to bytecode -func (s *Sandbox) Compile(code string) ([]byte, error) { - s.mutex.Lock() - defer s.mutex.Unlock() - - if s.state == nil { - return nil, fmt.Errorf("sandbox is closed") - } - - return s.state.CompileBytecode(code, "sandbox") -} - -// RunBytecode executes precompiled Lua bytecode -func (s *Sandbox) RunBytecode(bytecode []byte) (any, error) { - s.mutex.Lock() - defer s.mutex.Unlock() - - if s.state == nil { - return nil, fmt.Errorf("sandbox is closed") - } - - // Initialize if needed - if !s.initialized { - if err := s.initializeUnlocked(); err != nil { - return nil, err - } - } - - // Load the bytecode - if err := s.state.LoadBytecode(bytecode, "sandbox"); err != nil { - return nil, err - } - - // Get the sandbox executor - s.state.GetGlobal("__execute_sandbox") - - // Push bytecode function - s.state.PushCopy(-2) - s.state.Remove(-3) - - // Execute in sandbox - if err := s.state.Call(1, 1); err != nil { - return nil, err - } - - // Get result - result, err := s.state.ToValue(-1) - s.state.Pop(1) - - if err != nil { - return nil, err - } - - return s.processResult(result), nil -} - -// LoadModule loads a Lua module -func (s *Sandbox) LoadModule(name string) error { - code := fmt.Sprintf("require('%s')", name) - _, err := s.Run(code) - return err -} - -// SetPackagePath sets the sandbox package.path -func (s *Sandbox) SetPackagePath(path string) error { - s.mutex.Lock() - defer s.mutex.Unlock() - - if s.state == nil { - return fmt.Errorf("sandbox is closed") - } - - // Update global package.path - if err := s.state.SetPackagePath(path); err != nil { - return err - } - - // Initialize if needed - if !s.initialized { - if err := s.initializeUnlocked(); err != nil { - return err - } - } - - // Update base environment's package.path - return s.state.DoString(`__env_system.base_env.package.path = package.path`) -} - -// AddPackagePath adds a path to the sandbox package.path -func (s *Sandbox) AddPackagePath(path string) error { - s.mutex.Lock() - defer s.mutex.Unlock() - - if s.state == nil { - return fmt.Errorf("sandbox is closed") - } - - // Update global package.path - if err := s.state.AddPackagePath(path); err != nil { - return err - } - - // Initialize if needed - if !s.initialized { - if err := s.initializeUnlocked(); err != nil { - return err - } - } - - // Update base environment's package.path - return s.state.DoString(`__env_system.base_env.package.path = package.path`) -} - -// AddModule adds a module to the sandbox environment -func (s *Sandbox) AddModule(name string, module any) error { - s.mutex.Lock() - defer s.mutex.Unlock() - - if s.state == nil { - return fmt.Errorf("sandbox is closed") - } - - s.modules[name] = module - return nil -} - -// AddPermanentLua adds Lua code to the environment permanently -func (s *Sandbox) AddPermanentLua(code string) error { - s.mutex.Lock() - defer s.mutex.Unlock() - - if s.state == nil { - return fmt.Errorf("sandbox is closed") - } - - // Initialize if needed - if !s.initialized { - if err := s.initializeUnlocked(); err != nil { - return err - } - } - - // Simplified approach to add code to base environment - return s.state.DoString(` - local f, err = loadstring([=[` + code + `]=], "permanent") - if not f then error(err, 0) end - - local env = setmetatable({}, {__index = __env_system.base_env}) - setfenv(f, env) - - local ok, err = pcall(f) - if not ok then error(err, 0) end - - for k, v in pairs(env) do - __env_system.base_env[k] = v - end - `) -} - -// ResetEnvironment resets the sandbox to its initial state -func (s *Sandbox) ResetEnvironment() error { - s.mutex.Lock() - defer s.mutex.Unlock() - - if s.state == nil { - return fmt.Errorf("sandbox is closed") - } - - // Clear the environment system - s.state.DoString(`__env_system = nil`) - - // Reinitialize - s.initialized = false - if err := s.initializeUnlocked(); err != nil { - return err - } - - // Re-register all functions - for name, fn := range s.functions { - if err := s.state.RegisterGoFunction(name, fn); err != nil { - return err - } - - if err := s.state.DoString(`__env_system.base_env["` + name + `"] = ` + name); err != nil { - return err - } - } - - return nil -} - -// unwrapResult processes results from Lua executions -func (s *Sandbox) processResult(result any) any { - // Handle []float64 (common LuaJIT return type) - if floats, ok := result.([]float64); ok { - if len(floats) == 1 { - // Single number - return as float64 - return floats[0] - } - // Multiple numbers - MUST convert to []any for tests to pass - anySlice := make([]any, len(floats)) - for i, v := range floats { - anySlice[i] = v - } - return anySlice - } - - // Handle maps with numeric keys (Lua tables) - if m, ok := result.(map[string]any); ok { - // Handle return tables with special structure - if vals, ok := m[""]; ok { - // This is a special case used by some Lua returns - if arr, ok := vals.([]float64); ok { - // Convert to []any for consistency - anySlice := make([]any, len(arr)) - for i, v := range arr { - anySlice[i] = v - } - return anySlice - } - return vals - } - - if len(m) == 1 { - // Check for single value map - for k, v := range m { - if k == "1" { - return v - } - } - } - } - - // Other array types should be preserved - return result -} diff --git a/tests/sandbox_test.go b/tests/sandbox_test.go deleted file mode 100644 index 2784b54..0000000 --- a/tests/sandbox_test.go +++ /dev/null @@ -1,650 +0,0 @@ -package luajit_test - -import ( - "os" - "sync" - "testing" - - luajit "git.sharkk.net/Sky/LuaJIT-to-Go" -) - -// TestSandboxLifecycle tests sandbox creation and closing -func TestSandboxLifecycle(t *testing.T) { - // Create a new sandbox - sandbox := luajit.NewSandbox() - if sandbox == nil { - t.Fatal("Failed to create sandbox") - } - - // Close the sandbox - sandbox.Close() - - // Test idempotent close (should not panic) - sandbox.Close() -} - -// TestSandboxFunctionRegistration tests registering Go functions in the sandbox -func TestSandboxFunctionRegistration(t *testing.T) { - sandbox := luajit.NewSandbox() - defer sandbox.Close() - - // Register a simple addition function - add := func(s *luajit.State) int { - a := s.ToNumber(1) - b := s.ToNumber(2) - s.PushNumber(a + b) - return 1 - } - - err := sandbox.RegisterFunction("add", add) - if err != nil { - t.Fatalf("Failed to register function: %v", err) - } - - // Test the function - result, err := sandbox.Run("return add(3, 4)") - if err != nil { - t.Fatalf("Failed to execute function: %v", err) - } - - if result != float64(7) { - t.Fatalf("Expected 7, got %v", result) - } - - // Test after sandbox is closed - sandbox.Close() - err = sandbox.RegisterFunction("test", add) - if err == nil { - t.Fatal("Expected error when registering function on closed sandbox") - } -} - -// TestSandboxGlobalVariables tests setting and getting global variables -func TestSandboxGlobalVariables(t *testing.T) { - sandbox := luajit.NewSandbox() - defer sandbox.Close() - - // Set a global variable - err := sandbox.SetGlobal("answer", 42) - if err != nil { - t.Fatalf("Failed to set global: %v", err) - } - - // Get the global variable - value, err := sandbox.GetGlobal("answer") - if err != nil { - t.Fatalf("Failed to get global: %v", err) - } - - if value != float64(42) { - t.Fatalf("Expected 42, got %v", value) - } - - // Test different types - testCases := []struct { - name string - value any - }{ - {"nil_value", nil}, - {"bool_value", true}, - {"string_value", "hello"}, - {"table_value", map[string]any{"key": "value"}}, - {"array_value", []float64{1, 2, 3}}, - } - - for _, tc := range testCases { - err := sandbox.SetGlobal(tc.name, tc.value) - if err != nil { - t.Fatalf("Failed to set global %s: %v", tc.name, err) - } - - value, err := sandbox.GetGlobal(tc.name) - if err != nil { - t.Fatalf("Failed to get global %s: %v", tc.name, err) - } - - // For tables/arrays, just check they're not nil - switch tc.value.(type) { - case map[string]any, []float64: - if value == nil { - t.Fatalf("Expected non-nil for %s, got nil", tc.name) - } - default: - if value != tc.value && !(tc.value == nil && value == nil) { - t.Fatalf("For %s: expected %v, got %v", tc.name, tc.value, value) - } - } - } - - // Test after sandbox is closed - sandbox.Close() - err = sandbox.SetGlobal("test", 123) - if err == nil { - t.Fatal("Expected error when setting global on closed sandbox") - } - - _, err = sandbox.GetGlobal("test") - if err == nil { - t.Fatal("Expected error when getting global from closed sandbox") - } -} - -// TestSandboxCodeExecution tests running Lua code in the sandbox -func TestSandboxCodeExecution(t *testing.T) { - sandbox := luajit.NewSandbox() - defer sandbox.Close() - - // Run simple code - result, err := sandbox.Run("return 'hello'") - if err != nil { - t.Fatalf("Failed to run code: %v", err) - } - - if result != "hello" { - t.Fatalf("Expected 'hello', got %v", result) - } - - // Run code with multiple return values - result, err = sandbox.Run("return 1, 2, 3") - if err != nil { - t.Fatalf("Failed to run code: %v", err) - } - - // Should return array for multiple values - results, ok := result.([]any) - if !ok { - t.Fatalf("Expected array for multiple returns, got %T", result) - } - - if len(results) != 3 || results[0] != float64(1) || results[1] != float64(2) || results[2] != float64(3) { - t.Fatalf("Expected [1, 2, 3], got %v", results) - } - - // Run code that sets a global - _, err = sandbox.Run("global_var = 'set from Lua'") - if err != nil { - t.Fatalf("Failed to run code: %v", err) - } - - value, err := sandbox.GetGlobal("global_var") - if err != nil { - t.Fatalf("Failed to get global: %v", err) - } - - if value != "set from Lua" { - t.Fatalf("Expected 'set from Lua', got %v", value) - } - - // Run invalid code - _, err = sandbox.Run("this is not valid Lua") - if err == nil { - t.Fatal("Expected error for invalid code") - } - - // Test after sandbox is closed - sandbox.Close() - _, err = sandbox.Run("return true") - if err == nil { - t.Fatal("Expected error when running code on closed sandbox") - } -} - -// TestSandboxBytecodeExecution tests bytecode compilation and execution -func TestSandboxBytecodeExecution(t *testing.T) { - sandbox := luajit.NewSandbox() - defer sandbox.Close() - - // Compile code to bytecode - code := ` - local function greet(name) - return "Hello, " .. name - end - return greet("World") - ` - bytecode, err := sandbox.Compile(code) - if err != nil { - t.Fatalf("Failed to compile bytecode: %v", err) - } - - if len(bytecode) == 0 { - t.Fatal("Expected non-empty bytecode") - } - - // Run the bytecode - result, err := sandbox.RunBytecode(bytecode) - if err != nil { - t.Fatalf("Failed to run bytecode: %v", err) - } - - if result != "Hello, World" { - t.Fatalf("Expected 'Hello, World', got %v", result) - } - - // Test bytecode that sets a global - bytecode, err = sandbox.Compile("bytecode_var = 42") - if err != nil { - t.Fatalf("Failed to compile bytecode: %v", err) - } - - _, err = sandbox.RunBytecode(bytecode) - if err != nil { - t.Fatalf("Failed to run bytecode: %v", err) - } - - value, err := sandbox.GetGlobal("bytecode_var") - if err != nil { - t.Fatalf("Failed to get global: %v", err) - } - - if value != float64(42) { - t.Fatalf("Expected 42, got %v", value) - } - - // Test invalid bytecode - _, err = sandbox.RunBytecode([]byte("not valid bytecode")) - if err == nil { - t.Fatal("Expected error for invalid bytecode") - } - - // Test after sandbox is closed - sandbox.Close() - _, err = sandbox.Compile("return true") - if err == nil { - t.Fatal("Expected error when compiling on closed sandbox") - } - - _, err = sandbox.RunBytecode(bytecode) - if err == nil { - t.Fatal("Expected error when running bytecode on closed sandbox") - } -} - -// TestSandboxPersistence tests state persistence across executions -func TestSandboxPersistence(t *testing.T) { - sandbox := luajit.NewSandbox() - defer sandbox.Close() - - // Set up initial state - _, err := sandbox.Run(` - counter = 0 - function increment() - counter = counter + 1 - return counter - end - `) - if err != nil { - t.Fatalf("Failed to set up state: %v", err) - } - - // Run multiple executions - for i := 1; i <= 3; i++ { - result, err := sandbox.Run("return increment()") - if err != nil { - t.Fatalf("Failed to run code: %v", err) - } - - if result != float64(i) { - t.Fatalf("Expected %d, got %v", i, result) - } - } - - // Check final counter value - value, err := sandbox.GetGlobal("counter") - if err != nil { - t.Fatalf("Failed to get global: %v", err) - } - - if value != float64(3) { - t.Fatalf("Expected final counter to be 3, got %v", value) - } - - // Test persistence with bytecode - bytecode, err := sandbox.Compile("return counter + 1") - if err != nil { - t.Fatalf("Failed to compile bytecode: %v", err) - } - - result, err := sandbox.RunBytecode(bytecode) - if err != nil { - t.Fatalf("Failed to run bytecode: %v", err) - } - - if result != float64(4) { - t.Fatalf("Expected 4, got %v", result) - } -} - -// TestSandboxConcurrency tests concurrent access to the sandbox -func TestSandboxConcurrency(t *testing.T) { - sandbox := luajit.NewSandbox() - defer sandbox.Close() - - // Set up a counter - _, err := sandbox.Run("counter = 0") - if err != nil { - t.Fatalf("Failed to set up counter: %v", err) - } - - // Run concurrent increments - const numGoroutines = 10 - const incrementsPerGoroutine = 100 - - var wg sync.WaitGroup - wg.Add(numGoroutines) - - for i := 0; i < numGoroutines; i++ { - go func() { - defer wg.Done() - for j := 0; j < incrementsPerGoroutine; j++ { - _, err := sandbox.Run("counter = counter + 1") - if err != nil { - t.Errorf("Failed to increment counter: %v", err) - return - } - } - }() - } - - wg.Wait() - - // Check the final counter value - value, err := sandbox.GetGlobal("counter") - if err != nil { - t.Fatalf("Failed to get counter: %v", err) - } - - expected := float64(numGoroutines * incrementsPerGoroutine) - if value != expected { - t.Fatalf("Expected counter to be %v, got %v", expected, value) - } -} - -// TestPermanentLua tests the AddPermanentLua method -func TestPermanentLua(t *testing.T) { - sandbox := luajit.NewSandbox() - defer sandbox.Close() - - // Add permanent Lua environment - err := sandbox.AddPermanentLua(` - -- Create utility functions - function double(x) - return x * 2 - end - - function square(x) - return x * x - end - - -- Create a protected environment - env = { - add = function(a, b) return a + b end, - sub = function(a, b) return a - b end - } - `) - if err != nil { - t.Fatalf("Failed to add permanent Lua: %v", err) - } - - // Test using the permanent functions - testCases := []struct { - code string - expected float64 - }{ - {"return double(5)", 10}, - {"return square(4)", 16}, - {"return env.add(10, 20)", 30}, - {"return env.sub(50, 30)", 20}, - } - - for _, tc := range testCases { - result, err := sandbox.Run(tc.code) - if err != nil { - t.Fatalf("Failed to run code '%s': %v", tc.code, err) - } - - if result != tc.expected { - t.Fatalf("For '%s': expected %v, got %v", tc.code, tc.expected, result) - } - } - - // Test persistence of permanent code across executions - _, err = sandbox.Run("counter = 0") - if err != nil { - t.Fatalf("Failed to set counter: %v", err) - } - - result, err := sandbox.Run("counter = counter + 1; return double(counter)") - if err != nil { - t.Fatalf("Failed to run code: %v", err) - } - - if result != float64(2) { - t.Fatalf("Expected 2, got %v", result) - } - - // Test after sandbox is closed - sandbox.Close() - err = sandbox.AddPermanentLua("function test() end") - if err == nil { - t.Fatal("Expected error when adding permanent Lua to closed sandbox") - } -} - -// TestResetEnvironment tests the ResetEnvironment method -func TestResetEnvironment(t *testing.T) { - sandbox := luajit.NewSandbox() - defer sandbox.Close() - - // Set up some Go functions and Lua code - sandbox.RegisterFunction("timeNow", func(s *luajit.State) int { - s.PushString("test") - return 1 - }) - - sandbox.AddPermanentLua(` - function permanent() - return "permanent function" - end - `) - - _, err := sandbox.Run(` - temp_var = "will be reset" - function temp_func() - return "temp function" - end - `) - if err != nil { - t.Fatalf("Failed to run setup code: %v", err) - } - - // Verify everything is set up correctly - result, err := sandbox.Run("return timeNow()") - if err != nil || result != "test" { - t.Fatalf("Go function not working: %v, %v", result, err) - } - - result, err = sandbox.Run("return permanent()") - if err != nil || result != "permanent function" { - t.Fatalf("Permanent function not working: %v, %v", result, err) - } - - result, err = sandbox.Run("return temp_func()") - if err != nil || result != "temp function" { - t.Fatalf("Temp function not working: %v, %v", result, err) - } - - value, err := sandbox.GetGlobal("temp_var") - if err != nil || value != "will be reset" { - t.Fatalf("Temp var not set correctly: %v, %v", value, err) - } - - // Reset the environment - err = sandbox.ResetEnvironment() - if err != nil { - t.Fatalf("Failed to reset environment: %v", err) - } - - // Check Go function survives reset - result, err = sandbox.Run("return timeNow()") - if err != nil || result != "test" { - t.Fatalf("Go function should survive reset: %v, %v", result, err) - } - - // Check permanent function is gone (it was added with AddPermanentLua but reset removes it) - _, err = sandbox.Run("return permanent()") - if err == nil { - t.Fatal("Permanent function should be gone after reset") - } - - // Check temp function is gone - _, err = sandbox.Run("return temp_func()") - if err == nil { - t.Fatal("Temp function should be gone after reset") - } - - // Check temp var is gone - value, err = sandbox.GetGlobal("temp_var") - if err != nil || value != nil { - t.Fatalf("Temp var should be nil after reset: %v", value) - } - - // Test after sandbox is closed - sandbox.Close() - err = sandbox.ResetEnvironment() - if err == nil { - t.Fatal("Expected error when resetting closed sandbox") - } -} - -// TestRunFile tests the RunFile method -func TestRunFile(t *testing.T) { - // Create a temporary Lua file - tmpfile, err := createTempLuaFile("return 42") - if err != nil { - t.Fatalf("Failed to create temp file: %v", err) - } - defer removeTempFile(tmpfile) - - sandbox := luajit.NewSandbox() - defer sandbox.Close() - - // Run the file - err = sandbox.RunFile(tmpfile) - if err != nil { - t.Fatalf("Failed to run file: %v", err) - } - - // Test non-existent file - err = sandbox.RunFile("does_not_exist.lua") - if err == nil { - t.Fatal("Expected error for non-existent file") - } - - // Test after sandbox is closed - sandbox.Close() - err = sandbox.RunFile(tmpfile) - if err == nil { - t.Fatal("Expected error when running file on closed sandbox") - } -} - -// TestSandboxPackagePath tests the SetPackagePath and AddPackagePath methods -func TestSandboxPackagePath(t *testing.T) { - sandbox := luajit.NewSandbox() - defer sandbox.Close() - - // Set package path - testPath := "/test/path/?.lua" - err := sandbox.SetPackagePath(testPath) - if err != nil { - t.Fatalf("Failed to set package path: %v", err) - } - - // Check path was set - result, err := sandbox.Run("return package.path") - if err != nil { - t.Fatalf("Failed to get package.path: %v", err) - } - - if result != testPath { - t.Fatalf("Expected package.path to be %q, got %q", testPath, result) - } - - // Add to package path - addPath := "/another/path/?.lua" - err = sandbox.AddPackagePath(addPath) - if err != nil { - t.Fatalf("Failed to add package path: %v", err) - } - - // Check path was updated - result, err = sandbox.Run("return package.path") - if err != nil { - t.Fatalf("Failed to get updated package.path: %v", err) - } - - expected := testPath + ";" + addPath - if result != expected { - t.Fatalf("Expected package.path to be %q, got %q", expected, result) - } - - // Test after sandbox is closed - sandbox.Close() - err = sandbox.SetPackagePath(testPath) - if err == nil { - t.Fatal("Expected error when setting package path on closed sandbox") - } - - err = sandbox.AddPackagePath(addPath) - if err == nil { - t.Fatal("Expected error when adding package path to closed sandbox") - } -} - -// TestSandboxLoadModule tests loading modules -func TestSandboxLoadModule(t *testing.T) { - // Skip for now since we don't have actual modules to load in the test environment - t.Skip("Skipping module loading test as it requires actual modules") - - sandbox := luajit.NewSandbox() - defer sandbox.Close() - - // Set package path to include current directory - err := sandbox.SetPackagePath("./?.lua") - if err != nil { - t.Fatalf("Failed to set package path: %v", err) - } - - // Try to load a non-existent module - err = sandbox.LoadModule("nonexistent_module") - if err == nil { - t.Fatal("Expected error when loading non-existent module") - } -} - -// Helper functions - -// createTempLuaFile creates a temporary Lua file with the given content -func createTempLuaFile(content string) (string, error) { - tmpfile, err := os.CreateTemp("", "test-*.lua") - if err != nil { - return "", err - } - - if _, err := tmpfile.WriteString(content); err != nil { - os.Remove(tmpfile.Name()) - return "", err - } - - if err := tmpfile.Close(); err != nil { - os.Remove(tmpfile.Name()) - return "", err - } - - return tmpfile.Name(), nil -} - -// removeTempFile removes a temporary file -func removeTempFile(path string) { - os.Remove(path) -}