diff --git a/Moonshark.go b/Moonshark.go index a659cc7..4a89bdf 100644 --- a/Moonshark.go +++ b/Moonshark.go @@ -21,7 +21,7 @@ type Moonshark struct { Config *config.Config LuaRouter *routers.LuaRouter StaticRouter *routers.StaticRouter - LuaRunner *runner.LuaRunner + LuaRunner *runner.Runner HTTPServer *http.Server // Clean-up functions for watchers diff --git a/core/http/Server.go b/core/http/Server.go index 350fbdd..99b96d3 100644 --- a/core/http/Server.go +++ b/core/http/Server.go @@ -19,7 +19,7 @@ import ( type Server struct { luaRouter *routers.LuaRouter staticRouter *routers.StaticRouter - luaRunner *runner.LuaRunner + luaRunner *runner.Runner httpServer *http.Server loggingEnabled bool debugMode bool // Controls whether to show error details @@ -28,7 +28,7 @@ type Server struct { } // New creates a new HTTP server with optimized connection settings -func New(luaRouter *routers.LuaRouter, staticRouter *routers.StaticRouter, runner *runner.LuaRunner, +func New(luaRouter *routers.LuaRouter, staticRouter *routers.StaticRouter, runner *runner.Runner, loggingEnabled bool, debugMode bool, overrideDir string, config *config.Config) *Server { server := &Server{ diff --git a/core/runner/Context.go b/core/runner/Context.go index 3078093..1a829e4 100644 --- a/core/runner/Context.go +++ b/core/runner/Context.go @@ -4,8 +4,11 @@ import "sync" // Context represents execution context for a Lua script type Context struct { - // Generic map for any context values (route params, HTTP request info, etc.) + // Values stores any context values (route params, HTTP request info, etc.) Values map[string]any + + // internal mutex for concurrent access + mu sync.RWMutex } // Context pool to reduce allocations @@ -24,19 +27,59 @@ func NewContext() *Context { // Release returns the context to the pool after clearing its values func (c *Context) Release() { + c.mu.Lock() + defer c.mu.Unlock() + // Clear all values to prevent data leakage for k := range c.Values { delete(c.Values, k) } + contextPool.Put(c) } // Set adds a value to the context func (c *Context) Set(key string, value any) { + c.mu.Lock() + defer c.mu.Unlock() + c.Values[key] = value } // Get retrieves a value from the context func (c *Context) Get(key string) any { + c.mu.RLock() + defer c.mu.RUnlock() + return c.Values[key] } + +// Contains checks if a key exists in the context +func (c *Context) Contains(key string) bool { + c.mu.RLock() + defer c.mu.RUnlock() + + _, exists := c.Values[key] + return exists +} + +// Delete removes a value from the context +func (c *Context) Delete(key string) { + c.mu.Lock() + defer c.mu.Unlock() + + delete(c.Values, key) +} + +// All returns a copy of all values in the context +func (c *Context) All() map[string]any { + c.mu.RLock() + defer c.mu.RUnlock() + + result := make(map[string]any, len(c.Values)) + for k, v := range c.Values { + result[k] = v + } + + return result +} diff --git a/core/runner/CoreModules.go b/core/runner/CoreModules.go index 52ebb29..dffcedb 100644 --- a/core/runner/CoreModules.go +++ b/core/runner/CoreModules.go @@ -9,12 +9,15 @@ import ( "git.sharkk.net/Sky/Moonshark/core/logger" ) +// StateInitFunc is a function that initializes a module in a Lua state +type StateInitFunc func(*luajit.State) error + // CoreModuleRegistry manages the initialization and reloading of core modules type CoreModuleRegistry struct { - modules map[string]StateInitFunc - initOrder []string // Explicit initialization order - dependencies map[string][]string // Module dependencies - initializedFlag map[string]bool // Track which modules are initialized + modules map[string]StateInitFunc // Module initializers + initOrder []string // Explicit initialization order + dependencies map[string][]string // Module dependencies + initializedFlag map[string]bool // Track which modules are initialized mu sync.RWMutex debug bool } @@ -46,21 +49,18 @@ func (r *CoreModuleRegistry) debugLog(format string, args ...interface{}) { func (r *CoreModuleRegistry) Register(name string, initFunc StateInitFunc) { r.mu.Lock() defer r.mu.Unlock() + r.modules[name] = initFunc // Add to initialization order if not already there - found := false for _, n := range r.initOrder { if n == name { - found = true - break + r.debugLog("Module already in init order: %s", name) + return } } - if !found { - r.initOrder = append(r.initOrder, name) - } - + r.initOrder = append(r.initOrder, name) r.debugLog("Registered module: %s", name) } @@ -73,18 +73,14 @@ func (r *CoreModuleRegistry) RegisterWithDependencies(name string, initFunc Stat r.dependencies[name] = dependencies // Add to initialization order if not already there - found := false for _, n := range r.initOrder { if n == name { - found = true - break + r.debugLog("Module already in init order: %s", name) + return } } - if !found { - r.initOrder = append(r.initOrder, name) - } - + r.initOrder = append(r.initOrder, name) r.debugLog("Registered module %s with dependencies: %v", name, dependencies) } @@ -93,17 +89,32 @@ func (r *CoreModuleRegistry) SetInitOrder(order []string) { r.mu.Lock() defer r.mu.Unlock() + // Create new init order + newOrder := make([]string, 0, len(order)) + // First add all known modules that are in the specified order for _, name := range order { if _, exists := r.modules[name]; exists { - r.initOrder = append(r.initOrder, name) + // Check for duplicates + isDuplicate := false + for _, existing := range newOrder { + if existing == name { + isDuplicate = true + break + } + } + + if !isDuplicate { + newOrder = append(newOrder, name) + } } } // Then add any modules not in the specified order for name := range r.modules { + // Check if module already in the new order found := false - for _, n := range r.initOrder { + for _, n := range newOrder { if n == name { found = true break @@ -111,10 +122,11 @@ func (r *CoreModuleRegistry) SetInitOrder(order []string) { } if !found { - r.initOrder = append(r.initOrder, name) + newOrder = append(newOrder, name) } } + r.initOrder = newOrder r.debugLog("Set initialization order: %v", r.initOrder) } @@ -162,10 +174,12 @@ func (r *CoreModuleRegistry) initializeModule(state *luajit.State, name string, // Initialize dependencies first deps := r.dependencies[name] - for _, dep := range deps { + if len(deps) > 0 { newStack := append(initStack, name) - if err := r.initializeModule(state, dep, newStack); err != nil { - return err + for _, dep := range deps { + if err := r.initializeModule(state, dep, newStack); err != nil { + return err + } } } @@ -233,7 +247,6 @@ var GlobalRegistry = NewCoreModuleRegistry() func init() { GlobalRegistry.EnableDebug() // Enable debugging by default - // Register modules GlobalRegistry.Register("go", GoModuleInitFunc()) // Register HTTP module (no dependencies) @@ -256,8 +269,7 @@ func init() { logger.Debug("[CoreModuleRegistry] Core modules registered in init()") } -// RegisterCoreModule is a helper to register a core module -// with the global registry +// RegisterCoreModule is a helper to register a core module with the global registry func RegisterCoreModule(name string, initFunc StateInitFunc) { GlobalRegistry.Register(name, initFunc) } diff --git a/core/runner/Csrf.go b/core/runner/Csrf.go index f0812e9..5d4f538 100644 --- a/core/runner/Csrf.go +++ b/core/runner/Csrf.go @@ -179,7 +179,7 @@ func ValidateCSRFToken(state *luajit.State, ctx *Context) bool { // WithCSRFProtection creates a runner option to add CSRF protection func WithCSRFProtection() RunnerOption { - return func(r *LuaRunner) { + return func(r *Runner) { r.AddInitHook(func(state *luajit.State, ctx *Context) error { // Get request method method, ok := ctx.Get("method").(string) diff --git a/core/runner/Go.go b/core/runner/Go.go index ccd2a64..b3163f3 100644 --- a/core/runner/Go.go +++ b/core/runner/Go.go @@ -55,8 +55,3 @@ func GoModuleInitFunc() StateInitFunc { return RegisterModule(state, "go", GoModuleFunctions()) } } - -// Initialize the core module during startup -func init() { - RegisterCoreModule("go", GoModuleInitFunc()) -} diff --git a/core/runner/Http.go b/core/runner/Http.go index f8e40fe..2fa6378 100644 --- a/core/runner/Http.go +++ b/core/runner/Http.go @@ -483,7 +483,7 @@ func GetHTTPResponse(state *luajit.State) (*HTTPResponse, bool) { // WithHTTPClientConfig creates a runner option to configure the HTTP client func WithHTTPClientConfig(config HTTPClientConfig) RunnerOption { - return func(r *LuaRunner) { + return func(r *Runner) { // Store the config to be applied during initialization r.AddModule("__http_client_config", map[string]any{ "max_timeout": float64(config.MaxTimeout / time.Second), diff --git a/core/runner/LuaRunner.go b/core/runner/LuaRunner.go deleted file mode 100644 index 58ba40e..0000000 --- a/core/runner/LuaRunner.go +++ /dev/null @@ -1,535 +0,0 @@ -package runner - -import ( - "context" - "errors" - "path/filepath" - "runtime" - "sync" - "sync/atomic" - - luajit "git.sharkk.net/Sky/LuaJIT-to-Go" - "git.sharkk.net/Sky/Moonshark/core/logger" -) - -// Common errors -var ( - ErrRunnerClosed = errors.New("lua runner is closed") - ErrInitFailed = errors.New("initialization failed") -) - -// StateInitFunc is a function that initializes a Lua state -type StateInitFunc func(*luajit.State) error - -// RunnerOption defines a functional option for configuring the LuaRunner -type RunnerOption func(*LuaRunner) - -// JobResult represents the result of a Lua script execution -type JobResult struct { - Value any // Return value from Lua - Error error // Error if any -} - -// StateWrapper wraps a Lua state with its sandbox -type StateWrapper struct { - state *luajit.State // The Lua state - sandbox *Sandbox // Associated sandbox - index int // Index for debugging -} - -// InitHook is a function that runs before executing a script -type InitHook func(*luajit.State, *Context) error - -// FinalizeHook is a function that runs after executing a script -type FinalizeHook func(*luajit.State, *Context, any) error - -// LuaRunner runs Lua scripts using a pool of Lua states -type LuaRunner struct { - states []*StateWrapper // Pool of Lua states - stateSem chan int // Semaphore with state indexes - poolSize int // Size of the state pool - initFunc StateInitFunc // Optional function to initialize Lua states - moduleLoader *NativeModuleLoader // Native module loader for require - isRunning atomic.Bool // Flag indicating if the runner is active - mu sync.RWMutex // Mutex for thread safety - debug bool // Enable debug logging - initHooks []InitHook // Hooks to run before script execution - finalizeHooks []FinalizeHook // Hooks to run after script execution -} - -// WithPoolSize sets the state pool size -func WithPoolSize(size int) RunnerOption { - return func(r *LuaRunner) { - if size > 0 { - r.poolSize = size - } - } -} - -// WithInitFunc sets the init function for the Lua state -func WithInitFunc(initFunc StateInitFunc) RunnerOption { - return func(r *LuaRunner) { - r.initFunc = initFunc - } -} - -// WithLibDirs sets additional library directories -func WithLibDirs(dirs ...string) RunnerOption { - return func(r *LuaRunner) { - if r.moduleLoader == nil || r.moduleLoader.config == nil { - r.moduleLoader = NewNativeModuleLoader(&RequireConfig{ - LibDirs: dirs, - }) - } else { - r.moduleLoader.config.LibDirs = dirs - } - } -} - -// WithDebugEnabled enables debug output -func WithDebugEnabled() RunnerOption { - return func(r *LuaRunner) { - r.debug = true - } -} - -// NewRunner creates a new LuaRunner with a pool of states -func NewRunner(options ...RunnerOption) (*LuaRunner, error) { - // Default configuration - runner := &LuaRunner{ - poolSize: runtime.GOMAXPROCS(0), - debug: false, - initHooks: make([]InitHook, 0), - finalizeHooks: make([]FinalizeHook, 0), - } - - // Apply options - for _, opt := range options { - opt(runner) - } - - // Set up module loader if not already initialized - if runner.moduleLoader == nil { - requireConfig := &RequireConfig{ - ScriptDir: "", - LibDirs: []string{}, - } - runner.moduleLoader = NewNativeModuleLoader(requireConfig) - } - - // Initialize states and semaphore - runner.states = make([]*StateWrapper, runner.poolSize) - runner.stateSem = make(chan int, runner.poolSize) - - // Create and initialize all states - for i := 0; i < runner.poolSize; i++ { - wrapper, err := runner.initState(i) - if err != nil { - runner.Close() // Clean up already created states - return nil, err - } - - runner.states[i] = wrapper - runner.stateSem <- i // Add index to semaphore - } - - runner.isRunning.Store(true) - - return runner, nil -} - -// debugLog logs a message if debug mode is enabled -func (r *LuaRunner) debugLog(format string, args ...interface{}) { - if r.debug { - logger.Debug("[LuaRunner] "+format, args...) - } -} - -// initState creates and initializes a new state -func (r *LuaRunner) initState(index int) (*StateWrapper, error) { - r.debugLog("Initializing Lua state %d", index) - - // Create a new state - state := luajit.New() - if state == nil { - return nil, errors.New("failed to create Lua state") - } - r.debugLog("Created new Lua state %d", index) - - // Create sandbox - sandbox := NewSandbox() - if r.debug { - sandbox.EnableDebug() - } - - // Set up require paths and mechanism - if err := r.moduleLoader.SetupRequire(state); err != nil { - r.debugLog("Failed to set up require for state %d: %v", index, err) - state.Cleanup() - state.Close() - return nil, ErrInitFailed - } - r.debugLog("Require system initialized for state %d", index) - - // Initialize all core modules from the registry - if err := GlobalRegistry.Initialize(state); err != nil { - r.debugLog("Failed to initialize core modules for state %d: %v", index, err) - state.Cleanup() - state.Close() - return nil, ErrInitFailed - } - r.debugLog("Core modules initialized for state %d", index) - - // Set up sandbox after core modules are initialized - if err := sandbox.Setup(state); err != nil { - r.debugLog("Failed to set up sandbox for state %d: %v", index, err) - state.Cleanup() - state.Close() - return nil, ErrInitFailed - } - r.debugLog("Sandbox environment set up for state %d", index) - - // Preload all modules into package.loaded - if err := r.moduleLoader.PreloadAllModules(state); err != nil { - r.debugLog("Failed to preload modules for state %d: %v", index, err) - state.Cleanup() - state.Close() - return nil, errors.New("failed to preload modules") - } - r.debugLog("All modules preloaded for state %d", index) - - // Run init function if provided - if r.initFunc != nil { - if err := r.initFunc(state); err != nil { - r.debugLog("Custom init function failed for state %d: %v", index, err) - state.Cleanup() - state.Close() - return nil, ErrInitFailed - } - r.debugLog("Custom init function completed for state %d", index) - } - - r.debugLog("State %d initialization complete", index) - - return &StateWrapper{ - state: state, - sandbox: sandbox, - index: index, - }, nil -} - -// AddInitHook adds a hook to be called before script execution -func (r *LuaRunner) AddInitHook(hook InitHook) { - r.mu.Lock() - defer r.mu.Unlock() - r.initHooks = append(r.initHooks, hook) -} - -// AddFinalizeHook adds a hook to be called after script execution -func (r *LuaRunner) AddFinalizeHook(hook FinalizeHook) { - r.mu.Lock() - defer r.mu.Unlock() - r.finalizeHooks = append(r.finalizeHooks, hook) -} - -// RunWithContext executes a Lua script with context and timeout -func (r *LuaRunner) RunWithContext(ctx context.Context, bytecode []byte, execCtx *Context, scriptPath string) (any, error) { - if !r.isRunning.Load() { - return nil, ErrRunnerClosed - } - - // Create a result channel - resultChan := make(chan JobResult, 1) - - // Get a state index with timeout - var stateIndex int - select { - case stateIndex = <-r.stateSem: - // Got a state - case <-ctx.Done(): - return nil, ctx.Err() - } - - // Launch a goroutine to execute the job - go func() { - // Make sure to return the state to the pool when done - defer func() { - // Only return if runner is still open - if r.isRunning.Load() { - select { - case r.stateSem <- stateIndex: - // State returned to pool - default: - // Pool is full or closed (shouldn't happen) - } - } - }() - - // Execute the job - var result JobResult - - r.mu.RLock() - state := r.states[stateIndex] - // Copy hooks to ensure we don't hold the lock during execution - initHooks := make([]InitHook, len(r.initHooks)) - copy(initHooks, r.initHooks) - finalizeHooks := make([]FinalizeHook, len(r.finalizeHooks)) - copy(finalizeHooks, r.finalizeHooks) - r.mu.RUnlock() - - if state == nil { - result = JobResult{nil, errors.New("state is not initialized")} - } else { - // Set script directory for module resolution - if scriptPath != "" { - r.mu.Lock() - r.moduleLoader.config.ScriptDir = filepath.Dir(scriptPath) - r.mu.Unlock() - } - - // Run init hooks - for _, hook := range initHooks { - if err := hook(state.state, execCtx); err != nil { - result = JobResult{nil, err} - // Send result and return early - select { - case resultChan <- result: - default: - } - return - } - } - - // Convert context - var ctxMap map[string]any - if execCtx != nil { - ctxMap = execCtx.Values - } - - // Execute in sandbox - value, err := state.sandbox.Execute(state.state, bytecode, ctxMap) - - // Run finalize hooks - for _, hook := range finalizeHooks { - hookErr := hook(state.state, execCtx, value) - if hookErr != nil && err == nil { - // Only override nil errors - err = hookErr - } - } - - result = JobResult{value, err} - } - - // Send result - select { - case resultChan <- result: - // Result sent - default: - // Result channel closed or full (shouldn't happen with buffered channel) - } - }() - - // Wait for result with context - select { - case result := <-resultChan: - return result.Value, result.Error - case <-ctx.Done(): - // Note: we can't cancel the Lua execution, but we can stop waiting for it - // The state will be returned to the pool when the goroutine completes - return nil, ctx.Err() - } -} - -// Run executes a Lua script -func (r *LuaRunner) Run(bytecode []byte, execCtx *Context, scriptPath string) (any, error) { - return r.RunWithContext(context.Background(), bytecode, execCtx, scriptPath) -} - -// Close gracefully shuts down the LuaRunner -func (r *LuaRunner) Close() error { - r.mu.Lock() - defer r.mu.Unlock() - - if !r.isRunning.Load() { - return ErrRunnerClosed - } - - r.isRunning.Store(false) - - // Drain the semaphore (non-blocking) - for { - select { - case <-r.stateSem: - // Drained one slot - default: - // Empty - goto drained - } - } -drained: - - // Clean up all states - for i := 0; i < len(r.states); i++ { - if r.states[i] != nil { - r.states[i].state.Cleanup() - r.states[i].state.Close() - r.states[i] = nil - } - } - - return nil -} - -// NotifyFileChanged handles file change notifications from watchers -func (r *LuaRunner) NotifyFileChanged(filePath string) bool { - r.debugLog("File change detected: %s", filePath) - - r.mu.Lock() - defer r.mu.Unlock() - - // Check if runner is closed - if !r.isRunning.Load() { - return false - } - - // Create a new semaphore - newSem := make(chan int, cap(r.stateSem)) - - // Drain the current semaphore (non-blocking) - for { - select { - case <-r.stateSem: - // Drained one slot - default: - // Empty - goto drained - } - } -drained: - - r.stateSem = newSem - - // Reinitialize all states - success := true - for i := 0; i < len(r.states); i++ { - // Clean up old state - if r.states[i] != nil { - r.states[i].state.Cleanup() - r.states[i].state.Close() - } - - // Initialize new state - wrapper, err := r.initState(i) - if err != nil { - r.debugLog("Failed to reinitialize state %d: %v", i, err) - success = false - r.states[i] = nil - continue - } - - r.states[i] = wrapper - - // Add to semaphore - select { - case newSem <- i: - // Added to semaphore - default: - // Semaphore full (shouldn't happen) - } - } - - return success -} - -// ResetModuleCache clears non-core modules from package.loaded in all states -func (r *LuaRunner) ResetModuleCache() { - if r.moduleLoader != nil { - r.debugLog("Resetting module cache in all states") - r.mu.RLock() - defer r.mu.RUnlock() - - for i := 0; i < len(r.states); i++ { - if r.states[i] != nil && r.states[i].state != nil { - r.moduleLoader.ResetModules(r.states[i].state) - } - } - } -} - -// ReloadAllModules reloads all modules into package.loaded in all states -func (r *LuaRunner) ReloadAllModules() error { - if r.moduleLoader != nil { - r.debugLog("Reloading all modules in all states") - r.mu.RLock() - defer r.mu.RUnlock() - - for i := 0; i < len(r.states); i++ { - if r.states[i] != nil && r.states[i].state != nil { - if err := r.moduleLoader.PreloadAllModules(r.states[i].state); err != nil { - return err - } - } - } - } - return nil -} - -// RefreshModuleByName invalidates a specific module in package.loaded in all states -func (r *LuaRunner) RefreshModuleByName(modName string) bool { - r.mu.RLock() - defer r.mu.RUnlock() - - success := true - for i := 0; i < len(r.states); i++ { - if r.states[i] != nil && r.states[i].state != nil { - r.debugLog("Refreshing module %s in state %d", modName, i) - if err := r.states[i].state.DoString(`package.loaded["` + modName + `"] = nil`); err != nil { - success = false - } - } - } - return success -} - -// AddModule adds a module to all sandbox environments -func (r *LuaRunner) AddModule(name string, module any) { - r.debugLog("Adding module %s to all sandboxes", name) - r.mu.RLock() - defer r.mu.RUnlock() - - for i := 0; i < len(r.states); i++ { - if r.states[i] != nil && r.states[i].sandbox != nil { - r.states[i].sandbox.AddModule(name, module) - } - } -} - -// GetModuleCount returns the number of loaded modules in the first state -func (r *LuaRunner) GetModuleCount() int { - r.mu.RLock() - defer r.mu.RUnlock() - - count := 0 - - // Get count from the first available state - for i := 0; i < len(r.states); i++ { - if r.states[i] != nil && r.states[i].state != nil { - // Execute a Lua snippet to count modules - if res, err := r.states[i].state.ExecuteWithResult(` - local count = 0 - for _ in pairs(package.loaded) do - count = count + 1 - end - return count - `); err == nil { - if num, ok := res.(float64); ok { - count = int(num) - } - } - break - } - } - - return count -} diff --git a/core/runner/ModuleLoader.go b/core/runner/ModuleLoader.go new file mode 100644 index 0000000..ba24147 --- /dev/null +++ b/core/runner/ModuleLoader.go @@ -0,0 +1,488 @@ +package runner + +import ( + "os" + "path/filepath" + "strings" + "sync" + + luajit "git.sharkk.net/Sky/LuaJIT-to-Go" + "git.sharkk.net/Sky/Moonshark/core/logger" +) + +// ModuleConfig holds configuration for Lua's module loading system +type ModuleConfig struct { + ScriptDir string // Base directory for script being executed + LibDirs []string // Additional library directories +} + +// ModuleInfo stores information about a loaded module +type ModuleInfo struct { + Name string + Path string + IsCore bool + Bytecode []byte +} + +// ModuleLoader manages module loading and caching +type ModuleLoader struct { + config *ModuleConfig + registry *ModuleRegistry + pathCache map[string]string // Cache module paths for fast lookups + bytecodeCache map[string][]byte // Cache of compiled bytecode + debug bool + mu sync.RWMutex +} + +// ModuleRegistry keeps track of Lua modules for file watching +type ModuleRegistry struct { + // Maps file paths to module names + pathToModule sync.Map + // Maps module names to file paths + moduleToPath sync.Map +} + +// NewModuleRegistry creates a new module registry +func NewModuleRegistry() *ModuleRegistry { + return &ModuleRegistry{} +} + +// Register adds a module path to the registry +func (r *ModuleRegistry) Register(path string, name string) { + r.pathToModule.Store(path, name) + r.moduleToPath.Store(name, path) +} + +// GetModuleName retrieves a module name by path +func (r *ModuleRegistry) GetModuleName(path string) (string, bool) { + value, ok := r.pathToModule.Load(path) + if !ok { + return "", false + } + return value.(string), true +} + +// GetModulePath retrieves a path by module name +func (r *ModuleRegistry) GetModulePath(name string) (string, bool) { + value, ok := r.moduleToPath.Load(name) + if !ok { + return "", false + } + return value.(string), true +} + +// NewModuleLoader creates a new module loader +func NewModuleLoader(config *ModuleConfig) *ModuleLoader { + if config == nil { + config = &ModuleConfig{ + ScriptDir: "", + LibDirs: []string{}, + } + } + + return &ModuleLoader{ + config: config, + registry: NewModuleRegistry(), + pathCache: make(map[string]string), + bytecodeCache: make(map[string][]byte), + debug: false, + } +} + +// EnableDebug turns on debug logging +func (l *ModuleLoader) EnableDebug() { + l.debug = true +} + +// debugLog logs a message if debug is enabled +func (l *ModuleLoader) debugLog(format string, args ...interface{}) { + if l.debug { + logger.Debug("[ModuleLoader] "+format, args...) + } +} + +// SetScriptDir sets the script directory +func (l *ModuleLoader) SetScriptDir(dir string) { + l.mu.Lock() + defer l.mu.Unlock() + l.config.ScriptDir = dir +} + +// SetupRequire configures the require system in a Lua state +func (l *ModuleLoader) SetupRequire(state *luajit.State) error { + l.mu.RLock() + defer l.mu.RUnlock() + + // Initialize our module registry in Lua + err := state.DoString(` + -- Initialize global module registry + __module_paths = {} + + -- Setup fast module loading system + __module_bytecode = {} + + -- Create module preload table + package.preload = package.preload or {} + + -- Setup module state registry + __ready_modules = {} + `) + + if err != nil { + return err + } + + // Set up package.path based on search paths + paths := l.getSearchPaths() + pathStr := strings.Join(paths, ";") + escapedPathStr := escapeLuaString(pathStr) + + return state.DoString(`package.path = "` + escapedPathStr + `"`) +} + +// getSearchPaths returns a list of Lua search paths +func (l *ModuleLoader) getSearchPaths() []string { + absPaths := []string{} + seen := map[string]bool{} + + // Add script directory (highest priority) + if l.config.ScriptDir != "" { + absPath, err := filepath.Abs(l.config.ScriptDir) + if err == nil && !seen[absPath] { + absPaths = append(absPaths, filepath.Join(absPath, "?.lua")) + seen[absPath] = true + } + } + + // Add lib directories + for _, dir := range l.config.LibDirs { + if dir == "" { + continue + } + + absPath, err := filepath.Abs(dir) + if err == nil && !seen[absPath] { + absPaths = append(absPaths, filepath.Join(absPath, "?.lua")) + seen[absPath] = true + } + } + + return absPaths +} + +// PreloadModules preloads modules from library directories +func (l *ModuleLoader) PreloadModules(state *luajit.State) error { + l.mu.Lock() + defer l.mu.Unlock() + + // Reset caches + l.pathCache = make(map[string]string) + l.bytecodeCache = make(map[string][]byte) + + // Reset module registry in Lua + if err := state.DoString(` + -- Reset module registry + __module_paths = {} + __module_bytecode = {} + __ready_modules = {} + + -- Clear non-core modules from package.loaded + local core_modules = { + string = true, table = true, math = true, os = true, + package = true, io = true, coroutine = true, debug = true, _G = true + } + + for name in pairs(package.loaded) do + if not core_modules[name] then + package.loaded[name] = nil + end + end + + -- Reset preload table + package.preload = {} + `); err != nil { + return err + } + + // Scan and preload modules from all library directories + for _, dir := range l.config.LibDirs { + if dir == "" { + continue + } + + absDir, err := filepath.Abs(dir) + if err != nil { + continue + } + + // Find all Lua files + err = filepath.Walk(absDir, func(path string, info os.FileInfo, err error) error { + if err != nil || info.IsDir() || !strings.HasSuffix(path, ".lua") { + return nil + } + + // Get module name from path + relPath, err := filepath.Rel(absDir, path) + if err != nil || strings.HasPrefix(relPath, "..") { + return nil + } + + // Convert path to module name + modName := strings.TrimSuffix(relPath, ".lua") + modName = strings.ReplaceAll(modName, string(filepath.Separator), ".") + + // Register in our caches + l.pathCache[modName] = path + l.registry.Register(path, modName) + + // Load file content + content, err := os.ReadFile(path) + if err != nil { + return nil + } + + // Compile to bytecode + bytecode, err := state.CompileBytecode(string(content), path) + if err != nil { + return nil + } + + // Cache bytecode + l.bytecodeCache[modName] = bytecode + + // Register in Lua + escapedPath := escapeLuaString(path) + escapedName := escapeLuaString(modName) + + if err := state.DoString(`__module_paths["` + escapedName + `"] = "` + escapedPath + `"`); err != nil { + return nil + } + + // Load bytecode into Lua state + if err := state.LoadBytecode(bytecode, path); err != nil { + return nil + } + + // Add to package.preload + luaCode := ` + local modname = "` + escapedName + `" + local chunk = ... + package.preload[modname] = chunk + __ready_modules[modname] = true + ` + + if err := state.DoString(luaCode); err != nil { + state.Pop(1) // Remove chunk from stack + return nil + } + + state.Pop(1) // Remove chunk from stack + return nil + }) + + if err != nil { + return err + } + } + + // Install optimized require implementation + return state.DoString(` + -- Setup environment-aware require function + function __setup_require(env) + -- Create require function specific to this environment + env.require = function(modname) + -- Check if already loaded + if package.loaded[modname] then + return package.loaded[modname] + end + + -- Check preloaded modules + if __ready_modules[modname] then + local loader = package.preload[modname] + if loader then + -- Set environment for loader + setfenv(loader, env) + + -- Execute and store result + local result = loader() + if result == nil then + result = true + end + + package.loaded[modname] = result + return result + end + end + + -- Direct file load as fallback + if __module_paths[modname] then + local path = __module_paths[modname] + local chunk, err = loadfile(path) + if chunk then + setfenv(chunk, env) + local result = chunk() + if result == nil then + result = true + end + package.loaded[modname] = result + return result + end + end + + -- Full path search as last resort + local errors = {} + for path in package.path:gmatch("[^;]+") do + local file_path = path:gsub("?", modname:gsub("%.", "/")) + local chunk, err = loadfile(file_path) + if chunk then + setfenv(chunk, env) + local result = chunk() + if result == nil then + result = true + end + package.loaded[modname] = result + return result + end + table.insert(errors, "\tno file '" .. file_path .. "'") + end + + error("module '" .. modname .. "' not found:\n" .. table.concat(errors, "\n"), 2) + end + + return env + end + `) +} + +// GetModuleByPath finds the module name for a file path +func (l *ModuleLoader) GetModuleByPath(path string) (string, bool) { + l.mu.RLock() + defer l.mu.RUnlock() + + // Clean path for proper comparison + path = filepath.Clean(path) + + // Try direct lookup from registry + modName, found := l.registry.GetModuleName(path) + if found { + return modName, true + } + + // Try to find by relative path from lib dirs + for _, dir := range l.config.LibDirs { + absDir, err := filepath.Abs(dir) + if err != nil { + continue + } + + relPath, err := filepath.Rel(absDir, path) + if err != nil || strings.HasPrefix(relPath, "..") { + continue + } + + if strings.HasSuffix(relPath, ".lua") { + modName = strings.TrimSuffix(relPath, ".lua") + modName = strings.ReplaceAll(modName, string(filepath.Separator), ".") + return modName, true + } + } + + return "", false +} + +// ReloadModule reloads a module from disk +func (l *ModuleLoader) ReloadModule(state *luajit.State, name string) (bool, error) { + l.mu.Lock() + defer l.mu.Unlock() + + // Get module path + path, ok := l.registry.GetModulePath(name) + if !ok { + for modName, modPath := range l.pathCache { + if modName == name { + path = modPath + ok = true + break + } + } + } + + if !ok || path == "" { + return false, nil + } + + // Invalidate module in Lua + err := state.DoString(` + package.loaded["` + name + `"] = nil + __ready_modules["` + name + `"] = nil + if package.preload then + package.preload["` + name + `"] = nil + end + `) + + if err != nil { + return false, err + } + + // Check if file still exists + if _, err := os.Stat(path); os.IsNotExist(err) { + // File was deleted, just invalidate + delete(l.pathCache, name) + delete(l.bytecodeCache, name) + l.registry.moduleToPath.Delete(name) + l.registry.pathToModule.Delete(path) + return true, nil + } + + // Read updated file + content, err := os.ReadFile(path) + if err != nil { + return false, err + } + + // Compile to bytecode + bytecode, err := state.CompileBytecode(string(content), path) + if err != nil { + return false, err + } + + // Update cache + l.bytecodeCache[name] = bytecode + + // Load bytecode into state + if err := state.LoadBytecode(bytecode, path); err != nil { + return false, err + } + + // Update preload + luaCode := ` + local modname = "` + name + `" + package.loaded[modname] = nil + package.preload[modname] = ... + __ready_modules[modname] = true + ` + + if err := state.DoString(luaCode); err != nil { + state.Pop(1) // Remove chunk from stack + return false, err + } + + state.Pop(1) // Remove chunk from stack + return true, nil +} + +// ResetModules clears non-core modules from package.loaded +func (l *ModuleLoader) ResetModules(state *luajit.State) error { + return state.DoString(` + local core_modules = { + string = true, table = true, math = true, os = true, + package = true, io = true, coroutine = true, debug = true, _G = true + } + + for name in pairs(package.loaded) do + if not core_modules[name] then + package.loaded[name] = nil + end + end + `) +} diff --git a/core/runner/Modules.go b/core/runner/Modules.go index 3a66921..d84df43 100644 --- a/core/runner/Modules.go +++ b/core/runner/Modules.go @@ -2,6 +2,7 @@ package runner import ( luajit "git.sharkk.net/Sky/LuaJIT-to-Go" + "git.sharkk.net/Sky/Moonshark/core/logger" ) // ModuleFunc is a function that returns a map of module functions @@ -19,7 +20,7 @@ func RegisterModule(state *luajit.State, name string, funcs map[string]luajit.Go // Push function if err := state.PushGoFunction(f); err != nil { - state.Pop(2) // Pop table and function name + state.Pop(1) // Pop table return err } @@ -37,6 +38,7 @@ func ModuleInitFunc(modules map[string]ModuleFunc) StateInitFunc { return func(state *luajit.State) error { for name, moduleFunc := range modules { if err := RegisterModule(state, name, moduleFunc()); err != nil { + logger.Error("Failed to register module %s: %v", name, err) return err } } @@ -57,3 +59,27 @@ func CombineInitFuncs(funcs ...StateInitFunc) StateInitFunc { return nil } } + +// RegisterLuaCode registers a Lua code snippet as a module +func RegisterLuaCode(state *luajit.State, code string) error { + return state.DoString(code) +} + +// RegisterLuaCodeInitFunc returns a StateInitFunc that registers Lua code +func RegisterLuaCodeInitFunc(code string) StateInitFunc { + return func(state *luajit.State) error { + return RegisterLuaCode(state, code) + } +} + +// RegisterLuaModuleInitFunc returns a StateInitFunc that registers a Lua module +func RegisterLuaModuleInitFunc(name string, code string) StateInitFunc { + return func(state *luajit.State) error { + // Create name = {} global + state.NewTable() + state.SetGlobal(name) + + // Then run the module code which will populate it + return state.DoString(code) + } +} diff --git a/core/runner/Require.go b/core/runner/Require.go index 89df7d2..dbcace1 100644 --- a/core/runner/Require.go +++ b/core/runner/Require.go @@ -22,46 +22,6 @@ type NativeModuleLoader struct { mu sync.RWMutex } -// ModuleRegistry keeps track of Lua modules for file watching -type ModuleRegistry struct { - // Maps file paths to module names - pathToModule sync.Map - // Maps module names to file paths (for direct access) - moduleToPath sync.Map -} - -// NewModuleRegistry creates a new module registry -func NewModuleRegistry() *ModuleRegistry { - return &ModuleRegistry{ - pathToModule: sync.Map{}, - moduleToPath: sync.Map{}, - } -} - -// Register adds a module path to the registry -func (r *ModuleRegistry) Register(path string, name string) { - r.pathToModule.Store(path, name) - r.moduleToPath.Store(name, path) -} - -// GetModuleName retrieves a module name by path -func (r *ModuleRegistry) GetModuleName(path string) (string, bool) { - value, ok := r.pathToModule.Load(path) - if !ok { - return "", false - } - return value.(string), true -} - -// GetModulePath retrieves a path by module name -func (r *ModuleRegistry) GetModulePath(name string) (string, bool) { - value, ok := r.moduleToPath.Load(name) - if !ok { - return "", false - } - return value.(string), true -} - // NewNativeModuleLoader creates a new native module loader func NewNativeModuleLoader(config *RequireConfig) *NativeModuleLoader { return &NativeModuleLoader{ diff --git a/core/runner/Runner.go b/core/runner/Runner.go new file mode 100644 index 0000000..ce9411c --- /dev/null +++ b/core/runner/Runner.go @@ -0,0 +1,574 @@ +package runner + +import ( + "context" + "errors" + "path/filepath" + "runtime" + "sync" + "sync/atomic" + "time" + + luajit "git.sharkk.net/Sky/LuaJIT-to-Go" + "git.sharkk.net/Sky/Moonshark/core/logger" +) + +// Common errors +var ( + ErrRunnerClosed = errors.New("lua runner is closed") + ErrInitFailed = errors.New("initialization failed") + ErrStateNotReady = errors.New("lua state not ready") + ErrTimeout = errors.New("operation timed out") +) + +// RunnerOption defines a functional option for configuring the Runner +type RunnerOption func(*Runner) + +// State wraps a Lua state with its sandbox +type State struct { + L *luajit.State // The Lua state + sandbox *Sandbox // Associated sandbox + index int // Index for debugging + inUse bool // Whether the state is currently in use + initTime time.Time // When this state was initialized +} + +// InitHook runs before executing a script +type InitHook func(*luajit.State, *Context) error + +// FinalizeHook runs after executing a script +type FinalizeHook func(*luajit.State, *Context, any) error + +// Runner runs Lua scripts using a pool of Lua states +type Runner struct { + states []*State // All states managed by this runner + statePool chan int // Pool of available state indexes + poolSize int // Size of the state pool + moduleLoader *ModuleLoader // Module loader + isRunning atomic.Bool // Whether the runner is active + mu sync.RWMutex // Mutex for thread safety + debug bool // Enable debug logging + initHooks []InitHook // Hooks run before script execution + finalizeHooks []FinalizeHook // Hooks run after script execution + scriptDir string // Current script directory +} + +// WithPoolSize sets the state pool size +func WithPoolSize(size int) RunnerOption { + return func(r *Runner) { + if size > 0 { + r.poolSize = size + } + } +} + +// WithDebugEnabled enables debug output +func WithDebugEnabled() RunnerOption { + return func(r *Runner) { + r.debug = true + } +} + +// WithLibDirs sets additional library directories +func WithLibDirs(dirs ...string) RunnerOption { + return func(r *Runner) { + if r.moduleLoader == nil { + r.moduleLoader = NewModuleLoader(&ModuleConfig{ + LibDirs: dirs, + }) + } else { + r.moduleLoader.config.LibDirs = dirs + } + } +} + +// WithInitHook adds a hook to run before script execution +func WithInitHook(hook InitHook) RunnerOption { + return func(r *Runner) { + r.initHooks = append(r.initHooks, hook) + } +} + +// WithFinalizeHook adds a hook to run after script execution +func WithFinalizeHook(hook FinalizeHook) RunnerOption { + return func(r *Runner) { + r.finalizeHooks = append(r.finalizeHooks, hook) + } +} + +// NewRunner creates a new Runner with a pool of states +func NewRunner(options ...RunnerOption) (*Runner, error) { + // Default configuration + runner := &Runner{ + poolSize: runtime.GOMAXPROCS(0), + debug: false, + initHooks: make([]InitHook, 0, 4), + finalizeHooks: make([]FinalizeHook, 0, 4), + } + + // Apply options + for _, opt := range options { + opt(runner) + } + + // Set up module loader if not already initialized + if runner.moduleLoader == nil { + config := &ModuleConfig{ + ScriptDir: "", + LibDirs: []string{}, + } + runner.moduleLoader = NewModuleLoader(config) + } + + // Initialize states and pool + runner.states = make([]*State, runner.poolSize) + runner.statePool = make(chan int, runner.poolSize) + + // Create and initialize all states + if err := runner.initializeStates(); err != nil { + runner.Close() // Clean up already created states + return nil, err + } + + runner.isRunning.Store(true) + return runner, nil +} + +// debugLog logs a message if debug mode is enabled +func (r *Runner) debugLog(format string, args ...interface{}) { + if r.debug { + logger.Debug("[Runner] "+format, args...) + } +} + +// initializeStates creates and initializes all states in the pool +func (r *Runner) initializeStates() error { + r.debugLog("Initializing %d Lua states", r.poolSize) + + // Create main template state first + templateState, err := r.createState(0) + if err != nil { + return err + } + + r.states[0] = templateState + r.statePool <- 0 // Add index to the pool + + // Create remaining states + for i := 1; i < r.poolSize; i++ { + state, err := r.createState(i) + if err != nil { + return err + } + + r.states[i] = state + r.statePool <- i // Add index to the pool + } + + r.debugLog("All %d Lua states initialized successfully", r.poolSize) + return nil +} + +// createState initializes a new Lua state +func (r *Runner) createState(index int) (*State, error) { + r.debugLog("Creating Lua state %d", index) + + // Create a new state + L := luajit.New() + if L == nil { + return nil, errors.New("failed to create Lua state") + } + + // Create sandbox + sandbox := NewSandbox() + if r.debug { + sandbox.EnableDebug() + } + + // Set up require system + if err := r.moduleLoader.SetupRequire(L); err != nil { + r.debugLog("Failed to set up require for state %d: %v", index, err) + L.Cleanup() + L.Close() + return nil, ErrInitFailed + } + + // Initialize all core modules from the registry + if err := GlobalRegistry.Initialize(L); err != nil { + r.debugLog("Failed to initialize core modules for state %d: %v", index, err) + L.Cleanup() + L.Close() + return nil, ErrInitFailed + } + + // Set up sandbox after core modules are initialized + if err := sandbox.Setup(L); err != nil { + r.debugLog("Failed to set up sandbox for state %d: %v", index, err) + L.Cleanup() + L.Close() + return nil, ErrInitFailed + } + + // Preload all modules + if err := r.moduleLoader.PreloadModules(L); err != nil { + r.debugLog("Failed to preload modules for state %d: %v", index, err) + L.Cleanup() + L.Close() + return nil, errors.New("failed to preload modules") + } + + state := &State{ + L: L, + sandbox: sandbox, + index: index, + inUse: false, + initTime: time.Now(), + } + + r.debugLog("State %d created successfully", index) + return state, nil +} + +// Execute runs a script with context +func (r *Runner) Execute(ctx context.Context, bytecode []byte, execCtx *Context, scriptPath string) (any, error) { + if !r.isRunning.Load() { + return nil, ErrRunnerClosed + } + + // Set script directory if provided + if scriptPath != "" { + r.mu.Lock() + r.scriptDir = filepath.Dir(scriptPath) + r.moduleLoader.SetScriptDir(r.scriptDir) + r.mu.Unlock() + } + + // Get a state index from the pool with timeout + var stateIndex int + select { + case stateIndex = <-r.statePool: + // Got a state + case <-ctx.Done(): + return nil, ctx.Err() + } + + // Get the actual state + r.mu.RLock() + state := r.states[stateIndex] + r.mu.RUnlock() + + if state == nil { + // This shouldn't happen, but recover gracefully + r.statePool <- stateIndex + return nil, ErrStateNotReady + } + + // Mark state as in use + state.inUse = true + + // Ensure state is returned to pool when done + defer func() { + state.inUse = false + if r.isRunning.Load() { + select { + case r.statePool <- stateIndex: + // State returned to pool + default: + // Pool is full or closed (shouldn't happen) + } + } + }() + + // Copy hooks to avoid holding lock during execution + r.mu.RLock() + initHooks := make([]InitHook, len(r.initHooks)) + copy(initHooks, r.initHooks) + finalizeHooks := make([]FinalizeHook, len(r.finalizeHooks)) + copy(finalizeHooks, r.finalizeHooks) + r.mu.RUnlock() + + // Run init hooks + for _, hook := range initHooks { + if err := hook(state.L, execCtx); err != nil { + return nil, err + } + } + + // Prepare context values + var ctxValues map[string]any + if execCtx != nil { + ctxValues = execCtx.Values + } + + // Execute in sandbox + result, err := state.sandbox.Execute(state.L, bytecode, ctxValues) + + // Run finalize hooks + for _, hook := range finalizeHooks { + if hookErr := hook(state.L, execCtx, result); hookErr != nil && err == nil { + err = hookErr + } + } + + return result, err +} + +// Run executes a Lua script (convenience wrapper) +func (r *Runner) Run(bytecode []byte, execCtx *Context, scriptPath string) (any, error) { + return r.Execute(context.Background(), bytecode, execCtx, scriptPath) +} + +// Close gracefully shuts down the Runner +func (r *Runner) Close() error { + r.mu.Lock() + defer r.mu.Unlock() + + if !r.isRunning.Load() { + return ErrRunnerClosed + } + + r.isRunning.Store(false) + r.debugLog("Closing Runner and destroying all states") + + // Drain the state pool + r.drainStatePool() + + // Clean up all states + for i, state := range r.states { + if state != nil { + state.L.Cleanup() + state.L.Close() + r.states[i] = nil + } + } + + return nil +} + +// drainStatePool removes all states from the pool +func (r *Runner) drainStatePool() { + for { + select { + case <-r.statePool: + // Drain one state + default: + // Pool is empty + return + } + } +} + +// RefreshStates rebuilds all states in the pool +func (r *Runner) RefreshStates() error { + r.mu.Lock() + defer r.mu.Unlock() + + if !r.isRunning.Load() { + return ErrRunnerClosed + } + + r.debugLog("Refreshing all Lua states") + + // Drain all states from the pool + r.drainStatePool() + + // Destroy all existing states + for i, state := range r.states { + if state != nil { + if state.inUse { + r.debugLog("Warning: attempting to refresh state %d that is in use", i) + } + state.L.Cleanup() + state.L.Close() + r.states[i] = nil + } + } + + // Reinitialize all states + if err := r.initializeStates(); err != nil { + return err + } + + r.debugLog("All states refreshed successfully") + return nil +} + +// NotifyFileChanged handles file change notifications +func (r *Runner) NotifyFileChanged(filePath string) bool { + r.debugLog("File change detected: %s", filePath) + + // Check if it's a module file + module, isModule := r.moduleLoader.GetModuleByPath(filePath) + if isModule { + r.debugLog("File is a module: %s", module) + return r.RefreshModule(module) + } + + // For non-module files, refresh all states + if err := r.RefreshStates(); err != nil { + r.debugLog("Failed to refresh states: %v", err) + return false + } + + return true +} + +// RefreshModule refreshes a specific module across all states +func (r *Runner) RefreshModule(moduleName string) bool { + r.mu.RLock() + defer r.mu.RUnlock() + + if !r.isRunning.Load() { + return false + } + + r.debugLog("Refreshing module: %s", moduleName) + + // Check if it's a core module + coreName, isCore := GlobalRegistry.MatchModuleName(moduleName) + + success := true + for _, state := range r.states { + if state == nil { + continue + } + + // Skip states that are in use + if state.inUse { + r.debugLog("Skipping refresh for state %d (in use)", state.index) + success = false + continue + } + + // Invalidate module in Lua + if err := state.L.DoString(`package.loaded["` + moduleName + `"] = nil`); err != nil { + r.debugLog("Failed to invalidate module %s in state %d: %v", + moduleName, state.index, err) + success = false + continue + } + + // For core modules, reinitialize them + if isCore { + if err := GlobalRegistry.InitializeModule(state.L, coreName); err != nil { + r.debugLog("Failed to reinitialize core module %s in state %d: %v", + coreName, state.index, err) + success = false + } + } + } + + if success { + r.debugLog("Module %s refreshed successfully in all states", moduleName) + } else { + r.debugLog("Module %s refresh had some failures", moduleName) + } + + return success +} + +// AddModule adds a module to all sandbox environments +func (r *Runner) AddModule(name string, module any) { + r.debugLog("Adding module %s to all sandboxes", name) + r.mu.RLock() + defer r.mu.RUnlock() + + for _, state := range r.states { + if state != nil && state.sandbox != nil && !state.inUse { + state.sandbox.AddModule(name, module) + } + } +} + +// AddInitHook adds a hook to be called before script execution +func (r *Runner) AddInitHook(hook InitHook) { + r.mu.Lock() + defer r.mu.Unlock() + r.initHooks = append(r.initHooks, hook) +} + +// AddFinalizeHook adds a hook to be called after script execution +func (r *Runner) AddFinalizeHook(hook FinalizeHook) { + r.mu.Lock() + defer r.mu.Unlock() + r.finalizeHooks = append(r.finalizeHooks, hook) +} + +// ResetModuleCache clears the module cache in all states +func (r *Runner) ResetModuleCache() { + r.mu.RLock() + defer r.mu.RUnlock() + + if !r.isRunning.Load() { + return + } + + r.debugLog("Resetting module cache in all states") + + for _, state := range r.states { + if state != nil && !state.inUse { + r.moduleLoader.ResetModules(state.L) + } + } +} + +// GetStateCount returns the number of initialized states +func (r *Runner) GetStateCount() int { + r.mu.RLock() + defer r.mu.RUnlock() + + count := 0 + for _, state := range r.states { + if state != nil { + count++ + } + } + + return count +} + +// GetActiveStateCount returns the number of states currently in use +func (r *Runner) GetActiveStateCount() int { + r.mu.RLock() + defer r.mu.RUnlock() + + count := 0 + for _, state := range r.states { + if state != nil && state.inUse { + count++ + } + } + + return count +} + +// GetModuleCount returns the number of loaded modules in the first available state +func (r *Runner) GetModuleCount() int { + r.mu.RLock() + defer r.mu.RUnlock() + + if !r.isRunning.Load() { + return 0 + } + + // Find first available state + for _, state := range r.states { + if state != nil && !state.inUse { + // Execute a Lua snippet to count modules + if res, err := state.L.ExecuteWithResult(` + local count = 0 + for _ in pairs(package.loaded) do + count = count + 1 + end + return count + `); err == nil { + if num, ok := res.(float64); ok { + return int(num) + } + } + break + } + } + + return 0 +} diff --git a/core/runner/Sandbox.go b/core/runner/Sandbox.go index a6f253c..b2b9c8a 100644 --- a/core/runner/Sandbox.go +++ b/core/runner/Sandbox.go @@ -2,64 +2,73 @@ package runner import ( "fmt" + "sync" luajit "git.sharkk.net/Sky/LuaJIT-to-Go" "git.sharkk.net/Sky/Moonshark/core/logger" ) -// Sandbox manages a simplified Lua environment +// Sandbox provides a secure execution environment for Lua scripts type Sandbox struct { modules map[string]any // Custom modules for environment debug bool // Enable debug output + mu sync.RWMutex // Protects modules } -// NewSandbox creates a new sandbox +// NewSandbox creates a new sandbox environment func NewSandbox() *Sandbox { return &Sandbox{ - modules: make(map[string]any), + modules: make(map[string]any, 8), // Pre-allocate with reasonable capacity debug: false, } } -// EnableDebug turns on debug output +// EnableDebug turns on debug logging func (s *Sandbox) EnableDebug() { s.debug = true } -// AddModule adds a module to the sandbox environment -func (s *Sandbox) AddModule(name string, module any) { - s.modules[name] = module -} - -// debugLog prints debug messages if debug is enabled +// debugLog logs a message if debug mode is enabled func (s *Sandbox) debugLog(format string, args ...interface{}) { if s.debug { - logger.Debug("[Sandbox Debug] "+format, args...) + logger.Debug("[Sandbox] "+format, args...) } } +// AddModule adds a module to the sandbox environment +func (s *Sandbox) AddModule(name string, module any) { + s.mu.Lock() + defer s.mu.Unlock() + + s.modules[name] = module + s.debugLog("Added module: %s", name) +} + // Setup initializes the sandbox in a Lua state func (s *Sandbox) Setup(state *luajit.State) error { s.debugLog("Setting up sandbox environment") // Register modules in the global environment + s.mu.RLock() for name, module := range s.modules { s.debugLog("Registering module: %s", name) if err := state.PushValue(module); err != nil { + s.mu.RUnlock() s.debugLog("Failed to register module %s: %v", name, err) return err } state.SetGlobal(name) } + s.mu.RUnlock() - // Initialize simple environment setup + // Initialize environment setup err := state.DoString(` -- Global tables for response handling __http_responses = __http_responses or {} - -- Simple environment creation + -- Create environment inheriting from _G function __create_env(ctx) - -- Create environment inheriting from _G + -- Create environment with metatable inheriting from _G local env = setmetatable({}, {__index = _G}) -- Add context if provided @@ -67,6 +76,11 @@ func (s *Sandbox) Setup(state *luajit.State) error { env.ctx = ctx end + -- Add proper require function to this environment + if __setup_require then + __setup_require(env) + end + return env end @@ -97,24 +111,6 @@ func (s *Sandbox) Setup(state *luajit.State) error { } s.debugLog("Sandbox setup complete") - - // Verify HTTP module is accessible - httpResult, _ := state.ExecuteWithResult(` - if type(http) == "table" and - type(http.client) == "table" and - type(http.client.get) == "function" then - return "HTTP module verified OK" - else - local status = { - http = type(http), - client = type(http) == "table" and type(http.client) or "N/A", - get = type(http) == "table" and type(http.client) == "table" and type(http.client.get) or "N/A" - } - return status - end - `) - - s.debugLog("HTTP verification result: %v", httpResult) return nil } @@ -123,7 +119,7 @@ func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx map[string]a // Load bytecode if err := state.LoadBytecode(bytecode, "script"); err != nil { s.debugLog("Failed to load bytecode: %v", err) - return nil, err + return nil, fmt.Errorf("failed to load script: %w", err) } // Prepare context @@ -132,9 +128,9 @@ func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx map[string]a for k, v := range ctx { state.PushString(k) if err := state.PushValue(v); err != nil { - state.Pop(2) + state.Pop(2) // Pop key and table s.debugLog("Failed to push context value %s: %v", k, err) - return nil, err + return nil, fmt.Errorf("failed to prepare context: %w", err) } state.SetTable(-3) } @@ -145,32 +141,33 @@ func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx map[string]a // Get execution function state.GetGlobal("__execute_script") if !state.IsFunction(-1) { - state.Pop(2) // Pop nil and non-function + state.Pop(2) // Pop context and non-function s.debugLog("__execute_script is not a function") return nil, fmt.Errorf("sandbox execution function not found") } - // Push arguments - state.PushCopy(-3) // bytecode function - state.PushCopy(-3) // context + // Stack setup for call: __execute_script, bytecode function, context + state.PushCopy(-3) // bytecode function (copy from -3) + state.PushCopy(-3) // context (copy from -3) - // Clean up stack - state.Remove(-5) // original bytecode - state.Remove(-4) // original context + // Clean up duplicate references + state.Remove(-5) // Remove original bytecode function + state.Remove(-4) // Remove original context - // Call with 2 args, 1 result + // Call with 2 args (function, context), 1 result if err := state.Call(2, 1); err != nil { s.debugLog("Execution failed: %v", err) - return nil, err + return nil, fmt.Errorf("script execution failed: %w", err) } // Get result result, err := state.ToValue(-1) - state.Pop(1) + state.Pop(1) // Pop result // Check for HTTP response httpResponse, hasResponse := GetHTTPResponse(state) if hasResponse { + // Add the script result as the response body httpResponse.Body = result return httpResponse, nil } diff --git a/core/runner/SessionHandler.go b/core/runner/SessionHandler.go index d8468c1..dd035b1 100644 --- a/core/runner/SessionHandler.go +++ b/core/runner/SessionHandler.go @@ -36,7 +36,7 @@ func (h *SessionHandler) debug(format string, args ...interface{}) { // WithSessionManager creates a RunnerOption to add session support func WithSessionManager(manager *sessions.SessionManager) RunnerOption { - return func(r *LuaRunner) { + return func(r *Runner) { handler := NewSessionHandler(manager) // Register the session module diff --git a/core/watchers/Api.go b/core/watchers/Api.go index 178a6b0..39bf109 100644 --- a/core/watchers/Api.go +++ b/core/watchers/Api.go @@ -58,7 +58,7 @@ func (w *Watcher) Close() error { // WatchLuaRouter sets up a watcher for a LuaRouter's routes directory; also updates // the LuaRunner so that the state can be rebuilt -func WatchLuaRouter(router *routers.LuaRouter, runner *runner.LuaRunner, routesDir string) (*Watcher, error) { +func WatchLuaRouter(router *routers.LuaRouter, runner *runner.Runner, routesDir string) (*Watcher, error) { manager := GetWatcherManager(true) runnerRefresh := func() error { @@ -104,7 +104,7 @@ func WatchStaticRouter(router *routers.StaticRouter, staticDir string) (*Watcher } // WatchLuaModules sets up watchers for Lua module directories -func WatchLuaModules(luaRunner *runner.LuaRunner, libDirs []string) ([]*Watcher, error) { +func WatchLuaModules(luaRunner *runner.Runner, libDirs []string) ([]*Watcher, error) { manager := GetWatcherManager(true) watchers := make([]*Watcher, 0, len(libDirs)) @@ -115,8 +115,7 @@ func WatchLuaModules(luaRunner *runner.LuaRunner, libDirs []string) ([]*Watcher, callback := func() error { logger.Debug("Detected changes in Lua module directory: %s", dirCopy) - // Reload modules from this directory - if err := luaRunner.ReloadAllModules(); err != nil { + if err := luaRunner.RefreshStates(); err != nil { logger.Warning("Error reloading modules: %v", err) }