diff --git a/core/workers/job.go b/core/workers/job.go new file mode 100644 index 0000000..b5db4fb --- /dev/null +++ b/core/workers/job.go @@ -0,0 +1,14 @@ +package workers + +// JobResult represents the result of a Lua script execution +type JobResult struct { + Value interface{} // Return value from Lua + Error error // Error if any +} + +// job represents a Lua script execution request +type job struct { + Bytecode []byte // Compiled LuaJIT bytecode + Params map[string]interface{} // Parameters to pass to the script + Result chan<- JobResult // Channel to send result back +} diff --git a/core/workers/pool.go b/core/workers/pool.go new file mode 100644 index 0000000..1c5c1fb --- /dev/null +++ b/core/workers/pool.go @@ -0,0 +1,98 @@ +package workers + +import ( + "context" + "sync" + "sync/atomic" +) + +// Pool manages a pool of Lua worker goroutines +type Pool struct { + workers uint32 // Number of workers + jobs chan job // Channel to send jobs to workers + wg sync.WaitGroup // WaitGroup to track active workers + quit chan struct{} // Channel to signal shutdown + isRunning atomic.Bool // Flag to track if pool is running +} + +// NewPool creates a new worker pool with the specified number of workers +func NewPool(numWorkers int) (*Pool, error) { + if numWorkers <= 0 { + return nil, ErrNoWorkers + } + + p := &Pool{ + workers: uint32(numWorkers), + jobs: make(chan job, numWorkers), // Buffer equal to worker count + quit: make(chan struct{}), + } + p.isRunning.Store(true) + + // Start workers + p.wg.Add(numWorkers) + for i := 0; i < numWorkers; i++ { + w := &worker{ + pool: p, + id: uint32(i), + } + go w.run() + } + + return p, nil +} + +// SubmitWithContext sends a job to the worker pool with context +func (p *Pool) SubmitWithContext(ctx context.Context, bytecode []byte, params map[string]interface{}) (interface{}, error) { + if !p.isRunning.Load() { + return nil, ErrPoolClosed + } + + resultChan := make(chan JobResult, 1) + j := job{ + Bytecode: bytecode, + Params: params, + Result: resultChan, + } + + // Submit job with context + select { + case p.jobs <- j: + // Job submitted + case <-ctx.Done(): + return nil, ctx.Err() + } + + // Wait for result with context + select { + case result := <-resultChan: + return result.Value, result.Error + case <-ctx.Done(): + // Note: The job will still be processed by a worker, + // but the result will be discarded + return nil, ctx.Err() + } +} + +// Shutdown gracefully shuts down the worker pool +func (p *Pool) Shutdown() error { + if !p.isRunning.Load() { + return ErrPoolClosed + } + p.isRunning.Store(false) + + // Signal workers to quit + close(p.quit) + + // Wait for workers to finish + p.wg.Wait() + + // Close jobs channel + close(p.jobs) + + return nil +} + +// ActiveWorkers returns the number of active workers +func (p *Pool) ActiveWorkers() uint32 { + return atomic.LoadUint32(&p.workers) +} diff --git a/core/workers/worker.go b/core/workers/worker.go new file mode 100644 index 0000000..c481cbd --- /dev/null +++ b/core/workers/worker.go @@ -0,0 +1,162 @@ +package workers + +import ( + "context" + "errors" + "sync/atomic" + + luajit "git.sharkk.net/Sky/LuaJIT-to-Go" +) + +// Common errors +var ( + ErrPoolClosed = errors.New("worker pool is closed") + ErrNoWorkers = errors.New("no workers available") +) + +// worker represents a single Lua execution worker +type worker struct { + pool *Pool // Reference to the pool + state *luajit.State // Lua state + id uint32 // Worker ID +} + +// run is the main worker function that processes jobs +func (w *worker) run() { + defer w.pool.wg.Done() + + // Initialize Lua state + w.state = luajit.New() + if w.state == nil { + // Worker failed to initialize, decrement counter + atomic.AddUint32(&w.pool.workers, ^uint32(0)) + return + } + defer w.state.Close() + + // Set up reset function for clearing state between requests + if err := w.setupResetFunction(); err != nil { + // Worker failed to initialize reset function, decrement counter + atomic.AddUint32(&w.pool.workers, ^uint32(0)) + return + } + + // Main worker loop + for { + select { + case job, ok := <-w.pool.jobs: + if !ok { + // Jobs channel closed, exit + return + } + + // Execute job + result := w.executeJob(job) + job.Result <- result + + case <-w.pool.quit: + // Quit signal received, exit + return + } + } +} + +// setupResetFunction initializes the reset function for clearing globals +func (w *worker) setupResetFunction() error { + resetScript := ` + -- Create reset function to efficiently clear globals after each request + function __reset_globals() + -- Only keep builtin globals, remove all user-defined globals + local preserve = { + ["_G"] = true, ["_VERSION"] = true, ["__reset_globals"] = true, + ["assert"] = true, ["collectgarbage"] = true, ["coroutine"] = true, + ["debug"] = true, ["dofile"] = true, ["error"] = true, + ["getmetatable"] = true, ["io"] = true, ["ipairs"] = true, + ["load"] = true, ["loadfile"] = true, ["loadstring"] = true, + ["math"] = true, ["next"] = true, ["os"] = true, + ["package"] = true, ["pairs"] = true, ["pcall"] = true, + ["print"] = true, ["rawequal"] = true, ["rawget"] = true, + ["rawset"] = true, ["require"] = true, ["select"] = true, + ["setmetatable"] = true, ["string"] = true, ["table"] = true, + ["tonumber"] = true, ["tostring"] = true, ["type"] = true, + ["unpack"] = true, ["xpcall"] = true + } + + -- Clear all non-standard globals + for name in pairs(_G) do + if not preserve[name] then + _G[name] = nil + end + end + + -- Run garbage collection to release memory + collectgarbage('collect') + end + ` + + return w.state.DoString(resetScript) +} + +// resetState prepares the Lua state for a new job +func (w *worker) resetState() { + w.state.DoString("__reset_globals()") +} + +// setParams sets job parameters as a global 'params' table +func (w *worker) setParams(params map[string]interface{}) error { + // Create new table for params + w.state.NewTable() + + // Add each parameter to the table + for key, value := range params { + // Push key + w.state.PushString(key) + + // Push value + if err := w.state.PushValue(value); err != nil { + return err + } + + // Set table[key] = value + w.state.SetTable(-3) + } + + // Set the table as global 'params' + w.state.SetGlobal("params") + + return nil +} + +// executeJob executes a Lua job in the worker's state +func (w *worker) executeJob(j job) JobResult { + // Reset state before execution + w.resetState() + + // Set parameters + if j.Params != nil { + if err := w.setParams(j.Params); err != nil { + return JobResult{nil, err} + } + } + + // Load bytecode + if err := w.state.LoadBytecode(j.Bytecode, "script"); err != nil { + return JobResult{nil, err} + } + + // Execute script with one result + if err := w.state.RunBytecodeWithResults(1); err != nil { + return JobResult{nil, err} + } + + // Get result + value, err := w.state.ToValue(-1) + w.state.Pop(1) // Pop result + + return JobResult{value, err} +} + +// Submit sends a job to the worker pool +func (p *Pool) Submit(bytecode []byte, params map[string]interface{}) (interface{}, error) { + return p.SubmitWithContext(context.Background(), bytecode, params) +} diff --git a/core/workers/workers_test.go b/core/workers/workers_test.go new file mode 100644 index 0000000..63517f8 --- /dev/null +++ b/core/workers/workers_test.go @@ -0,0 +1,365 @@ +package workers + +import ( + "context" + "testing" + "time" + + luajit "git.sharkk.net/Sky/LuaJIT-to-Go" +) + +// This helper function creates real LuaJIT bytecode for our tests. Instead of using +// mocks, we compile actual Lua code into bytecode just like we would in production. +func createTestBytecode(t *testing.T, code string) []byte { + state := luajit.New() + if state == nil { + t.Fatal("Failed to create Lua state") + } + defer state.Close() + + bytecode, err := state.CompileBytecode(code, "test") + if err != nil { + t.Fatalf("Failed to compile test bytecode: %v", err) + } + + return bytecode +} + +// This test makes sure we can create a worker pool with a valid number of workers, +// and that we properly reject attempts to create a pool with zero or negative workers. +func TestNewPool(t *testing.T) { + tests := []struct { + name string + workers int + expectErr bool + }{ + {"valid workers", 4, false}, + {"zero workers", 0, true}, + {"negative workers", -1, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pool, err := NewPool(tt.workers) + + if tt.expectErr { + if err == nil { + t.Errorf("Expected error for %d workers, got nil", tt.workers) + } + } else { + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if pool == nil { + t.Errorf("Expected non-nil pool") + } else { + pool.Shutdown() + } + } + }) + } +} + +// Here we're testing the basic job submission flow. We run a simple Lua script +// that returns the number 42 and make sure we get that same value back from the worker pool. +func TestPoolSubmit(t *testing.T) { + pool, err := NewPool(2) + if err != nil { + t.Fatalf("Failed to create pool: %v", err) + } + defer pool.Shutdown() + + bytecode := createTestBytecode(t, "return 42") + + result, err := pool.Submit(bytecode, nil) + if err != nil { + t.Fatalf("Failed to submit job: %v", err) + } + + num, ok := result.(float64) + if !ok { + t.Fatalf("Expected float64 result, got %T", result) + } + + if num != 42 { + t.Errorf("Expected 42, got %f", num) + } +} + +// This test checks how our worker pool handles timeouts. We run a script that takes +// some time to complete and verify two scenarios: one where the timeout is long enough +// for successful completion, and another where we expect the operation to be canceled +// due to a short timeout. +func TestPoolSubmitWithContext(t *testing.T) { + pool, err := NewPool(2) + if err != nil { + t.Fatalf("Failed to create pool: %v", err) + } + defer pool.Shutdown() + + // Create bytecode that sleeps + bytecode := createTestBytecode(t, ` + -- Sleep for 500ms + local start = os.time() + while os.difftime(os.time(), start) < 0.5 do end + return "done" + `) + + // Test with timeout that should succeed + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + result, err := pool.SubmitWithContext(ctx, bytecode, nil) + if err != nil { + t.Fatalf("Unexpected error with sufficient timeout: %v", err) + } + if result != "done" { + t.Errorf("Expected 'done', got %v", result) + } + + // Test with timeout that should fail + ctx, cancel = context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + _, err = pool.SubmitWithContext(ctx, bytecode, nil) + if err == nil { + t.Errorf("Expected timeout error, got nil") + } +} + +// We need to make sure we can pass different types of parameters from Go to Lua and +// get them back properly. This test sends numbers, strings, booleans, and arrays to +// a Lua script and verifies they're all handled correctly in both directions. +func TestJobParameters(t *testing.T) { + pool, err := NewPool(2) + if err != nil { + t.Fatalf("Failed to create pool: %v", err) + } + defer pool.Shutdown() + + bytecode := createTestBytecode(t, ` + return { + num = params.number, + str = params.text, + flag = params.enabled, + list = {params.table[1], params.table[2], params.table[3]}, + } + `) + + params := map[string]any{ + "number": 42.5, + "text": "hello", + "enabled": true, + "table": []float64{10, 20, 30}, + } + + result, err := pool.Submit(bytecode, params) + if err != nil { + t.Fatalf("Failed to submit job: %v", err) + } + + // Result should be a map + resultMap, ok := result.(map[string]any) + if !ok { + t.Fatalf("Expected map result, got %T", result) + } + + // Check values + if resultMap["num"] != 42.5 { + t.Errorf("Expected num=42.5, got %v", resultMap["num"]) + } + if resultMap["str"] != "hello" { + t.Errorf("Expected str=hello, got %v", resultMap["str"]) + } + if resultMap["flag"] != true { + t.Errorf("Expected flag=true, got %v", resultMap["flag"]) + } + + arr, ok := resultMap["list"].([]float64) + if !ok { + t.Fatalf("Expected []float64, got %T", resultMap["list"]) + } + + expected := []float64{10, 20, 30} + for i, v := range expected { + if arr[i] != v { + t.Errorf("Expected list[%d]=%f, got %f", i, v, arr[i]) + } + } +} + +// A key requirement for our worker pool is that we don't leak state between executions. +// This test confirms that by setting a global variable in one job and then checking +// that it's been cleared before the next job runs on the same worker. +func TestStateReset(t *testing.T) { + pool, err := NewPool(1) // Use 1 worker to ensure same state is reused + if err != nil { + t.Fatalf("Failed to create pool: %v", err) + } + defer pool.Shutdown() + + // First job sets a global + bytecode1 := createTestBytecode(t, ` + global_var = "should be cleared" + return true + `) + + // Second job checks if global exists + bytecode2 := createTestBytecode(t, ` + return global_var ~= nil + `) + + // Run first job + _, err = pool.Submit(bytecode1, nil) + if err != nil { + t.Fatalf("Failed to submit first job: %v", err) + } + + // Run second job + result, err := pool.Submit(bytecode2, nil) + if err != nil { + t.Fatalf("Failed to submit second job: %v", err) + } + + // Global should be cleared + if result.(bool) { + t.Errorf("Expected global_var to be cleared, but it still exists") + } +} + +// Let's make sure our pool shuts down cleanly. This test confirms that jobs work +// before shutdown, that we get the right error when trying to submit after shutdown, +// and that we properly handle attempts to shut down an already closed pool. +func TestPoolShutdown(t *testing.T) { + pool, err := NewPool(2) + if err != nil { + t.Fatalf("Failed to create pool: %v", err) + } + + // Submit a job to verify pool works + bytecode := createTestBytecode(t, "return 42") + _, err = pool.Submit(bytecode, nil) + if err != nil { + t.Fatalf("Failed to submit job: %v", err) + } + + // Shutdown + if err := pool.Shutdown(); err != nil { + t.Errorf("Shutdown failed: %v", err) + } + + // Submit after shutdown should fail + _, err = pool.Submit(bytecode, nil) + if err != ErrPoolClosed { + t.Errorf("Expected ErrPoolClosed, got %v", err) + } + + // Second shutdown should return error + if err := pool.Shutdown(); err != ErrPoolClosed { + t.Errorf("Expected ErrPoolClosed on second shutdown, got %v", err) + } +} + +// A robust worker pool needs to handle errors gracefully. This test checks various +// error scenarios: invalid bytecode, Lua runtime errors, nil parameters (which +// should work fine), and unsupported parameter types (which should properly error out). +func TestErrorHandling(t *testing.T) { + pool, err := NewPool(2) + if err != nil { + t.Fatalf("Failed to create pool: %v", err) + } + defer pool.Shutdown() + + // Test invalid bytecode + _, err = pool.Submit([]byte("not valid bytecode"), nil) + if err == nil { + t.Errorf("Expected error for invalid bytecode, got nil") + } + + // Test Lua runtime error + bytecode := createTestBytecode(t, ` + error("intentional error") + return true + `) + + _, err = pool.Submit(bytecode, nil) + if err == nil { + t.Errorf("Expected error from Lua error() call, got nil") + } + + // Test invalid parameter + bytecode = createTestBytecode(t, "return param") + + // This should work with nil value + _, err = pool.Submit(bytecode, map[string]any{ + "param": nil, + }) + if err != nil { + t.Errorf("Unexpected error with nil param: %v", err) + } + + // Complex type that can't be converted + complex := map[string]any{ + "param": complex128(1 + 2i), // Unsupported type + } + + _, err = pool.Submit(bytecode, complex) + if err == nil { + t.Errorf("Expected error for unsupported parameter type, got nil") + } +} + +// The whole point of a worker pool is concurrent processing, so we need to verify +// it works under load. This test submits multiple jobs simultaneously and makes sure +// they all complete correctly with their own unique results. +func TestConcurrentExecution(t *testing.T) { + const workers = 4 + const jobs = 20 + + pool, err := NewPool(workers) + if err != nil { + t.Fatalf("Failed to create pool: %v", err) + } + defer pool.Shutdown() + + // Create bytecode that returns its input + bytecode := createTestBytecode(t, "return params.n") + + // Run multiple jobs concurrently + results := make(chan int, jobs) + for i := 0; i < jobs; i++ { + i := i // Capture loop variable + go func() { + params := map[string]any{"n": float64(i)} + result, err := pool.Submit(bytecode, params) + if err != nil { + t.Errorf("Job %d failed: %v", i, err) + results <- -1 + return + } + + num, ok := result.(float64) + if !ok { + t.Errorf("Job %d: expected float64, got %T", i, result) + results <- -1 + return + } + + results <- int(num) + }() + } + + // Collect results + counts := make(map[int]bool) + for i := 0; i < jobs; i++ { + result := <-results + if result != -1 { + counts[result] = true + } + } + + // Verify all jobs were processed + if len(counts) != jobs { + t.Errorf("Expected %d unique results, got %d", jobs, len(counts)) + } +}