LuaJIT-to-Go/tests/sandbox_test.go
2025-03-27 21:31:41 -05:00

777 lines
18 KiB
Go

package luajit_test
import (
"os"
"sync"
"testing"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
// TestSandboxLifecycle verifies that NewSandbox returns a usable sandbox and
// that Close may be invoked repeatedly without panicking.
func TestSandboxLifecycle(t *testing.T) {
	sb := luajit.NewSandbox()
	if sb == nil {
		t.Fatal("Failed to create sandbox")
	}

	// Close is called twice on purpose: the repeat call must be a no-op,
	// not a panic.
	for i := 0; i < 2; i++ {
		sb.Close()
	}
}
// TestSandboxFunctionRegistration checks that a Go function registered via
// RegisterFunction is callable from Lua, and that registering on a closed
// sandbox returns an error.
func TestSandboxFunctionRegistration(t *testing.T) {
	sb := luajit.NewSandbox()
	defer sb.Close()

	// sum reads its two numeric arguments from the Lua stack and pushes a
	// single result. Go evaluates the two ToNumber calls left to right.
	sum := func(st *luajit.State) int {
		st.PushNumber(st.ToNumber(1) + st.ToNumber(2))
		return 1
	}

	if err := sb.RegisterFunction("add", sum); err != nil {
		t.Fatalf("Failed to register function: %v", err)
	}

	got, err := sb.Run("return add(3, 4)")
	if err != nil {
		t.Fatalf("Failed to execute function: %v", err)
	}
	if want := float64(7); got != want {
		t.Fatalf("Expected 7, got %v", got)
	}

	// Once closed, the sandbox must refuse new registrations.
	sb.Close()
	if err := sb.RegisterFunction("test", sum); err == nil {
		t.Fatal("Expected error when registering function on closed sandbox")
	}
}
// TestSandboxGlobalVariables round-trips values of several Go types through
// SetGlobal/GetGlobal, then checks that both calls fail on a closed sandbox.
func TestSandboxGlobalVariables(t *testing.T) {
	sandbox := luajit.NewSandbox()
	defer sandbox.Close()

	if err := sandbox.SetGlobal("answer", 42); err != nil {
		t.Fatalf("Failed to set global: %v", err)
	}
	value, err := sandbox.GetGlobal("answer")
	if err != nil {
		t.Fatalf("Failed to get global: %v", err)
	}
	// A Go int set as a global is read back as a float64.
	if value != float64(42) {
		t.Fatalf("Expected 42, got %v", value)
	}

	testCases := []struct {
		name  string
		value any
	}{
		{"nil_value", nil},
		{"bool_value", true},
		{"string_value", "hello"},
		{"table_value", map[string]any{"key": "value"}},
		{"array_value", []float64{1, 2, 3}},
	}
	for _, tc := range testCases {
		// Subtests run synchronously here, so capturing tc is safe.
		t.Run(tc.name, func(t *testing.T) {
			if err := sandbox.SetGlobal(tc.name, tc.value); err != nil {
				t.Fatalf("Failed to set global %s: %v", tc.name, err)
			}
			got, err := sandbox.GetGlobal(tc.name)
			if err != nil {
				t.Fatalf("Failed to get global %s: %v", tc.name, err)
			}
			switch tc.value.(type) {
			case map[string]any, []float64:
				// The Go type a Lua table converts back to is not pinned
				// here, so only presence is checked. This also avoids an
				// interface comparison on uncomparable types, which would
				// panic.
				if got == nil {
					t.Fatalf("Expected non-nil for %s, got nil", tc.name)
				}
			default:
				// Scalars must round-trip exactly. Two nil interfaces compare
				// equal, so the nil case needs no special handling.
				if got != tc.value {
					t.Fatalf("For %s: expected %v, got %v", tc.name, tc.value, got)
				}
			}
		})
	}

	// Both accessors must fail after Close.
	sandbox.Close()
	if err := sandbox.SetGlobal("test", 123); err == nil {
		t.Fatal("Expected error when setting global on closed sandbox")
	}
	if _, err := sandbox.GetGlobal("test"); err == nil {
		t.Fatal("Expected error when getting global from closed sandbox")
	}
}
// TestSandboxCodeExecution covers Run:
//   - single and multiple return values;
//   - globals set by a chunk;
//   - syntax errors;
//   - use after Close.
func TestSandboxCodeExecution(t *testing.T) {
	sb := luajit.NewSandbox()
	defer sb.Close()

	// A single return value comes back as a plain Go value.
	got, err := sb.Run("return 'hello'")
	if err != nil {
		t.Fatalf("Failed to run code: %v", err)
	}
	if got != "hello" {
		t.Fatalf("Expected 'hello', got %v", got)
	}

	// Multiple return values are collected into a []any.
	got, err = sb.Run("return 1, 2, 3")
	if err != nil {
		t.Fatalf("Failed to run code: %v", err)
	}
	multi, ok := got.([]any)
	if !ok {
		t.Fatalf("Expected array for multiple returns, got %T", got)
	}
	wantMulti := []float64{1, 2, 3}
	mismatch := len(multi) != len(wantMulti)
	for i := 0; !mismatch && i < len(wantMulti); i++ {
		mismatch = multi[i] != wantMulti[i]
	}
	if mismatch {
		t.Fatalf("Expected [1, 2, 3], got %v", multi)
	}

	// A global assigned by a chunk is visible through GetGlobal afterwards.
	if _, err = sb.Run("global_var = 'set from Lua'"); err != nil {
		t.Fatalf("Failed to run code: %v", err)
	}
	gv, err := sb.GetGlobal("global_var")
	if err != nil {
		t.Fatalf("Failed to get global: %v", err)
	}
	if gv != "set from Lua" {
		t.Fatalf("Expected 'set from Lua', got %v", gv)
	}

	// Syntax errors must surface as an error.
	if _, err = sb.Run("this is not valid Lua"); err == nil {
		t.Fatal("Expected error for invalid code")
	}

	sb.Close()
	if _, err = sb.Run("return true"); err == nil {
		t.Fatal("Expected error when running code on closed sandbox")
	}
}
// TestSandboxBytecodeExecution compiles Lua source to bytecode and runs it.
// It verifies that:
//   - bytecode can assign globals;
//   - malformed bytecode is rejected;
//   - Compile and RunBytecode fail once the sandbox is closed.
func TestSandboxBytecodeExecution(t *testing.T) {
	sb := luajit.NewSandbox()
	defer sb.Close()

	src := `
local function greet(name)
return "Hello, " .. name
end
return greet("World")
`
	bc, err := sb.Compile(src)
	if err != nil {
		t.Fatalf("Failed to compile bytecode: %v", err)
	}
	if len(bc) == 0 {
		t.Fatal("Expected non-empty bytecode")
	}

	out, err := sb.RunBytecode(bc)
	if err != nil {
		t.Fatalf("Failed to run bytecode: %v", err)
	}
	if out != "Hello, World" {
		t.Fatalf("Expected 'Hello, World', got %v", out)
	}

	// A global written by bytecode must be readable via GetGlobal.
	if bc, err = sb.Compile("bytecode_var = 42"); err != nil {
		t.Fatalf("Failed to compile bytecode: %v", err)
	}
	if _, err = sb.RunBytecode(bc); err != nil {
		t.Fatalf("Failed to run bytecode: %v", err)
	}
	v, err := sb.GetGlobal("bytecode_var")
	if err != nil {
		t.Fatalf("Failed to get global: %v", err)
	}
	if v != float64(42) {
		t.Fatalf("Expected 42, got %v", v)
	}

	// Arbitrary bytes are not valid bytecode.
	if _, err = sb.RunBytecode([]byte("not valid bytecode")); err == nil {
		t.Fatal("Expected error for invalid bytecode")
	}

	// After Close, both compiling and running must fail (bc still holds the
	// previously compiled, otherwise valid chunk).
	sb.Close()
	if _, err = sb.Compile("return true"); err == nil {
		t.Fatal("Expected error when compiling on closed sandbox")
	}
	if _, err = sb.RunBytecode(bc); err == nil {
		t.Fatal("Expected error when running bytecode on closed sandbox")
	}
}
// TestSandboxPersistence verifies that Lua globals and functions defined by
// one Run call remain visible to later Run and RunBytecode calls.
func TestSandboxPersistence(t *testing.T) {
	sb := luajit.NewSandbox()
	defer sb.Close()

	setup := `
counter = 0
function increment()
counter = counter + 1
return counter
end
`
	if _, err := sb.Run(setup); err != nil {
		t.Fatalf("Failed to set up state: %v", err)
	}

	// Each call must observe the counter value left by the previous call.
	for want := 1; want <= 3; want++ {
		got, err := sb.Run("return increment()")
		if err != nil {
			t.Fatalf("Failed to run code: %v", err)
		}
		if got != float64(want) {
			t.Fatalf("Expected %d, got %v", want, got)
		}
	}

	final, err := sb.GetGlobal("counter")
	if err != nil {
		t.Fatalf("Failed to get global: %v", err)
	}
	if final != float64(3) {
		t.Fatalf("Expected final counter to be 3, got %v", final)
	}

	// Bytecode executes against the same global state.
	bc, err := sb.Compile("return counter + 1")
	if err != nil {
		t.Fatalf("Failed to compile bytecode: %v", err)
	}
	got, err := sb.RunBytecode(bc)
	if err != nil {
		t.Fatalf("Failed to run bytecode: %v", err)
	}
	if got != float64(4) {
		t.Fatalf("Expected 4, got %v", got)
	}
}
// TestSandboxConcurrency runs Lua increments on one sandbox from several
// goroutines. It expects the final counter to equal the total number of
// increments, i.e. no update is lost.
func TestSandboxConcurrency(t *testing.T) {
	const (
		numGoroutines          = 10
		incrementsPerGoroutine = 100
	)

	sb := luajit.NewSandbox()
	defer sb.Close()

	if _, err := sb.Run("counter = 0"); err != nil {
		t.Fatalf("Failed to set up counter: %v", err)
	}

	var wg sync.WaitGroup
	for w := 0; w < numGoroutines; w++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for n := 0; n < incrementsPerGoroutine; n++ {
				if _, err := sb.Run("counter = counter + 1"); err != nil {
					// Errorf rather than Fatalf: Fatal must only be called
					// from the goroutine running the test.
					t.Errorf("Failed to increment counter: %v", err)
					return
				}
			}
		}()
	}
	wg.Wait()

	got, err := sb.GetGlobal("counter")
	if err != nil {
		t.Fatalf("Failed to get counter: %v", err)
	}
	expected := float64(numGoroutines * incrementsPerGoroutine)
	if got != expected {
		t.Fatalf("Expected counter to be %v, got %v", expected, got)
	}
}
// TestPermanentLua checks that functions and tables installed with
// AddPermanentLua can be called from later chunks. It also checks that
// AddPermanentLua fails on a closed sandbox.
func TestPermanentLua(t *testing.T) {
	sb := luajit.NewSandbox()
	defer sb.Close()

	prelude := `
-- Create utility functions
function double(x)
return x * 2
end
function square(x)
return x * x
end
-- Create a protected environment
env = {
add = function(a, b) return a + b end,
sub = function(a, b) return a - b end
}
`
	if err := sb.AddPermanentLua(prelude); err != nil {
		t.Fatalf("Failed to add permanent Lua: %v", err)
	}

	// Every helper defined by the prelude must be reachable.
	checks := []struct {
		code string
		want float64
	}{
		{"return double(5)", 10},
		{"return square(4)", 16},
		{"return env.add(10, 20)", 30},
		{"return env.sub(50, 30)", 20},
	}
	for _, c := range checks {
		got, err := sb.Run(c.code)
		if err != nil {
			t.Fatalf("Failed to run code '%s': %v", c.code, err)
		}
		if got != c.want {
			t.Fatalf("For '%s': expected %v, got %v", c.code, c.want, got)
		}
	}

	// Permanent functions work alongside globals created by ordinary chunks.
	if _, err := sb.Run("counter = 0"); err != nil {
		t.Fatalf("Failed to set counter: %v", err)
	}
	got, err := sb.Run("counter = counter + 1; return double(counter)")
	if err != nil {
		t.Fatalf("Failed to run code: %v", err)
	}
	if got != float64(2) {
		t.Fatalf("Expected 2, got %v", got)
	}

	sb.Close()
	if err := sb.AddPermanentLua("function test() end"); err == nil {
		t.Fatal("Expected error when adding permanent Lua to closed sandbox")
	}
}
// TestResetEnvironment verifies ResetEnvironment:
//   - Go functions registered with RegisterFunction survive the reset;
//   - Lua globals (including those installed via AddPermanentLua) are cleared;
//   - resetting a closed sandbox returns an error.
func TestResetEnvironment(t *testing.T) {
	sandbox := luajit.NewSandbox()
	defer sandbox.Close()

	// Setup errors are checked so a failure is reported at its source rather
	// than as a confusing assertion failure further down.
	err := sandbox.RegisterFunction("timeNow", func(s *luajit.State) int {
		s.PushString("test")
		return 1
	})
	if err != nil {
		t.Fatalf("Failed to register function: %v", err)
	}
	err = sandbox.AddPermanentLua(`
function permanent()
return "permanent function"
end
`)
	if err != nil {
		t.Fatalf("Failed to add permanent Lua: %v", err)
	}
	_, err = sandbox.Run(`
temp_var = "will be reset"
function temp_func()
return "temp function"
end
`)
	if err != nil {
		t.Fatalf("Failed to run setup code: %v", err)
	}

	// Verify everything is set up correctly before resetting.
	result, err := sandbox.Run("return timeNow()")
	if err != nil || result != "test" {
		t.Fatalf("Go function not working: %v, %v", result, err)
	}
	result, err = sandbox.Run("return permanent()")
	if err != nil || result != "permanent function" {
		t.Fatalf("Permanent function not working: %v, %v", result, err)
	}
	result, err = sandbox.Run("return temp_func()")
	if err != nil || result != "temp function" {
		t.Fatalf("Temp function not working: %v, %v", result, err)
	}
	value, err := sandbox.GetGlobal("temp_var")
	if err != nil || value != "will be reset" {
		t.Fatalf("Temp var not set correctly: %v, %v", value, err)
	}

	if err := sandbox.ResetEnvironment(); err != nil {
		t.Fatalf("Failed to reset environment: %v", err)
	}

	// Registered Go functions are expected to survive the reset.
	result, err = sandbox.Run("return timeNow()")
	if err != nil || result != "test" {
		t.Fatalf("Go function should survive reset: %v, %v", result, err)
	}

	// The function added with AddPermanentLua is expected to be removed by the reset.
	if _, err := sandbox.Run("return permanent()"); err == nil {
		t.Fatal("Permanent function should be gone after reset")
	}
	if _, err := sandbox.Run("return temp_func()"); err == nil {
		t.Fatal("Temp function should be gone after reset")
	}

	// Reading a cleared global succeeds and yields nil.
	value, err = sandbox.GetGlobal("temp_var")
	if err != nil || value != nil {
		t.Fatalf("Temp var should be nil after reset: %v, %v", value, err)
	}

	sandbox.Close()
	if err := sandbox.ResetEnvironment(); err == nil {
		t.Fatal("Expected error when resetting closed sandbox")
	}
}
// TestRunFile checks that RunFile:
//   - executes an existing Lua file;
//   - reports an error for a missing path;
//   - fails on a closed sandbox.
func TestRunFile(t *testing.T) {
	path, err := createTempLuaFile("return 42")
	if err != nil {
		t.Fatalf("Failed to create temp file: %v", err)
	}
	defer removeTempFile(path)

	sb := luajit.NewSandbox()
	defer sb.Close()

	if err := sb.RunFile(path); err != nil {
		t.Fatalf("Failed to run file: %v", err)
	}
	if err := sb.RunFile("does_not_exist.lua"); err == nil {
		t.Fatal("Expected error for non-existent file")
	}

	sb.Close()
	if err := sb.RunFile(path); err == nil {
		t.Fatal("Expected error when running file on closed sandbox")
	}
}
// TestSandboxPackagePath checks how SetPackagePath and AddPackagePath affect
// package.path as seen from Lua:
//   - SetPackagePath replaces it;
//   - AddPackagePath appends with ';';
//   - both fail on a closed sandbox.
func TestSandboxPackagePath(t *testing.T) {
	sb := luajit.NewSandbox()
	defer sb.Close()

	const (
		testPath = "/test/path/?.lua"
		addPath  = "/another/path/?.lua"
	)

	if err := sb.SetPackagePath(testPath); err != nil {
		t.Fatalf("Failed to set package path: %v", err)
	}
	got, err := sb.Run("return package.path")
	if err != nil {
		t.Fatalf("Failed to get package.path: %v", err)
	}
	if got != testPath {
		t.Fatalf("Expected package.path to be %q, got %q", testPath, got)
	}

	if err := sb.AddPackagePath(addPath); err != nil {
		t.Fatalf("Failed to add package path: %v", err)
	}
	if got, err = sb.Run("return package.path"); err != nil {
		t.Fatalf("Failed to get updated package.path: %v", err)
	}
	if want := testPath + ";" + addPath; got != want {
		t.Fatalf("Expected package.path to be %q, got %q", want, got)
	}

	sb.Close()
	if err := sb.SetPackagePath(testPath); err == nil {
		t.Fatal("Expected error when setting package path on closed sandbox")
	}
	if err := sb.AddPackagePath(addPath); err == nil {
		t.Fatal("Expected error when adding package path to closed sandbox")
	}
}
// TestSandboxLoadModule exercises LoadModule. It is skipped unconditionally
// because the test environment provides no Lua modules. The code after the
// skip records the intended check: loading a missing module must fail.
func TestSandboxLoadModule(t *testing.T) {
	t.Skip("Skipping module loading test as it requires actual modules")

	sb := luajit.NewSandbox()
	defer sb.Close()

	// Resolve modules relative to the working directory.
	if err := sb.SetPackagePath("./?.lua"); err != nil {
		t.Fatalf("Failed to set package path: %v", err)
	}
	if err := sb.LoadModule("nonexistent_module"); err == nil {
		t.Fatal("Expected error when loading non-existent module")
	}
}
// Helper functions
// createTempLuaFile writes content to a new file in the default temp
// directory and returns the file's path. The file name matches "test-*.lua".
// The caller is responsible for deleting the file (e.g. with removeTempFile).
// On error the file is closed and removed, and an empty path is returned.
func createTempLuaFile(content string) (string, error) {
	f, err := os.CreateTemp("", "test-*.lua")
	if err != nil {
		return "", err
	}
	path := f.Name()

	_, err = f.WriteString(content)
	// Close unconditionally so a failed write does not leak the descriptor.
	// Closing first also lets the removal below succeed on platforms that
	// refuse to delete open files. A write error takes precedence over a
	// close error.
	if cerr := f.Close(); err == nil {
		err = cerr
	}
	if err != nil {
		// Best-effort cleanup; the write/close error is the one to report.
		_ = os.Remove(path)
		return "", err
	}
	return path, nil
}
// removeTempFile deletes the file at path. Any error (including the file
// already being gone) is deliberately ignored, so cleanup never fails a test.
func removeTempFile(path string) {
	_ = os.Remove(path)
}
// BenchmarkSandboxLuaExecution times Sandbox.Run when the Lua source is
// passed on every iteration. Each run must return 5050.
func BenchmarkSandboxLuaExecution(b *testing.B) {
	sb := luajit.NewSandbox()
	defer sb.Close()

	// Sums 1..100.
	const want = float64(5050)
	src := `
local sum = 0
for i = 1, 100 do
sum = sum + i
end
return sum
`

	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		got, err := sb.Run(src)
		if err != nil {
			b.Fatalf("Failed to run code: %v", err)
		}
		if got != want {
			b.Fatalf("Incorrect result: %v", got)
		}
	}
}
// BenchmarkSandboxBytecodeExecution times RunBytecode on a 1..100 summation
// loop. The loop is compiled once before the timer is reset, so compilation
// is excluded from the measurement.
func BenchmarkSandboxBytecodeExecution(b *testing.B) {
	sb := luajit.NewSandbox()
	defer sb.Close()

	src := `
local sum = 0
for i = 1, 100 do
sum = sum + i
end
return sum
`
	bc, err := sb.Compile(src)
	if err != nil {
		b.Fatalf("Failed to compile bytecode: %v", err)
	}

	const want = float64(5050)
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		got, err := sb.RunBytecode(bc)
		if err != nil {
			b.Fatalf("Failed to run bytecode: %v", err)
		}
		if got != want {
			b.Fatalf("Incorrect result: %v", got)
		}
	}
}
// BenchmarkSandboxComplexComputation times Run on a chunk that (re)defines a
// global recursive fibonacci function and evaluates fibonacci(15), which
// must equal 610.
func BenchmarkSandboxComplexComputation(b *testing.B) {
	sb := luajit.NewSandbox()
	defer sb.Close()

	src := `
function fibonacci(n)
if n <= 1 then
return n
end
return fibonacci(n-1) + fibonacci(n-2)
end
return fibonacci(15) -- Not too high to avoid excessive runtime
`
	const want = float64(610)

	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		got, err := sb.Run(src)
		if err != nil {
			b.Fatalf("Failed to run code: %v", err)
		}
		if got != want {
			b.Fatalf("Incorrect result: %v", got)
		}
	}
}
// BenchmarkSandboxFunctionCall times a Lua loop that calls a registered Go
// function ("add") 100 times per iteration. The accumulated sum must be 5050.
func BenchmarkSandboxFunctionCall(b *testing.B) {
	sandbox := luajit.NewSandbox()
	defer sandbox.Close()

	// add pushes the sum of its two numeric arguments. The locals are named
	// x/y so they do not shadow the benchmark's b *testing.B.
	add := func(s *luajit.State) int {
		x := s.ToNumber(1)
		y := s.ToNumber(2)
		s.PushNumber(x + y)
		return 1
	}
	if err := sandbox.RegisterFunction("add", add); err != nil {
		b.Fatalf("Failed to register function: %v", err)
	}

	// Lua code that calls the Go function in a loop.
	code := `
local sum = 0
for i = 1, 100 do
sum = add(sum, i)
end
return sum
`

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		result, err := sandbox.Run(code)
		if err != nil {
			b.Fatalf("Failed to run code: %v", err)
		}
		if result != float64(5050) {
			b.Fatalf("Incorrect result: %v", result)
		}
	}
}