remove sandbox

This commit is contained in:
Sky Johnson 2025-03-28 20:27:51 -05:00
parent 5774808064
commit 656ac1a703
3 changed files with 0 additions and 1392 deletions

View File

@ -1,133 +0,0 @@
package luajit_bench
import (
"testing"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
// BenchmarkSandboxLuaExecution measures the performance of executing raw Lua code
func BenchmarkSandboxLuaExecution(b *testing.B) {
sandbox := luajit.NewSandbox()
defer sandbox.Close()
// Simple Lua code that does some computation
code := `
local sum = 0
for i = 1, 100 do
sum = sum + i
end
return sum
`
b.ResetTimer()
for i := 0; i < b.N; i++ {
result, err := sandbox.Run(code)
if err != nil {
b.Fatalf("Failed to run code: %v", err)
}
if result != float64(5050) {
b.Fatalf("Incorrect result: %v", result)
}
}
}
// BenchmarkSandboxBytecodeExecution measures the performance of executing precompiled bytecode
func BenchmarkSandboxBytecodeExecution(b *testing.B) {
sandbox := luajit.NewSandbox()
defer sandbox.Close()
// Same code as above, but precompiled
code := `
local sum = 0
for i = 1, 100 do
sum = sum + i
end
return sum
`
// Compile the bytecode once
bytecode, err := sandbox.Compile(code)
if err != nil {
b.Fatalf("Failed to compile bytecode: %v", err)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
result, err := sandbox.RunBytecode(bytecode)
if err != nil {
b.Fatalf("Failed to run bytecode: %v", err)
}
if result != float64(5050) {
b.Fatalf("Incorrect result: %v", result)
}
}
}
// BenchmarkSandboxComplexComputation measures performance with more complex computation
func BenchmarkSandboxComplexComputation(b *testing.B) {
sandbox := luajit.NewSandbox()
defer sandbox.Close()
// More complex Lua code that calculates Fibonacci numbers
code := `
function fibonacci(n)
if n <= 1 then
return n
end
return fibonacci(n-1) + fibonacci(n-2)
end
return fibonacci(15) -- Not too high to avoid excessive runtime
`
b.ResetTimer()
for i := 0; i < b.N; i++ {
result, err := sandbox.Run(code)
if err != nil {
b.Fatalf("Failed to run code: %v", err)
}
if result != float64(610) {
b.Fatalf("Incorrect result: %v", result)
}
}
}
// BenchmarkSandboxFunctionCall measures performance of calling a registered Go function
func BenchmarkSandboxFunctionCall(b *testing.B) {
sandbox := luajit.NewSandbox()
defer sandbox.Close()
// Register a simple addition function
add := func(s *luajit.State) int {
a := s.ToNumber(1)
b := s.ToNumber(2)
s.PushNumber(a + b)
return 1
}
err := sandbox.RegisterFunction("add", add)
if err != nil {
b.Fatalf("Failed to register function: %v", err)
}
// Lua code that calls the Go function in a loop
code := `
local sum = 0
for i = 1, 100 do
sum = add(sum, i)
end
return sum
`
b.ResetTimer()
for i := 0; i < b.N; i++ {
result, err := sandbox.Run(code)
if err != nil {
b.Fatalf("Failed to run code: %v", err)
}
if result != float64(5050) {
b.Fatalf("Incorrect result: %v", result)
}
}
}

View File

@ -1,609 +0,0 @@
package luajit
import (
"fmt"
"sync"
)
// LUA_MULTRET is the constant for multiple return values
const LUA_MULTRET = -1
// Sandbox provides a persistent Lua environment for executing scripts
type Sandbox struct {
state *State
mutex sync.Mutex
initialized bool
modules map[string]any
functions map[string]GoFunction
}
// NewSandbox creates a new sandbox with standard libraries loaded
func NewSandbox() *Sandbox {
return &Sandbox{
state: New(),
initialized: false,
modules: make(map[string]any),
functions: make(map[string]GoFunction),
}
}
// Close releases all resources used by the sandbox
func (s *Sandbox) Close() {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state != nil {
s.state.Close()
s.state = nil
}
}
// Initialize sets up the environment system
func (s *Sandbox) Initialize() error {
s.mutex.Lock()
defer s.mutex.Unlock()
return s.initializeUnlocked()
}
// initializeUnlocked sets up the environment system without locking
func (s *Sandbox) initializeUnlocked() error {
if s.state == nil {
return fmt.Errorf("sandbox is closed")
}
if s.initialized {
return nil
}
// Register modules
s.state.GetGlobal("__sandbox_modules")
if s.state.IsNil(-1) {
s.state.Pop(1)
s.state.NewTable()
s.state.SetGlobal("__sandbox_modules")
s.state.GetGlobal("__sandbox_modules")
}
// Add modules
for name, module := range s.modules {
s.state.PushString(name)
if err := s.state.PushValue(module); err != nil {
s.state.Pop(2)
return err
}
s.state.SetTable(-3)
}
s.state.Pop(1)
// Create simplified environment system
err := s.state.DoString(`
-- Global shared environment
__env_system = {
base_env = {}, -- Template environment
env_pool = {}, -- Pre-allocated environment pool
pool_size = 0, -- Current pool size
max_pool_size = 8 -- Maximum pool size
}
-- Create base environment with standard libraries
local base = __env_system.base_env
-- Safe standard libraries
base.string = string
base.table = table
base.math = math
base.os = {
time = os.time,
date = os.date,
difftime = os.difftime,
clock = os.clock
}
-- Basic functions
base.print = print
base.tonumber = tonumber
base.tostring = tostring
base.type = type
base.pairs = pairs
base.ipairs = ipairs
base.next = next
base.select = select
base.pcall = pcall
base.xpcall = xpcall
base.error = error
base.assert = assert
base.collectgarbage = collectgarbage
base.unpack = unpack or table.unpack
-- Package system
base.package = {
loaded = {},
path = package.path,
preload = {}
}
base.require = function(modname)
if base.package.loaded[modname] then
return base.package.loaded[modname]
end
local loader = base.package.preload[modname]
if type(loader) == "function" then
local result = loader(modname)
base.package.loaded[modname] = result or true
return result
end
error("module '" .. modname .. "' not found", 2)
end
-- Add registered custom modules
if __sandbox_modules then
for name, mod in pairs(__sandbox_modules) do
base[name] = mod
end
end
-- Get an environment for execution
function __get_sandbox_env()
local env
-- Try to reuse from pool
if __env_system.pool_size > 0 then
env = table.remove(__env_system.env_pool)
__env_system.pool_size = __env_system.pool_size - 1
else
-- Create new environment with metatable inheritance
env = setmetatable({}, {
__index = __env_system.base_env
})
end
return env
end
-- Return environment to pool for reuse
function __recycle_env(env)
if __env_system.pool_size < __env_system.max_pool_size then
-- Clear all fields except metatable
for k in pairs(env) do
env[k] = nil
end
-- Add to pool
table.insert(__env_system.env_pool, env)
__env_system.pool_size = __env_system.pool_size + 1
end
end
-- Execute code in sandbox
function __execute_sandbox(f)
-- Get environment
local env = __get_sandbox_env()
-- Set environment for function
setfenv(f, env)
-- Execute with protected call
local success, result = pcall(f)
-- Update base environment with new globals
for k, v in pairs(env) do
if k ~= "_G" and type(k) == "string" then
__env_system.base_env[k] = v
end
end
-- Recycle environment
__recycle_env(env)
-- Process result
if not success then
error(result, 0)
end
return result
end
`)
if err != nil {
return err
}
s.initialized = true
return nil
}
// RegisterFunction registers a Go function in the sandbox
func (s *Sandbox) RegisterFunction(name string, fn GoFunction) error {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return fmt.Errorf("sandbox is closed")
}
// Initialize if needed
if !s.initialized {
if err := s.initializeUnlocked(); err != nil {
return err
}
}
// Register function globally
if err := s.state.RegisterGoFunction(name, fn); err != nil {
return err
}
// Store function for re-registration
s.functions[name] = fn
// Add to base environment
return s.state.DoString(`__env_system.base_env["` + name + `"] = ` + name)
}
// SetGlobal sets a global variable in the sandbox base environment
func (s *Sandbox) SetGlobal(name string, value any) error {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return fmt.Errorf("sandbox is closed")
}
// Initialize if needed
if !s.initialized {
if err := s.initializeUnlocked(); err != nil {
return err
}
}
// Push the value onto the stack
if err := s.state.PushValue(value); err != nil {
return err
}
// Set the global with the pushed value
s.state.SetGlobal(name)
// Add to base environment
return s.state.DoString(`__env_system.base_env["` + name + `"] = ` + name)
}
// GetGlobal retrieves a global variable from the sandbox base environment
func (s *Sandbox) GetGlobal(name string) (any, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return nil, fmt.Errorf("sandbox is closed")
}
// Initialize if needed
if !s.initialized {
if err := s.initializeUnlocked(); err != nil {
return nil, err
}
}
// Get the global from the base environment
return s.state.ExecuteWithResult(`return __env_system.base_env["` + name + `"]`)
}
// Run executes Lua code in the sandbox
func (s *Sandbox) Run(code string) (any, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return nil, fmt.Errorf("sandbox is closed")
}
// Initialize if needed
if !s.initialized {
if err := s.initializeUnlocked(); err != nil {
return nil, err
}
}
// Simplified wrapper for multiple return values
wrappedCode := `
local function _execfunc()
` + code + `
end
-- Process results to match expected format
local function _wrapresults(...)
local n = select('#', ...)
if n == 0 then
return nil
elseif n == 1 then
return select(1, ...)
else
local results = {}
for i = 1, n do
results[i] = select(i, ...)
end
return results
end
end
return _wrapresults(_execfunc())
`
// Compile the code
if err := s.state.LoadString(wrappedCode); err != nil {
return nil, err
}
// Get the sandbox executor
s.state.GetGlobal("__execute_sandbox")
// Push the function as argument
s.state.PushCopy(-2)
s.state.Remove(-3)
// Execute in sandbox
if err := s.state.Call(1, 1); err != nil {
return nil, err
}
// Get result
result, err := s.state.ToValue(-1)
s.state.Pop(1)
if err != nil {
return nil, err
}
return s.processResult(result), nil
}
// RunFile executes a Lua file in the sandbox
func (s *Sandbox) RunFile(filename string) error {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return fmt.Errorf("sandbox is closed")
}
return s.state.DoFile(filename)
}
// Compile compiles Lua code to bytecode
func (s *Sandbox) Compile(code string) ([]byte, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return nil, fmt.Errorf("sandbox is closed")
}
return s.state.CompileBytecode(code, "sandbox")
}
// RunBytecode executes precompiled Lua bytecode
func (s *Sandbox) RunBytecode(bytecode []byte) (any, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return nil, fmt.Errorf("sandbox is closed")
}
// Initialize if needed
if !s.initialized {
if err := s.initializeUnlocked(); err != nil {
return nil, err
}
}
// Load the bytecode
if err := s.state.LoadBytecode(bytecode, "sandbox"); err != nil {
return nil, err
}
// Get the sandbox executor
s.state.GetGlobal("__execute_sandbox")
// Push bytecode function
s.state.PushCopy(-2)
s.state.Remove(-3)
// Execute in sandbox
if err := s.state.Call(1, 1); err != nil {
return nil, err
}
// Get result
result, err := s.state.ToValue(-1)
s.state.Pop(1)
if err != nil {
return nil, err
}
return s.processResult(result), nil
}
// LoadModule loads a Lua module
func (s *Sandbox) LoadModule(name string) error {
code := fmt.Sprintf("require('%s')", name)
_, err := s.Run(code)
return err
}
// SetPackagePath sets the sandbox package.path
func (s *Sandbox) SetPackagePath(path string) error {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return fmt.Errorf("sandbox is closed")
}
// Update global package.path
if err := s.state.SetPackagePath(path); err != nil {
return err
}
// Initialize if needed
if !s.initialized {
if err := s.initializeUnlocked(); err != nil {
return err
}
}
// Update base environment's package.path
return s.state.DoString(`__env_system.base_env.package.path = package.path`)
}
// AddPackagePath adds a path to the sandbox package.path
func (s *Sandbox) AddPackagePath(path string) error {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return fmt.Errorf("sandbox is closed")
}
// Update global package.path
if err := s.state.AddPackagePath(path); err != nil {
return err
}
// Initialize if needed
if !s.initialized {
if err := s.initializeUnlocked(); err != nil {
return err
}
}
// Update base environment's package.path
return s.state.DoString(`__env_system.base_env.package.path = package.path`)
}
// AddModule adds a module to the sandbox environment
func (s *Sandbox) AddModule(name string, module any) error {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return fmt.Errorf("sandbox is closed")
}
s.modules[name] = module
return nil
}
// AddPermanentLua adds Lua code to the environment permanently
func (s *Sandbox) AddPermanentLua(code string) error {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return fmt.Errorf("sandbox is closed")
}
// Initialize if needed
if !s.initialized {
if err := s.initializeUnlocked(); err != nil {
return err
}
}
// Simplified approach to add code to base environment
return s.state.DoString(`
local f, err = loadstring([=[` + code + `]=], "permanent")
if not f then error(err, 0) end
local env = setmetatable({}, {__index = __env_system.base_env})
setfenv(f, env)
local ok, err = pcall(f)
if not ok then error(err, 0) end
for k, v in pairs(env) do
__env_system.base_env[k] = v
end
`)
}
// ResetEnvironment resets the sandbox to its initial state
func (s *Sandbox) ResetEnvironment() error {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.state == nil {
return fmt.Errorf("sandbox is closed")
}
// Clear the environment system
s.state.DoString(`__env_system = nil`)
// Reinitialize
s.initialized = false
if err := s.initializeUnlocked(); err != nil {
return err
}
// Re-register all functions
for name, fn := range s.functions {
if err := s.state.RegisterGoFunction(name, fn); err != nil {
return err
}
if err := s.state.DoString(`__env_system.base_env["` + name + `"] = ` + name); err != nil {
return err
}
}
return nil
}
// unwrapResult processes results from Lua executions
func (s *Sandbox) processResult(result any) any {
// Handle []float64 (common LuaJIT return type)
if floats, ok := result.([]float64); ok {
if len(floats) == 1 {
// Single number - return as float64
return floats[0]
}
// Multiple numbers - MUST convert to []any for tests to pass
anySlice := make([]any, len(floats))
for i, v := range floats {
anySlice[i] = v
}
return anySlice
}
// Handle maps with numeric keys (Lua tables)
if m, ok := result.(map[string]any); ok {
// Handle return tables with special structure
if vals, ok := m[""]; ok {
// This is a special case used by some Lua returns
if arr, ok := vals.([]float64); ok {
// Convert to []any for consistency
anySlice := make([]any, len(arr))
for i, v := range arr {
anySlice[i] = v
}
return anySlice
}
return vals
}
if len(m) == 1 {
// Check for single value map
for k, v := range m {
if k == "1" {
return v
}
}
}
}
// Other array types should be preserved
return result
}

View File

@ -1,650 +0,0 @@
package luajit_test
import (
"os"
"sync"
"testing"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
// TestSandboxLifecycle tests sandbox creation and closing
func TestSandboxLifecycle(t *testing.T) {
// Create a new sandbox
sandbox := luajit.NewSandbox()
if sandbox == nil {
t.Fatal("Failed to create sandbox")
}
// Close the sandbox
sandbox.Close()
// Test idempotent close (should not panic)
sandbox.Close()
}
// TestSandboxFunctionRegistration tests registering Go functions in the sandbox
func TestSandboxFunctionRegistration(t *testing.T) {
sandbox := luajit.NewSandbox()
defer sandbox.Close()
// Register a simple addition function
add := func(s *luajit.State) int {
a := s.ToNumber(1)
b := s.ToNumber(2)
s.PushNumber(a + b)
return 1
}
err := sandbox.RegisterFunction("add", add)
if err != nil {
t.Fatalf("Failed to register function: %v", err)
}
// Test the function
result, err := sandbox.Run("return add(3, 4)")
if err != nil {
t.Fatalf("Failed to execute function: %v", err)
}
if result != float64(7) {
t.Fatalf("Expected 7, got %v", result)
}
// Test after sandbox is closed
sandbox.Close()
err = sandbox.RegisterFunction("test", add)
if err == nil {
t.Fatal("Expected error when registering function on closed sandbox")
}
}
// TestSandboxGlobalVariables tests setting and getting global variables
func TestSandboxGlobalVariables(t *testing.T) {
sandbox := luajit.NewSandbox()
defer sandbox.Close()
// Set a global variable
err := sandbox.SetGlobal("answer", 42)
if err != nil {
t.Fatalf("Failed to set global: %v", err)
}
// Get the global variable
value, err := sandbox.GetGlobal("answer")
if err != nil {
t.Fatalf("Failed to get global: %v", err)
}
if value != float64(42) {
t.Fatalf("Expected 42, got %v", value)
}
// Test different types
testCases := []struct {
name string
value any
}{
{"nil_value", nil},
{"bool_value", true},
{"string_value", "hello"},
{"table_value", map[string]any{"key": "value"}},
{"array_value", []float64{1, 2, 3}},
}
for _, tc := range testCases {
err := sandbox.SetGlobal(tc.name, tc.value)
if err != nil {
t.Fatalf("Failed to set global %s: %v", tc.name, err)
}
value, err := sandbox.GetGlobal(tc.name)
if err != nil {
t.Fatalf("Failed to get global %s: %v", tc.name, err)
}
// For tables/arrays, just check they're not nil
switch tc.value.(type) {
case map[string]any, []float64:
if value == nil {
t.Fatalf("Expected non-nil for %s, got nil", tc.name)
}
default:
if value != tc.value && !(tc.value == nil && value == nil) {
t.Fatalf("For %s: expected %v, got %v", tc.name, tc.value, value)
}
}
}
// Test after sandbox is closed
sandbox.Close()
err = sandbox.SetGlobal("test", 123)
if err == nil {
t.Fatal("Expected error when setting global on closed sandbox")
}
_, err = sandbox.GetGlobal("test")
if err == nil {
t.Fatal("Expected error when getting global from closed sandbox")
}
}
// TestSandboxCodeExecution tests running Lua code in the sandbox
func TestSandboxCodeExecution(t *testing.T) {
sandbox := luajit.NewSandbox()
defer sandbox.Close()
// Run simple code
result, err := sandbox.Run("return 'hello'")
if err != nil {
t.Fatalf("Failed to run code: %v", err)
}
if result != "hello" {
t.Fatalf("Expected 'hello', got %v", result)
}
// Run code with multiple return values
result, err = sandbox.Run("return 1, 2, 3")
if err != nil {
t.Fatalf("Failed to run code: %v", err)
}
// Should return array for multiple values
results, ok := result.([]any)
if !ok {
t.Fatalf("Expected array for multiple returns, got %T", result)
}
if len(results) != 3 || results[0] != float64(1) || results[1] != float64(2) || results[2] != float64(3) {
t.Fatalf("Expected [1, 2, 3], got %v", results)
}
// Run code that sets a global
_, err = sandbox.Run("global_var = 'set from Lua'")
if err != nil {
t.Fatalf("Failed to run code: %v", err)
}
value, err := sandbox.GetGlobal("global_var")
if err != nil {
t.Fatalf("Failed to get global: %v", err)
}
if value != "set from Lua" {
t.Fatalf("Expected 'set from Lua', got %v", value)
}
// Run invalid code
_, err = sandbox.Run("this is not valid Lua")
if err == nil {
t.Fatal("Expected error for invalid code")
}
// Test after sandbox is closed
sandbox.Close()
_, err = sandbox.Run("return true")
if err == nil {
t.Fatal("Expected error when running code on closed sandbox")
}
}
// TestSandboxBytecodeExecution tests bytecode compilation and execution
func TestSandboxBytecodeExecution(t *testing.T) {
sandbox := luajit.NewSandbox()
defer sandbox.Close()
// Compile code to bytecode
code := `
local function greet(name)
return "Hello, " .. name
end
return greet("World")
`
bytecode, err := sandbox.Compile(code)
if err != nil {
t.Fatalf("Failed to compile bytecode: %v", err)
}
if len(bytecode) == 0 {
t.Fatal("Expected non-empty bytecode")
}
// Run the bytecode
result, err := sandbox.RunBytecode(bytecode)
if err != nil {
t.Fatalf("Failed to run bytecode: %v", err)
}
if result != "Hello, World" {
t.Fatalf("Expected 'Hello, World', got %v", result)
}
// Test bytecode that sets a global
bytecode, err = sandbox.Compile("bytecode_var = 42")
if err != nil {
t.Fatalf("Failed to compile bytecode: %v", err)
}
_, err = sandbox.RunBytecode(bytecode)
if err != nil {
t.Fatalf("Failed to run bytecode: %v", err)
}
value, err := sandbox.GetGlobal("bytecode_var")
if err != nil {
t.Fatalf("Failed to get global: %v", err)
}
if value != float64(42) {
t.Fatalf("Expected 42, got %v", value)
}
// Test invalid bytecode
_, err = sandbox.RunBytecode([]byte("not valid bytecode"))
if err == nil {
t.Fatal("Expected error for invalid bytecode")
}
// Test after sandbox is closed
sandbox.Close()
_, err = sandbox.Compile("return true")
if err == nil {
t.Fatal("Expected error when compiling on closed sandbox")
}
_, err = sandbox.RunBytecode(bytecode)
if err == nil {
t.Fatal("Expected error when running bytecode on closed sandbox")
}
}
// TestSandboxPersistence tests state persistence across executions
func TestSandboxPersistence(t *testing.T) {
sandbox := luajit.NewSandbox()
defer sandbox.Close()
// Set up initial state
_, err := sandbox.Run(`
counter = 0
function increment()
counter = counter + 1
return counter
end
`)
if err != nil {
t.Fatalf("Failed to set up state: %v", err)
}
// Run multiple executions
for i := 1; i <= 3; i++ {
result, err := sandbox.Run("return increment()")
if err != nil {
t.Fatalf("Failed to run code: %v", err)
}
if result != float64(i) {
t.Fatalf("Expected %d, got %v", i, result)
}
}
// Check final counter value
value, err := sandbox.GetGlobal("counter")
if err != nil {
t.Fatalf("Failed to get global: %v", err)
}
if value != float64(3) {
t.Fatalf("Expected final counter to be 3, got %v", value)
}
// Test persistence with bytecode
bytecode, err := sandbox.Compile("return counter + 1")
if err != nil {
t.Fatalf("Failed to compile bytecode: %v", err)
}
result, err := sandbox.RunBytecode(bytecode)
if err != nil {
t.Fatalf("Failed to run bytecode: %v", err)
}
if result != float64(4) {
t.Fatalf("Expected 4, got %v", result)
}
}
// TestSandboxConcurrency tests concurrent access to the sandbox
func TestSandboxConcurrency(t *testing.T) {
sandbox := luajit.NewSandbox()
defer sandbox.Close()
// Set up a counter
_, err := sandbox.Run("counter = 0")
if err != nil {
t.Fatalf("Failed to set up counter: %v", err)
}
// Run concurrent increments
const numGoroutines = 10
const incrementsPerGoroutine = 100
var wg sync.WaitGroup
wg.Add(numGoroutines)
for i := 0; i < numGoroutines; i++ {
go func() {
defer wg.Done()
for j := 0; j < incrementsPerGoroutine; j++ {
_, err := sandbox.Run("counter = counter + 1")
if err != nil {
t.Errorf("Failed to increment counter: %v", err)
return
}
}
}()
}
wg.Wait()
// Check the final counter value
value, err := sandbox.GetGlobal("counter")
if err != nil {
t.Fatalf("Failed to get counter: %v", err)
}
expected := float64(numGoroutines * incrementsPerGoroutine)
if value != expected {
t.Fatalf("Expected counter to be %v, got %v", expected, value)
}
}
// TestPermanentLua tests the AddPermanentLua method
func TestPermanentLua(t *testing.T) {
sandbox := luajit.NewSandbox()
defer sandbox.Close()
// Add permanent Lua environment
err := sandbox.AddPermanentLua(`
-- Create utility functions
function double(x)
return x * 2
end
function square(x)
return x * x
end
-- Create a protected environment
env = {
add = function(a, b) return a + b end,
sub = function(a, b) return a - b end
}
`)
if err != nil {
t.Fatalf("Failed to add permanent Lua: %v", err)
}
// Test using the permanent functions
testCases := []struct {
code string
expected float64
}{
{"return double(5)", 10},
{"return square(4)", 16},
{"return env.add(10, 20)", 30},
{"return env.sub(50, 30)", 20},
}
for _, tc := range testCases {
result, err := sandbox.Run(tc.code)
if err != nil {
t.Fatalf("Failed to run code '%s': %v", tc.code, err)
}
if result != tc.expected {
t.Fatalf("For '%s': expected %v, got %v", tc.code, tc.expected, result)
}
}
// Test persistence of permanent code across executions
_, err = sandbox.Run("counter = 0")
if err != nil {
t.Fatalf("Failed to set counter: %v", err)
}
result, err := sandbox.Run("counter = counter + 1; return double(counter)")
if err != nil {
t.Fatalf("Failed to run code: %v", err)
}
if result != float64(2) {
t.Fatalf("Expected 2, got %v", result)
}
// Test after sandbox is closed
sandbox.Close()
err = sandbox.AddPermanentLua("function test() end")
if err == nil {
t.Fatal("Expected error when adding permanent Lua to closed sandbox")
}
}
// TestResetEnvironment tests the ResetEnvironment method
func TestResetEnvironment(t *testing.T) {
sandbox := luajit.NewSandbox()
defer sandbox.Close()
// Set up some Go functions and Lua code
sandbox.RegisterFunction("timeNow", func(s *luajit.State) int {
s.PushString("test")
return 1
})
sandbox.AddPermanentLua(`
function permanent()
return "permanent function"
end
`)
_, err := sandbox.Run(`
temp_var = "will be reset"
function temp_func()
return "temp function"
end
`)
if err != nil {
t.Fatalf("Failed to run setup code: %v", err)
}
// Verify everything is set up correctly
result, err := sandbox.Run("return timeNow()")
if err != nil || result != "test" {
t.Fatalf("Go function not working: %v, %v", result, err)
}
result, err = sandbox.Run("return permanent()")
if err != nil || result != "permanent function" {
t.Fatalf("Permanent function not working: %v, %v", result, err)
}
result, err = sandbox.Run("return temp_func()")
if err != nil || result != "temp function" {
t.Fatalf("Temp function not working: %v, %v", result, err)
}
value, err := sandbox.GetGlobal("temp_var")
if err != nil || value != "will be reset" {
t.Fatalf("Temp var not set correctly: %v, %v", value, err)
}
// Reset the environment
err = sandbox.ResetEnvironment()
if err != nil {
t.Fatalf("Failed to reset environment: %v", err)
}
// Check Go function survives reset
result, err = sandbox.Run("return timeNow()")
if err != nil || result != "test" {
t.Fatalf("Go function should survive reset: %v, %v", result, err)
}
// Check permanent function is gone (it was added with AddPermanentLua but reset removes it)
_, err = sandbox.Run("return permanent()")
if err == nil {
t.Fatal("Permanent function should be gone after reset")
}
// Check temp function is gone
_, err = sandbox.Run("return temp_func()")
if err == nil {
t.Fatal("Temp function should be gone after reset")
}
// Check temp var is gone
value, err = sandbox.GetGlobal("temp_var")
if err != nil || value != nil {
t.Fatalf("Temp var should be nil after reset: %v", value)
}
// Test after sandbox is closed
sandbox.Close()
err = sandbox.ResetEnvironment()
if err == nil {
t.Fatal("Expected error when resetting closed sandbox")
}
}
// TestRunFile tests the RunFile method
func TestRunFile(t *testing.T) {
// Create a temporary Lua file
tmpfile, err := createTempLuaFile("return 42")
if err != nil {
t.Fatalf("Failed to create temp file: %v", err)
}
defer removeTempFile(tmpfile)
sandbox := luajit.NewSandbox()
defer sandbox.Close()
// Run the file
err = sandbox.RunFile(tmpfile)
if err != nil {
t.Fatalf("Failed to run file: %v", err)
}
// Test non-existent file
err = sandbox.RunFile("does_not_exist.lua")
if err == nil {
t.Fatal("Expected error for non-existent file")
}
// Test after sandbox is closed
sandbox.Close()
err = sandbox.RunFile(tmpfile)
if err == nil {
t.Fatal("Expected error when running file on closed sandbox")
}
}
// TestSandboxPackagePath tests the SetPackagePath and AddPackagePath methods
func TestSandboxPackagePath(t *testing.T) {
sandbox := luajit.NewSandbox()
defer sandbox.Close()
// Set package path
testPath := "/test/path/?.lua"
err := sandbox.SetPackagePath(testPath)
if err != nil {
t.Fatalf("Failed to set package path: %v", err)
}
// Check path was set
result, err := sandbox.Run("return package.path")
if err != nil {
t.Fatalf("Failed to get package.path: %v", err)
}
if result != testPath {
t.Fatalf("Expected package.path to be %q, got %q", testPath, result)
}
// Add to package path
addPath := "/another/path/?.lua"
err = sandbox.AddPackagePath(addPath)
if err != nil {
t.Fatalf("Failed to add package path: %v", err)
}
// Check path was updated
result, err = sandbox.Run("return package.path")
if err != nil {
t.Fatalf("Failed to get updated package.path: %v", err)
}
expected := testPath + ";" + addPath
if result != expected {
t.Fatalf("Expected package.path to be %q, got %q", expected, result)
}
// Test after sandbox is closed
sandbox.Close()
err = sandbox.SetPackagePath(testPath)
if err == nil {
t.Fatal("Expected error when setting package path on closed sandbox")
}
err = sandbox.AddPackagePath(addPath)
if err == nil {
t.Fatal("Expected error when adding package path to closed sandbox")
}
}
// TestSandboxLoadModule tests loading modules
func TestSandboxLoadModule(t *testing.T) {
// Skip for now since we don't have actual modules to load in the test environment
t.Skip("Skipping module loading test as it requires actual modules")
sandbox := luajit.NewSandbox()
defer sandbox.Close()
// Set package path to include current directory
err := sandbox.SetPackagePath("./?.lua")
if err != nil {
t.Fatalf("Failed to set package path: %v", err)
}
// Try to load a non-existent module
err = sandbox.LoadModule("nonexistent_module")
if err == nil {
t.Fatal("Expected error when loading non-existent module")
}
}
// Helper functions
// createTempLuaFile creates a temporary Lua file with the given content
func createTempLuaFile(content string) (string, error) {
tmpfile, err := os.CreateTemp("", "test-*.lua")
if err != nil {
return "", err
}
if _, err := tmpfile.WriteString(content); err != nil {
os.Remove(tmpfile.Name())
return "", err
}
if err := tmpfile.Close(); err != nil {
os.Remove(tmpfile.Name())
return "", err
}
return tmpfile.Name(), nil
}
// removeTempFile removes a temporary file
func removeTempFile(path string) {
os.Remove(path)
}