diff --git a/sandbox.go b/sandbox.go new file mode 100644 index 0000000..04fb667 --- /dev/null +++ b/sandbox.go @@ -0,0 +1,606 @@ +package luajit + +import ( + "fmt" + "sync" +) + +// LUA_MULTRET is the constant for multiple return values +const LUA_MULTRET = -1 + +// Sandbox provides a persistent Lua environment for executing scripts +type Sandbox struct { + state *State + mutex sync.Mutex + initialized bool + modules map[string]any +} + +// NewSandbox creates a new sandbox with standard libraries loaded +func NewSandbox() *Sandbox { + return &Sandbox{ + state: New(), + initialized: false, + modules: make(map[string]any), + } +} + +// Close releases all resources used by the sandbox +func (s *Sandbox) Close() { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.state != nil { + s.state.Close() + s.state = nil + } +} + +// RegisterFunction registers a Go function in the sandbox +func (s *Sandbox) RegisterFunction(name string, fn GoFunction) error { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.state == nil { + return fmt.Errorf("sandbox is closed") + } + + // Make sure sandbox is initialized + if !s.initialized { + if err := s.initializeUnlocked(); err != nil { + return err + } + } + + // Register function globally + if err := s.state.RegisterGoFunction(name, fn); err != nil { + return err + } + + // Add to base environment + return s.state.DoString(` + -- Add the function to base environment + __env_system.base_env["` + name + `"] = ` + name + ` + `) +} + +// SetGlobal sets a global variable in the sandbox base environment +func (s *Sandbox) SetGlobal(name string, value any) error { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.state == nil { + return fmt.Errorf("sandbox is closed") + } + + // Make sure sandbox is initialized + if !s.initialized { + if err := s.initializeUnlocked(); err != nil { + return err + } + } + + // Push the value onto the stack + if err := s.state.PushValue(value); err != nil { + return err + } + + // Set the global with the pushed value + s.state.SetGlobal(name) + + // Add to base environment + return s.state.DoString(` + -- Add the global to base environment + __env_system.base_env["` + name + `"] = ` + name + ` + `) +} + +// GetGlobal retrieves a global variable from the sandbox base environment +func (s *Sandbox) GetGlobal(name string) (any, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.state == nil { + return nil, fmt.Errorf("sandbox is closed") + } + + // Make sure sandbox is initialized + if !s.initialized { + if err := s.initializeUnlocked(); err != nil { + return nil, err + } + } + + // Get the global from the base environment + code := `return __env_system.base_env["` + name + `"]` + return s.state.ExecuteWithResult(code) +} + +// Run executes Lua code in the sandbox and returns the result +func (s *Sandbox) Run(code string) (any, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.state == nil { + return nil, fmt.Errorf("sandbox is closed") + } + + // Make sure sandbox is initialized + if !s.initialized { + if err := s.initializeUnlocked(); err != nil { + return nil, err + } + } + + // Add wrapper for multiple return values + wrappedCode := ` + local function _execfunc() + ` + code + ` + end + + local function _wrapresults(...) + local results = {n = select('#', ...)} + for i = 1, results.n do + results[i] = select(i, ...) + end + return results + end + + return _wrapresults(_execfunc()) + ` + + // Compile the code + if err := s.state.LoadString(wrappedCode); err != nil { + return nil, err + } + + // Get the sandbox executor + s.state.GetGlobal("__execute_sandbox") + + // Setup call with correct argument order + s.state.PushCopy(-2) // Copy the function + + // Remove the original function + s.state.Remove(-3) + + // Execute in sandbox + if err := s.state.Call(1, 1); err != nil { + return nil, err + } + + // Get result + result, err := s.state.ToValue(-1) + s.state.Pop(1) + + // Handle multiple return values + if results, ok := result.([]any); ok && len(results) == 1 { + return results[0], err + } + + return result, err +} + +// RunFile executes a Lua file in the sandbox +func (s *Sandbox) RunFile(filename string) error { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.state == nil { + return fmt.Errorf("sandbox is closed") + } + + return s.state.DoFile(filename) +} + +// Compile compiles Lua code to bytecode without executing it +func (s *Sandbox) Compile(code string) ([]byte, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.state == nil { + return nil, fmt.Errorf("sandbox is closed") + } + + return s.state.CompileBytecode(code, "sandbox") +} + +// RunBytecode executes precompiled Lua bytecode in the sandbox +func (s *Sandbox) RunBytecode(bytecode []byte) (any, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.state == nil { + return nil, fmt.Errorf("sandbox is closed") + } + + // Make sure sandbox is initialized + if !s.initialized { + if err := s.initializeUnlocked(); err != nil { + return nil, err + } + } + + // Load the bytecode + if err := s.state.LoadBytecode(bytecode, "sandbox"); err != nil { + return nil, err + } + + // Add wrapper for multiple return values + if err := s.state.DoString(` + __wrap_bytecode = function(f) + local function _wrapresults(...) + local results = {n = select('#', ...)} + for i = 1, results.n do + results[i] = select(i, ...) + end + return results + end + + return function() + return _wrapresults(f()) + end + end + `); err != nil { + return nil, err + } + + // Get wrapper function + s.state.GetGlobal("__wrap_bytecode") + + // Push bytecode function + s.state.PushCopy(-2) + + // Call wrapper to create wrapped function + if err := s.state.Call(1, 1); err != nil { + return nil, err + } + + // Remove original bytecode function + s.state.Remove(-2) + + // Get the sandbox executor + s.state.GetGlobal("__execute_sandbox") + + // Push wrapped function + s.state.PushCopy(-2) + + // Remove the wrapped function + s.state.Remove(-3) + + // Execute in sandbox + if err := s.state.Call(1, 1); err != nil { + return nil, err + } + + // Get result + result, err := s.state.ToValue(-1) + s.state.Pop(1) + + // Handle multiple return values + if results, ok := result.([]any); ok && len(results) == 1 { + return results[0], err + } + + return result, err +} + +// getResults collects results from the stack (must be called with mutex locked) +func (s *Sandbox) getResults() (any, error) { + numResults := s.state.GetTop() + if numResults == 0 { + return nil, nil + } else if numResults == 1 { + // Return single result directly + value, err := s.state.ToValue(-1) + s.state.Pop(1) + return value, err + } + + // Return multiple results as slice + results := make([]any, numResults) + for i := 0; i < numResults; i++ { + value, err := s.state.ToValue(i - numResults) + if err != nil { + s.state.Pop(numResults) + return nil, err + } + results[i] = value + } + s.state.Pop(numResults) + return results, nil +} + +// LoadModule loads a Lua module in the sandbox +func (s *Sandbox) LoadModule(name string) error { + code := fmt.Sprintf("require('%s')", name) + _, err := s.Run(code) + return err +} + +// SetPackagePath sets the sandbox package.path +func (s *Sandbox) SetPackagePath(path string) error { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.state == nil { + return fmt.Errorf("sandbox is closed") + } + + return s.state.SetPackagePath(path) +} + +// AddPackagePath adds a path to the sandbox package.path +func (s *Sandbox) AddPackagePath(path string) error { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.state == nil { + return fmt.Errorf("sandbox is closed") + } + + return s.state.AddPackagePath(path) +} + +// AddModule adds a module to the sandbox environment +func (s *Sandbox) AddModule(name string, module any) error { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.state == nil { + return fmt.Errorf("sandbox is closed") + } + + s.modules[name] = module + return nil +} + +// Initialize sets up the environment system +func (s *Sandbox) Initialize() error { + s.mutex.Lock() + defer s.mutex.Unlock() + + return s.initializeUnlocked() +} + +// initializeUnlocked sets up the environment system without locking +// It should only be called when the mutex is already locked +func (s *Sandbox) initializeUnlocked() error { + if s.state == nil { + return fmt.Errorf("sandbox is closed") + } + + if s.initialized { + return nil // Already initialized + } + + // Register modules + s.state.GetGlobal("__sandbox_modules") + if s.state.IsNil(-1) { + s.state.Pop(1) + s.state.NewTable() + s.state.SetGlobal("__sandbox_modules") + s.state.GetGlobal("__sandbox_modules") + } + + // Add modules + for name, module := range s.modules { + s.state.PushString(name) + if err := s.state.PushValue(module); err != nil { + s.state.Pop(2) + return err + } + s.state.SetTable(-3) + } + s.state.Pop(1) + + // Create the environment system + err := s.state.DoString(` + -- Global shared environment (created once) + __env_system = __env_system or { + base_env = nil, -- Template environment + initialized = false, -- Initialization flag + env_pool = {}, -- Pre-allocated environment pool + pool_size = 0, -- Current pool size + max_pool_size = 8 -- Maximum pool size + } + + -- Initialize base environment once + if not __env_system.initialized then + -- Create base environment with all standard libraries + local base = {} + + -- Safe standard libraries + base.string = string + base.table = table + base.math = math + base.os = { + time = os.time, + date = os.date, + difftime = os.difftime, + clock = os.clock + } + + -- Basic functions + base.print = print + base.tonumber = tonumber + base.tostring = tostring + base.type = type + base.pairs = pairs + base.ipairs = ipairs + base.next = next + base.select = select + base.pcall = pcall + base.xpcall = xpcall + base.error = error + base.assert = assert + base.collectgarbage = collectgarbage + base.unpack = unpack or table.unpack + + -- Package system + base.package = { + loaded = {}, + path = package.path, + preload = {} + } + + base.require = function(modname) + if base.package.loaded[modname] then + return base.package.loaded[modname] + end + + local loader = base.package.preload[modname] + if type(loader) == "function" then + local result = loader(modname) + base.package.loaded[modname] = result or true + return result + end + + error("module '" .. modname .. "' not found", 2) + end + + -- Add registered custom modules + if __sandbox_modules then + for name, mod in pairs(__sandbox_modules) do + base[name] = mod + end + end + + -- Store base environment + __env_system.base_env = base + __env_system.initialized = true + end + + -- Global variable for tracking current environment + __last_env = nil + + -- Get an environment for execution + function __get_sandbox_env() + local env + + -- Try to reuse from pool + if __env_system.pool_size > 0 then + env = table.remove(__env_system.env_pool) + __env_system.pool_size = __env_system.pool_size - 1 + else + -- Create new environment with metatable inheritance + env = setmetatable({}, { + __index = _G -- Use global environment as fallback + }) + end + + -- Store reference to current environment + __last_env = env + + return env + end + + -- Return environment to pool for reuse + function __recycle_env(env) + -- Only recycle if pool isn't full + if __env_system.pool_size < __env_system.max_pool_size then + -- Clear all fields except metatable + for k in pairs(env) do + env[k] = nil + end + + -- Add to pool + table.insert(__env_system.env_pool, env) + __env_system.pool_size = __env_system.pool_size + 1 + end + end + + -- Execute code in sandbox + function __execute_sandbox(f) + -- Get environment + local env = __get_sandbox_env() + + -- Set environment for function + setfenv(f, env) + + -- Execute with protected call + local success, result = pcall(f) + + -- Recycle environment + __recycle_env(env) + + -- Process result + if not success then + error(result, 0) + end + + -- Handle multiple return values + if type(result) == "table" and result.n ~= nil then + local returnValues = {} + for i=1, result.n do + returnValues[i] = result[i] + end + return returnValues + end + + return result + end + `) + + if err != nil { + return err + } + + s.initialized = true + return nil +} + +// AddPermanentLua adds Lua code to the environment permanently +// This code becomes part of the base environment +func (s *Sandbox) AddPermanentLua(code string) error { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.state == nil { + return fmt.Errorf("sandbox is closed") + } + + // Make sure sandbox is initialized + if !s.initialized { + if err := s.initializeUnlocked(); err != nil { + return err + } + } + + // Add code to base environment + return s.state.DoString(` + -- First compile the code + local f, err = loadstring([=[` + code + `]=], "permanent") + if not f then + error(err, 0) + end + + -- Create a temporary environment based on base env + local temp_env = setmetatable({}, {__index = __env_system.base_env}) + setfenv(f, temp_env) + + -- Run the code in the temporary environment + local ok, err = pcall(f) + if not ok then + error(err, 0) + end + + -- Copy new values to base environment + for k, v in pairs(temp_env) do + __env_system.base_env[k] = v + end + `) +} + +// ResetEnvironment resets the sandbox to its initial state +func (s *Sandbox) ResetEnvironment() error { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.state == nil { + return fmt.Errorf("sandbox is closed") + } + + // Reinitialize the environment system + s.initialized = false + return s.Initialize() +} diff --git a/tests/sandbox_test.go b/tests/sandbox_test.go new file mode 100644 index 0000000..2784b54 --- /dev/null +++ b/tests/sandbox_test.go @@ -0,0 +1,650 @@ +package luajit_test + +import ( + "os" + "sync" + "testing" + + luajit "git.sharkk.net/Sky/LuaJIT-to-Go" +) + +// TestSandboxLifecycle tests sandbox creation and closing +func TestSandboxLifecycle(t *testing.T) { + // Create a new sandbox + sandbox := luajit.NewSandbox() + if sandbox == nil { + t.Fatal("Failed to create sandbox") + } + + // Close the sandbox + sandbox.Close() + + // Test idempotent close (should not panic) + sandbox.Close() +} + +// TestSandboxFunctionRegistration tests registering Go functions in the sandbox +func TestSandboxFunctionRegistration(t *testing.T) { + sandbox := luajit.NewSandbox() + defer sandbox.Close() + + // Register a simple addition function + add := func(s *luajit.State) int { + a := s.ToNumber(1) + b := s.ToNumber(2) + s.PushNumber(a + b) + return 1 + } + + err := sandbox.RegisterFunction("add", add) + if err != nil { + t.Fatalf("Failed to register function: %v", err) + } + + // Test the function + result, err := sandbox.Run("return add(3, 4)") + if err != nil { + t.Fatalf("Failed to execute function: %v", err) + } + + if result != float64(7) { + t.Fatalf("Expected 7, got %v", result) + } + + // Test after sandbox is closed + sandbox.Close() + err = sandbox.RegisterFunction("test", add) + if err == nil { + t.Fatal("Expected error when registering function on closed sandbox") + } +} + +// TestSandboxGlobalVariables tests setting and getting global variables +func TestSandboxGlobalVariables(t *testing.T) { + sandbox := luajit.NewSandbox() + defer sandbox.Close() + + // Set a global variable + err := sandbox.SetGlobal("answer", 42) + if err != nil { + t.Fatalf("Failed to set global: %v", err) + } + + // Get the global variable + value, err := sandbox.GetGlobal("answer") + if err != nil { + t.Fatalf("Failed to get global: %v", err) + } + + if value != float64(42) { + t.Fatalf("Expected 42, got %v", value) + } + + // Test different types + testCases := []struct { + name string + value any + }{ + {"nil_value", nil}, + {"bool_value", true}, + {"string_value", "hello"}, + {"table_value", map[string]any{"key": "value"}}, + {"array_value", []float64{1, 2, 3}}, + } + + for _, tc := range testCases { + err := sandbox.SetGlobal(tc.name, tc.value) + if err != nil { + t.Fatalf("Failed to set global %s: %v", tc.name, err) + } + + value, err := sandbox.GetGlobal(tc.name) + if err != nil { + t.Fatalf("Failed to get global %s: %v", tc.name, err) + } + + // For tables/arrays, just check they're not nil + switch tc.value.(type) { + case map[string]any, []float64: + if value == nil { + t.Fatalf("Expected non-nil for %s, got nil", tc.name) + } + default: + if value != tc.value && !(tc.value == nil && value == nil) { + t.Fatalf("For %s: expected %v, got %v", tc.name, tc.value, value) + } + } + } + + // Test after sandbox is closed + sandbox.Close() + err = sandbox.SetGlobal("test", 123) + if err == nil { + t.Fatal("Expected error when setting global on closed sandbox") + } + + _, err = sandbox.GetGlobal("test") + if err == nil { + t.Fatal("Expected error when getting global from closed sandbox") + } +} + +// TestSandboxCodeExecution tests running Lua code in the sandbox +func TestSandboxCodeExecution(t *testing.T) { + sandbox := luajit.NewSandbox() + defer sandbox.Close() + + // Run simple code + result, err := sandbox.Run("return 'hello'") + if err != nil { + t.Fatalf("Failed to run code: %v", err) + } + + if result != "hello" { + t.Fatalf("Expected 'hello', got %v", result) + } + + // Run code with multiple return values + result, err = sandbox.Run("return 1, 2, 3") + if err != nil { + t.Fatalf("Failed to run code: %v", err) + } + + // Should return array for multiple values + results, ok := result.([]any) + if !ok { + t.Fatalf("Expected array for multiple returns, got %T", result) + } + + if len(results) != 3 || results[0] != float64(1) || results[1] != float64(2) || results[2] != float64(3) { + t.Fatalf("Expected [1, 2, 3], got %v", results) + } + + // Run code that sets a global + _, err = sandbox.Run("global_var = 'set from Lua'") + if err != nil { + t.Fatalf("Failed to run code: %v", err) + } + + value, err := sandbox.GetGlobal("global_var") + if err != nil { + t.Fatalf("Failed to get global: %v", err) + } + + if value != "set from Lua" { + t.Fatalf("Expected 'set from Lua', got %v", value) + } + + // Run invalid code + _, err = sandbox.Run("this is not valid Lua") + if err == nil { + t.Fatal("Expected error for invalid code") + } + + // Test after sandbox is closed + sandbox.Close() + _, err = sandbox.Run("return true") + if err == nil { + t.Fatal("Expected error when running code on closed sandbox") + } +} + +// TestSandboxBytecodeExecution tests bytecode compilation and execution +func TestSandboxBytecodeExecution(t *testing.T) { + sandbox := luajit.NewSandbox() + defer sandbox.Close() + + // Compile code to bytecode + code := ` + local function greet(name) + return "Hello, " .. name + end + return greet("World") + ` + bytecode, err := sandbox.Compile(code) + if err != nil { + t.Fatalf("Failed to compile bytecode: %v", err) + } + + if len(bytecode) == 0 { + t.Fatal("Expected non-empty bytecode") + } + + // Run the bytecode + result, err := sandbox.RunBytecode(bytecode) + if err != nil { + t.Fatalf("Failed to run bytecode: %v", err) + } + + if result != "Hello, World" { + t.Fatalf("Expected 'Hello, World', got %v", result) + } + + // Test bytecode that sets a global + bytecode, err = sandbox.Compile("bytecode_var = 42") + if err != nil { + t.Fatalf("Failed to compile bytecode: %v", err) + } + + _, err = sandbox.RunBytecode(bytecode) + if err != nil { + t.Fatalf("Failed to run bytecode: %v", err) + } + + value, err := sandbox.GetGlobal("bytecode_var") + if err != nil { + t.Fatalf("Failed to get global: %v", err) + } + + if value != float64(42) { + t.Fatalf("Expected 42, got %v", value) + } + + // Test invalid bytecode + _, err = sandbox.RunBytecode([]byte("not valid bytecode")) + if err == nil { + t.Fatal("Expected error for invalid bytecode") + } + + // Test after sandbox is closed + sandbox.Close() + _, err = sandbox.Compile("return true") + if err == nil { + t.Fatal("Expected error when compiling on closed sandbox") + } + + _, err = sandbox.RunBytecode(bytecode) + if err == nil { + t.Fatal("Expected error when running bytecode on closed sandbox") + } +} + +// TestSandboxPersistence tests state persistence across executions +func TestSandboxPersistence(t *testing.T) { + sandbox := luajit.NewSandbox() + defer sandbox.Close() + + // Set up initial state + _, err := sandbox.Run(` + counter = 0 + function increment() + counter = counter + 1 + return counter + end + `) + if err != nil { + t.Fatalf("Failed to set up state: %v", err) + } + + // Run multiple executions + for i := 1; i <= 3; i++ { + result, err := sandbox.Run("return increment()") + if err != nil { + t.Fatalf("Failed to run code: %v", err) + } + + if result != float64(i) { + t.Fatalf("Expected %d, got %v", i, result) + } + } + + // Check final counter value + value, err := sandbox.GetGlobal("counter") + if err != nil { + t.Fatalf("Failed to get global: %v", err) + } + + if value != float64(3) { + t.Fatalf("Expected final counter to be 3, got %v", value) + } + + // Test persistence with bytecode + bytecode, err := sandbox.Compile("return counter + 1") + if err != nil { + t.Fatalf("Failed to compile bytecode: %v", err) + } + + result, err := sandbox.RunBytecode(bytecode) + if err != nil { + t.Fatalf("Failed to run bytecode: %v", err) + } + + if result != float64(4) { + t.Fatalf("Expected 4, got %v", result) + } +} + +// TestSandboxConcurrency tests concurrent access to the sandbox +func TestSandboxConcurrency(t *testing.T) { + sandbox := luajit.NewSandbox() + defer sandbox.Close() + + // Set up a counter + _, err := sandbox.Run("counter = 0") + if err != nil { + t.Fatalf("Failed to set up counter: %v", err) + } + + // Run concurrent increments + const numGoroutines = 10 + const incrementsPerGoroutine = 100 + + var wg sync.WaitGroup + wg.Add(numGoroutines) + + for i := 0; i < numGoroutines; i++ { + go func() { + defer wg.Done() + for j := 0; j < incrementsPerGoroutine; j++ { + _, err := sandbox.Run("counter = counter + 1") + if err != nil { + t.Errorf("Failed to increment counter: %v", err) + return + } + } + }() + } + + wg.Wait() + + // Check the final counter value + value, err := sandbox.GetGlobal("counter") + if err != nil { + t.Fatalf("Failed to get counter: %v", err) + } + + expected := float64(numGoroutines * incrementsPerGoroutine) + if value != expected { + t.Fatalf("Expected counter to be %v, got %v", expected, value) + } +} + +// TestPermanentLua tests the AddPermanentLua method +func TestPermanentLua(t *testing.T) { + sandbox := luajit.NewSandbox() + defer sandbox.Close() + + // Add permanent Lua environment + err := sandbox.AddPermanentLua(` + -- Create utility functions + function double(x) + return x * 2 + end + + function square(x) + return x * x + end + + -- Create a protected environment + env = { + add = function(a, b) return a + b end, + sub = function(a, b) return a - b end + } + `) + if err != nil { + t.Fatalf("Failed to add permanent Lua: %v", err) + } + + // Test using the permanent functions + testCases := []struct { + code string + expected float64 + }{ + {"return double(5)", 10}, + {"return square(4)", 16}, + {"return env.add(10, 20)", 30}, + {"return env.sub(50, 30)", 20}, + } + + for _, tc := range testCases { + result, err := sandbox.Run(tc.code) + if err != nil { + t.Fatalf("Failed to run code '%s': %v", tc.code, err) + } + + if result != tc.expected { + t.Fatalf("For '%s': expected %v, got %v", tc.code, tc.expected, result) + } + } + + // Test persistence of permanent code across executions + _, err = sandbox.Run("counter = 0") + if err != nil { + t.Fatalf("Failed to set counter: %v", err) + } + + result, err := sandbox.Run("counter = counter + 1; return double(counter)") + if err != nil { + t.Fatalf("Failed to run code: %v", err) + } + + if result != float64(2) { + t.Fatalf("Expected 2, got %v", result) + } + + // Test after sandbox is closed + sandbox.Close() + err = sandbox.AddPermanentLua("function test() end") + if err == nil { + t.Fatal("Expected error when adding permanent Lua to closed sandbox") + } +} + +// TestResetEnvironment tests the ResetEnvironment method +func TestResetEnvironment(t *testing.T) { + sandbox := luajit.NewSandbox() + defer sandbox.Close() + + // Set up some Go functions and Lua code + sandbox.RegisterFunction("timeNow", func(s *luajit.State) int { + s.PushString("test") + return 1 + }) + + sandbox.AddPermanentLua(` + function permanent() + return "permanent function" + end + `) + + _, err := sandbox.Run(` + temp_var = "will be reset" + function temp_func() + return "temp function" + end + `) + if err != nil { + t.Fatalf("Failed to run setup code: %v", err) + } + + // Verify everything is set up correctly + result, err := sandbox.Run("return timeNow()") + if err != nil || result != "test" { + t.Fatalf("Go function not working: %v, %v", result, err) + } + + result, err = sandbox.Run("return permanent()") + if err != nil || result != "permanent function" { + t.Fatalf("Permanent function not working: %v, %v", result, err) + } + + result, err = sandbox.Run("return temp_func()") + if err != nil || result != "temp function" { + t.Fatalf("Temp function not working: %v, %v", result, err) + } + + value, err := sandbox.GetGlobal("temp_var") + if err != nil || value != "will be reset" { + t.Fatalf("Temp var not set correctly: %v, %v", value, err) + } + + // Reset the environment + err = sandbox.ResetEnvironment() + if err != nil { + t.Fatalf("Failed to reset environment: %v", err) + } + + // Check Go function survives reset + result, err = sandbox.Run("return timeNow()") + if err != nil || result != "test" { + t.Fatalf("Go function should survive reset: %v, %v", result, err) + } + + // Check permanent function is gone (it was added with AddPermanentLua but reset removes it) + _, err = sandbox.Run("return permanent()") + if err == nil { + t.Fatal("Permanent function should be gone after reset") + } + + // Check temp function is gone + _, err = sandbox.Run("return temp_func()") + if err == nil { + t.Fatal("Temp function should be gone after reset") + } + + // Check temp var is gone + value, err = sandbox.GetGlobal("temp_var") + if err != nil || value != nil { + t.Fatalf("Temp var should be nil after reset: %v", value) + } + + // Test after sandbox is closed + sandbox.Close() + err = sandbox.ResetEnvironment() + if err == nil { + t.Fatal("Expected error when resetting closed sandbox") + } +} + +// TestRunFile tests the RunFile method +func TestRunFile(t *testing.T) { + // Create a temporary Lua file + tmpfile, err := createTempLuaFile("return 42") + if err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } + defer removeTempFile(tmpfile) + + sandbox := luajit.NewSandbox() + defer sandbox.Close() + + // Run the file + err = sandbox.RunFile(tmpfile) + if err != nil { + t.Fatalf("Failed to run file: %v", err) + } + + // Test non-existent file + err = sandbox.RunFile("does_not_exist.lua") + if err == nil { + t.Fatal("Expected error for non-existent file") + } + + // Test after sandbox is closed + sandbox.Close() + err = sandbox.RunFile(tmpfile) + if err == nil { + t.Fatal("Expected error when running file on closed sandbox") + } +} + +// TestSandboxPackagePath tests the SetPackagePath and AddPackagePath methods +func TestSandboxPackagePath(t *testing.T) { + sandbox := luajit.NewSandbox() + defer sandbox.Close() + + // Set package path + testPath := "/test/path/?.lua" + err := sandbox.SetPackagePath(testPath) + if err != nil { + t.Fatalf("Failed to set package path: %v", err) + } + + // Check path was set + result, err := sandbox.Run("return package.path") + if err != nil { + t.Fatalf("Failed to get package.path: %v", err) + } + + if result != testPath { + t.Fatalf("Expected package.path to be %q, got %q", testPath, result) + } + + // Add to package path + addPath := "/another/path/?.lua" + err = sandbox.AddPackagePath(addPath) + if err != nil { + t.Fatalf("Failed to add package path: %v", err) + } + + // Check path was updated + result, err = sandbox.Run("return package.path") + if err != nil { + t.Fatalf("Failed to get updated package.path: %v", err) + } + + expected := testPath + ";" + addPath + if result != expected { + t.Fatalf("Expected package.path to be %q, got %q", expected, result) + } + + // Test after sandbox is closed + sandbox.Close() + err = sandbox.SetPackagePath(testPath) + if err == nil { + t.Fatal("Expected error when setting package path on closed sandbox") + } + + err = sandbox.AddPackagePath(addPath) + if err == nil { + t.Fatal("Expected error when adding package path to closed sandbox") + } +} + +// TestSandboxLoadModule tests loading modules +func TestSandboxLoadModule(t *testing.T) { + // Skip for now since we don't have actual modules to load in the test environment + t.Skip("Skipping module loading test as it requires actual modules") + + sandbox := luajit.NewSandbox() + defer sandbox.Close() + + // Set package path to include current directory + err := sandbox.SetPackagePath("./?.lua") + if err != nil { + t.Fatalf("Failed to set package path: %v", err) + } + + // Try to load a non-existent module + err = sandbox.LoadModule("nonexistent_module") + if err == nil { + t.Fatal("Expected error when loading non-existent module") + } +} + +// Helper functions + +// createTempLuaFile creates a temporary Lua file with the given content +func createTempLuaFile(content string) (string, error) { + tmpfile, err := os.CreateTemp("", "test-*.lua") + if err != nil { + return "", err + } + + if _, err := tmpfile.WriteString(content); err != nil { + os.Remove(tmpfile.Name()) + return "", err + } + + if err := tmpfile.Close(); err != nil { + os.Remove(tmpfile.Name()) + return "", err + } + + return tmpfile.Name(), nil +} + +// removeTempFile removes a temporary file +func removeTempFile(path string) { + os.Remove(path) +} diff --git a/wrapper.go b/wrapper.go index e09e956..f5b2add 100644 --- a/wrapper.go +++ b/wrapper.go @@ -1,9 +1,9 @@ package luajit /* +#cgo !windows pkg-config: --static luajit #cgo windows CFLAGS: -I${SRCDIR}/vendor/luajit/include #cgo windows LDFLAGS: -L${SRCDIR}/vendor/luajit/windows -lluajit -static -#cgo !windows pkg-config: --static luajit #include #include