package luajit_test

import (
	"os"
	"sync"
	"testing"

	luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)

// TestSandboxLifecycle tests sandbox creation and closing.
func TestSandboxLifecycle(t *testing.T) {
	// Create a new sandbox
	sandbox := luajit.NewSandbox()
	if sandbox == nil {
		t.Fatal("Failed to create sandbox")
	}

	// Close the sandbox
	sandbox.Close()

	// Test idempotent close (should not panic)
	sandbox.Close()
}

// TestSandboxFunctionRegistration tests registering Go functions in the sandbox.
func TestSandboxFunctionRegistration(t *testing.T) {
	sandbox := luajit.NewSandbox()
	// The explicit Close near the end of the test plus this deferred Close
	// relies on Close being idempotent (see TestSandboxLifecycle).
	defer sandbox.Close()

	// Register a simple addition function
	add := func(s *luajit.State) int {
		a := s.ToNumber(1)
		b := s.ToNumber(2)
		s.PushNumber(a + b)
		return 1
	}

	err := sandbox.RegisterFunction("add", add)
	if err != nil {
		t.Fatalf("Failed to register function: %v", err)
	}

	// Test the function
	result, err := sandbox.Run("return add(3, 4)")
	if err != nil {
		t.Fatalf("Failed to execute function: %v", err)
	}
	if result != float64(7) {
		t.Fatalf("Expected 7, got %v", result)
	}

	// Test after sandbox is closed
	sandbox.Close()
	err = sandbox.RegisterFunction("test", add)
	if err == nil {
		t.Fatal("Expected error when registering function on closed sandbox")
	}
}

// TestSandboxGlobalVariables tests setting and getting global variables.
func TestSandboxGlobalVariables(t *testing.T) {
	sandbox := luajit.NewSandbox()
	defer sandbox.Close()

	// Set a global variable
	err := sandbox.SetGlobal("answer", 42)
	if err != nil {
		t.Fatalf("Failed to set global: %v", err)
	}

	// Get the global variable; Lua numbers come back as float64.
	value, err := sandbox.GetGlobal("answer")
	if err != nil {
		t.Fatalf("Failed to get global: %v", err)
	}
	if value != float64(42) {
		t.Fatalf("Expected 42, got %v", value)
	}

	// Test different types
	testCases := []struct {
		name  string
		value any
	}{
		{"nil_value", nil},
		{"bool_value", true},
		{"string_value", "hello"},
		{"table_value", map[string]any{"key": "value"}},
		{"array_value", []float64{1, 2, 3}},
	}

	for _, tc := range testCases {
		err := sandbox.SetGlobal(tc.name, tc.value)
		if err != nil {
			t.Fatalf("Failed to set global %s: %v", tc.name, err)
		}

		value, err := sandbox.GetGlobal(tc.name)
		if err != nil {
			t.Fatalf("Failed to get global %s: %v", tc.name, err)
		}

		// For tables/arrays, just check they're not nil
		switch tc.value.(type) {
		case map[string]any, []float64:
			if value == nil {
				t.Fatalf("Expected non-nil for %s, got nil", tc.name)
			}
		default:
			// tc.value is comparable here (nil, bool or string), so == is
			// safe; a nil/nil pair compares equal and needs no special case.
			if value != tc.value {
				t.Fatalf("For %s: expected %v, got %v", tc.name, tc.value, value)
			}
		}
	}

	// Test after sandbox is closed
	sandbox.Close()
	err = sandbox.SetGlobal("test", 123)
	if err == nil {
		t.Fatal("Expected error when setting global on closed sandbox")
	}

	_, err = sandbox.GetGlobal("test")
	if err == nil {
		t.Fatal("Expected error when getting global from closed sandbox")
	}
}

// TestSandboxCodeExecution tests running Lua code in the sandbox.
func TestSandboxCodeExecution(t *testing.T) {
	sandbox := luajit.NewSandbox()
	defer sandbox.Close()

	// Run simple code
	result, err := sandbox.Run("return 'hello'")
	if err != nil {
		t.Fatalf("Failed to run code: %v", err)
	}
	if result != "hello" {
		t.Fatalf("Expected 'hello', got %v", result)
	}

	// Run code with multiple return values
	result, err = sandbox.Run("return 1, 2, 3")
	if err != nil {
		t.Fatalf("Failed to run code: %v", err)
	}

	// Should return array for multiple values
	results, ok := result.([]any)
	if !ok {
		t.Fatalf("Expected array for multiple returns, got %T", result)
	}
	if len(results) != 3 || results[0] != float64(1) || results[1] != float64(2) || results[2] != float64(3) {
		t.Fatalf("Expected [1, 2, 3], got %v", results)
	}

	// Run code that sets a global
	_, err = sandbox.Run("global_var = 'set from Lua'")
	if err != nil {
		t.Fatalf("Failed to run code: %v", err)
	}

	value, err := sandbox.GetGlobal("global_var")
	if err != nil {
		t.Fatalf("Failed to get global: %v", err)
	}
	if value != "set from Lua" {
		t.Fatalf("Expected 'set from Lua', got %v", value)
	}

	// Run invalid code
	_, err = sandbox.Run("this is not valid Lua")
	if err == nil {
		t.Fatal("Expected error for invalid code")
	}

	// Test after sandbox is closed
	sandbox.Close()
	_, err = sandbox.Run("return true")
	if err == nil {
		t.Fatal("Expected error when running code on closed sandbox")
	}
}

// TestSandboxBytecodeExecution tests bytecode compilation and execution.
func TestSandboxBytecodeExecution(t *testing.T) {
	sandbox := luajit.NewSandbox()
	defer sandbox.Close()

	// Compile code to bytecode
	code := `
		local function greet(name)
			return "Hello, " .. name
		end
		return greet("World")
	`

	bytecode, err := sandbox.Compile(code)
	if err != nil {
		t.Fatalf("Failed to compile bytecode: %v", err)
	}
	if len(bytecode) == 0 {
		t.Fatal("Expected non-empty bytecode")
	}

	// Run the bytecode
	result, err := sandbox.RunBytecode(bytecode)
	if err != nil {
		t.Fatalf("Failed to run bytecode: %v", err)
	}
	if result != "Hello, World" {
		t.Fatalf("Expected 'Hello, World', got %v", result)
	}

	// Test bytecode that sets a global
	bytecode, err = sandbox.Compile("bytecode_var = 42")
	if err != nil {
		t.Fatalf("Failed to compile bytecode: %v", err)
	}

	_, err = sandbox.RunBytecode(bytecode)
	if err != nil {
		t.Fatalf("Failed to run bytecode: %v", err)
	}

	value, err := sandbox.GetGlobal("bytecode_var")
	if err != nil {
		t.Fatalf("Failed to get global: %v", err)
	}
	if value != float64(42) {
		t.Fatalf("Expected 42, got %v", value)
	}

	// Test invalid bytecode
	_, err = sandbox.RunBytecode([]byte("not valid bytecode"))
	if err == nil {
		t.Fatal("Expected error for invalid bytecode")
	}

	// Test after sandbox is closed
	sandbox.Close()
	_, err = sandbox.Compile("return true")
	if err == nil {
		t.Fatal("Expected error when compiling on closed sandbox")
	}

	_, err = sandbox.RunBytecode(bytecode)
	if err == nil {
		t.Fatal("Expected error when running bytecode on closed sandbox")
	}
}

// TestSandboxPersistence tests state persistence across executions.
func TestSandboxPersistence(t *testing.T) {
	sandbox := luajit.NewSandbox()
	defer sandbox.Close()

	// Set up initial state
	_, err := sandbox.Run(`
		counter = 0
		function increment()
			counter = counter + 1
			return counter
		end
	`)
	if err != nil {
		t.Fatalf("Failed to set up state: %v", err)
	}

	// Run multiple executions; each call must see the previous increment.
	for i := 1; i <= 3; i++ {
		result, err := sandbox.Run("return increment()")
		if err != nil {
			t.Fatalf("Failed to run code: %v", err)
		}
		if result != float64(i) {
			t.Fatalf("Expected %d, got %v", i, result)
		}
	}

	// Check final counter value
	value, err := sandbox.GetGlobal("counter")
	if err != nil {
		t.Fatalf("Failed to get global: %v", err)
	}
	if value != float64(3) {
		t.Fatalf("Expected final counter to be 3, got %v", value)
	}

	// Test persistence with bytecode
	bytecode, err := sandbox.Compile("return counter + 1")
	if err != nil {
		t.Fatalf("Failed to compile bytecode: %v", err)
	}

	result, err := sandbox.RunBytecode(bytecode)
	if err != nil {
		t.Fatalf("Failed to run bytecode: %v", err)
	}
	if result != float64(4) {
		t.Fatalf("Expected 4, got %v", result)
	}
}

// TestSandboxConcurrency tests concurrent access to the sandbox.
//
// The exact final count asserted below only holds if the sandbox serializes
// concurrent Run calls; lost updates would show up as a lower counter.
func TestSandboxConcurrency(t *testing.T) {
	sandbox := luajit.NewSandbox()
	defer sandbox.Close()

	// Set up a counter
	_, err := sandbox.Run("counter = 0")
	if err != nil {
		t.Fatalf("Failed to set up counter: %v", err)
	}

	// Run concurrent increments
	const numGoroutines = 10
	const incrementsPerGoroutine = 100

	var wg sync.WaitGroup
	wg.Add(numGoroutines)

	for i := 0; i < numGoroutines; i++ {
		go func() {
			defer wg.Done()
			for j := 0; j < incrementsPerGoroutine; j++ {
				_, err := sandbox.Run("counter = counter + 1")
				if err != nil {
					// t.Errorf (not t.Fatalf) is used because FailNow must
					// only be called from the test goroutine.
					t.Errorf("Failed to increment counter: %v", err)
					return
				}
			}
		}()
	}

	wg.Wait()

	// Check the final counter value
	value, err := sandbox.GetGlobal("counter")
	if err != nil {
		t.Fatalf("Failed to get counter: %v", err)
	}

	expected := float64(numGoroutines * incrementsPerGoroutine)
	if value != expected {
		t.Fatalf("Expected counter to be %v, got %v", expected, value)
	}
}

// TestPermanentLua tests the AddPermanentLua method.
func TestPermanentLua(t *testing.T) {
	sandbox := luajit.NewSandbox()
	defer sandbox.Close()

	// Add permanent Lua environment
	err := sandbox.AddPermanentLua(`
		-- Create utility functions
		function double(x)
			return x * 2
		end

		function square(x)
			return x * x
		end

		-- Create a protected environment
		env = {
			add = function(a, b) return a + b end,
			sub = function(a, b) return a - b end
		}
	`)
	if err != nil {
		t.Fatalf("Failed to add permanent Lua: %v", err)
	}

	// Test using the permanent functions
	testCases := []struct {
		code     string
		expected float64
	}{
		{"return double(5)", 10},
		{"return square(4)", 16},
		{"return env.add(10, 20)", 30},
		{"return env.sub(50, 30)", 20},
	}

	for _, tc := range testCases {
		result, err := sandbox.Run(tc.code)
		if err != nil {
			t.Fatalf("Failed to run code '%s': %v", tc.code, err)
		}
		if result != tc.expected {
			t.Fatalf("For '%s': expected %v, got %v", tc.code, tc.expected, result)
		}
	}

	// Test persistence of permanent code across executions
	_, err = sandbox.Run("counter = 0")
	if err != nil {
		t.Fatalf("Failed to set counter: %v", err)
	}

	result, err := sandbox.Run("counter = counter + 1; return double(counter)")
	if err != nil {
		t.Fatalf("Failed to run code: %v", err)
	}
	if result != float64(2) {
		t.Fatalf("Expected 2, got %v", result)
	}

	// Test after sandbox is closed
	sandbox.Close()
	err = sandbox.AddPermanentLua("function test() end")
	if err == nil {
		t.Fatal("Expected error when adding permanent Lua to closed sandbox")
	}
}

// TestResetEnvironment tests the ResetEnvironment method.
func TestResetEnvironment(t *testing.T) {
	sandbox := luajit.NewSandbox()
	defer sandbox.Close()

	// Set up some Go functions and Lua code. Setup errors are checked so a
	// failure here is reported directly instead of surfacing later as a
	// misleading assertion failure.
	if err := sandbox.RegisterFunction("timeNow", func(s *luajit.State) int {
		s.PushString("test")
		return 1
	}); err != nil {
		t.Fatalf("Failed to register function: %v", err)
	}

	if err := sandbox.AddPermanentLua(`
		function permanent()
			return "permanent function"
		end
	`); err != nil {
		t.Fatalf("Failed to add permanent Lua: %v", err)
	}

	_, err := sandbox.Run(`
		temp_var = "will be reset"
		function temp_func()
			return "temp function"
		end
	`)
	if err != nil {
		t.Fatalf("Failed to run setup code: %v", err)
	}

	// Verify everything is set up correctly
	result, err := sandbox.Run("return timeNow()")
	if err != nil || result != "test" {
		t.Fatalf("Go function not working: %v, %v", result, err)
	}

	result, err = sandbox.Run("return permanent()")
	if err != nil || result != "permanent function" {
		t.Fatalf("Permanent function not working: %v, %v", result, err)
	}

	result, err = sandbox.Run("return temp_func()")
	if err != nil || result != "temp function" {
		t.Fatalf("Temp function not working: %v, %v", result, err)
	}

	value, err := sandbox.GetGlobal("temp_var")
	if err != nil || value != "will be reset" {
		t.Fatalf("Temp var not set correctly: %v, %v", value, err)
	}

	// Reset the environment
	err = sandbox.ResetEnvironment()
	if err != nil {
		t.Fatalf("Failed to reset environment: %v", err)
	}

	// Check Go function survives reset
	result, err = sandbox.Run("return timeNow()")
	if err != nil || result != "test" {
		t.Fatalf("Go function should survive reset: %v, %v", result, err)
	}

	// Check permanent function is gone (it was added with AddPermanentLua but reset removes it)
	_, err = sandbox.Run("return permanent()")
	if err == nil {
		t.Fatal("Permanent function should be gone after reset")
	}

	// Check temp function is gone
	_, err = sandbox.Run("return temp_func()")
	if err == nil {
		t.Fatal("Temp function should be gone after reset")
	}

	// Check temp var is gone
	value, err = sandbox.GetGlobal("temp_var")
	if err != nil || value != nil {
		t.Fatalf("Temp var should be nil after reset: %v", value)
	}

	// Test after sandbox is closed
	sandbox.Close()
	err = sandbox.ResetEnvironment()
	if err == nil {
		t.Fatal("Expected error when resetting closed sandbox")
	}
}

// TestRunFile tests the RunFile method.
func TestRunFile(t *testing.T) {
	// Create a temporary Lua file
	tmpfile, err := createTempLuaFile("return 42")
	if err != nil {
		t.Fatalf("Failed to create temp file: %v", err)
	}
	defer removeTempFile(tmpfile)

	sandbox := luajit.NewSandbox()
	defer sandbox.Close()

	// Run the file
	err = sandbox.RunFile(tmpfile)
	if err != nil {
		t.Fatalf("Failed to run file: %v", err)
	}

	// Test non-existent file
	err = sandbox.RunFile("does_not_exist.lua")
	if err == nil {
		t.Fatal("Expected error for non-existent file")
	}

	// Test after sandbox is closed
	sandbox.Close()
	err = sandbox.RunFile(tmpfile)
	if err == nil {
		t.Fatal("Expected error when running file on closed sandbox")
	}
}

// TestSandboxPackagePath tests the SetPackagePath and AddPackagePath methods.
func TestSandboxPackagePath(t *testing.T) {
	sandbox := luajit.NewSandbox()
	defer sandbox.Close()

	// Set package path
	testPath := "/test/path/?.lua"
	err := sandbox.SetPackagePath(testPath)
	if err != nil {
		t.Fatalf("Failed to set package path: %v", err)
	}

	// Check path was set
	result, err := sandbox.Run("return package.path")
	if err != nil {
		t.Fatalf("Failed to get package.path: %v", err)
	}
	if result != testPath {
		t.Fatalf("Expected package.path to be %q, got %q", testPath, result)
	}

	// Add to package path
	addPath := "/another/path/?.lua"
	err = sandbox.AddPackagePath(addPath)
	if err != nil {
		t.Fatalf("Failed to add package path: %v", err)
	}

	// Check path was updated; AddPackagePath is expected to append with ";".
	result, err = sandbox.Run("return package.path")
	if err != nil {
		t.Fatalf("Failed to get updated package.path: %v", err)
	}

	expected := testPath + ";" + addPath
	if result != expected {
		t.Fatalf("Expected package.path to be %q, got %q", expected, result)
	}

	// Test after sandbox is closed
	sandbox.Close()
	err = sandbox.SetPackagePath(testPath)
	if err == nil {
		t.Fatal("Expected error when setting package path on closed sandbox")
	}

	err = sandbox.AddPackagePath(addPath)
	if err == nil {
		t.Fatal("Expected error when adding package path to closed sandbox")
	}
}

// TestSandboxLoadModule tests loading modules.
//
// The test is currently skipped; the code after t.Skip documents the intended
// checks and still must compile.
func TestSandboxLoadModule(t *testing.T) {
	// Skip for now since we don't have actual modules to load in the test environment
	t.Skip("Skipping module loading test as it requires actual modules")

	sandbox := luajit.NewSandbox()
	defer sandbox.Close()

	// Set package path to include current directory
	err := sandbox.SetPackagePath("./?.lua")
	if err != nil {
		t.Fatalf("Failed to set package path: %v", err)
	}

	// Try to load a non-existent module
	err = sandbox.LoadModule("nonexistent_module")
	if err == nil {
		t.Fatal("Expected error when loading non-existent module")
	}
}

// Helper functions

// createTempLuaFile writes content to a new file in the OS temp directory
// (name pattern "test-*.lua") and returns its path. On a write or close
// failure the partially created file is removed and the error is returned;
// the os.Remove error is deliberately ignored so the original error wins.
func createTempLuaFile(content string) (string, error) {
	tmpfile, err := os.CreateTemp("", "test-*.lua")
	if err != nil {
		return "", err
	}

	if _, err := tmpfile.WriteString(content); err != nil {
		tmpfile.Close()
		os.Remove(tmpfile.Name())
		return "", err
	}

	if err := tmpfile.Close(); err != nil {
		os.Remove(tmpfile.Name())
		return "", err
	}

	return tmpfile.Name(), nil
}

// removeTempFile removes a temporary file. Cleanup is best-effort, so any
// error from os.Remove is ignored.
func removeTempFile(path string) {
	os.Remove(path)
}