diff --git a/core/workers/eventloop.go b/core/workers/eventloop.go deleted file mode 100644 index e262f08..0000000 --- a/core/workers/eventloop.go +++ /dev/null @@ -1,372 +0,0 @@ -package workers - -import ( - "context" - "errors" - "sync" - "sync/atomic" - "time" - - luajit "git.sharkk.net/Sky/LuaJIT-to-Go" -) - -// Common errors -var ( - ErrLoopClosed = errors.New("event loop is closed") - ErrExecutionTimeout = errors.New("script execution timed out") -) - -// StateInitFunc is a function that initializes a Lua state -type StateInitFunc func(*luajit.State) error - -// EventLoop represents a single-threaded Lua execution environment -type EventLoop struct { - state *luajit.State // Single Lua state for all executions - jobQueue chan job // Channel for receiving jobs - quit chan struct{} // Channel for shutdown signaling - wg sync.WaitGroup // WaitGroup for clean shutdown - isRunning atomic.Bool // Flag to track if loop is running - timeout time.Duration // Default timeout for script execution - stateInit StateInitFunc // Optional function to initialize Lua state - bufferSize int // Size of job queue buffer -} - -// EventLoopConfig contains configuration options for creating an EventLoop -type EventLoopConfig struct { - // StateInit is a function to initialize the Lua state with custom modules and functions - StateInit StateInitFunc - - // BufferSize is the size of the job queue buffer (default: 100) - BufferSize int - - // Timeout is the default execution timeout (default: 30s, 0 means no timeout) - Timeout time.Duration -} - -// NewEventLoop creates a new event loop with default configuration -func NewEventLoop() (*EventLoop, error) { - return NewEventLoopWithConfig(EventLoopConfig{}) -} - -// NewEventLoopWithInit creates a new event loop with a state initialization function -func NewEventLoopWithInit(init StateInitFunc) (*EventLoop, error) { - return NewEventLoopWithConfig(EventLoopConfig{ - StateInit: init, - }) -} - -// NewEventLoopWithConfig creates a new event loop with custom configuration -func NewEventLoopWithConfig(config EventLoopConfig) (*EventLoop, error) { - // Set default values - bufferSize := config.BufferSize - if bufferSize <= 0 { - bufferSize = 100 // Default buffer size - } - - timeout := config.Timeout - if timeout == 0 { - timeout = 30 * time.Second // Default timeout - } - - // Initialize the Lua state - state := luajit.New() - if state == nil { - return nil, errors.New("failed to create Lua state") - } - - // Create the event loop instance - el := &EventLoop{ - state: state, - jobQueue: make(chan job, bufferSize), - quit: make(chan struct{}), - timeout: timeout, - stateInit: config.StateInit, - bufferSize: bufferSize, - } - el.isRunning.Store(true) - - // Set up the sandbox environment - if err := setupSandbox(el.state); err != nil { - state.Close() - return nil, err - } - - // Initialize the state if needed - if el.stateInit != nil { - if err := el.stateInit(el.state); err != nil { - state.Close() - return nil, err - } - } - - // Start the event loop - el.wg.Add(1) - go el.run() - - return el, nil -} - -// run is the main event loop goroutine -func (el *EventLoop) run() { - defer el.wg.Done() - defer el.state.Close() - - for { - select { - case job, ok := <-el.jobQueue: - if !ok { - // Job queue closed, exit - return - } - - // Execute job with timeout if configured - if el.timeout > 0 { - el.executeJobWithTimeout(job) - } else { - // Execute without timeout - result := executeJobSandboxed(el.state, job) - job.Result <- result - } - - case <-el.quit: - // Quit signal received, exit - return - } - } -} - -// executeJobWithTimeout executes a job with a timeout -func (el *EventLoop) executeJobWithTimeout(j job) { - // Create a context with timeout - ctx, cancel := context.WithTimeout(context.Background(), el.timeout) - defer cancel() - - // Create a channel for the result - resultCh := make(chan JobResult, 1) - - // Execute the job in a separate goroutine - go func() { - result := executeJobSandboxed(el.state, j) - select { - case resultCh <- result: - // Result sent successfully - case <-ctx.Done(): - // Context canceled, result no longer needed - } - }() - - // Wait for result or timeout - select { - case result := <-resultCh: - // Send result to the original channel - j.Result <- result - case <-ctx.Done(): - // Timeout occurred - j.Result <- JobResult{nil, ErrExecutionTimeout} - // NOTE: The Lua execution continues in the background until it completes, - // but the result is discarded. This is a compromise to avoid forcibly - // terminating Lua code which could corrupt the state. - } -} - -// Submit sends a job to the event loop -func (el *EventLoop) Submit(bytecode []byte, execCtx *Context) (any, error) { - return el.SubmitWithContext(context.Background(), bytecode, execCtx) -} - -// SubmitWithTimeout sends a job to the event loop with a specific timeout -func (el *EventLoop) SubmitWithTimeout(bytecode []byte, execCtx *Context, timeout time.Duration) (any, error) { - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - return el.SubmitWithContext(ctx, bytecode, execCtx) -} - -// SubmitWithContext sends a job to the event loop with context for cancellation -func (el *EventLoop) SubmitWithContext(ctx context.Context, bytecode []byte, execCtx *Context) (any, error) { - if !el.isRunning.Load() { - return nil, ErrLoopClosed - } - - resultChan := make(chan JobResult, 1) - j := job{ - Bytecode: bytecode, - Context: execCtx, - Result: resultChan, - } - - // Submit job with context - select { - case el.jobQueue <- j: - // Job submitted - case <-ctx.Done(): - return nil, ctx.Err() - } - - // Wait for result with context - select { - case result := <-resultChan: - return result.Value, result.Error - case <-ctx.Done(): - // Context canceled, but the job might still be processed - return nil, ctx.Err() - } -} - -// SetTimeout updates the default timeout for script execution -func (el *EventLoop) SetTimeout(timeout time.Duration) { - el.timeout = timeout -} - -// Shutdown gracefully shuts down the event loop -func (el *EventLoop) Shutdown() error { - if !el.isRunning.Load() { - return ErrLoopClosed - } - el.isRunning.Store(false) - - // Signal event loop to quit - close(el.quit) - - // Wait for event loop to finish - el.wg.Wait() - - // Close job queue - close(el.jobQueue) - - return nil -} - -// setupSandbox initializes the sandbox environment in the Lua state -func setupSandbox(state *luajit.State) error { - // This is the Lua script that creates our sandbox function - setupScript := ` - -- Create a function to run code in a sandbox environment - function __create_sandbox() - -- Create new environment table - local env = {} - - -- Add standard library modules (can be restricted as needed) - env.string = string - env.table = table - env.math = math - env.os = { - time = os.time, - date = os.date, - difftime = os.difftime, - clock = os.clock - } - env.tonumber = tonumber - env.tostring = tostring - env.type = type - env.pairs = pairs - env.ipairs = ipairs - env.next = next - env.select = select - env.unpack = unpack - env.pcall = pcall - env.xpcall = xpcall - env.error = error - env.assert = assert - - -- Allow access to package.loaded for modules - env.require = function(name) - return package.loaded[name] - end - - -- Create metatable to restrict access to _G - local mt = { - __index = function(t, k) - -- First check in env table - local v = rawget(env, k) - if v ~= nil then return v end - - -- If not found, check for registered modules/functions - local moduleValue = _G[k] - if type(moduleValue) == "table" or - type(moduleValue) == "function" then - return moduleValue - end - - return nil - end, - __newindex = function(t, k, v) - rawset(env, k, v) - end - } - - setmetatable(env, mt) - return env - end - - -- Create function to execute code with a sandbox - function __run_sandboxed(f, ctx) - local env = __create_sandbox() - - -- Add context to the environment if provided - if ctx then - env.ctx = ctx - end - - -- Set the environment and run the function - setfenv(f, env) - return f() - end - ` - - return state.DoString(setupScript) -} - -// executeJobSandboxed runs a script in a sandbox environment -func executeJobSandboxed(state *luajit.State, j job) JobResult { - // Set up context if provided - if j.Context != nil { - // Push context table - state.NewTable() - - // Add values to context table - for key, value := range j.Context.Values { - // Push key - state.PushString(key) - - // Push value - if err := state.PushValue(value); err != nil { - state.Pop(1) // Pop table - return JobResult{nil, err} - } - - // Set table[key] = value - state.SetTable(-3) - } - } else { - // Push nil if no context - state.PushNil() - } - - // Load bytecode - if err := state.LoadBytecode(j.Bytecode, "script"); err != nil { - state.Pop(1) // Pop context - return JobResult{nil, err} - } - - // Get the sandbox runner function - state.GetGlobal("__run_sandboxed") - - // Push loaded function and context as arguments - state.PushCopy(-2) // Copy the loaded function - state.PushCopy(-4) // Copy the context table or nil - - // Remove the original function and context - state.Remove(-5) // Remove original context - state.Remove(-4) // Remove original function - - // Call the sandbox runner with 2 args (function and context), expecting 1 result - if err := state.Call(2, 1); err != nil { - return JobResult{nil, err} - } - - // Get result - value, err := state.ToValue(-1) - state.Pop(1) // Pop result - - return JobResult{value, err} -} diff --git a/core/workers/eventloop_test.go b/core/workers/eventloop_test.go deleted file mode 100644 index ef1dac1..0000000 --- a/core/workers/eventloop_test.go +++ /dev/null @@ -1,741 +0,0 @@ -package workers - -import ( - "context" - "sync" - "testing" - "time" - - luajit "git.sharkk.net/Sky/LuaJIT-to-Go" -) - -// Helper function to create bytecode for testing -func createTestBytecode(t *testing.T, code string) []byte { - state := luajit.New() - if state == nil { - t.Fatal("Failed to create Lua state") - } - defer state.Close() - - bytecode, err := state.CompileBytecode(code, "test") - if err != nil { - t.Fatalf("Failed to compile test bytecode: %v", err) - } - - return bytecode -} - -// Test creating a new event loop with default and custom configs -func TestNewEventLoop(t *testing.T) { - tests := []struct { - name string - config EventLoopConfig - expectError bool - }{ - { - name: "Default config", - config: EventLoopConfig{}, - expectError: false, - }, - { - name: "Custom buffer size", - config: EventLoopConfig{ - BufferSize: 200, - }, - expectError: false, - }, - { - name: "Custom timeout", - config: EventLoopConfig{ - Timeout: 5 * time.Second, - }, - expectError: false, - }, - { - name: "With init function", - config: EventLoopConfig{ - StateInit: func(state *luajit.State) error { - return nil - }, - }, - expectError: false, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - el, err := NewEventLoopWithConfig(tc.config) - if tc.expectError { - if err == nil { - t.Errorf("Expected error but got nil") - } - } else { - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - if el == nil { - t.Errorf("Expected non-nil event loop") - } else { - el.Shutdown() - } - } - }) - } -} - -// Test basic job submission and execution -func TestEventLoopBasicSubmission(t *testing.T) { - el, err := NewEventLoop() - if err != nil { - t.Fatalf("Failed to create event loop: %v", err) - } - defer el.Shutdown() - - // Simple return a value - bytecode := createTestBytecode(t, "return 42") - - result, err := el.Submit(bytecode, nil) - if err != nil { - t.Fatalf("Failed to submit job: %v", err) - } - - num, ok := result.(float64) - if !ok { - t.Fatalf("Expected float64 result, got %T", result) - } - - if num != 42 { - t.Errorf("Expected 42, got %f", num) - } - - // Test more complex Lua code - bytecode = createTestBytecode(t, ` - local result = 0 - for i = 1, 10 do - result = result + i - end - return result - `) - - result, err = el.Submit(bytecode, nil) - if err != nil { - t.Fatalf("Failed to submit job: %v", err) - } - - num, ok = result.(float64) - if !ok { - t.Fatalf("Expected float64 result, got %T", result) - } - - if num != 55 { - t.Errorf("Expected 55, got %f", num) - } -} - -// Test context passing between Go and Lua -func TestEventLoopContext(t *testing.T) { - el, err := NewEventLoop() - if err != nil { - t.Fatalf("Failed to create event loop: %v", err) - } - defer el.Shutdown() - - bytecode := createTestBytecode(t, ` - return { - num = ctx.number, - str = ctx.text, - flag = ctx.enabled, - list = {ctx.items[1], ctx.items[2], ctx.items[3]}, - } - `) - - execCtx := NewContext() - execCtx.Set("number", 42.5) - execCtx.Set("text", "hello") - execCtx.Set("enabled", true) - execCtx.Set("items", []float64{10, 20, 30}) - - result, err := el.Submit(bytecode, execCtx) - if err != nil { - t.Fatalf("Failed to submit job: %v", err) - } - - // Result should be a map - resultMap, ok := result.(map[string]any) - if !ok { - t.Fatalf("Expected map result, got %T", result) - } - - // Check values - if resultMap["num"] != 42.5 { - t.Errorf("Expected num=42.5, got %v", resultMap["num"]) - } - if resultMap["str"] != "hello" { - t.Errorf("Expected str=hello, got %v", resultMap["str"]) - } - if resultMap["flag"] != true { - t.Errorf("Expected flag=true, got %v", resultMap["flag"]) - } - - arr, ok := resultMap["list"].([]float64) - if !ok { - t.Fatalf("Expected []float64, got %T", resultMap["list"]) - } - - expected := []float64{10, 20, 30} - for i, v := range expected { - if arr[i] != v { - t.Errorf("Expected list[%d]=%f, got %f", i, v, arr[i]) - } - } - - // Test complex nested context - nestedCtx := NewContext() - nestedCtx.Set("user", map[string]any{ - "id": 123, - "name": "test user", - "roles": []any{ - "admin", - "editor", - }, - }) - - bytecode = createTestBytecode(t, ` - return { - id = ctx.user.id, - name = ctx.user.name, - role1 = ctx.user.roles[1], - role2 = ctx.user.roles[2], - } - `) - - result, err = el.Submit(bytecode, nestedCtx) - if err != nil { - t.Fatalf("Failed to submit job with nested context: %v", err) - } - - resultMap, ok = result.(map[string]any) - if !ok { - t.Fatalf("Expected map result, got %T", result) - } - - if resultMap["id"] != float64(123) { - t.Errorf("Expected id=123, got %v", resultMap["id"]) - } - if resultMap["name"] != "test user" { - t.Errorf("Expected name='test user', got %v", resultMap["name"]) - } - if resultMap["role1"] != "admin" { - t.Errorf("Expected role1='admin', got %v", resultMap["role1"]) - } - if resultMap["role2"] != "editor" { - t.Errorf("Expected role2='editor', got %v", resultMap["role2"]) - } -} - -// Test execution timeout -func TestEventLoopTimeout(t *testing.T) { - // Create event loop with short timeout - el, err := NewEventLoopWithConfig(EventLoopConfig{ - Timeout: 100 * time.Millisecond, - }) - if err != nil { - t.Fatalf("Failed to create event loop: %v", err) - } - defer el.Shutdown() - - // Create bytecode that runs for longer than the timeout - bytecode := createTestBytecode(t, ` - -- Loop for 500ms - local start = os.time() - while os.difftime(os.time(), start) < 0.5 do end - return "done" - `) - - // This should time out - _, err = el.Submit(bytecode, nil) - if err != ErrExecutionTimeout { - t.Errorf("Expected timeout error, got: %v", err) - } - - // Now set a longer timeout and try again - el.SetTimeout(1 * time.Second) - - // This should succeed - result, err := el.Submit(bytecode, nil) - if err != nil { - t.Fatalf("Expected success with longer timeout, got: %v", err) - } - if result != "done" { - t.Errorf("Expected 'done', got %v", result) - } - - // Test per-call timeout with SubmitWithTimeout - bytecode = createTestBytecode(t, ` - -- Loop for 300ms - local start = os.time() - while os.difftime(os.time(), start) < 0.3 do end - return "done again" - `) - - // This should time out with a custom timeout - _, err = el.SubmitWithTimeout(bytecode, nil, 50*time.Millisecond) - if err == nil { - t.Errorf("Expected timeout error, got success") - } - - // This should succeed with a longer custom timeout - result, err = el.SubmitWithTimeout(bytecode, nil, 500*time.Millisecond) - if err != nil { - t.Fatalf("Expected success with custom timeout, got: %v", err) - } - if result != "done again" { - t.Errorf("Expected 'done again', got %v", result) - } -} - -// Test module registration and execution -func TestEventLoopModules(t *testing.T) { - // Define an init function that registers a simple "math" module - mathInit := func(state *luajit.State) error { - // Register the "add" function directly - err := state.RegisterGoFunction("add", func(s *luajit.State) int { - a := s.ToNumber(1) - b := s.ToNumber(2) - s.PushNumber(a + b) - return 1 // Return one result - }) - - if err != nil { - return err - } - - // Register a math module with multiple functions - mathFuncs := map[string]luajit.GoFunction{ - "multiply": func(s *luajit.State) int { - a := s.ToNumber(1) - b := s.ToNumber(2) - s.PushNumber(a * b) - return 1 - }, - "subtract": func(s *luajit.State) int { - a := s.ToNumber(1) - b := s.ToNumber(2) - s.PushNumber(a - b) - return 1 - }, - } - - return RegisterModule(state, "math2", mathFuncs) - } - - // Create an event loop with our init function - el, err := NewEventLoopWithInit(mathInit) - if err != nil { - t.Fatalf("Failed to create event loop: %v", err) - } - defer el.Shutdown() - - // Test the add function - bytecode1 := createTestBytecode(t, "return add(5, 7)") - result1, err := el.Submit(bytecode1, nil) - if err != nil { - t.Fatalf("Failed to call add function: %v", err) - } - - num1, ok := result1.(float64) - if !ok || num1 != 12 { - t.Errorf("Expected add(5, 7) = 12, got %v", result1) - } - - // Test the math2 module - bytecode2 := createTestBytecode(t, "return math2.multiply(6, 8)") - result2, err := el.Submit(bytecode2, nil) - if err != nil { - t.Fatalf("Failed to call math2.multiply: %v", err) - } - - num2, ok := result2.(float64) - if !ok || num2 != 48 { - t.Errorf("Expected math2.multiply(6, 8) = 48, got %v", result2) - } - - // Test multiple operations - bytecode3 := createTestBytecode(t, ` - local a = add(10, 20) - local b = math2.subtract(a, 5) - return math2.multiply(b, 2) - `) - - result3, err := el.Submit(bytecode3, nil) - if err != nil { - t.Fatalf("Failed to execute combined operations: %v", err) - } - - num3, ok := result3.(float64) - if !ok || num3 != 50 { - t.Errorf("Expected ((10 + 20) - 5) * 2 = 50, got %v", result3) - } -} - -// Test combined module init functions -func TestEventLoopCombinedModules(t *testing.T) { - // First init function adds a function to get a constant value - init1 := func(state *luajit.State) error { - return state.RegisterGoFunction("getAnswer", func(s *luajit.State) int { - s.PushNumber(42) - return 1 - }) - } - - // Second init function registers a function that multiplies a number by 2 - init2 := func(state *luajit.State) error { - return state.RegisterGoFunction("double", func(s *luajit.State) int { - n := s.ToNumber(1) - s.PushNumber(n * 2) - return 1 - }) - } - - // Combine the init functions - combinedInit := CombineInitFuncs(init1, init2) - - // Create an event loop with the combined init function - el, err := NewEventLoopWithInit(combinedInit) - if err != nil { - t.Fatalf("Failed to create event loop: %v", err) - } - defer el.Shutdown() - - // Test using both functions together in a single script - bytecode := createTestBytecode(t, "return double(getAnswer())") - result, err := el.Submit(bytecode, nil) - if err != nil { - t.Fatalf("Failed to execute: %v", err) - } - - num, ok := result.(float64) - if !ok || num != 84 { - t.Errorf("Expected double(getAnswer()) = 84, got %v", result) - } -} - -// Test sandbox isolation between executions -func TestEventLoopSandboxIsolation(t *testing.T) { - el, err := NewEventLoop() - if err != nil { - t.Fatalf("Failed to create event loop: %v", err) - } - defer el.Shutdown() - - // Create a script that tries to modify a global variable - bytecode1 := createTestBytecode(t, ` - -- Set a "global" variable - my_global = "test value" - return true - `) - - _, err = el.Submit(bytecode1, nil) - if err != nil { - t.Fatalf("Failed to execute first script: %v", err) - } - - // Now try to access that variable from another script - bytecode2 := createTestBytecode(t, ` - -- Try to access the previously set global - return my_global ~= nil - `) - - result, err := el.Submit(bytecode2, nil) - if err != nil { - t.Fatalf("Failed to execute second script: %v", err) - } - - // The variable should not be accessible (sandbox isolation) - if result.(bool) { - t.Errorf("Expected sandbox isolation, but global variable was accessible") - } -} - -// Test error handling -func TestEventLoopErrorHandling(t *testing.T) { - el, err := NewEventLoop() - if err != nil { - t.Fatalf("Failed to create event loop: %v", err) - } - defer el.Shutdown() - - // Test invalid bytecode - _, err = el.Submit([]byte("not valid bytecode"), nil) - if err == nil { - t.Errorf("Expected error for invalid bytecode, got nil") - } - - // Test Lua runtime error - bytecode := createTestBytecode(t, ` - error("intentional error") - return true - `) - - _, err = el.Submit(bytecode, nil) - if err == nil { - t.Errorf("Expected error from Lua error() call, got nil") - } - - // Test with nil context (should work fine) - bytecode = createTestBytecode(t, "return ctx == nil") - result, err := el.Submit(bytecode, nil) - if err != nil { - t.Errorf("Unexpected error with nil context: %v", err) - } - if result.(bool) != true { - t.Errorf("Expected ctx to be nil in Lua, but it wasn't") - } - - // Test access to restricted library - bytecode = createTestBytecode(t, ` - -- Try to access io library directly - return io ~= nil - `) - - result, err = el.Submit(bytecode, nil) - if err != nil { - t.Fatalf("Failed to execute sandbox test: %v", err) - } - - // io should not be directly accessible - if result.(bool) { - t.Errorf("Expected io library to be restricted, but it was accessible") - } -} - -// Test concurrent job submission -func TestEventLoopConcurrency(t *testing.T) { - el, err := NewEventLoopWithConfig(EventLoopConfig{ - BufferSize: 100, // Buffer for concurrent submissions - Timeout: 5 * time.Second, - }) - if err != nil { - t.Fatalf("Failed to create event loop: %v", err) - } - defer el.Shutdown() - - // Create bytecode that returns its input value - bytecode := createTestBytecode(t, "return ctx.n") - - // Submit multiple jobs concurrently - const jobCount = 50 - var wg sync.WaitGroup - results := make([]int, jobCount) - - wg.Add(jobCount) - for i := 0; i < jobCount; i++ { - i := i // Capture loop variable - go func() { - defer wg.Done() - - // Create context with job number - ctx := NewContext() - ctx.Set("n", float64(i)) - - // Submit job - result, err := el.Submit(bytecode, ctx) - if err != nil { - t.Errorf("Job %d failed: %v", i, err) - return - } - - // Verify result matches job number - num, ok := result.(float64) - if !ok { - t.Errorf("Job %d: expected float64, got %T", i, result) - return - } - - results[i] = int(num) - }() - } - - wg.Wait() - - // Verify all results - for i, res := range results { - if res != i && res != 0 { // 0 means error already logged - t.Errorf("Expected result[%d] = %d, got %d", i, i, res) - } - } -} - -// Test state consistency across multiple calls -func TestEventLoopStateConsistency(t *testing.T) { - // Create an event loop with a module that maintains count between calls - initFunc := func(state *luajit.State) error { - // Create a closure that increments a counter in upvalue - code := ` - -- Create a counter with initial value 0 - local counter = 0 - - -- Create a function that returns and increments the counter - function get_next_count() - local current = counter - counter = counter + 1 - return current - end - ` - return state.DoString(code) - } - - el, err := NewEventLoopWithInit(initFunc) - if err != nil { - t.Fatalf("Failed to create event loop: %v", err) - } - defer el.Shutdown() - - // Now run multiple scripts that call the counter function - bytecode := createTestBytecode(t, "return get_next_count()") - - // Each call should return an incremented value - for i := 0; i < 5; i++ { - result, err := el.Submit(bytecode, nil) - if err != nil { - t.Fatalf("Call %d failed: %v", i, err) - } - - num, ok := result.(float64) - if !ok { - t.Fatalf("Expected float64 result, got %T", result) - } - - if int(num) != i { - t.Errorf("Expected count %d, got %d", i, int(num)) - } - } -} - -// Test shutdown and cleanup -func TestEventLoopShutdown(t *testing.T) { - el, err := NewEventLoop() - if err != nil { - t.Fatalf("Failed to create event loop: %v", err) - } - - // Submit a job to verify it works - bytecode := createTestBytecode(t, "return 42") - _, err = el.Submit(bytecode, nil) - if err != nil { - t.Fatalf("Failed to submit job: %v", err) - } - - // Shutdown - if err := el.Shutdown(); err != nil { - t.Errorf("Shutdown failed: %v", err) - } - - // Submit after shutdown should fail - _, err = el.Submit(bytecode, nil) - if err != ErrLoopClosed { - t.Errorf("Expected ErrLoopClosed, got %v", err) - } - - // Second shutdown should return error - if err := el.Shutdown(); err != ErrLoopClosed { - t.Errorf("Expected ErrLoopClosed on second shutdown, got %v", err) - } -} - -// Test high load with multiple sequential and concurrent jobs -func TestEventLoopHighLoad(t *testing.T) { - el, err := NewEventLoopWithConfig(EventLoopConfig{ - BufferSize: 1000, // Large buffer for high load - Timeout: 5 * time.Second, - }) - if err != nil { - t.Fatalf("Failed to create event loop: %v", err) - } - defer el.Shutdown() - - // Sequential load test - bytecode := createTestBytecode(t, ` - -- Do some work - local result = 0 - for i = 1, 1000 do - result = result + i - end - return result - `) - - start := time.Now() - for i := 0; i < 100; i++ { - _, err := el.Submit(bytecode, nil) - if err != nil { - t.Fatalf("Sequential job %d failed: %v", i, err) - } - } - seqDuration := time.Since(start) - t.Logf("Sequential load test: 100 jobs in %v", seqDuration) - - // Concurrent load test - start = time.Now() - var wg sync.WaitGroup - wg.Add(100) - for i := 0; i < 100; i++ { - go func() { - defer wg.Done() - _, err := el.Submit(bytecode, nil) - if err != nil { - t.Errorf("Concurrent job failed: %v", err) - } - }() - } - wg.Wait() - concDuration := time.Since(start) - t.Logf("Concurrent load test: 100 jobs in %v", concDuration) -} - -// Test context cancellation -func TestEventLoopCancel(t *testing.T) { - el, err := NewEventLoop() - if err != nil { - t.Fatalf("Failed to create event loop: %v", err) - } - defer el.Shutdown() - - // Create a long-running script - bytecode := createTestBytecode(t, ` - -- Sleep for 500ms - local start = os.time() - while os.difftime(os.time(), start) < 0.5 do end - return "done" - `) - - // Create a context that we can cancel - ctx, cancel := context.WithCancel(context.Background()) - - // Start execution in a goroutine - resultCh := make(chan any, 1) - errCh := make(chan error, 1) - go func() { - res, err := el.SubmitWithContext(ctx, bytecode, nil) - if err != nil { - errCh <- err - } else { - resultCh <- res - } - }() - - // Cancel quickly - time.Sleep(50 * time.Millisecond) - cancel() - - // Should get cancellation error - select { - case err := <-errCh: - if ctx.Err() == nil || err == nil { - t.Errorf("Expected context cancellation error") - } - case res := <-resultCh: - t.Errorf("Expected cancellation, got result: %v", res) - case <-time.After(1 * time.Second): - t.Errorf("Timed out waiting for cancellation") - } -} diff --git a/core/workers/init_test.go b/core/workers/init_test.go new file mode 100644 index 0000000..ea0557d --- /dev/null +++ b/core/workers/init_test.go @@ -0,0 +1,346 @@ +package workers + +import ( + "testing" + + luajit "git.sharkk.net/Sky/LuaJIT-to-Go" +) + +func TestModuleRegistration(t *testing.T) { + // Define an init function that registers a simple "math" module + mathInit := func(state *luajit.State) error { + // Register the "add" function + err := state.RegisterGoFunction("add", func(s *luajit.State) int { + a := s.ToNumber(1) + b := s.ToNumber(2) + s.PushNumber(a + b) + return 1 // Return one result + }) + + if err != nil { + return err + } + + // Register a whole module + mathFuncs := map[string]luajit.GoFunction{ + "multiply": func(s *luajit.State) int { + a := s.ToNumber(1) + b := s.ToNumber(2) + s.PushNumber(a * b) + return 1 + }, + "subtract": func(s *luajit.State) int { + a := s.ToNumber(1) + b := s.ToNumber(2) + s.PushNumber(a - b) + return 1 + }, + } + + return RegisterModule(state, "math2", mathFuncs) + } + + // Create a pool with our init function + pool, err := NewPoolWithInit(2, mathInit) + if err != nil { + t.Fatalf("Failed to create pool: %v", err) + } + defer pool.Shutdown() + + // Test the add function + bytecode1 := createTestBytecode(t, "return add(5, 7)") + result1, err := pool.Submit(bytecode1, nil) + if err != nil { + t.Fatalf("Failed to call add function: %v", err) + } + + num1, ok := result1.(float64) + if !ok || num1 != 12 { + t.Errorf("Expected add(5, 7) = 12, got %v", result1) + } + + // Test the math2 module + bytecode2 := createTestBytecode(t, "return math2.multiply(6, 8)") + result2, err := pool.Submit(bytecode2, nil) + if err != nil { + t.Fatalf("Failed to call math2.multiply: %v", err) + } + + num2, ok := result2.(float64) + if !ok || num2 != 48 { + t.Errorf("Expected math2.multiply(6, 8) = 48, got %v", result2) + } + + // Test multiple operations + bytecode3 := createTestBytecode(t, ` + local a = add(10, 20) + local b = math2.subtract(a, 5) + return math2.multiply(b, 2) + `) + + result3, err := pool.Submit(bytecode3, nil) + if err != nil { + t.Fatalf("Failed to execute combined operations: %v", err) + } + + num3, ok := result3.(float64) + if !ok || num3 != 50 { + t.Errorf("Expected ((10 + 20) - 5) * 2 = 50, got %v", result3) + } +} + +func TestModuleInitFunc(t *testing.T) { + // Define math module functions + mathModule := func() map[string]luajit.GoFunction { + return map[string]luajit.GoFunction{ + "add": func(s *luajit.State) int { + a := s.ToNumber(1) + b := s.ToNumber(2) + s.PushNumber(a + b) + return 1 + }, + "multiply": func(s *luajit.State) int { + a := s.ToNumber(1) + b := s.ToNumber(2) + s.PushNumber(a * b) + return 1 + }, + } + } + + // Define string module functions + strModule := func() map[string]luajit.GoFunction { + return map[string]luajit.GoFunction{ + "concat": func(s *luajit.State) int { + a := s.ToString(1) + b := s.ToString(2) + s.PushString(a + b) + return 1 + }, + } + } + + // Create module map + modules := map[string]ModuleFunc{ + "math2": mathModule, + "str": strModule, + } + + // Create pool with module init + pool, err := NewPoolWithInit(2, ModuleInitFunc(modules)) + if err != nil { + t.Fatalf("Failed to create pool: %v", err) + } + defer pool.Shutdown() + + // Test math module + bytecode1 := createTestBytecode(t, "return math2.add(5, 7)") + result1, err := pool.Submit(bytecode1, nil) + if err != nil { + t.Fatalf("Failed to call math2.add: %v", err) + } + + num1, ok := result1.(float64) + if !ok || num1 != 12 { + t.Errorf("Expected math2.add(5, 7) = 12, got %v", result1) + } + + // Test string module + bytecode2 := createTestBytecode(t, "return str.concat('hello', 'world')") + result2, err := pool.Submit(bytecode2, nil) + if err != nil { + t.Fatalf("Failed to call str.concat: %v", err) + } + + str2, ok := result2.(string) + if !ok || str2 != "helloworld" { + t.Errorf("Expected str.concat('hello', 'world') = 'helloworld', got %v", result2) + } +} + +func TestCombineInitFuncs(t *testing.T) { + // First init function adds a function to get a constant value + init1 := func(state *luajit.State) error { + return state.RegisterGoFunction("getAnswer", func(s *luajit.State) int { + s.PushNumber(42) + return 1 + }) + } + + // Second init function registers a function that multiplies a number by 2 + init2 := func(state *luajit.State) error { + return state.RegisterGoFunction("double", func(s *luajit.State) int { + n := s.ToNumber(1) + s.PushNumber(n * 2) + return 1 + }) + } + + // Combine the init functions + combinedInit := CombineInitFuncs(init1, init2) + + // Create a pool with the combined init function + pool, err := NewPoolWithInit(1, combinedInit) + if err != nil { + t.Fatalf("Failed to create pool: %v", err) + } + defer pool.Shutdown() + + // Test using both functions together in a single script + bytecode := createTestBytecode(t, "return double(getAnswer())") + result, err := pool.Submit(bytecode, nil) + if err != nil { + t.Fatalf("Failed to execute: %v", err) + } + + num, ok := result.(float64) + if !ok || num != 84 { + t.Errorf("Expected double(getAnswer()) = 84, got %v", result) + } +} + +func TestSandboxIsolation(t *testing.T) { + // Create a pool + pool, err := NewPool(2) + if err != nil { + t.Fatalf("Failed to create pool: %v", err) + } + defer pool.Shutdown() + + // Create a script that tries to modify a global variable + bytecode1 := createTestBytecode(t, ` + -- Set a "global" variable + my_global = "test value" + return true + `) + + _, err = pool.Submit(bytecode1, nil) + if err != nil { + t.Fatalf("Failed to execute first script: %v", err) + } + + // Now try to access that variable from another script + bytecode2 := createTestBytecode(t, ` + -- Try to access the previously set global + return my_global ~= nil + `) + + result, err := pool.Submit(bytecode2, nil) + if err != nil { + t.Fatalf("Failed to execute second script: %v", err) + } + + // The variable should not be accessible (sandbox isolation) + if result.(bool) { + t.Errorf("Expected sandbox isolation, but global variable was accessible") + } +} + +func TestContextInSandbox(t *testing.T) { + // Create a pool + pool, err := NewPool(2) + if err != nil { + t.Fatalf("Failed to create pool: %v", err) + } + defer pool.Shutdown() + + // Create a context with test data + ctx := NewContext() + ctx.Set("name", "test") + ctx.Set("value", 42.5) + ctx.Set("items", []float64{1, 2, 3}) + + bytecode := createTestBytecode(t, ` + -- Access and manipulate context values + local sum = 0 + for i, v in ipairs(ctx.items) do + sum = sum + v + end + + return { + name_length = string.len(ctx.name), + value_doubled = ctx.value * 2, + items_sum = sum + } + `) + + result, err := pool.Submit(bytecode, ctx) + if err != nil { + t.Fatalf("Failed to execute script with context: %v", err) + } + + resultMap, ok := result.(map[string]any) + if !ok { + t.Fatalf("Expected map result, got %T", result) + } + + // Check context values were correctly accessible + if resultMap["name_length"].(float64) != 4 { + t.Errorf("Expected name_length = 4, got %v", resultMap["name_length"]) + } + + if resultMap["value_doubled"].(float64) != 85 { + t.Errorf("Expected value_doubled = 85, got %v", resultMap["value_doubled"]) + } + + if resultMap["items_sum"].(float64) != 6 { + t.Errorf("Expected items_sum = 6, got %v", resultMap["items_sum"]) + } +} + +func TestStandardLibsInSandbox(t *testing.T) { + // Create a pool + pool, err := NewPool(2) + if err != nil { + t.Fatalf("Failed to create pool: %v", err) + } + defer pool.Shutdown() + + // Test access to standard libraries + bytecode := createTestBytecode(t, ` + local results = {} + + -- Test string library + results.string_upper = string.upper("test") + + -- Test math library + results.math_sqrt = math.sqrt(16) + + -- Test table library + local tbl = {10, 20, 30} + table.insert(tbl, 40) + results.table_length = #tbl + + -- Test os library (limited functions) + results.has_os_time = type(os.time) == "function" + + return results + `) + + result, err := pool.Submit(bytecode, nil) + if err != nil { + t.Fatalf("Failed to execute script: %v", err) + } + + resultMap, ok := result.(map[string]any) + if !ok { + t.Fatalf("Expected map result, got %T", result) + } + + // Check standard library functions worked + if resultMap["string_upper"] != "TEST" { + t.Errorf("Expected string_upper = 'TEST', got %v", resultMap["string_upper"]) + } + + if resultMap["math_sqrt"].(float64) != 4 { + t.Errorf("Expected math_sqrt = 4, got %v", resultMap["math_sqrt"]) + } + + if resultMap["table_length"].(float64) != 4 { + t.Errorf("Expected table_length = 4, got %v", resultMap["table_length"]) + } + + if resultMap["has_os_time"] != true { + t.Errorf("Expected has_os_time = true, got %v", resultMap["has_os_time"]) + } +} diff --git a/core/workers/pool.go b/core/workers/pool.go new file mode 100644 index 0000000..e6b85f5 --- /dev/null +++ b/core/workers/pool.go @@ -0,0 +1,123 @@ +package workers + +import ( + "context" + "sync" + "sync/atomic" + + luajit "git.sharkk.net/Sky/LuaJIT-to-Go" +) + +// StateInitFunc is a function that initializes a Lua state +// It can be used to register custom functions and modules +type StateInitFunc func(*luajit.State) error + +// Pool manages a pool of Lua worker goroutines +type Pool struct { + workers uint32 // Number of workers + jobs chan job // Channel to send jobs to workers + wg sync.WaitGroup // WaitGroup to track active workers + quit chan struct{} // Channel to signal shutdown + isRunning atomic.Bool // Flag to track if pool is running + stateInit StateInitFunc // Optional function to initialize Lua state +} + +// NewPool creates a new worker pool with the specified number of workers +func NewPool(numWorkers int) (*Pool, error) { + return NewPoolWithInit(numWorkers, nil) +} + +// NewPoolWithInit creates a new worker pool with the specified number of workers +// and a function to initialize each worker's Lua state +func NewPoolWithInit(numWorkers int, initFunc StateInitFunc) (*Pool, error) { + if numWorkers <= 0 { + return nil, ErrNoWorkers + } + + p := &Pool{ + workers: uint32(numWorkers), + jobs: make(chan job, numWorkers), // Buffer equal to worker count + quit: make(chan struct{}), + stateInit: initFunc, + } + p.isRunning.Store(true) + + // Start workers + p.wg.Add(numWorkers) + for i := 0; i < numWorkers; i++ { + w := &worker{ + pool: p, + id: uint32(i), + } + go w.run() + } + + return p, nil +} + +// RegisterGlobal is no longer needed with the sandbox approach +// but kept as a no-op for backward compatibility +func (p *Pool) RegisterGlobal(name string) { + // No-op in sandbox mode +} + +// SubmitWithContext sends a job to the worker pool with context +func (p *Pool) SubmitWithContext(ctx context.Context, bytecode []byte, execCtx *Context) (any, error) { + if !p.isRunning.Load() { + return nil, ErrPoolClosed + } + + resultChan := make(chan JobResult, 1) + j := job{ + Bytecode: bytecode, + Context: execCtx, + Result: resultChan, + } + + // Submit job with context + select { + case p.jobs <- j: + // Job submitted + case <-ctx.Done(): + return nil, ctx.Err() + } + + // Wait for result with context + select { + case result := <-resultChan: + return result.Value, result.Error + case <-ctx.Done(): + // Note: The job will still be processed by a worker, + // but the result will be discarded + return nil, ctx.Err() + } +} + +// Submit sends a job to the worker pool +func (p *Pool) Submit(bytecode []byte, execCtx *Context) (any, error) { + return p.SubmitWithContext(context.Background(), bytecode, execCtx) +} + +// Shutdown gracefully shuts down the worker pool +func (p *Pool) Shutdown() error { + if !p.isRunning.Load() { + return ErrPoolClosed + } + p.isRunning.Store(false) + + // Signal workers to quit + close(p.quit) + + // Wait for workers to finish + p.wg.Wait() + + // Close jobs channel + close(p.jobs) + + return nil +} + +// ActiveWorkers returns the number of active workers +func (p *Pool) ActiveWorkers() uint32 { + return atomic.LoadUint32(&p.workers) +} diff --git a/core/workers/sandbox.go b/core/workers/sandbox.go new file mode 100644 index 0000000..ddf9930 --- /dev/null +++ b/core/workers/sandbox.go @@ -0,0 +1,144 @@ +package workers + +// setupSandbox initializes the sandbox environment creation function +func (w *worker) setupSandbox() error { + // This is the Lua script that creates our sandbox function + setupScript := ` + -- Create a function to run code in a sandbox environment + function __create_sandbox() + -- Create new environment table + local env = {} + + -- Add standard library modules (can be restricted as needed) + env.string = string + env.table = table + env.math = math + env.os = { + time = os.time, + date = os.date, + difftime = os.difftime, + clock = os.clock + } + env.tonumber = tonumber + env.tostring = tostring + env.type = type + env.pairs = pairs + env.ipairs = ipairs + env.next = next + env.select = select + env.unpack = unpack + env.pcall = pcall + env.xpcall = xpcall + env.error = error + env.assert = assert + + -- Allow access to package.loaded for modules + env.require = function(name) + return package.loaded[name] + end + + -- Create metatable to restrict access to _G + local mt = { + __index = function(t, k) + -- First check in env table + local v = rawget(env, k) + if v ~= nil then return v end + + -- If not found, check for registered modules/functions + local moduleValue = _G[k] + if type(moduleValue) == "table" or + type(moduleValue) == "function" then + return moduleValue + end + + return nil + end, + __newindex = function(t, k, v) + rawset(env, k, v) + end + } + + setmetatable(env, mt) + return env + end + + -- Create function to execute code with a sandbox + function __run_sandboxed(f, ctx) + local env = __create_sandbox() + + -- Add context to the environment if provided + if ctx then + env.ctx = ctx + end + + -- Set the environment and run the function + setfenv(f, env) + return f() + end + ` + + return w.state.DoString(setupScript) +} + +// executeJobSandboxed runs a script in a sandbox environment +func (w *worker) executeJobSandboxed(j job) JobResult { + // No need to reset the state for each execution, since we're using a sandbox + + // Re-run init function to register functions and modules if needed + if w.pool.stateInit != nil { + if err := w.pool.stateInit(w.state); err != nil { + return JobResult{nil, err} + } + } + + // Set up context if provided + if j.Context != nil { + // Push context table + w.state.NewTable() + + // Add values to context table + for key, value := range j.Context.Values { + // Push key + w.state.PushString(key) + + // Push value + if err := w.state.PushValue(value); err != nil { + return JobResult{nil, err} + } + + // Set table[key] = value + w.state.SetTable(-3) + } + } else { + // Push nil if no context + w.state.PushNil() + } + + // Load bytecode + if err := w.state.LoadBytecode(j.Bytecode, "script"); err != nil { + w.state.Pop(1) // Pop context + return JobResult{nil, err} + } + + // Get the sandbox runner function + w.state.GetGlobal("__run_sandboxed") + + // Push loaded function and context as arguments + w.state.PushCopy(-2) // Copy the loaded function + w.state.PushCopy(-4) // Copy the context table or nil + + // Remove the original function and context + w.state.Remove(-5) // Remove original context + w.state.Remove(-4) // Remove original function + + // Call the sandbox runner with 2 args (function and context), expecting 1 result + if err := w.state.Call(2, 1); err != nil { + return JobResult{nil, err} + } + + // Get result + value, err := w.state.ToValue(-1) + w.state.Pop(1) // Pop result + + return JobResult{value, err} +} diff --git a/core/workers/worker.go b/core/workers/worker.go new file mode 100644 index 0000000..e83ce0c --- /dev/null +++ b/core/workers/worker.go @@ -0,0 +1,71 @@ +package workers + +import ( + "errors" + "sync/atomic" + + luajit "git.sharkk.net/Sky/LuaJIT-to-Go" +) + +// Common errors +var ( + ErrPoolClosed = errors.New("worker pool is closed") + ErrNoWorkers = errors.New("no workers available") + ErrInitFailed = errors.New("worker initialization failed") +) + +// worker represents a single Lua execution worker +type worker struct { + pool *Pool // Reference to the pool + state *luajit.State // Lua state + id uint32 // Worker ID +} + +// run is the main worker function that processes jobs +func (w *worker) run() { + defer w.pool.wg.Done() + + // Initialize Lua state + w.state = luajit.New() + if w.state == nil { + // Worker failed to initialize, decrement counter + atomic.AddUint32(&w.pool.workers, ^uint32(0)) + return + } + defer w.state.Close() + + // Set up sandbox environment + if err := w.setupSandbox(); err != nil { + // Worker failed to initialize sandbox, decrement counter + atomic.AddUint32(&w.pool.workers, ^uint32(0)) + return + } + + // Run init function if provided + if w.pool.stateInit != nil { + if err := w.pool.stateInit(w.state); err != nil { + // Worker failed to initialize with custom init function + atomic.AddUint32(&w.pool.workers, ^uint32(0)) + return + } + } + + // Main worker loop + for { + select { + case job, ok := <-w.pool.jobs: + if !ok { + // Jobs channel closed, exit + return + } + + // Execute job + result := w.executeJobSandboxed(job) + job.Result <- result + + case <-w.pool.quit: + // Quit signal received, exit + return + } + } +} diff --git a/core/workers/workers_test.go b/core/workers/workers_test.go new file mode 100644 index 0000000..23dd712 --- /dev/null +++ b/core/workers/workers_test.go @@ -0,0 +1,445 @@ +package workers + +import ( + "context" + "testing" + "time" + + luajit "git.sharkk.net/Sky/LuaJIT-to-Go" +) + +// This helper function creates real LuaJIT bytecode for our tests. Instead of using +// mocks, we compile actual Lua code into bytecode just like we would in production. +func createTestBytecode(t *testing.T, code string) []byte { + state := luajit.New() + if state == nil { + t.Fatal("Failed to create Lua state") + } + defer state.Close() + + bytecode, err := state.CompileBytecode(code, "test") + if err != nil { + t.Fatalf("Failed to compile test bytecode: %v", err) + } + + return bytecode +} + +// This test makes sure we can create a worker pool with a valid number of workers, +// and that we properly reject attempts to create a pool with zero or negative workers. +func TestNewPool(t *testing.T) { + tests := []struct { + name string + workers int + expectErr bool + }{ + {"valid workers", 4, false}, + {"zero workers", 0, true}, + {"negative workers", -1, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pool, err := NewPool(tt.workers) + + if tt.expectErr { + if err == nil { + t.Errorf("Expected error for %d workers, got nil", tt.workers) + } + } else { + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if pool == nil { + t.Errorf("Expected non-nil pool") + } else { + pool.Shutdown() + } + } + }) + } +} + +// Here we're testing the basic job submission flow. We run a simple Lua script +// that returns the number 42 and make sure we get that same value back from the worker pool. +func TestPoolSubmit(t *testing.T) { + pool, err := NewPool(2) + if err != nil { + t.Fatalf("Failed to create pool: %v", err) + } + defer pool.Shutdown() + + bytecode := createTestBytecode(t, "return 42") + + result, err := pool.Submit(bytecode, nil) + if err != nil { + t.Fatalf("Failed to submit job: %v", err) + } + + num, ok := result.(float64) + if !ok { + t.Fatalf("Expected float64 result, got %T", result) + } + + if num != 42 { + t.Errorf("Expected 42, got %f", num) + } +} + +// This test checks how our worker pool handles timeouts. We run a script that takes +// some time to complete and verify two scenarios: one where the timeout is long enough +// for successful completion, and another where we expect the operation to be canceled +// due to a short timeout. +func TestPoolSubmitWithContext(t *testing.T) { + pool, err := NewPool(2) + if err != nil { + t.Fatalf("Failed to create pool: %v", err) + } + defer pool.Shutdown() + + // Create bytecode that sleeps + bytecode := createTestBytecode(t, ` + -- Sleep for 500ms + local start = os.time() + while os.difftime(os.time(), start) < 0.5 do end + return "done" + `) + + // Test with timeout that should succeed + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + result, err := pool.SubmitWithContext(ctx, bytecode, nil) + if err != nil { + t.Fatalf("Unexpected error with sufficient timeout: %v", err) + } + if result != "done" { + t.Errorf("Expected 'done', got %v", result) + } + + // Test with timeout that should fail + ctx, cancel = context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + _, err = pool.SubmitWithContext(ctx, bytecode, nil) + if err == nil { + t.Errorf("Expected timeout error, got nil") + } +} + +// We need to make sure we can pass different types of context values from Go to Lua and +// get them back properly. This test sends numbers, strings, booleans, and arrays to +// a Lua script and verifies they're all handled correctly in both directions. +func TestContextValues(t *testing.T) { + pool, err := NewPool(2) + if err != nil { + t.Fatalf("Failed to create pool: %v", err) + } + defer pool.Shutdown() + + bytecode := createTestBytecode(t, ` + return { + num = ctx.number, + str = ctx.text, + flag = ctx.enabled, + list = {ctx.table[1], ctx.table[2], ctx.table[3]}, + } + `) + + execCtx := NewContext() + execCtx.Set("number", 42.5) + execCtx.Set("text", "hello") + execCtx.Set("enabled", true) + execCtx.Set("table", []float64{10, 20, 30}) + + result, err := pool.Submit(bytecode, execCtx) + if err != nil { + t.Fatalf("Failed to submit job: %v", err) + } + + // Result should be a map + resultMap, ok := result.(map[string]any) + if !ok { + t.Fatalf("Expected map result, got %T", result) + } + + // Check values + if resultMap["num"] != 42.5 { + t.Errorf("Expected num=42.5, got %v", resultMap["num"]) + } + if resultMap["str"] != "hello" { + t.Errorf("Expected str=hello, got %v", resultMap["str"]) + } + if resultMap["flag"] != true { + t.Errorf("Expected flag=true, got %v", resultMap["flag"]) + } + + arr, ok := resultMap["list"].([]float64) + if !ok { + t.Fatalf("Expected []float64, got %T", resultMap["list"]) + } + + expected := []float64{10, 20, 30} + for i, v := range expected { + if arr[i] != v { + t.Errorf("Expected list[%d]=%f, got %f", i, v, arr[i]) + } + } +} + +// Test context with nested data structures +func TestNestedContext(t *testing.T) { + pool, err := NewPool(2) + if err != nil { + t.Fatalf("Failed to create pool: %v", err) + } + defer pool.Shutdown() + + bytecode := createTestBytecode(t, ` + return { + id = ctx.params.id, + name = ctx.params.name, + method = ctx.request.method, + path = ctx.request.path + } + `) + + execCtx := NewContext() + + // Set nested params + params := map[string]any{ + "id": "123", + "name": "test", + } + execCtx.Set("params", params) + + // Set nested request info + request := map[string]any{ + "method": "GET", + "path": "/api/test", + } + execCtx.Set("request", request) + + result, err := pool.Submit(bytecode, execCtx) + if err != nil { + t.Fatalf("Failed to submit job: %v", err) + } + + // Result should be a map + resultMap, ok := result.(map[string]any) + if !ok { + t.Fatalf("Expected map result, got %T", result) + } + + if resultMap["id"] != "123" { + t.Errorf("Expected id=123, got %v", resultMap["id"]) + } + if resultMap["name"] != "test" { + t.Errorf("Expected name=test, got %v", resultMap["name"]) + } + if resultMap["method"] != "GET" { + t.Errorf("Expected method=GET, got %v", resultMap["method"]) + } + if resultMap["path"] != "/api/test" { + t.Errorf("Expected path=/api/test, got %v", resultMap["path"]) + } +} + +// A key requirement for our worker pool is that we don't leak state between executions. +// This test confirms that by setting a global variable in one job and then checking +// that it's been cleared before the next job runs on the same worker. +func TestStateReset(t *testing.T) { + pool, err := NewPool(1) // Use 1 worker to ensure same state is reused + if err != nil { + t.Fatalf("Failed to create pool: %v", err) + } + defer pool.Shutdown() + + // First job sets a global + bytecode1 := createTestBytecode(t, ` + global_var = "should be cleared" + return true + `) + + // Second job checks if global exists + bytecode2 := createTestBytecode(t, ` + return global_var ~= nil + `) + + // Run first job + _, err = pool.Submit(bytecode1, nil) + if err != nil { + t.Fatalf("Failed to submit first job: %v", err) + } + + // Run second job + result, err := pool.Submit(bytecode2, nil) + if err != nil { + t.Fatalf("Failed to submit second job: %v", err) + } + + // Global should be cleared + if result.(bool) { + t.Errorf("Expected global_var to be cleared, but it still exists") + } +} + +// Let's make sure our pool shuts down cleanly. This test confirms that jobs work +// before shutdown, that we get the right error when trying to submit after shutdown, +// and that we properly handle attempts to shut down an already closed pool. +func TestPoolShutdown(t *testing.T) { + pool, err := NewPool(2) + if err != nil { + t.Fatalf("Failed to create pool: %v", err) + } + + // Submit a job to verify pool works + bytecode := createTestBytecode(t, "return 42") + _, err = pool.Submit(bytecode, nil) + if err != nil { + t.Fatalf("Failed to submit job: %v", err) + } + + // Shutdown + if err := pool.Shutdown(); err != nil { + t.Errorf("Shutdown failed: %v", err) + } + + // Submit after shutdown should fail + _, err = pool.Submit(bytecode, nil) + if err != ErrPoolClosed { + t.Errorf("Expected ErrPoolClosed, got %v", err) + } + + // Second shutdown should return error + if err := pool.Shutdown(); err != ErrPoolClosed { + t.Errorf("Expected ErrPoolClosed on second shutdown, got %v", err) + } +} + +// A robust worker pool needs to handle errors gracefully. This test checks various +// error scenarios: invalid bytecode, Lua runtime errors, nil context (which +// should work fine), and unsupported parameter types (which should properly error out). +func TestErrorHandling(t *testing.T) { + pool, err := NewPool(2) + if err != nil { + t.Fatalf("Failed to create pool: %v", err) + } + defer pool.Shutdown() + + // Test invalid bytecode + _, err = pool.Submit([]byte("not valid bytecode"), nil) + if err == nil { + t.Errorf("Expected error for invalid bytecode, got nil") + } + + // Test Lua runtime error + bytecode := createTestBytecode(t, ` + error("intentional error") + return true + `) + + _, err = pool.Submit(bytecode, nil) + if err == nil { + t.Errorf("Expected error from Lua error() call, got nil") + } + + // Test with nil context + bytecode = createTestBytecode(t, "return ctx == nil") + result, err := pool.Submit(bytecode, nil) + if err != nil { + t.Errorf("Unexpected error with nil context: %v", err) + } + if result.(bool) != true { + t.Errorf("Expected ctx to be nil in Lua, but it wasn't") + } + + // Test invalid context value + execCtx := NewContext() + execCtx.Set("param", complex128(1+2i)) // Unsupported type + + bytecode = createTestBytecode(t, "return ctx.param") + _, err = pool.Submit(bytecode, execCtx) + if err == nil { + t.Errorf("Expected error for unsupported context value type, got nil") + } +} + +// The whole point of a worker pool is concurrent processing, so we need to verify +// it works under load. This test submits multiple jobs simultaneously and makes sure +// they all complete correctly with their own unique results. +func TestConcurrentExecution(t *testing.T) { + const workers = 4 + const jobs = 20 + + pool, err := NewPool(workers) + if err != nil { + t.Fatalf("Failed to create pool: %v", err) + } + defer pool.Shutdown() + + // Create bytecode that returns its input + bytecode := createTestBytecode(t, "return ctx.n") + + // Run multiple jobs concurrently + results := make(chan int, jobs) + for i := 0; i < jobs; i++ { + i := i // Capture loop variable + go func() { + execCtx := NewContext() + execCtx.Set("n", float64(i)) + + result, err := pool.Submit(bytecode, execCtx) + if err != nil { + t.Errorf("Job %d failed: %v", i, err) + results <- -1 + return + } + + num, ok := result.(float64) + if !ok { + t.Errorf("Job %d: expected float64, got %T", i, result) + results <- -1 + return + } + + results <- int(num) + }() + } + + // Collect results + counts := make(map[int]bool) + for i := 0; i < jobs; i++ { + result := <-results + if result != -1 { + counts[result] = true + } + } + + // Verify all jobs were processed + if len(counts) != jobs { + t.Errorf("Expected %d unique results, got %d", jobs, len(counts)) + } +} + +// Test context operations +func TestContext(t *testing.T) { + ctx := NewContext() + + // Test Set and Get + ctx.Set("key", "value") + if ctx.Get("key") != "value" { + t.Errorf("Expected value, got %v", ctx.Get("key")) + } + + // Test overwriting + ctx.Set("key", 123) + if ctx.Get("key") != 123 { + t.Errorf("Expected 123, got %v", ctx.Get("key")) + } + + // Test missing key + if ctx.Get("missing") != nil { + t.Errorf("Expected nil for missing key, got %v", ctx.Get("missing")) + } +}