Compare commits

..

2 Commits

Author SHA1 Message Date
78337988bd round-robin 1 2025-03-28 15:14:36 -05:00
d07fc638e6 cap naming convention 2025-03-28 14:44:49 -05:00
19 changed files with 155 additions and 89 deletions

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"errors" "errors"
"path/filepath" "path/filepath"
"runtime"
"sync" "sync"
"sync/atomic" "sync/atomic"
@ -30,17 +31,19 @@ var resultChanPool = sync.Pool{
}, },
} }
// LuaRunner runs Lua scripts using a single Lua state // LuaRunner runs Lua scripts using multiple Lua states in a round-robin fashion
type LuaRunner struct { type LuaRunner struct {
state *luajit.State // The Lua state states []*luajit.State // Multiple Lua states for parallel execution
jobQueue chan job // Channel for incoming jobs jobQueues []chan job // Each state has its own job queue
workerCount int // Number of worker states (default 4)
nextWorker int32 // Atomic counter for round-robin distribution
isRunning atomic.Bool // Flag indicating if the runner is active isRunning atomic.Bool // Flag indicating if the runner is active
mu sync.RWMutex // Mutex for thread safety mu sync.RWMutex // Mutex for thread safety
wg sync.WaitGroup // WaitGroup for clean shutdown wg sync.WaitGroup // WaitGroup for clean shutdown
initFunc StateInitFunc // Optional function to initialize Lua state initFunc StateInitFunc // Optional function to initialize Lua states
bufferSize int // Size of the job queue buffer bufferSize int // Size of each job queue buffer
moduleLoader *NativeModuleLoader // Native module loader for require moduleLoader *NativeModuleLoader // Native module loader for require
sandbox *Sandbox // The sandbox environment sandboxes []*Sandbox // Sandbox for each state
debug bool // Enable debug logging debug bool // Enable debug logging
} }
@ -80,12 +83,21 @@ func WithDebugEnabled() RunnerOption {
} }
} }
// NewRunner creates a new LuaRunner // WithWorkerCount sets the number of worker states (min 1)
func WithWorkerCount(count int) RunnerOption {
return func(r *LuaRunner) {
if count > 0 {
r.workerCount = count
}
}
}
// NewRunner creates a new LuaRunner with multiple worker states
func NewRunner(options ...RunnerOption) (*LuaRunner, error) { func NewRunner(options ...RunnerOption) (*LuaRunner, error) {
// Default configuration // Default configuration
runner := &LuaRunner{ runner := &LuaRunner{
bufferSize: 10, // Default buffer size bufferSize: 10,
sandbox: NewSandbox(), workerCount: runtime.GOMAXPROCS(0),
debug: false, debug: false,
} }
@ -94,9 +106,10 @@ func NewRunner(options ...RunnerOption) (*LuaRunner, error) {
opt(runner) opt(runner)
} }
// Create job queue // Initialize states and job queues
runner.jobQueue = make(chan job, runner.bufferSize) runner.states = make([]*luajit.State, runner.workerCount)
runner.isRunning.Store(true) runner.jobQueues = make([]chan job, runner.workerCount)
runner.sandboxes = make([]*Sandbox, runner.workerCount)
// Set up module loader if not already initialized // Set up module loader if not already initialized
if runner.moduleLoader == nil { if runner.moduleLoader == nil {
@ -107,14 +120,28 @@ func NewRunner(options ...RunnerOption) (*LuaRunner, error) {
runner.moduleLoader = NewNativeModuleLoader(requireConfig) runner.moduleLoader = NewNativeModuleLoader(requireConfig)
} }
// Initialize Lua state // Create job queues and initialize states
if err := runner.initState(true); err != nil { for i := 0; i < runner.workerCount; i++ {
// Create job queue
runner.jobQueues[i] = make(chan job, runner.bufferSize)
// Create sandbox
runner.sandboxes[i] = NewSandbox()
// Initialize state
if err := runner.initState(i, true); err != nil {
// Clean up if initialization fails
runner.Close()
return nil, err return nil, err
} }
// Start the event loop // Start worker goroutine
runner.wg.Add(1) runner.wg.Add(1)
go runner.processJobs() go runner.processJobs(i)
}
runner.isRunning.Store(true)
runner.nextWorker = 0
return runner, nil return runner, nil
} }
@ -126,17 +153,17 @@ func (r *LuaRunner) debugLog(format string, args ...interface{}) {
} }
} }
// initState initializes or reinitializes the Lua state // initState initializes or reinitializes a specific Lua state
func (r *LuaRunner) initState(initial bool) error { func (r *LuaRunner) initState(workerIndex int, initial bool) error {
r.debugLog("Initializing Lua state (initial=%v)", initial) r.debugLog("Initializing Lua state %d (initial=%v)", workerIndex, initial)
// Clean up existing state if there is one // Clean up existing state if there is one
if r.state != nil { if r.states[workerIndex] != nil {
r.debugLog("Cleaning up existing state") r.debugLog("Cleaning up existing state %d", workerIndex)
// Always call Cleanup before Close to properly free function pointers // Always call Cleanup before Close to properly free function pointers
r.state.Cleanup() r.states[workerIndex].Cleanup()
r.state.Close() r.states[workerIndex].Close()
r.state = nil r.states[workerIndex] = nil
} }
// Create fresh state // Create fresh state
@ -144,25 +171,25 @@ func (r *LuaRunner) initState(initial bool) error {
if state == nil { if state == nil {
return errors.New("failed to create Lua state") return errors.New("failed to create Lua state")
} }
r.debugLog("Created new Lua state") r.debugLog("Created new Lua state %d", workerIndex)
// Set up require paths and mechanism // Set up require paths and mechanism
if err := r.moduleLoader.SetupRequire(state); err != nil { if err := r.moduleLoader.SetupRequire(state); err != nil {
r.debugLog("Failed to set up require: %v", err) r.debugLog("Failed to set up require for state %d: %v", workerIndex, err)
state.Cleanup() state.Cleanup()
state.Close() state.Close()
return ErrInitFailed return ErrInitFailed
} }
r.debugLog("Require system initialized") r.debugLog("Require system initialized for state %d", workerIndex)
// Initialize all core modules from the registry // Initialize all core modules from the registry
if err := GlobalRegistry.Initialize(state); err != nil { if err := GlobalRegistry.Initialize(state); err != nil {
r.debugLog("Failed to initialize core modules: %v", err) r.debugLog("Failed to initialize core modules for state %d: %v", workerIndex, err)
state.Cleanup() state.Cleanup()
state.Close() state.Close()
return ErrInitFailed return ErrInitFailed
} }
r.debugLog("Core modules initialized") r.debugLog("Core modules initialized for state %d", workerIndex)
// Check if http module is properly registered // Check if http module is properly registered
testResult, err := state.ExecuteWithResult(` testResult, err := state.ExecuteWithResult(`
@ -174,42 +201,42 @@ func (r *LuaRunner) initState(initial bool) error {
end end
`) `)
if err != nil || testResult != true { if err != nil || testResult != true {
r.debugLog("HTTP module verification failed: %v, result: %v", err, testResult) r.debugLog("HTTP module verification failed for state %d: %v, result: %v", workerIndex, err, testResult)
} else { } else {
r.debugLog("HTTP module verified OK") r.debugLog("HTTP module verified OK for state %d", workerIndex)
} }
// Verify __http_request function // Verify __http_request function
testResult, _ = state.ExecuteWithResult(`return type(__http_request)`) testResult, _ = state.ExecuteWithResult(`return type(__http_request)`)
r.debugLog("__http_request function is of type: %v", testResult) r.debugLog("__http_request function for state %d is of type: %v", workerIndex, testResult)
// Set up sandbox after core modules are initialized // Set up sandbox after core modules are initialized
if err := r.sandbox.Setup(state); err != nil { if err := r.sandboxes[workerIndex].Setup(state); err != nil {
r.debugLog("Failed to set up sandbox: %v", err) r.debugLog("Failed to set up sandbox for state %d: %v", workerIndex, err)
state.Cleanup() state.Cleanup()
state.Close() state.Close()
return ErrInitFailed return ErrInitFailed
} }
r.debugLog("Sandbox environment set up") r.debugLog("Sandbox environment set up for state %d", workerIndex)
// Preload all modules into package.loaded // Preload all modules into package.loaded
if err := r.moduleLoader.PreloadAllModules(state); err != nil { if err := r.moduleLoader.PreloadAllModules(state); err != nil {
r.debugLog("Failed to preload modules: %v", err) r.debugLog("Failed to preload modules for state %d: %v", workerIndex, err)
state.Cleanup() state.Cleanup()
state.Close() state.Close()
return errors.New("failed to preload modules") return errors.New("failed to preload modules")
} }
r.debugLog("All modules preloaded") r.debugLog("All modules preloaded for state %d", workerIndex)
// Run init function if provided // Run init function if provided
if r.initFunc != nil { if r.initFunc != nil {
if err := r.initFunc(state); err != nil { if err := r.initFunc(state); err != nil {
r.debugLog("Custom init function failed: %v", err) r.debugLog("Custom init function failed for state %d: %v", workerIndex, err)
state.Cleanup() state.Cleanup()
state.Close() state.Close()
return ErrInitFailed return ErrInitFailed
} }
r.debugLog("Custom init function completed") r.debugLog("Custom init function completed for state %d", workerIndex)
} }
// Test for HTTP module again after full initialization // Test for HTTP module again after full initialization
@ -222,31 +249,31 @@ func (r *LuaRunner) initState(initial bool) error {
end end
`) `)
if err != nil || testResult != true { if err != nil || testResult != true {
r.debugLog("Final HTTP module verification failed: %v, result: %v", err, testResult) r.debugLog("Final HTTP module verification failed for state %d: %v, result: %v", workerIndex, err, testResult)
} else { } else {
r.debugLog("Final HTTP module verification OK") r.debugLog("Final HTTP module verification OK for state %d", workerIndex)
} }
r.state = state r.states[workerIndex] = state
r.debugLog("State initialization complete") r.debugLog("State %d initialization complete", workerIndex)
return nil return nil
} }
// processJobs handles the job queue // processJobs handles the job queue for a specific worker
func (r *LuaRunner) processJobs() { func (r *LuaRunner) processJobs(workerIndex int) {
defer r.wg.Done() defer r.wg.Done()
defer func() { defer func() {
if r.state != nil { if r.states[workerIndex] != nil {
r.debugLog("Cleaning up Lua state in processJobs") r.debugLog("Cleaning up Lua state %d in processJobs", workerIndex)
r.state.Cleanup() r.states[workerIndex].Cleanup()
r.state.Close() r.states[workerIndex].Close()
r.state = nil r.states[workerIndex] = nil
} }
}() }()
for job := range r.jobQueue { for job := range r.jobQueues[workerIndex] {
// Execute the job and send result // Execute the job and send result
result := r.executeJob(job) result := r.executeJob(workerIndex, job)
select { select {
case job.Result <- result: case job.Result <- result:
// Result sent successfully // Result sent successfully
@ -257,7 +284,7 @@ func (r *LuaRunner) processJobs() {
} }
// executeJob runs a script in the sandbox environment // executeJob runs a script in the sandbox environment
func (r *LuaRunner) executeJob(j job) JobResult { func (r *LuaRunner) executeJob(workerIndex int, j job) JobResult {
// If the job has a script path, update script dir for module resolution // If the job has a script path, update script dir for module resolution
if j.ScriptPath != "" { if j.ScriptPath != "" {
r.mu.Lock() r.mu.Lock()
@ -272,14 +299,16 @@ func (r *LuaRunner) executeJob(j job) JobResult {
} }
r.mu.RLock() r.mu.RLock()
defer r.mu.RUnlock() state := r.states[workerIndex]
sandbox := r.sandboxes[workerIndex]
r.mu.RUnlock()
if r.state == nil { if state == nil {
return JobResult{nil, errors.New("lua state is not initialized")} return JobResult{nil, errors.New("lua state is not initialized")}
} }
// Execute in sandbox // Execute in sandbox
value, err := r.sandbox.Execute(r.state, j.Bytecode, ctx) value, err := sandbox.Execute(state, j.Bytecode, ctx)
return JobResult{value, err} return JobResult{value, err}
} }
@ -311,10 +340,14 @@ func (r *LuaRunner) RunWithContext(ctx context.Context, bytecode []byte, execCtx
Result: resultChan, Result: resultChan,
} }
// Choose worker in round-robin fashion
workerIndex := int(atomic.AddInt32(&r.nextWorker, 1) % int32(r.workerCount))
// Submit job with context // Submit job with context
select { select {
case r.jobQueue <- j: case r.jobQueues[workerIndex] <- j:
// Job submitted // Job submitted
r.debugLog("Job submitted to worker %d", workerIndex)
case <-ctx.Done(): case <-ctx.Done():
// Return the channel to the pool before exiting // Return the channel to the pool before exiting
resultChanPool.Put(resultChan) resultChanPool.Put(resultChan)
@ -353,9 +386,15 @@ func (r *LuaRunner) Close() error {
} }
r.isRunning.Store(false) r.isRunning.Store(false)
close(r.jobQueue)
// Wait for event loop to finish // Close all job queues
for i := 0; i < r.workerCount; i++ {
if r.jobQueues[i] != nil {
close(r.jobQueues[i])
}
}
// Wait for all workers to finish
r.wg.Wait() r.wg.Wait()
return nil return nil
@ -368,63 +407,90 @@ func (r *LuaRunner) NotifyFileChanged(filePath string) bool {
r.mu.Lock() r.mu.Lock()
defer r.mu.Unlock() defer r.mu.Unlock()
// Reset the entire state on file changes // Reset all states on file changes
err := r.initState(false) success := true
for i := 0; i < r.workerCount; i++ {
err := r.initState(i, false)
if err != nil { if err != nil {
r.debugLog("Failed to reinitialize state: %v", err) r.debugLog("Failed to reinitialize state %d: %v", i, err)
return false success = false
} else {
r.debugLog("State %d successfully reinitialized", i)
}
} }
r.debugLog("State successfully reinitialized") return success
return true
} }
// ResetModuleCache clears non-core modules from package.loaded // ResetModuleCache clears non-core modules from package.loaded in all states
func (r *LuaRunner) ResetModuleCache() { func (r *LuaRunner) ResetModuleCache() {
if r.moduleLoader != nil { if r.moduleLoader != nil {
r.debugLog("Resetting module cache") r.debugLog("Resetting module cache in all states")
r.moduleLoader.ResetModules(r.state) r.mu.RLock()
defer r.mu.RUnlock()
for i := 0; i < r.workerCount; i++ {
if r.states[i] != nil {
r.moduleLoader.ResetModules(r.states[i])
}
}
} }
} }
// ReloadAllModules reloads all modules into package.loaded // ReloadAllModules reloads all modules into package.loaded in all states
func (r *LuaRunner) ReloadAllModules() error { func (r *LuaRunner) ReloadAllModules() error {
if r.moduleLoader != nil { if r.moduleLoader != nil {
r.debugLog("Reloading all modules") r.debugLog("Reloading all modules in all states")
return r.moduleLoader.PreloadAllModules(r.state) r.mu.RLock()
defer r.mu.RUnlock()
for i := 0; i < r.workerCount; i++ {
if r.states[i] != nil {
if err := r.moduleLoader.PreloadAllModules(r.states[i]); err != nil {
return err
}
}
}
} }
return nil return nil
} }
// RefreshModuleByName invalidates a specific module in package.loaded // RefreshModuleByName invalidates a specific module in package.loaded in all states
func (r *LuaRunner) RefreshModuleByName(modName string) bool { func (r *LuaRunner) RefreshModuleByName(modName string) bool {
if r.state != nil { r.mu.RLock()
r.debugLog("Refreshing module: %s", modName) defer r.mu.RUnlock()
if err := r.state.DoString(`package.loaded["` + modName + `"] = nil`); err != nil {
return false success := true
for i := 0; i < r.workerCount; i++ {
if r.states[i] != nil {
r.debugLog("Refreshing module %s in state %d", modName, i)
if err := r.states[i].DoString(`package.loaded["` + modName + `"] = nil`); err != nil {
success = false
} }
return true
} }
return false }
return success
} }
// AddModule adds a module to the sandbox environment // AddModule adds a module to all sandbox environments
func (r *LuaRunner) AddModule(name string, module any) { func (r *LuaRunner) AddModule(name string, module any) {
r.debugLog("Adding module: %s", name) r.debugLog("Adding module %s to all sandboxes", name)
r.sandbox.AddModule(name, module) for i := 0; i < r.workerCount; i++ {
r.sandboxes[i].AddModule(name, module)
}
} }
// GetModuleCount returns the number of loaded modules // GetModuleCount returns the number of loaded modules in the first state
func (r *LuaRunner) GetModuleCount() int { func (r *LuaRunner) GetModuleCount() int {
r.mu.RLock() r.mu.RLock()
defer r.mu.RUnlock() defer r.mu.RUnlock()
count := 0 count := 0
// Get module count from Lua // Get module count from the first Lua state
if r.state != nil { if r.states[0] != nil {
// Execute a Lua snippet to count modules // Execute a Lua snippet to count modules
if res, err := r.state.ExecuteWithResult(` if res, err := r.states[0].ExecuteWithResult(`
local count = 0 local count = 0
for _ in pairs(package.loaded) do for _ in pairs(package.loaded) do
count = count + 1 count = count + 1