From f60ce41ec179deec640c514d50f09b91a8b439b1 Mon Sep 17 00:00:00 2001 From: Sky Johnson Date: Thu, 13 Mar 2025 22:06:40 -0500 Subject: [PATCH] add lua worker config 1 --- core/workers/config.go | 65 +++++ core/workers/pool.go | 4 +- core/workers/worker.go | 66 +++++ core/workers/workers_config_test.go | 414 ++++++++++++++++++++++++++++ core/workers/workers_test.go | 26 +- luajit | 2 +- moonshark.go | 4 +- 7 files changed, 565 insertions(+), 16 deletions(-) create mode 100644 core/workers/config.go create mode 100644 core/workers/workers_config_test.go diff --git a/core/workers/config.go b/core/workers/config.go new file mode 100644 index 0000000..f81968b --- /dev/null +++ b/core/workers/config.go @@ -0,0 +1,65 @@ +package workers + +import ( + luajit "git.sharkk.net/Sky/LuaJIT-to-Go" +) + +// InitFunc is a function that initializes a Lua state +type InitFunc func(*luajit.State) error + +// WorkerConfig contains configuration for worker initialization +type WorkerConfig struct { + // Functions maps Lua global names to Go functions + Functions map[string]luajit.GoFunction + + // Modules maps module names to Lua code strings + Modules map[string]string + + // ModulePaths maps module names to file paths + ModulePaths map[string]string + + // PackagePath custom package.path to set (optional) + PackagePath string + + // CustomInit allows for complex custom initialization + CustomInit InitFunc +} + +// NewWorkerConfig creates a new worker configuration +func NewWorkerConfig() *WorkerConfig { + return &WorkerConfig{ + Functions: make(map[string]luajit.GoFunction), + Modules: make(map[string]string), + ModulePaths: make(map[string]string), + } +} + +// AddFunction registers a Go function to be available in Lua +func (c *WorkerConfig) AddFunction(name string, fn luajit.GoFunction) *WorkerConfig { + c.Functions[name] = fn + return c +} + +// AddModule adds a Lua module from a string +func (c *WorkerConfig) AddModule(name, code string) *WorkerConfig { + c.Modules[name] = code + return c +} + +// AddModuleFile adds a Lua module from a file +func (c *WorkerConfig) AddModuleFile(name, path string) *WorkerConfig { + c.ModulePaths[name] = path + return c +} + +// SetPackagePath sets a custom package.path +func (c *WorkerConfig) SetPackagePath(path string) *WorkerConfig { + c.PackagePath = path + return c +} + +// SetCustomInit sets a custom initialization function +func (c *WorkerConfig) SetCustomInit(fn InitFunc) *WorkerConfig { + c.CustomInit = fn + return c +} diff --git a/core/workers/pool.go b/core/workers/pool.go index 2859caf..29b4611 100644 --- a/core/workers/pool.go +++ b/core/workers/pool.go @@ -13,10 +13,11 @@ type Pool struct { wg sync.WaitGroup // WaitGroup to track active workers quit chan struct{} // Channel to signal shutdown isRunning atomic.Bool // Flag to track if pool is running + config *WorkerConfig // Configuration for worker initialization } // NewPool creates a new worker pool with the specified number of workers -func NewPool(numWorkers int) (*Pool, error) { +func NewPool(numWorkers int, config *WorkerConfig) (*Pool, error) { if numWorkers <= 0 { return nil, ErrNoWorkers } @@ -25,6 +26,7 @@ func NewPool(numWorkers int) (*Pool, error) { workers: uint32(numWorkers), jobs: make(chan job, numWorkers), // Buffer equal to worker count quit: make(chan struct{}), + config: config, } p.isRunning.Store(true) diff --git a/core/workers/worker.go b/core/workers/worker.go index 958cfed..bfd2b53 100644 --- a/core/workers/worker.go +++ b/core/workers/worker.go @@ -40,6 +40,15 @@ func (w *worker) run() { return } + // Apply worker configuration if available + if w.pool.config != nil { + if err := w.applyConfig(); err != nil { + // Worker failed to initialize with configuration + atomic.AddUint32(&w.pool.workers, ^uint32(0)) + return + } + } + // Main worker loop for { select { @@ -60,6 +69,63 @@ func (w *worker) run() { } } +// applyConfig applies the worker configuration to the Lua state +func (w *worker) applyConfig() error { + config := w.pool.config + + // Set package path if specified + if config.PackagePath != "" { + if err := w.state.SetPackagePath(config.PackagePath); err != nil { + return err + } + } + + // Register Go functions + for name, fn := range config.Functions { + if err := w.state.RegisterGoFunction(name, fn); err != nil { + return err + } + } + + // Load modules from strings + for name, code := range config.Modules { + moduleLoader := ` +local module_code = [=====[` + code + `]=====] +package.preload["` + name + `"] = function() + local fn, err = loadstring(module_code, "` + name + `") + if not fn then error(err) end + return fn() +end +` + if err := w.state.DoString(moduleLoader); err != nil { + return err + } + } + + // Load modules from files + for name, path := range config.ModulePaths { + moduleLoader := ` +package.preload["` + name + `"] = function() + local fn, err = loadfile("` + path + `") + if not fn then error(err) end + return fn() +end +` + if err := w.state.DoString(moduleLoader); err != nil { + return err + } + } + + // Apply custom initialization if provided + if config.CustomInit != nil { + if err := config.CustomInit(w.state); err != nil { + return err + } + } + + return nil +} + // setupResetFunction initializes the reset function for clearing globals func (w *worker) setupResetFunction() error { resetScript := ` diff --git a/core/workers/workers_config_test.go b/core/workers/workers_config_test.go new file mode 100644 index 0000000..b05b2a0 --- /dev/null +++ b/core/workers/workers_config_test.go @@ -0,0 +1,414 @@ +package workers + +import ( + "testing" + + luajit "git.sharkk.net/Sky/LuaJIT-to-Go" +) + +// TestWorkerWithGoFunctions tests registering and calling Go functions from Lua +func TestWorkerWithGoFunctions(t *testing.T) { + // Create config with Go functions + config := NewWorkerConfig() + config.AddFunction("add", func(s *luajit.State) int { + a := s.ToNumber(1) + b := s.ToNumber(2) + s.PushNumber(a + b) + return 1 + }) + config.AddFunction("concat", func(s *luajit.State) int { + a := s.ToString(1) + b := s.ToString(2) + s.PushString(a + b) + return 1 + }) + + // Create pool with config + pool, err := NewPool(2, config) + if err != nil { + t.Fatalf("Failed to create pool: %v", err) + } + defer pool.Shutdown() + + // Test using the Go functions from Lua + bytecode := createTestBytecode(t, ` + local sum = add(5, 7) + local str = concat("hello", "world") + return {sum = sum, str = str} + `) + + result, err := pool.Submit(bytecode, nil) + if err != nil { + t.Fatalf("Failed to submit job: %v", err) + } + + // Check results + resultMap, ok := result.(map[string]any) + if !ok { + t.Fatalf("Expected map result, got %T", result) + } + + if resultMap["sum"] != 12.0 { + t.Errorf("Expected sum=12, got %v", resultMap["sum"]) + } + + if resultMap["str"] != "helloworld" { + t.Errorf("Expected str=helloworld, got %v", resultMap["str"]) + } +} + +// TestWorkerWithModules tests loading and using Lua modules +func TestWorkerWithModules(t *testing.T) { + // Create config with modules + config := NewWorkerConfig() + + // Add a module as a string + mathModule := ` + local math_utils = {} + + function math_utils.add(a, b) + return a + b + end + + function math_utils.multiply(a, b) + return a * b + end + + return math_utils + ` + config.AddModule("math_utils", mathModule) + + // Add another module + stringModule := ` + local string_utils = {} + + function string_utils.concat(a, b) + return a .. b + end + + function string_utils.reverse(s) + return string.reverse(s) + end + + return string_utils + ` + config.AddModule("string_utils", stringModule) + + // Create pool with config + pool, err := NewPool(2, config) + if err != nil { + t.Fatalf("Failed to create pool: %v", err) + } + defer pool.Shutdown() + + // Test using the modules + bytecode := createTestBytecode(t, ` + local math = require("math_utils") + local str = require("string_utils") + + local sum = math.add(5, 7) + local product = math.multiply(3, 4) + local combined = str.concat("hello", "world") + local reversed = str.reverse("abcdef") + + return { + sum = sum, + product = product, + combined = combined, + reversed = reversed + } + `) + + result, err := pool.Submit(bytecode, nil) + if err != nil { + t.Fatalf("Failed to submit job: %v", err) + } + + // Check results + resultMap, ok := result.(map[string]any) + if !ok { + t.Fatalf("Expected map result, got %T", result) + } + + if resultMap["sum"] != 12.0 { + t.Errorf("Expected sum=12, got %v", resultMap["sum"]) + } + + if resultMap["product"] != 12.0 { + t.Errorf("Expected product=12, got %v", resultMap["product"]) + } + + if resultMap["combined"] != "helloworld" { + t.Errorf("Expected combined=helloworld, got %v", resultMap["combined"]) + } + + if resultMap["reversed"] != "fedcba" { + t.Errorf("Expected reversed=fedcba, got %v", resultMap["reversed"]) + } +} + +// TestCustomInitFunction tests using a custom initialization function +func TestCustomInitFunction(t *testing.T) { + // Create config with custom init function + config := NewWorkerConfig() + + // Set custom init function that sets a global variable + config.SetCustomInit(func(s *luajit.State) error { + return s.DoString(` + -- Create a helper utility in the global environment + -- We're adding this through custom init rather than as a module + function split_string(str, sep) + local result = {} + for part in string.gmatch(str .. sep, "(.-)" .. sep) do + table.insert(result, part) + end + return result + end + `) + }) + + // Create pool with config + pool, err := NewPool(2, config) + if err != nil { + t.Fatalf("Failed to create pool: %v", err) + } + defer pool.Shutdown() + + // Test using the custom initialized function + bytecode := createTestBytecode(t, ` + local parts = split_string("a,b,c,d", ",") + return parts + `) + + result, err := pool.Submit(bytecode, nil) + if err != nil { + t.Fatalf("Failed to submit job: %v", err) + } + + // Result should be an array + resultMap, ok := result.(map[string]any) + if !ok { + t.Fatalf("Expected map result, got %T", result) + } + + // Check if the empty key exists for array values + if _, hasArray := resultMap[""]; !hasArray { + t.Fatalf("Expected array to be present at empty key, got: %v", resultMap) + } + + // Check array elements + array, ok := resultMap[""].([]float64) + if !ok { + t.Fatalf("Expected []float64 under empty key, got %T", resultMap[""]) + } + + // Check array length (should have 4 elements) + if len(array) != 4 { + t.Errorf("Expected 4 elements, got %d", len(array)) + } +} + +// TestPersistedState tests maintaining state between executions using upvalues +func TestPersistedState(t *testing.T) { + // Create config with a stateful module + config := NewWorkerConfig() + + // Add a module that has internal state via upvalues + counterModule := ` + -- Module with internal counter + local counter = 0 + + local stateful = {} + + function stateful.increment() + counter = counter + 1 + return counter + end + + function stateful.get_value() + return counter + end + + function stateful.reset() + counter = 0 + return counter + end + + return stateful + ` + config.AddModule("counter", counterModule) + + // Create pool with single worker to ensure same Lua state is used + pool, err := NewPool(1, config) + if err != nil { + t.Fatalf("Failed to create pool: %v", err) + } + defer pool.Shutdown() + + // First job: increment counter + bytecode1 := createTestBytecode(t, ` + local counter = require("counter") + return counter.increment() + `) + + result1, err := pool.Submit(bytecode1, nil) + if err != nil { + t.Fatalf("Failed to submit first job: %v", err) + } + + if result1 != 1.0 { + t.Errorf("Expected counter value 1, got %v", result1) + } + + // Second job: increment counter again + bytecode2 := createTestBytecode(t, ` + local counter = require("counter") + return counter.increment() + `) + + result2, err := pool.Submit(bytecode2, nil) + if err != nil { + t.Fatalf("Failed to submit second job: %v", err) + } + + if result2 != 2.0 { + t.Errorf("Expected counter value 2, got %v", result2) + } + + // Third job: get current value without incrementing + bytecode3 := createTestBytecode(t, ` + local counter = require("counter") + return counter.get_value() + `) + + result3, err := pool.Submit(bytecode3, nil) + if err != nil { + t.Fatalf("Failed to submit third job: %v", err) + } + + if result3 != 2.0 { + t.Errorf("Expected counter value 2, got %v", result3) + } +} + +// TestComplexConfiguration tests a more complex setup with multiple features +func TestComplexConfiguration(t *testing.T) { + // Create config with multiple features + config := NewWorkerConfig() + + // Add Go functions + config.AddFunction("add", func(s *luajit.State) int { + a := s.ToNumber(1) + b := s.ToNumber(2) + s.PushNumber(a + b) + return 1 + }) + + // Add a module + mathModule := ` + local math_ext = {} + + function math_ext.square(x) + return x * x + end + + function math_ext.cube(x) + return x * x * x + end + + return math_ext + ` + config.AddModule("math_ext", mathModule) + + // Set custom package path + config.SetPackagePath("./?.lua;./test/?.lua") + + // Set custom init function + config.SetCustomInit(func(s *luajit.State) error { + return s.DoString(` + -- Set a global helper + function capitalize(s) + return string.upper(string.sub(s, 1, 1)) .. string.sub(s, 2) + end + `) + }) + + // Create pool with config + pool, err := NewPool(2, config) + if err != nil { + t.Fatalf("Failed to create pool: %v", err) + } + defer pool.Shutdown() + + // Test using all the features together + bytecode := createTestBytecode(t, ` + local math_ext = require("math_ext") + + local sum = add(10, 5) + local squared = math_ext.square(sum) + local name = capitalize("test") + + -- Check package.path was set correctly + local has_custom_path = string.find(package.path, "./test/?.lua") ~= nil + + return { + sum = sum, + squared = squared, + name = name, + has_custom_path = has_custom_path + } + `) + + result, err := pool.Submit(bytecode, nil) + if err != nil { + t.Fatalf("Failed to submit job: %v", err) + } + + // Check results + resultMap, ok := result.(map[string]any) + if !ok { + t.Fatalf("Expected map result, got %T", result) + } + + if resultMap["sum"] != 15.0 { + t.Errorf("Expected sum=15, got %v", resultMap["sum"]) + } + + if resultMap["squared"] != 225.0 { + t.Errorf("Expected squared=225, got %v", resultMap["squared"]) + } + + if resultMap["name"] != "Test" { + t.Errorf("Expected name=Test, got %v", resultMap["name"]) + } + + if resultMap["has_custom_path"] != true { + t.Errorf("Expected has_custom_path=true, got %v", resultMap["has_custom_path"]) + } +} + +// TestWorkerConfigInitFail tests error handling during worker initialization +func TestWorkerConfigInitFail(t *testing.T) { + // Create config with invalid Lua code that will fail + config := NewWorkerConfig() + + // Set custom init function that will definitely fail + config.SetCustomInit(func(s *luajit.State) error { + return s.DoString(` + -- This has a syntax error + function without_end( + `) + }) + + // Create pool - should still work even though module is invalid + pool, err := NewPool(1, config) + if err != nil { + t.Fatalf("Pool creation should succeed despite bad module: %v", err) + } + defer pool.Shutdown() + + // Worker count should be 0 as initialization failed + if pool.ActiveWorkers() != 0 { + t.Errorf("Expected 0 active workers, got %d", pool.ActiveWorkers()) + } +} diff --git a/core/workers/workers_test.go b/core/workers/workers_test.go index 23dd712..6224cc4 100644 --- a/core/workers/workers_test.go +++ b/core/workers/workers_test.go @@ -31,16 +31,18 @@ func TestNewPool(t *testing.T) { tests := []struct { name string workers int + config *WorkerConfig expectErr bool }{ - {"valid workers", 4, false}, - {"zero workers", 0, true}, - {"negative workers", -1, true}, + {"valid workers with no config", 4, nil, false}, + {"valid workers with config", 3, NewWorkerConfig(), false}, + {"zero workers", 0, nil, true}, + {"negative workers", -1, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - pool, err := NewPool(tt.workers) + pool, err := NewPool(tt.workers, tt.config) if tt.expectErr { if err == nil { @@ -63,7 +65,7 @@ func TestNewPool(t *testing.T) { // Here we're testing the basic job submission flow. We run a simple Lua script // that returns the number 42 and make sure we get that same value back from the worker pool. func TestPoolSubmit(t *testing.T) { - pool, err := NewPool(2) + pool, err := NewPool(2, nil) if err != nil { t.Fatalf("Failed to create pool: %v", err) } @@ -91,7 +93,7 @@ func TestPoolSubmit(t *testing.T) { // for successful completion, and another where we expect the operation to be canceled // due to a short timeout. func TestPoolSubmitWithContext(t *testing.T) { - pool, err := NewPool(2) + pool, err := NewPool(2, nil) if err != nil { t.Fatalf("Failed to create pool: %v", err) } @@ -131,7 +133,7 @@ func TestPoolSubmitWithContext(t *testing.T) { // get them back properly. This test sends numbers, strings, booleans, and arrays to // a Lua script and verifies they're all handled correctly in both directions. func TestContextValues(t *testing.T) { - pool, err := NewPool(2) + pool, err := NewPool(2, nil) if err != nil { t.Fatalf("Failed to create pool: %v", err) } @@ -189,7 +191,7 @@ func TestContextValues(t *testing.T) { // Test context with nested data structures func TestNestedContext(t *testing.T) { - pool, err := NewPool(2) + pool, err := NewPool(2, nil) if err != nil { t.Fatalf("Failed to create pool: %v", err) } @@ -249,7 +251,7 @@ func TestNestedContext(t *testing.T) { // This test confirms that by setting a global variable in one job and then checking // that it's been cleared before the next job runs on the same worker. func TestStateReset(t *testing.T) { - pool, err := NewPool(1) // Use 1 worker to ensure same state is reused + pool, err := NewPool(1, nil) // Use 1 worker to ensure same state is reused if err != nil { t.Fatalf("Failed to create pool: %v", err) } @@ -288,7 +290,7 @@ func TestStateReset(t *testing.T) { // before shutdown, that we get the right error when trying to submit after shutdown, // and that we properly handle attempts to shut down an already closed pool. func TestPoolShutdown(t *testing.T) { - pool, err := NewPool(2) + pool, err := NewPool(2, nil) if err != nil { t.Fatalf("Failed to create pool: %v", err) } @@ -321,7 +323,7 @@ func TestPoolShutdown(t *testing.T) { // error scenarios: invalid bytecode, Lua runtime errors, nil context (which // should work fine), and unsupported parameter types (which should properly error out). func TestErrorHandling(t *testing.T) { - pool, err := NewPool(2) + pool, err := NewPool(2, nil) if err != nil { t.Fatalf("Failed to create pool: %v", err) } @@ -372,7 +374,7 @@ func TestConcurrentExecution(t *testing.T) { const workers = 4 const jobs = 20 - pool, err := NewPool(workers) + pool, err := NewPool(workers, nil) if err != nil { t.Fatalf("Failed to create pool: %v", err) } diff --git a/luajit b/luajit index 13686b3..7ea0dbc 160000 --- a/luajit +++ b/luajit @@ -1 +1 @@ -Subproject commit 13686b3e66b388a31d459fe95d1aa3bfa05aeb27 +Subproject commit 7ea0dbcb7b2ddcd8758e66b034c300ee55178b29 diff --git a/moonshark.go b/moonshark.go index 30856f7..95a7578 100644 --- a/moonshark.go +++ b/moonshark.go @@ -47,7 +47,7 @@ func initRouters(routesDir, staticDir string, log *logger.Logger) (*routers.LuaR func main() { // Initialize logger - log := logger.New(logger.LevelDebug, true) + log := logger.New(logger.LevelInfo, true) log.Info("Starting Moonshark server") @@ -91,7 +91,7 @@ func main() { workerPoolSize := cfg.GetInt("worker_pool_size", 4) // Initialize worker pool - pool, err := workers.NewPool(workerPoolSize) + pool, err := workers.NewPool(workerPoolSize, nil) if err != nil { log.Fatal("Failed to initialize worker pool: %v", err) }