From 8e511c5dc98ecf0c13e227260cd16d0b1e866c1b Mon Sep 17 00:00:00 2001 From: Sky Johnson Date: Tue, 8 Apr 2025 22:10:50 -0500 Subject: [PATCH] runner cleanup 2 --- core/runner/Context.go | 41 +-- core/runner/Cookies.go | 77 ------ core/runner/CoreModules.go | 97 +++---- core/runner/GoModules.go | 86 ------ core/runner/Runner.go | 480 ++++++++++++--------------------- core/runner/Sessions.go | 47 +--- core/runner/sandbox/Http.go | 87 +++--- core/runner/sandbox/Modules.go | 44 ++- 8 files changed, 272 insertions(+), 687 deletions(-) delete mode 100644 core/runner/Cookies.go delete mode 100644 core/runner/GoModules.go diff --git a/core/runner/Context.go b/core/runner/Context.go index 9586030..e916761 100644 --- a/core/runner/Context.go +++ b/core/runner/Context.go @@ -3,8 +3,6 @@ package runner import ( "sync" - "maps" - "github.com/valyala/bytebufferpool" "github.com/valyala/fasthttp" ) @@ -14,9 +12,6 @@ type Context struct { // Values stores any context values (route params, HTTP request info, etc.) Values map[string]any - // internal mutex for concurrent access - mu sync.RWMutex - // FastHTTP context if this was created from an HTTP request RequestCtx *fasthttp.RequestCtx @@ -28,7 +23,7 @@ type Context struct { var contextPool = sync.Pool{ New: func() any { return &Context{ - Values: make(map[string]any, 16), // Pre-allocate with reasonable capacity + Values: make(map[string]any, 16), } }, } @@ -47,9 +42,6 @@ func NewHTTPContext(requestCtx *fasthttp.RequestCtx) *Context { // Release returns the context to the pool after clearing its values func (c *Context) Release() { - c.mu.Lock() - defer c.mu.Unlock() - // Clear all values to prevent data leakage for k := range c.Values { delete(c.Values, k) @@ -69,9 +61,6 @@ func (c *Context) Release() { // GetBuffer returns a byte buffer for efficient string operations func (c *Context) GetBuffer() *bytebufferpool.ByteBuffer { - c.mu.Lock() - defer c.mu.Unlock() - if c.buffer == nil { c.buffer = bytebufferpool.Get() } @@ -80,49 +69,21 @@ func (c *Context) GetBuffer() *bytebufferpool.ByteBuffer { // Set adds a value to the context func (c *Context) Set(key string, value any) { - c.mu.Lock() - defer c.mu.Unlock() - c.Values[key] = value } // Get retrieves a value from the context func (c *Context) Get(key string) any { - c.mu.RLock() - defer c.mu.RUnlock() - return c.Values[key] } // Contains checks if a key exists in the context func (c *Context) Contains(key string) bool { - c.mu.RLock() - defer c.mu.RUnlock() - _, exists := c.Values[key] return exists } // Delete removes a value from the context func (c *Context) Delete(key string) { - c.mu.Lock() - defer c.mu.Unlock() - delete(c.Values, key) } - -// All returns a copy of all values in the context -func (c *Context) All() map[string]any { - c.mu.RLock() - defer c.mu.RUnlock() - - result := make(map[string]any, len(c.Values)) - maps.Copy(result, c.Values) - - return result -} - -// IsHTTPRequest returns true if this context contains a fasthttp RequestCtx -func (c *Context) IsHTTPRequest() bool { - return c.RequestCtx != nil -} diff --git a/core/runner/Cookies.go b/core/runner/Cookies.go deleted file mode 100644 index b1258e4..0000000 --- a/core/runner/Cookies.go +++ /dev/null @@ -1,77 +0,0 @@ -package runner - -import ( - "time" - - luajit "git.sharkk.net/Sky/LuaJIT-to-Go" - "github.com/valyala/fasthttp" -) - -// extractCookie grabs cookies from the Lua state -func extractCookie(state *luajit.State) *fasthttp.Cookie { - cookie := fasthttp.AcquireCookie() - - // Get name - state.GetField(-1, "name") - if !state.IsString(-1) { - state.Pop(1) - fasthttp.ReleaseCookie(cookie) - return nil // Name is required - } - cookie.SetKey(state.ToString(-1)) - state.Pop(1) - - // Get value - state.GetField(-1, "value") - if state.IsString(-1) { - cookie.SetValue(state.ToString(-1)) - } - state.Pop(1) - - // Get path - state.GetField(-1, "path") - if state.IsString(-1) { - cookie.SetPath(state.ToString(-1)) - } else { - cookie.SetPath("/") // Default path - } - state.Pop(1) - - // Get domain - state.GetField(-1, "domain") - if state.IsString(-1) { - cookie.SetDomain(state.ToString(-1)) - } - state.Pop(1) - - // Get expires - state.GetField(-1, "expires") - if state.IsNumber(-1) { - expiry := int64(state.ToNumber(-1)) - cookie.SetExpire(time.Unix(expiry, 0)) - } - state.Pop(1) - - // Get max age - state.GetField(-1, "max_age") - if state.IsNumber(-1) { - cookie.SetMaxAge(int(state.ToNumber(-1))) - } - state.Pop(1) - - // Get secure - state.GetField(-1, "secure") - if state.IsBoolean(-1) { - cookie.SetSecure(state.ToBoolean(-1)) - } - state.Pop(1) - - // Get http only - state.GetField(-1, "http_only") - if state.IsBoolean(-1) { - cookie.SetHTTPOnly(state.ToBoolean(-1)) - } - state.Pop(1) - - return cookie -} diff --git a/core/runner/CoreModules.go b/core/runner/CoreModules.go index 98f18dc..21d9ed0 100644 --- a/core/runner/CoreModules.go +++ b/core/runner/CoreModules.go @@ -10,15 +10,12 @@ import ( luajit "git.sharkk.net/Sky/LuaJIT-to-Go" ) -// StateInitFunc is a function that initializes a module in a Lua state -type StateInitFunc func(*luajit.State) error - // CoreModuleRegistry manages the initialization and reloading of core modules type CoreModuleRegistry struct { - modules map[string]StateInitFunc // Module initializers - initOrder []string // Explicit initialization order - dependencies map[string][]string // Module dependencies - initializedFlag map[string]bool // Track which modules are initialized + modules map[string]sandbox.StateInitFunc // Module initializers + initOrder []string // Explicit initialization order + dependencies map[string][]string // Module dependencies + initializedFlag map[string]bool // Track which modules are initialized mu sync.RWMutex debug bool } @@ -26,7 +23,7 @@ type CoreModuleRegistry struct { // NewCoreModuleRegistry creates a new core module registry func NewCoreModuleRegistry() *CoreModuleRegistry { return &CoreModuleRegistry{ - modules: make(map[string]StateInitFunc), + modules: make(map[string]sandbox.StateInitFunc), initOrder: []string{}, dependencies: make(map[string][]string), initializedFlag: make(map[string]bool), @@ -46,15 +43,8 @@ func (r *CoreModuleRegistry) debugLog(format string, args ...interface{}) { } } -// debugLogCont prints continuation debug messages if enabled -func (r *CoreModuleRegistry) debugLogCont(format string, args ...interface{}) { - if r.debug { - logger.DebugCont(format, args...) - } -} - // Register adds a module to the registry -func (r *CoreModuleRegistry) Register(name string, initFunc StateInitFunc) { +func (r *CoreModuleRegistry) Register(name string, initFunc sandbox.StateInitFunc) { r.mu.Lock() defer r.mu.Unlock() @@ -63,7 +53,7 @@ func (r *CoreModuleRegistry) Register(name string, initFunc StateInitFunc) { // Add to initialization order if not already there for _, n := range r.initOrder { if n == name { - return // Already registered, silently continue + return // Already registered } } @@ -72,7 +62,7 @@ func (r *CoreModuleRegistry) Register(name string, initFunc StateInitFunc) { } // RegisterWithDependencies registers a module with explicit dependencies -func (r *CoreModuleRegistry) RegisterWithDependencies(name string, initFunc StateInitFunc, dependencies []string) { +func (r *CoreModuleRegistry) RegisterWithDependencies(name string, initFunc sandbox.StateInitFunc, dependencies []string) { r.mu.Lock() defer r.mu.Unlock() @@ -82,15 +72,12 @@ func (r *CoreModuleRegistry) RegisterWithDependencies(name string, initFunc Stat // Add to initialization order if not already there for _, n := range r.initOrder { if n == name { - return // Already registered, silently continue + return // Already registered } } r.initOrder = append(r.initOrder, name) - r.debugLog("registered module %s", name) - if len(dependencies) > 0 { - r.debugLogCont("Dependencies: %v", dependencies) - } + r.debugLog("registered module %s with dependencies: %v", name, dependencies) } // SetInitOrder sets explicit initialization order @@ -103,34 +90,14 @@ func (r *CoreModuleRegistry) SetInitOrder(order []string) { // First add all known modules that are in the specified order for _, name := range order { - if _, exists := r.modules[name]; exists { - // Check for duplicates - isDuplicate := false - for _, existing := range newOrder { - if existing == name { - isDuplicate = true - break - } - } - - if !isDuplicate { - newOrder = append(newOrder, name) - } + if _, exists := r.modules[name]; exists && !contains(newOrder, name) { + newOrder = append(newOrder, name) } } // Then add any modules not in the specified order for name := range r.modules { - // Check if module already in the new order - found := false - for _, n := range newOrder { - if n == name { - found = true - break - } - } - - if !found { + if !contains(newOrder, name) { newOrder = append(newOrder, name) } } @@ -160,7 +127,7 @@ func (r *CoreModuleRegistry) Initialize(state *luajit.State, stateIndex int) err } if verbose { - r.debugLogCont("All modules initialized successfully") + r.debugLog("All modules initialized successfully") } return nil } @@ -201,14 +168,14 @@ func (r *CoreModuleRegistry) initializeModule(state *luajit.State, name string, err := initFunc(state) if err != nil { // Always log failures regardless of verbose setting - r.debugLogCont("Initializing module %s... failure: %v", name, err) + r.debugLog("Initializing module %s... failure: %v", name, err) return fmt.Errorf("failed to initialize module %s: %w", name, err) } r.initializedFlag[name] = true if verbose { - r.debugLogCont("Initializing module %s... success", name) + r.debugLog("Initializing module %s... success", name) } return nil @@ -226,18 +193,6 @@ func (r *CoreModuleRegistry) InitializeModule(state *luajit.State, name string) return r.initializeModule(state, name, []string{}, true) } -// ModuleNames returns a list of all registered module names -func (r *CoreModuleRegistry) ModuleNames() []string { - r.mu.RLock() - defer r.mu.RUnlock() - - names := make([]string, 0, len(r.modules)) - for name := range r.modules { - names = append(names, name) - } - return names -} - // MatchModuleName checks if a file path corresponds to a registered module func (r *CoreModuleRegistry) MatchModuleName(modName string) (string, bool) { r.mu.RLock() @@ -266,7 +221,7 @@ func init() { GlobalRegistry.EnableDebug() // Enable debugging by default logger.Debug("[ModuleRegistry] Registering core modules...") - // Register core modules - these now point to the sandbox implementations + // Register core modules GlobalRegistry.Register("util", func(state *luajit.State) error { return sandbox.UtilModuleInitFunc()(state) }) @@ -283,15 +238,25 @@ func init() { "csrf", // Fourth: CSRF protection }) - logger.DebugCont("Core modules registered successfully") + logger.Debug("Core modules registered successfully") } -// RegisterCoreModule is a helper to register a core module with the global registry -func RegisterCoreModule(name string, initFunc StateInitFunc) { +// RegisterCoreModule registers a core module with the global registry +func RegisterCoreModule(name string, initFunc sandbox.StateInitFunc) { GlobalRegistry.Register(name, initFunc) } // RegisterCoreModuleWithDependencies registers a module with dependencies -func RegisterCoreModuleWithDependencies(name string, initFunc StateInitFunc, dependencies []string) { +func RegisterCoreModuleWithDependencies(name string, initFunc sandbox.StateInitFunc, dependencies []string) { GlobalRegistry.RegisterWithDependencies(name, initFunc, dependencies) } + +// Helper functions +func contains(slice []string, item string) bool { + for _, s := range slice { + if s == item { + return true + } + } + return false +} diff --git a/core/runner/GoModules.go b/core/runner/GoModules.go deleted file mode 100644 index daa1df7..0000000 --- a/core/runner/GoModules.go +++ /dev/null @@ -1,86 +0,0 @@ -package runner - -import ( - "Moonshark/core/utils/logger" - - luajit "git.sharkk.net/Sky/LuaJIT-to-Go" -) - -// ModuleFunc is a function that returns a map of module functions -type ModuleFunc func() map[string]luajit.GoFunction - -// RegisterModule registers a map of functions as a Lua module -func RegisterModule(state *luajit.State, name string, funcs map[string]luajit.GoFunction) error { - // Create a new table for the module - state.NewTable() - - // Add each function to the module table - for fname, f := range funcs { - // Push function name - state.PushString(fname) - - // Push function - if err := state.PushGoFunction(f); err != nil { - state.Pop(1) // Pop table - return err - } - - // Set table[fname] = f - state.SetTable(-3) - } - - // Register the module globally - state.SetGlobal(name) - return nil -} - -// ModuleInitFunc creates a state initializer that registers multiple modules -func ModuleInitFunc(modules map[string]ModuleFunc) StateInitFunc { - return func(state *luajit.State) error { - for name, moduleFunc := range modules { - if err := RegisterModule(state, name, moduleFunc()); err != nil { - logger.Error("Failed to register module %s: %v", name, err) - return err - } - } - return nil - } -} - -// CombineInitFuncs combines multiple state initializer functions into one -func CombineInitFuncs(funcs ...StateInitFunc) StateInitFunc { - return func(state *luajit.State) error { - for _, f := range funcs { - if f != nil { - if err := f(state); err != nil { - return err - } - } - } - return nil - } -} - -// RegisterLuaCode registers a Lua code snippet as a module -func RegisterLuaCode(state *luajit.State, code string) error { - return state.DoString(code) -} - -// RegisterLuaCodeInitFunc returns a StateInitFunc that registers Lua code -func RegisterLuaCodeInitFunc(code string) StateInitFunc { - return func(state *luajit.State) error { - return RegisterLuaCode(state, code) - } -} - -// RegisterLuaModuleInitFunc returns a StateInitFunc that registers a Lua module -func RegisterLuaModuleInitFunc(name string, code string) StateInitFunc { - return func(state *luajit.State) error { - // Create name = {} global - state.NewTable() - state.SetGlobal(name) - - // Then run the module code which will populate it - return state.DoString(code) - } -} diff --git a/core/runner/Runner.go b/core/runner/Runner.go index de6a647..0c5008d 100644 --- a/core/runner/Runner.go +++ b/core/runner/Runner.go @@ -3,15 +3,13 @@ package runner import ( "context" "errors" + "fmt" "path/filepath" "runtime" "sync" "sync/atomic" "time" - "github.com/panjf2000/ants/v2" - "github.com/valyala/bytebufferpool" - "Moonshark/core/runner/sandbox" "Moonshark/core/utils/logger" @@ -31,11 +29,10 @@ type RunnerOption func(*Runner) // State wraps a Lua state with its sandbox type State struct { - L *luajit.State // The Lua state - sandbox *sandbox.Sandbox // Associated sandbox - index int // Index for debugging - inUse bool // Whether the state is currently in use - initTime time.Time // When this state was initialized + L *luajit.State // The Lua state + sandbox *sandbox.Sandbox // Associated sandbox + index int // Index for debugging + inUse bool // Whether the state is currently in use } // InitHook runs before executing a script @@ -44,20 +41,6 @@ type InitHook func(*luajit.State, *Context) error // FinalizeHook runs after executing a script type FinalizeHook func(*luajit.State, *Context, any) error -// ExecuteTask represents a task in the execution goroutine pool -type ExecuteTask struct { - bytecode []byte - context *Context - scriptPath string - result chan<- taskResult -} - -// taskResult holds the result of an execution task -type taskResult struct { - value any - err error -} - // Runner runs Lua scripts using a pool of Lua states type Runner struct { states []*State // All states managed by this runner @@ -70,7 +53,6 @@ type Runner struct { initHooks []InitHook // Hooks run before script execution finalizeHooks []FinalizeHook // Hooks run after script execution scriptDir string // Current script directory - pool *ants.Pool // Goroutine pool for task execution } // WithPoolSize sets the state pool size @@ -144,13 +126,6 @@ func NewRunner(options ...RunnerOption) (*Runner, error) { runner.states = make([]*State, runner.poolSize) runner.statePool = make(chan int, runner.poolSize) - // Create ants goroutine pool - var err error - runner.pool, err = ants.NewPool(runner.poolSize * 2) - if err != nil { - return nil, err - } - // Create and initialize all states if err := runner.initializeStates(); err != nil { runner.Close() // Clean up already created states @@ -168,28 +143,12 @@ func (r *Runner) debugLog(format string, args ...interface{}) { } } -func (r *Runner) debugLogCont(format string, args ...interface{}) { - if r.debug { - logger.DebugCont(format, args...) - } -} - // initializeStates creates and initializes all states in the pool func (r *Runner) initializeStates() error { r.debugLog("is initializing %d states", r.poolSize) - // Create main template state first with full logging - templateState, err := r.createState(0) - if err != nil { - return err - } - - r.states[0] = templateState - r.statePool <- 0 // Add index to the pool - - // Create remaining states with minimal logging - successCount := 1 - for i := 1; i < r.poolSize; i++ { + // Create all states + for i := 0; i < r.poolSize; i++ { state, err := r.createState(i) if err != nil { return err @@ -197,10 +156,8 @@ func (r *Runner) initializeStates() error { r.states[i] = state r.statePool <- i // Add index to the pool - successCount++ } - r.debugLog("has built %d/%d states successfully", successCount, r.poolSize) return nil } @@ -218,16 +175,13 @@ func (r *Runner) createState(index int) (*State, error) { } // Create sandbox - sandbox := sandbox.NewSandbox() + sb := sandbox.NewSandbox() if r.debug && verbose { - sandbox.EnableDebug() + sb.EnableDebug() } // Set up require system if err := r.moduleLoader.SetupRequire(L); err != nil { - if verbose { - r.debugLogCont("Failed to set up require for state %d: %v", index, err) - } L.Cleanup() L.Close() return nil, ErrInitFailed @@ -235,19 +189,13 @@ func (r *Runner) createState(index int) (*State, error) { // Initialize all core modules from the registry if err := GlobalRegistry.Initialize(L, index); err != nil { - if verbose { - r.debugLogCont("Failed to initialize core modules for state %d: %v", index, err) - } L.Cleanup() L.Close() return nil, ErrInitFailed } // Set up sandbox after core modules are initialized - if err := sandbox.Setup(L, index); err != nil { - if verbose { - r.debugLogCont("Failed to set up sandbox for state %d: %v", index, err) - } + if err := sb.Setup(L, index); err != nil { L.Cleanup() L.Close() return nil, ErrInitFailed @@ -255,63 +203,49 @@ func (r *Runner) createState(index int) (*State, error) { // Preload all modules if err := r.moduleLoader.PreloadModules(L); err != nil { - if verbose { - r.debugLogCont("Failed to preload modules for state %d: %v", index, err) - } L.Cleanup() L.Close() return nil, errors.New("failed to preload modules") } - state := &State{ - L: L, - sandbox: sandbox, - index: index, - inUse: false, - initTime: time.Now(), - } - - if verbose { - r.debugLog("State %d created successfully", index) - } - return state, nil + return &State{ + L: L, + sandbox: sb, + index: index, + inUse: false, + }, nil } -// executeTask is the worker function for the ants pool -func (r *Runner) executeTask(i interface{}) { - task, ok := i.(*ExecuteTask) - if !ok { - return +// Execute runs a script with context +func (r *Runner) Execute(ctx context.Context, bytecode []byte, execCtx *Context, scriptPath string) (any, error) { + if !r.isRunning.Load() { + return nil, ErrRunnerClosed } // Set script directory if provided - if task.scriptPath != "" { + if scriptPath != "" { r.mu.Lock() - r.scriptDir = filepath.Dir(task.scriptPath) + r.scriptDir = filepath.Dir(scriptPath) r.moduleLoader.SetScriptDir(r.scriptDir) r.mu.Unlock() } - // Get a state index from the pool + // Get a state from the pool var stateIndex int select { case stateIndex = <-r.statePool: // Got a state - case <-time.After(5 * time.Second): // 5-second timeout - // Timed out waiting for a state - task.result <- taskResult{nil, errors.New("server busy - timed out waiting for a Lua state")} - return + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(5 * time.Second): + return nil, ErrTimeout } // Get the actual state - r.mu.RLock() state := r.states[stateIndex] - r.mu.RUnlock() - if state == nil { r.statePool <- stateIndex - task.result <- taskResult{nil, ErrStateNotReady} - return + return nil, ErrStateNotReady } // Mark state as in use @@ -325,75 +259,69 @@ func (r *Runner) executeTask(i interface{}) { case r.statePool <- stateIndex: // State returned to pool default: - // Pool is full or closed (shouldn't happen) + // Pool is full or closed } } }() - // Copy hooks to avoid holding lock during execution - r.mu.RLock() - initHooks := make([]InitHook, len(r.initHooks)) - copy(initHooks, r.initHooks) - finalizeHooks := make([]FinalizeHook, len(r.finalizeHooks)) - copy(finalizeHooks, r.finalizeHooks) - r.mu.RUnlock() - // Run init hooks - for _, hook := range initHooks { - if err := hook(state.L, task.context); err != nil { - task.result <- taskResult{nil, err} - return + for _, hook := range r.initHooks { + if err := hook(state.L, execCtx); err != nil { + return nil, err } } - // Prepare context values + // Get context values var ctxValues map[string]any - if task.context != nil { - ctxValues = task.context.Values + if execCtx != nil { + ctxValues = execCtx.Values } // Execute in sandbox - result, err := state.sandbox.Execute(state.L, task.bytecode, ctxValues) - - // Run finalize hooks - for _, hook := range finalizeHooks { - if hookErr := hook(state.L, task.context, result); hookErr != nil && err == nil { - err = hookErr - } - } - - task.result <- taskResult{result, err} -} - -// Execute runs a script with context -func (r *Runner) Execute(ctx context.Context, bytecode []byte, execCtx *Context, scriptPath string) (any, error) { - if !r.isRunning.Load() { - return nil, ErrRunnerClosed - } - - // Create result channel - resultChan := make(chan taskResult, 1) - - // Create task - task := &ExecuteTask{ - bytecode: bytecode, - context: execCtx, - scriptPath: scriptPath, - result: resultChan, - } - - // Submit task to pool - if err := r.pool.Submit(func() { r.executeTask(task) }); err != nil { + result, err := state.sandbox.Execute(state.L, bytecode, ctxValues) + if err != nil { return nil, err } - // Wait for result with context timeout - select { - case result := <-resultChan: - return result.value, result.err - case <-ctx.Done(): - return nil, ctx.Err() + // Run finalize hooks + for _, hook := range r.finalizeHooks { + if hookErr := hook(state.L, execCtx, result); hookErr != nil { + return nil, hookErr + } } + + // Check for HTTP response + httpResp, hasResponse := sandbox.GetHTTPResponse(state.L) + if hasResponse { + // Set result as body if not already set + if httpResp.Body == nil { + httpResp.Body = result + } + + // Apply directly to request context if available + if execCtx != nil && execCtx.RequestCtx != nil { + sandbox.ApplyHTTPResponse(httpResp, execCtx.RequestCtx) + sandbox.ReleaseResponse(httpResp) + return nil, nil + } + + return httpResp, nil + } + + // Handle direct result if FastHTTP context is available + if execCtx != nil && execCtx.RequestCtx != nil && result != nil { + switch r := result.(type) { + case string: + execCtx.RequestCtx.SetBodyString(r) + case []byte: + execCtx.RequestCtx.SetBody(r) + default: + execCtx.RequestCtx.SetBodyString(fmt.Sprintf("%v", r)) + } + return nil, nil + } + + return result, err } // Run executes a Lua script (convenience wrapper) @@ -411,14 +339,19 @@ func (r *Runner) Close() error { } r.isRunning.Store(false) - r.debugLog("Closing Runner and destroying all states") - - // Shut down goroutine pool - r.pool.Release() // Drain the state pool - r.drainStatePool() + for { + select { + case <-r.statePool: + // Drain one state + default: + // Pool is empty + goto cleanup + } + } +cleanup: // Clean up all states for i, state := range r.states { if state != nil { @@ -431,19 +364,6 @@ func (r *Runner) Close() error { return nil } -// drainStatePool removes all states from the pool -func (r *Runner) drainStatePool() { - for { - select { - case <-r.statePool: - // Drain one state - default: - // Pool is empty - return - } - } -} - // RefreshStates rebuilds all states in the pool func (r *Runner) RefreshStates() error { r.mu.Lock() @@ -453,11 +373,18 @@ func (r *Runner) RefreshStates() error { return ErrRunnerClosed } - r.debugLog("Refreshing all Lua states") - // Drain all states from the pool - r.drainStatePool() + for { + select { + case <-r.statePool: + // Drain one state + default: + // Pool is empty + goto cleanup + } + } +cleanup: // Destroy all existing states for i, state := range r.states { if state != nil { @@ -479,7 +406,82 @@ func (r *Runner) RefreshStates() error { return nil } -// NotifyFileChanged handles file change notifications +// AddInitHook adds a hook to be called before script execution +func (r *Runner) AddInitHook(hook InitHook) { + r.mu.Lock() + defer r.mu.Unlock() + r.initHooks = append(r.initHooks, hook) +} + +// AddFinalizeHook adds a hook to be called after script execution +func (r *Runner) AddFinalizeHook(hook FinalizeHook) { + r.mu.Lock() + defer r.mu.Unlock() + r.finalizeHooks = append(r.finalizeHooks, hook) +} + +// GetStateCount returns the number of initialized states +func (r *Runner) GetStateCount() int { + r.mu.RLock() + defer r.mu.RUnlock() + + count := 0 + for _, state := range r.states { + if state != nil { + count++ + } + } + + return count +} + +// GetActiveStateCount returns the number of states currently in use +func (r *Runner) GetActiveStateCount() int { + r.mu.RLock() + defer r.mu.RUnlock() + + count := 0 + for _, state := range r.states { + if state != nil && state.inUse { + count++ + } + } + + return count +} + +// GetModuleCount returns the number of loaded modules in the first available state +func (r *Runner) GetModuleCount() int { + r.mu.RLock() + defer r.mu.RUnlock() + + if !r.isRunning.Load() { + return 0 + } + + // Find first available state + for _, state := range r.states { + if state != nil && !state.inUse { + // Execute a Lua snippet to count modules + if res, err := state.L.ExecuteWithResult(` + local count = 0 + for _ in pairs(package.loaded) do + count = count + 1 + end + return count + `); err == nil { + if num, ok := res.(float64); ok { + return int(num) + } + } + break + } + } + + return 0 +} + +// NotifyFileChanged alerts the runner about file changes func (r *Runner) NotifyFileChanged(filePath string) bool { r.debugLog("File change detected: %s", filePath) @@ -515,21 +517,12 @@ func (r *Runner) RefreshModule(moduleName string) bool { success := true for _, state := range r.states { - if state == nil { - continue - } - - // Skip states that are in use - if state.inUse { - r.debugLog("Skipping refresh for state %d (in use)", state.index) - success = false + if state == nil || state.inUse { continue } // Invalidate module in Lua if err := state.L.DoString(`package.loaded["` + moduleName + `"] = nil`); err != nil { - r.debugLog("Failed to invalidate module %s in state %d: %v", - moduleName, state.index, err) success = false continue } @@ -537,139 +530,10 @@ func (r *Runner) RefreshModule(moduleName string) bool { // For core modules, reinitialize them if isCore { if err := GlobalRegistry.InitializeModule(state.L, coreName); err != nil { - r.debugLog("Failed to reinitialize core module %s in state %d: %v", - coreName, state.index, err) success = false } } } - if success { - r.debugLog("Module %s refreshed successfully in all states", moduleName) - } else { - r.debugLog("Module %s refresh had some failures", moduleName) - } - return success } - -// AddModule adds a module to all sandbox environments -func (r *Runner) AddModule(name string, module any) { - r.debugLog("Adding module %s to all sandboxes", name) - r.mu.RLock() - defer r.mu.RUnlock() - - for _, state := range r.states { - if state != nil && state.sandbox != nil && !state.inUse { - state.sandbox.AddModule(name, module) - } - } -} - -// AddInitHook adds a hook to be called before script execution -func (r *Runner) AddInitHook(hook InitHook) { - r.mu.Lock() - defer r.mu.Unlock() - r.initHooks = append(r.initHooks, hook) -} - -// AddFinalizeHook adds a hook to be called after script execution -func (r *Runner) AddFinalizeHook(hook FinalizeHook) { - r.mu.Lock() - defer r.mu.Unlock() - r.finalizeHooks = append(r.finalizeHooks, hook) -} - -// ResetModuleCache clears the module cache in all states -func (r *Runner) ResetModuleCache() { - r.mu.RLock() - defer r.mu.RUnlock() - - if !r.isRunning.Load() { - return - } - - r.debugLog("Resetting module cache in all states") - - for _, state := range r.states { - if state != nil && !state.inUse { - r.moduleLoader.ResetModules(state.L) - } - } -} - -// GetStateCount returns the number of initialized states -func (r *Runner) GetStateCount() int { - r.mu.RLock() - defer r.mu.RUnlock() - - count := 0 - for _, state := range r.states { - if state != nil { - count++ - } - } - - return count -} - -// GetActiveStateCount returns the number of states currently in use -func (r *Runner) GetActiveStateCount() int { - r.mu.RLock() - defer r.mu.RUnlock() - - count := 0 - for _, state := range r.states { - if state != nil && state.inUse { - count++ - } - } - - return count -} - -// GetWorkerPoolStats returns statistics about the worker pool -func (r *Runner) GetWorkerPoolStats() (running, capacity int) { - return r.pool.Running(), r.pool.Cap() -} - -// GetModuleCount returns the number of loaded modules in the first available state -func (r *Runner) GetModuleCount() int { - r.mu.RLock() - defer r.mu.RUnlock() - - if !r.isRunning.Load() { - return 0 - } - - // Find first available state - for _, state := range r.states { - if state != nil && !state.inUse { - // Execute a Lua snippet to count modules - if res, err := state.L.ExecuteWithResult(` - local count = 0 - for _ in pairs(package.loaded) do - count = count + 1 - end - return count - `); err == nil { - if num, ok := res.(float64); ok { - return int(num) - } - } - break - } - } - - return 0 -} - -// GetBufferPool returns a buffer from the bytebufferpool -func GetBufferPool() *bytebufferpool.ByteBuffer { - return bytebufferpool.Get() -} - -// ReleaseBufferPool returns a buffer to the bytebufferpool -func ReleaseBufferPool(buf *bytebufferpool.ByteBuffer) { - bytebufferpool.Put(buf) -} diff --git a/core/runner/Sessions.go b/core/runner/Sessions.go index e73de0e..00bd55f 100644 --- a/core/runner/Sessions.go +++ b/core/runner/Sessions.go @@ -1,13 +1,12 @@ package runner import ( - "github.com/valyala/fasthttp" - "Moonshark/core/runner/sandbox" "Moonshark/core/sessions" "Moonshark/core/utils/logger" luajit "git.sharkk.net/Sky/LuaJIT-to-Go" + "github.com/valyala/fasthttp" ) // SessionHandler handles session management for Lua scripts @@ -29,30 +28,17 @@ func (h *SessionHandler) EnableDebug() { h.debugLog = true } -// debug logs a message if debug is enabled -func (h *SessionHandler) debug(format string, args ...interface{}) { - if h.debugLog { - logger.Debug("[SessionHandler] "+format, args...) - } -} - // WithSessionManager creates a RunnerOption to add session support func WithSessionManager(manager *sessions.SessionManager) RunnerOption { return func(r *Runner) { handler := NewSessionHandler(manager) - - // Add hooks to the runner r.AddInitHook(handler.preRequestHook) r.AddFinalizeHook(handler.postRequestHook) } } -// preRequestHook is called before executing a request +// preRequestHook initializes session before script execution func (h *SessionHandler) preRequestHook(state *luajit.State, ctx *Context) error { - h.debug("Running pre-request session hook") - - // Check if we have cookie information in context - // Instead of raw request, we now look for the cookie map if ctx == nil || ctx.Values["_request_cookies"] == nil { return nil } @@ -71,43 +57,31 @@ func (h *SessionHandler) preRequestHook(state *luajit.State, ctx *Context) error if cookieValue, exists := cookies[cookieName]; exists { if strValue, ok := cookieValue.(string); ok && strValue != "" { sessionID = strValue - h.debug("Found existing session ID: %s", sessionID) } } - // If no session ID found, create new session + // Create new session if needed if sessionID == "" { - // Create a new session session := h.manager.CreateSession() sessionID = session.ID - h.debug("Created new session with ID: %s", sessionID) } - // Store the session ID in the context for later use + // Store the session ID in the context ctx.Set("_session_id", sessionID) - // Get the session data + // Get session data session := h.manager.GetSession(sessionID) sessionData := session.GetAll() // Set session data in Lua state - if err := SetSessionData(state, sessionID, sessionData); err != nil { - h.debug("Failed to set session data: %v", err) - return err - } - - h.debug("Session data initialized successfully") - return nil + return SetSessionData(state, sessionID, sessionData) } -// postRequestHook is called after executing a request +// postRequestHook handles session after script execution func (h *SessionHandler) postRequestHook(state *luajit.State, ctx *Context, result any) error { - h.debug("Running post-request session hook") - // Check if session was modified modifiedID, modifiedData, modified := GetSessionData(state) if !modified { - h.debug("Session not modified, skipping") return nil } @@ -125,12 +99,9 @@ func (h *SessionHandler) postRequestHook(state *luajit.State, ctx *Context, resu } if modifiedID == "" { - h.debug("No session ID found, cannot persist session data") return nil } - h.debug("Persisting modified session data for ID: %s", modifiedID) - // Update session in manager session := h.manager.GetSession(modifiedID) session.Clear() // clear to sync deleted values @@ -145,7 +116,6 @@ func (h *SessionHandler) postRequestHook(state *luajit.State, ctx *Context, resu h.addSessionCookie(httpResp, modifiedID) } - h.debug("Session data persisted successfully") return nil } @@ -158,13 +128,10 @@ func (h *SessionHandler) addSessionCookie(resp *sandbox.HTTPResponse, sessionID cookieName := opts["name"].(string) for _, cookie := range resp.Cookies { if string(cookie.Key()) == cookieName { - h.debug("Session cookie already set in response") return } } - h.debug("Adding session cookie to response") - // Create and add cookie cookie := fasthttp.AcquireCookie() cookie.SetKey(cookieName) diff --git a/core/runner/sandbox/Http.go b/core/runner/sandbox/Http.go index c207f70..18bc767 100644 --- a/core/runner/sandbox/Http.go +++ b/core/runner/sandbox/Http.go @@ -31,8 +31,8 @@ var responsePool = sync.Pool{ New: func() any { return &HTTPResponse{ Status: 200, - Headers: make(map[string]string, 8), // Pre-allocate with reasonable capacity - Cookies: make([]*fasthttp.Cookie, 0, 4), // Pre-allocate with reasonable capacity + Headers: make(map[string]string, 8), + Cookies: make([]*fasthttp.Cookie, 0, 4), } }, } @@ -48,14 +48,10 @@ var defaultFastClient fasthttp.Client = fasthttp.Client{ // HTTPClientConfig contains client settings type HTTPClientConfig struct { - // Maximum timeout for requests (0 = no limit) - MaxTimeout time.Duration - // Default request timeout - DefaultTimeout time.Duration - // Maximum response size in bytes (0 = no limit) - MaxResponseSize int64 - // Whether to allow remote connections - AllowRemote bool + MaxTimeout time.Duration // Maximum timeout for requests (0 = no limit) + DefaultTimeout time.Duration // Default request timeout + MaxResponseSize int64 // Maximum response size in bytes (0 = no limit) + AllowRemote bool // Whether to allow remote connections } // DefaultHTTPClientConfig provides sensible defaults @@ -66,12 +62,12 @@ var DefaultHTTPClientConfig = HTTPClientConfig{ AllowRemote: true, } -// NewHTTPResponse creates a default HTTP response, potentially reusing one from the pool +// NewHTTPResponse creates a default HTTP response from pool func NewHTTPResponse() *HTTPResponse { return responsePool.Get().(*HTTPResponse) } -// ReleaseResponse returns the response to the pool after clearing its values +// ReleaseResponse returns the response to the pool func ReleaseResponse(resp *HTTPResponse) { if resp == nil { return @@ -99,8 +95,7 @@ func HTTPModuleInitFunc() func(*luajit.State) error { return func(state *luajit.State) error { // Register the native Go function first if err := state.RegisterGoFunction("__http_request", httpRequest); err != nil { - logger.Error("[HTTP Module] Failed to register __http_request function") - logger.ErrorCont("%v", err) + logger.Error("[HTTP Module] Failed to register __http_request function: %v", err) return err } @@ -111,7 +106,7 @@ func HTTPModuleInitFunc() func(*luajit.State) error { } } -// Helper to set up HTTP client config +// setupHTTPClientConfig configures HTTP client in Lua func setupHTTPClientConfig(state *luajit.State) { state.NewTable() @@ -138,7 +133,7 @@ func GetHTTPResponse(state *luajit.State) (*HTTPResponse, bool) { state.GetGlobal("__http_responses") if state.IsNil(-1) { state.Pop(1) - ReleaseResponse(response) // Return unused response to pool + ReleaseResponse(response) return nil, false } @@ -147,7 +142,7 @@ func GetHTTPResponse(state *luajit.State) (*HTTPResponse, bool) { state.GetTable(-2) if state.IsNil(-1) { state.Pop(2) - ReleaseResponse(response) // Return unused response to pool + ReleaseResponse(response) return nil, false } @@ -340,39 +335,37 @@ func httpRequest(state *luajit.State) int { return -1 } - // Get client configuration from registry (if available) + // Get client configuration var config HTTPClientConfig = DefaultHTTPClientConfig state.GetGlobal("__http_client_config") - if !state.IsNil(-1) { - if state.IsTable(-1) { - // Extract max timeout - state.GetField(-1, "max_timeout") - if state.IsNumber(-1) { - config.MaxTimeout = time.Duration(state.ToNumber(-1)) * time.Second - } - state.Pop(1) - - // Extract default timeout - state.GetField(-1, "default_timeout") - if state.IsNumber(-1) { - config.DefaultTimeout = time.Duration(state.ToNumber(-1)) * time.Second - } - state.Pop(1) - - // Extract max response size - state.GetField(-1, "max_response_size") - if state.IsNumber(-1) { - config.MaxResponseSize = int64(state.ToNumber(-1)) - } - state.Pop(1) - - // Extract allow remote - state.GetField(-1, "allow_remote") - if state.IsBoolean(-1) { - config.AllowRemote = state.ToBoolean(-1) - } - state.Pop(1) + if !state.IsNil(-1) && state.IsTable(-1) { + // Extract max timeout + state.GetField(-1, "max_timeout") + if state.IsNumber(-1) { + config.MaxTimeout = time.Duration(state.ToNumber(-1)) * time.Second } + state.Pop(1) + + // Extract default timeout + state.GetField(-1, "default_timeout") + if state.IsNumber(-1) { + config.DefaultTimeout = time.Duration(state.ToNumber(-1)) * time.Second + } + state.Pop(1) + + // Extract max response size + state.GetField(-1, "max_response_size") + if state.IsNumber(-1) { + config.MaxResponseSize = int64(state.ToNumber(-1)) + } + state.Pop(1) + + // Extract allow remote + state.GetField(-1, "allow_remote") + if state.IsBoolean(-1) { + config.AllowRemote = state.ToBoolean(-1) + } + state.Pop(1) } state.Pop(1) diff --git a/core/runner/sandbox/Modules.go b/core/runner/sandbox/Modules.go index c9a9764..1641402 100644 --- a/core/runner/sandbox/Modules.go +++ b/core/runner/sandbox/Modules.go @@ -6,9 +6,12 @@ import ( luajit "git.sharkk.net/Sky/LuaJIT-to-Go" ) -// ModuleFunc is a function that returns a map of module functions +// ModuleFunc returns a map of module functions type ModuleFunc func() map[string]luajit.GoFunction +// StateInitFunc initializes a module in a Lua state +type StateInitFunc func(*luajit.State) error + // RegisterModule registers a map of functions as a Lua module func RegisterModule(state *luajit.State, name string, funcs map[string]luajit.GoFunction) error { // Create a new table for the module @@ -16,16 +19,11 @@ func RegisterModule(state *luajit.State, name string, funcs map[string]luajit.Go // Add each function to the module table for fname, f := range funcs { - // Push function name state.PushString(fname) - - // Push function if err := state.PushGoFunction(f); err != nil { state.Pop(1) // Pop table return err } - - // Set table[fname] = f state.SetTable(-3) } @@ -34,8 +32,21 @@ func RegisterModule(state *luajit.State, name string, funcs map[string]luajit.Go return nil } +// ModuleInitFunc creates a state initializer that registers multiple modules +func ModuleInitFunc(modules map[string]ModuleFunc) StateInitFunc { + return func(state *luajit.State) error { + for name, moduleFunc := range modules { + if err := RegisterModule(state, name, moduleFunc()); err != nil { + logger.Error("Failed to register module %s: %v", name, err) + return err + } + } + return nil + } +} + // CombineInitFuncs combines multiple state initializer functions into one -func CombineInitFuncs(funcs ...func(*luajit.State) error) func(*luajit.State) error { +func CombineInitFuncs(funcs ...StateInitFunc) StateInitFunc { return func(state *luajit.State) error { for _, f := range funcs { if f != nil { @@ -48,33 +59,20 @@ func CombineInitFuncs(funcs ...func(*luajit.State) error) func(*luajit.State) er } } -// ModuleInitFunc creates a state initializer that registers multiple modules -func ModuleInitFunc(modules map[string]ModuleFunc) func(*luajit.State) error { - return func(state *luajit.State) error { - for name, moduleFunc := range modules { - if err := RegisterModule(state, name, moduleFunc()); err != nil { - logger.Error("Failed to register module %s: %v", name, err) - return err - } - } - return nil - } -} - -// RegisterLuaCode registers a Lua code snippet as a module +// RegisterLuaCode registers a Lua code snippet in a state func RegisterLuaCode(state *luajit.State, code string) error { return state.DoString(code) } // RegisterLuaCodeInitFunc returns a StateInitFunc that registers Lua code -func RegisterLuaCodeInitFunc(code string) func(*luajit.State) error { +func RegisterLuaCodeInitFunc(code string) StateInitFunc { return func(state *luajit.State) error { return RegisterLuaCode(state, code) } } // RegisterLuaModuleInitFunc returns a StateInitFunc that registers a Lua module -func RegisterLuaModuleInitFunc(name string, code string) func(*luajit.State) error { +func RegisterLuaModuleInitFunc(name string, code string) StateInitFunc { return func(state *luajit.State) error { // Create name = {} global state.NewTable()