Compare commits

..

No commits in common. "d6feb408ce8e52b1cca103d499f68c2d8ba0c7e6" and "08a532f11a7dded00453926cb292617772270d7c" have entirely different histories.

11 changed files with 578 additions and 711 deletions

View File

@ -18,12 +18,12 @@ const (
) )
// LogRequest logs an HTTP request with custom formatting // LogRequest logs an HTTP request with custom formatting
func LogRequest(statusCode int, r *http.Request, duration time.Duration) { func LogRequest(log *logger.Logger, statusCode int, r *http.Request, duration time.Duration) {
statusColor := getStatusColor(statusCode) statusColor := getStatusColor(statusCode)
// Use the logger's raw message writer to bypass the standard format // Use the logger's raw message writer to bypass the standard format
logger.LogRaw("%s%s%s %s%d %s%s %s %s(%v)%s", log.LogRaw("%s%s%s %s%d %s%s %s %s(%v)%s",
colorGray, time.Now().Format(logger.TimeFormat()), colorReset, colorGray, time.Now().Format(log.TimeFormat()), colorReset,
statusColor, statusCode, r.Method, colorReset, r.URL.Path, colorGray, duration, colorReset) statusColor, statusCode, r.Method, colorReset, r.URL.Path, colorGray, duration, colorReset)
} }

View File

@ -19,6 +19,7 @@ type Server struct {
luaRouter *routers.LuaRouter luaRouter *routers.LuaRouter
staticRouter *routers.StaticRouter staticRouter *routers.StaticRouter
luaRunner *runner.LuaRunner luaRunner *runner.LuaRunner
logger *logger.Logger
httpServer *http.Server httpServer *http.Server
loggingEnabled bool loggingEnabled bool
debugMode bool // Controls whether to show error details debugMode bool // Controls whether to show error details
@ -27,13 +28,15 @@ type Server struct {
} }
// New creates a new HTTP server with optimized connection settings // New creates a new HTTP server with optimized connection settings
func New(luaRouter *routers.LuaRouter, staticRouter *routers.StaticRouter, runner *runner.LuaRunner, func New(luaRouter *routers.LuaRouter, staticRouter *routers.StaticRouter,
runner *runner.LuaRunner, log *logger.Logger,
loggingEnabled bool, debugMode bool, overrideDir string, config *config.Config) *Server { loggingEnabled bool, debugMode bool, overrideDir string, config *config.Config) *Server {
server := &Server{ server := &Server{
luaRouter: luaRouter, luaRouter: luaRouter,
staticRouter: staticRouter, staticRouter: staticRouter,
luaRunner: runner, luaRunner: runner,
logger: log,
httpServer: &http.Server{}, httpServer: &http.Server{},
loggingEnabled: loggingEnabled, loggingEnabled: loggingEnabled,
debugMode: debugMode, debugMode: debugMode,
@ -60,13 +63,13 @@ func New(luaRouter *routers.LuaRouter, staticRouter *routers.StaticRouter, runne
// ListenAndServe starts the server on the given address // ListenAndServe starts the server on the given address
func (s *Server) ListenAndServe(addr string) error { func (s *Server) ListenAndServe(addr string) error {
s.httpServer.Addr = addr s.httpServer.Addr = addr
logger.Info("Server listening at http://localhost%s", addr) s.logger.Info("Server listening at http://localhost%s", addr)
return s.httpServer.ListenAndServe() return s.httpServer.ListenAndServe()
} }
// Shutdown gracefully shuts down the server // Shutdown gracefully shuts down the server
func (s *Server) Shutdown(ctx context.Context) error { func (s *Server) Shutdown(ctx context.Context) error {
logger.Info("Server shutting down...") s.logger.Info("Server shutting down...")
return s.httpServer.Shutdown(ctx) return s.httpServer.Shutdown(ctx)
} }
@ -81,7 +84,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Calculate and log request duration // Calculate and log request duration
duration := time.Since(start) duration := time.Since(start)
if s.loggingEnabled { if s.loggingEnabled {
LogRequest(http.StatusOK, r, duration) LogRequest(s.logger, http.StatusOK, r, duration)
} }
return return
} }
@ -100,13 +103,13 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Log the request with our custom format // Log the request with our custom format
if s.loggingEnabled { if s.loggingEnabled {
LogRequest(statusCode, r, duration) LogRequest(s.logger, statusCode, r, duration)
} }
} }
// handleRequest processes the actual request // handleRequest processes the actual request
func (s *Server) handleRequest(w http.ResponseWriter, r *http.Request) { func (s *Server) handleRequest(w http.ResponseWriter, r *http.Request) {
logger.Debug("Processing request %s %s", r.Method, r.URL.Path) s.logger.Debug("Processing request %s %s", r.Method, r.URL.Path)
// Try Lua routes first // Try Lua routes first
params := &routers.Params{} params := &routers.Params{}
@ -123,7 +126,7 @@ func (s *Server) handleRequest(w http.ResponseWriter, r *http.Request) {
errorMsg = node.Error.Error() errorMsg = node.Error.Error()
} }
logger.Error("%s %s - %s", r.Method, r.URL.Path, errorMsg) s.logger.Error("%s %s - %s", r.Method, r.URL.Path, errorMsg)
// Show error page with the actual error message // Show error page with the actual error message
w.Header().Set("Content-Type", "text/html; charset=utf-8") w.Header().Set("Content-Type", "text/html; charset=utf-8")
@ -132,7 +135,7 @@ func (s *Server) handleRequest(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(errorHTML)) w.Write([]byte(errorHTML))
return return
} else if found { } else if found {
logger.Debug("Found Lua route match for %s %s with %d params", r.Method, r.URL.Path, params.Count) s.logger.Debug("Found Lua route match for %s %s with %d params", r.Method, r.URL.Path, params.Count)
s.handleLuaRoute(w, r, bytecode, scriptPath, params) s.handleLuaRoute(w, r, bytecode, scriptPath, params)
return return
} }
@ -162,7 +165,7 @@ func (s *Server) handleLuaRoute(w http.ResponseWriter, r *http.Request, bytecode
defer ctx.Release() defer ctx.Release()
// Log bytecode size // Log bytecode size
logger.Debug("Executing Lua route with %d bytes of bytecode", len(bytecode)) s.logger.Debug("Executing Lua route with %d bytes of bytecode", len(bytecode))
// Add request info directly to context // Add request info directly to context
ctx.Set("method", r.Method) ctx.Set("method", r.Method)
@ -216,7 +219,7 @@ func (s *Server) handleLuaRoute(w http.ResponseWriter, r *http.Request, bytecode
// Execute Lua script // Execute Lua script
result, err := s.luaRunner.Run(bytecode, ctx, scriptPath) result, err := s.luaRunner.Run(bytecode, ctx, scriptPath)
if err != nil { if err != nil {
logger.Error("Error executing Lua route: %v", err) s.logger.Error("Error executing Lua route: %v", err)
// Set content type to HTML // Set content type to HTML
w.Header().Set("Content-Type", "text/html; charset=utf-8") w.Header().Set("Content-Type", "text/html; charset=utf-8")
@ -229,7 +232,7 @@ func (s *Server) handleLuaRoute(w http.ResponseWriter, r *http.Request, bytecode
return return
} }
writeResponse(w, result) writeResponse(w, result, s.logger)
} }
// Content types for responses // Content types for responses
@ -239,7 +242,7 @@ const (
) )
// writeResponse writes the Lua result to the HTTP response // writeResponse writes the Lua result to the HTTP response
func writeResponse(w http.ResponseWriter, result any) { func writeResponse(w http.ResponseWriter, result any, log *logger.Logger) {
if result == nil { if result == nil {
w.WriteHeader(http.StatusNoContent) w.WriteHeader(http.StatusNoContent)
return return
@ -280,7 +283,7 @@ func writeResponse(w http.ResponseWriter, result any) {
setContentTypeIfMissing(w, contentTypeJSON) setContentTypeIfMissing(w, contentTypeJSON)
data, err := json.Marshal(res) data, err := json.Marshal(res)
if err != nil { if err != nil {
logger.Error("Failed to marshal response: %v", err) log.Error("Failed to marshal response: %v", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError) http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return return
} }

View File

@ -54,12 +54,6 @@ const (
defaultRateLimitTime = 10 * time.Second // How long to pause during rate limiting defaultRateLimitTime = 10 * time.Second // How long to pause during rate limiting
) )
// Single global logger instance with mutex for safe initialization
var (
globalLogger *Logger
globalLoggerOnce sync.Once
)
// Logger handles logging operations // Logger handles logging operations
type Logger struct { type Logger struct {
writer io.Writer writer io.Writer
@ -67,7 +61,6 @@ type Logger struct {
useColors bool useColors bool
timeFormat string timeFormat string
mu sync.Mutex // Mutex for thread-safe writing mu sync.Mutex // Mutex for thread-safe writing
debugMode atomic.Bool // Force debug logging regardless of level
// Simple rate limiting // Simple rate limiting
logCount atomic.Int64 // Number of logs in current window logCount atomic.Int64 // Number of logs in current window
@ -78,23 +71,8 @@ type Logger struct {
limitDuration time.Duration // How long to pause logging when rate limited limitDuration time.Duration // How long to pause logging when rate limited
} }
// GetLogger returns the global logger instance, creating it if needed // New creates a new logger
func GetLogger() *Logger { func New(minLevel int, useColors bool) *Logger {
globalLoggerOnce.Do(func() {
globalLogger = newLogger(LevelInfo, true)
})
return globalLogger
}
// InitGlobalLogger initializes the global logger with custom settings
func InitGlobalLogger(minLevel int, useColors bool) {
globalLoggerOnce.Do(func() {
globalLogger = newLogger(minLevel, useColors)
})
}
// newLogger creates a new logger instance (internal use)
func newLogger(minLevel int, useColors bool) *Logger {
logger := &Logger{ logger := &Logger{
writer: os.Stdout, writer: os.Stdout,
level: minLevel, level: minLevel,
@ -110,11 +88,6 @@ func newLogger(minLevel int, useColors bool) *Logger {
return logger return logger
} }
// New creates a new logger (deprecated - use GetLogger() instead)
func New(minLevel int, useColors bool) *Logger {
return newLogger(minLevel, useColors)
}
// resetCounters resets the rate limiting counters // resetCounters resets the rate limiting counters
func (l *Logger) resetCounters() { func (l *Logger) resetCounters() {
l.logCount.Store(0) l.logCount.Store(0)
@ -155,21 +128,6 @@ func (l *Logger) DisableColors() {
l.useColors = false l.useColors = false
} }
// EnableDebug forces debug logs to be shown regardless of level
func (l *Logger) EnableDebug() {
l.debugMode.Store(true)
}
// DisableDebug stops forcing debug logs
func (l *Logger) DisableDebug() {
l.debugMode.Store(false)
}
// IsDebugEnabled returns true if debug mode is enabled
func (l *Logger) IsDebugEnabled() bool {
return l.debugMode.Load()
}
// writeMessage writes a formatted log message directly to the writer // writeMessage writes a formatted log message directly to the writer
func (l *Logger) writeMessage(level int, message string, rawMode bool) { func (l *Logger) writeMessage(level int, message string, rawMode bool) {
var logLine string var logLine string
@ -259,9 +217,8 @@ func (l *Logger) checkRateLimit(level int) bool {
// log handles the core logging logic with level filtering // log handles the core logging logic with level filtering
func (l *Logger) log(level int, format string, args ...any) { func (l *Logger) log(level int, format string, args ...any) {
// Check if we should log this message // First check normal level filtering
// Either level is high enough OR (it's a debug message AND debug mode is enabled) if level < l.level {
if level < l.level && !(level == LevelDebug && l.debugMode.Load()) {
return return
} }
@ -369,69 +326,50 @@ func (l *Logger) Server(format string, args ...any) {
l.log(LevelServer, format, args...) l.log(LevelServer, format, args...)
} }
// Global helper functions that use the global logger // Default global logger
var defaultLogger = New(LevelInfo, true)
// Debug logs a debug message to the global logger // Debug logs a debug message to the default logger
func Debug(format string, args ...any) { func Debug(format string, args ...any) {
GetLogger().Debug(format, args...) defaultLogger.Debug(format, args...)
} }
// Info logs an informational message to the global logger // Info logs an informational message to the default logger
func Info(format string, args ...any) { func Info(format string, args ...any) {
GetLogger().Info(format, args...) defaultLogger.Info(format, args...)
} }
// Warning logs a warning message to the global logger // Warning logs a warning message to the default logger
func Warning(format string, args ...any) { func Warning(format string, args ...any) {
GetLogger().Warning(format, args...) defaultLogger.Warning(format, args...)
} }
// Error logs an error message to the global logger // Error logs an error message to the default logger
func Error(format string, args ...any) { func Error(format string, args ...any) {
GetLogger().Error(format, args...) defaultLogger.Error(format, args...)
} }
// Fatal logs a fatal error message to the global logger and exits // Fatal logs a fatal error message to the default logger and exits
func Fatal(format string, args ...any) { func Fatal(format string, args ...any) {
GetLogger().Fatal(format, args...) defaultLogger.Fatal(format, args...)
} }
// Server logs a server message to the global logger // Server logs a server message to the default logger
func Server(format string, args ...any) { func Server(format string, args ...any) {
GetLogger().Server(format, args...) defaultLogger.Server(format, args...)
} }
// LogRaw logs a raw message to the global logger // LogRaw logs a raw message to the default logger
func LogRaw(format string, args ...any) { func LogRaw(format string, args ...any) {
GetLogger().LogRaw(format, args...) defaultLogger.LogRaw(format, args...)
} }
// SetLevel changes the minimum log level of the global logger // SetLevel changes the minimum log level of the default logger
func SetLevel(level int) { func SetLevel(level int) {
GetLogger().SetLevel(level) defaultLogger.SetLevel(level)
} }
// SetOutput changes the output destination of the global logger // SetOutput changes the output destination of the default logger
func SetOutput(w io.Writer) { func SetOutput(w io.Writer) {
GetLogger().SetOutput(w) defaultLogger.SetOutput(w)
}
// TimeFormat returns the current time format of the global logger
func TimeFormat() string {
return GetLogger().TimeFormat()
}
// EnableDebug enables debug messages regardless of log level
func EnableDebug() {
GetLogger().EnableDebug()
}
// DisableDebug disables forced debug messages
func DisableDebug() {
GetLogger().DisableDebug()
}
// IsDebugEnabled returns true if debug mode is enabled
func IsDebugEnabled() bool {
return GetLogger().IsDebugEnabled()
} }

View File

@ -1,38 +1,22 @@
package runner package runner
import ( import (
"fmt"
"strings" "strings"
"sync" "sync"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go" luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
"git.sharkk.net/Sky/Moonshark/core/logger"
) )
// CoreModuleRegistry manages the initialization and reloading of core modules // CoreModuleRegistry manages the initialization and reloading of core modules
type CoreModuleRegistry struct { type CoreModuleRegistry struct {
modules map[string]StateInitFunc modules map[string]StateInitFunc
mu sync.RWMutex mu sync.RWMutex
debug bool
} }
// NewCoreModuleRegistry creates a new core module registry // NewCoreModuleRegistry creates a new core module registry
func NewCoreModuleRegistry() *CoreModuleRegistry { func NewCoreModuleRegistry() *CoreModuleRegistry {
return &CoreModuleRegistry{ return &CoreModuleRegistry{
modules: make(map[string]StateInitFunc), modules: make(map[string]StateInitFunc),
debug: false,
}
}
// EnableDebug turns on debug logging
func (r *CoreModuleRegistry) EnableDebug() {
r.debug = true
}
// debugLog prints debug messages if enabled
func (r *CoreModuleRegistry) debugLog(format string, args ...interface{}) {
if r.debug {
logger.Debug("[CoreModuleRegistry] "+format, args...)
} }
} }
@ -41,7 +25,6 @@ func (r *CoreModuleRegistry) Register(name string, initFunc StateInitFunc) {
r.mu.Lock() r.mu.Lock()
defer r.mu.Unlock() defer r.mu.Unlock()
r.modules[name] = initFunc r.modules[name] = initFunc
r.debugLog("Registered module: %s", name)
} }
// Initialize initializes all registered modules // Initialize initializes all registered modules
@ -49,30 +32,16 @@ func (r *CoreModuleRegistry) Initialize(state *luajit.State) error {
r.mu.RLock() r.mu.RLock()
defer r.mu.RUnlock() defer r.mu.RUnlock()
r.debugLog("Initializing all modules...") // Convert to StateInitFunc
initFunc := CombineInitFuncs(r.getInitFuncs()...)
// Get all module init functions return initFunc(state)
initFuncs := r.getInitFuncs()
// Initialize modules one by one to better track issues
for name, initFunc := range initFuncs {
r.debugLog("Initializing module: %s", name)
if err := initFunc(state); err != nil {
r.debugLog("Failed to initialize module %s: %v", name, err)
return fmt.Errorf("failed to initialize module %s: %w", name, err)
}
r.debugLog("Module %s initialized successfully", name)
}
r.debugLog("All modules initialized successfully")
return nil
} }
// getInitFuncs returns all module init functions // getInitFuncs returns all module init functions
func (r *CoreModuleRegistry) getInitFuncs() map[string]StateInitFunc { func (r *CoreModuleRegistry) getInitFuncs() []StateInitFunc {
funcs := make(map[string]StateInitFunc, len(r.modules)) funcs := make([]StateInitFunc, 0, len(r.modules))
for name, initFunc := range r.modules { for _, initFunc := range r.modules {
funcs[name] = initFunc funcs = append(funcs, initFunc)
} }
return funcs return funcs
} }
@ -84,11 +53,9 @@ func (r *CoreModuleRegistry) InitializeModule(state *luajit.State, name string)
initFunc, ok := r.modules[name] initFunc, ok := r.modules[name]
if !ok { if !ok {
r.debugLog("Module not found: %s", name)
return nil // Module not found, no error return nil // Module not found, no error
} }
r.debugLog("Reinitializing module: %s", name)
return initFunc(state) return initFunc(state)
} }
@ -129,10 +96,8 @@ var GlobalRegistry = NewCoreModuleRegistry()
// Initialize global registry with core modules // Initialize global registry with core modules
func init() { func init() {
GlobalRegistry.EnableDebug() // Enable debugging by default
GlobalRegistry.Register("http", HTTPModuleInitFunc()) GlobalRegistry.Register("http", HTTPModuleInitFunc())
GlobalRegistry.Register("cookie", CookieModuleInitFunc()) GlobalRegistry.Register("cookie", CookieModuleInitFunc())
logger.Debug("[CoreModuleRegistry] Core modules registered in init()")
} }
// RegisterCoreModule is a helper to register a core module // RegisterCoreModule is a helper to register a core module
@ -140,3 +105,6 @@ func init() {
func RegisterCoreModule(name string, initFunc StateInitFunc) { func RegisterCoreModule(name string, initFunc StateInitFunc) {
GlobalRegistry.Register(name, initFunc) GlobalRegistry.Register(name, initFunc)
} }
// To add a new module, simply call:
// RegisterCoreModule("new_module_name", NewModuleInitFunc())

View File

@ -12,7 +12,6 @@ import (
"time" "time"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go" luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
"git.sharkk.net/Sky/Moonshark/core/logger"
) )
// HTTPResponse represents an HTTP response from Lua // HTTPResponse represents an HTTP response from Lua
@ -365,34 +364,190 @@ func httpRequest(state *luajit.State) int {
return 1 return 1
} }
// HTTPModuleInitFunc returns an initializer function for the HTTP module // LuaHTTPModule is the pure Lua implementation of the HTTP module
func HTTPModuleInitFunc() StateInitFunc { const LuaHTTPModule = `
return func(state *luajit.State) error { -- Table to store response data
// CRITICAL: Register the native Go function first __http_responses = {}
// This must be done BEFORE any Lua code that references it
if err := state.RegisterGoFunction(httpRequestFuncName, httpRequest); err != nil {
logger.Error("[HTTP Module] Failed to register __http_request function: %v\n", err)
return err
}
// Set up default HTTP client configuration -- HTTP module implementation
setupHTTPClientConfig(state) local http = {
-- Set HTTP status code
set_status = function(code)
if type(code) ~= "number" then
error("http.set_status: status code must be a number", 2)
end
// Initialize Lua HTTP module local resp = __http_responses[1] or {}
if err := state.DoString(LuaHTTPModule); err != nil { resp.status = code
logger.Error("[HTTP Module] Failed to initialize HTTP module Lua code: %v\n", err) __http_responses[1] = resp
return err end,
}
// Verify HTTP client functions are available -- Set HTTP header
verifyHTTPClient(state) set_header = function(name, value)
if type(name) ~= "string" or type(value) ~= "string" then
error("http.set_header: name and value must be strings", 2)
end
return nil local resp = __http_responses[1] or {}
resp.headers = resp.headers or {}
resp.headers[name] = value
__http_responses[1] = resp
end,
-- Set content type; set_header helper
set_content_type = function(content_type)
http.set_header("Content-Type", content_type)
end,
-- HTTP client submodule
client = {
-- Generic request function
request = function(method, url, body, options)
if type(method) ~= "string" then
error("http.client.request: method must be a string", 2)
end
if type(url) ~= "string" then
error("http.client.request: url must be a string", 2)
end
-- Call native implementation
return __http_request(method, url, body, options)
end,
-- Simple GET request
get = function(url, options)
return http.client.request("GET", url, nil, options)
end,
-- Simple POST request with automatic content-type
post = function(url, body, options)
options = options or {}
return http.client.request("POST", url, body, options)
end,
-- Simple PUT request with automatic content-type
put = function(url, body, options)
options = options or {}
return http.client.request("PUT", url, body, options)
end,
-- Simple DELETE request
delete = function(url, options)
return http.client.request("DELETE", url, nil, options)
end,
-- Simple PATCH request
patch = function(url, body, options)
options = options or {}
return http.client.request("PATCH", url, body, options)
end,
-- Simple HEAD request
head = function(url, options)
options = options or {}
local old_options = options
options = {headers = old_options.headers, timeout = old_options.timeout, query = old_options.query}
local response = http.client.request("HEAD", url, nil, options)
return response
end,
-- Simple OPTIONS request
options = function(url, options)
return http.client.request("OPTIONS", url, nil, options)
end,
-- Shorthand function to directly get JSON
get_json = function(url, options)
options = options or {}
local response = http.client.get(url, options)
if response.ok and response.json then
return response.json
end
return nil, response
end,
-- Utility to build a URL with query parameters
build_url = function(base_url, params)
if not params or type(params) ~= "table" then
return base_url
end
local query = {}
for k, v in pairs(params) do
if type(v) == "table" then
for _, item in ipairs(v) do
table.insert(query, k .. "=" .. tostring(item))
end
else
table.insert(query, k .. "=" .. tostring(v))
end
end
if #query > 0 then
if base_url:find("?") then
return base_url .. "&" .. table.concat(query, "&")
else
return base_url .. "?" .. table.concat(query, "&")
end
end
return base_url
end
} }
} }
// Helper to set up HTTP client config -- Install HTTP module
func setupHTTPClientConfig(state *luajit.State) { _G.http = http
-- Override sandbox executor to clear HTTP responses
local old_execute_sandbox = __execute_sandbox
__execute_sandbox = function(bytecode, ctx)
-- Clear previous response for this thread
__http_responses[1] = nil
-- Execute the original function
local result = old_execute_sandbox(bytecode, ctx)
-- Return the result unchanged
return result
end
-- Make sure the HTTP module is accessible in sandbox
if __env_system and __env_system.base_env then
__env_system.base_env.http = http
end
`
// HTTPModuleInitFunc returns an initializer function for the HTTP module
func HTTPModuleInitFunc() StateInitFunc {
return func(state *luajit.State) error {
// The important fix: register the Go function directly to the global environment
if err := state.RegisterGoFunction(httpRequestFuncName, httpRequest); err != nil {
return err
}
// Initialize pure Lua HTTP module
if err := state.DoString(LuaHTTPModule); err != nil {
return err
}
// Check for existing config (in sandbox modules)
state.GetGlobal("__sandbox_modules")
if !state.IsNil(-1) && state.IsTable(-1) {
state.PushString("__http_client_config")
state.GetTable(-2)
if !state.IsNil(-1) && state.IsTable(-1) {
// Use the config from sandbox modules
state.SetGlobal("__http_client_config")
state.Pop(1) // Pop the sandbox modules table
return nil
}
state.Pop(1) // Pop the nil or non-table value
}
state.Pop(1) // Pop the nil or sandbox modules table
// Setup default configuration if no custom config exists
state.NewTable() state.NewTable()
state.PushNumber(float64(DefaultHTTPClientConfig.MaxTimeout / time.Second)) state.PushNumber(float64(DefaultHTTPClientConfig.MaxTimeout / time.Second))
@ -408,6 +563,16 @@ func setupHTTPClientConfig(state *luajit.State) {
state.SetField(-2, "allow_remote") state.SetField(-2, "allow_remote")
state.SetGlobal("__http_client_config") state.SetGlobal("__http_client_config")
// Ensure the Go function is registered with the base environment
// This is critical to make it persist across reloads
return state.DoString(`
-- Make the __http_request function available in the base environment
if __env_system and __env_system.base_env then
__env_system.base_env.__http_request = __http_request
end
`)
}
} }
// GetHTTPResponse extracts the HTTP response from Lua state // GetHTTPResponse extracts the HTTP response from Lua state
@ -504,187 +669,3 @@ func RestrictHTTPToLocalhost() RunnerOption {
AllowRemote: false, AllowRemote: false,
}) })
} }
// Verify that HTTP client is properly set up
func verifyHTTPClient(state *luajit.State) {
// Get the client table
state.GetGlobal("http")
if !state.IsTable(-1) {
logger.Warning("[HTTP Module] 'http' is not a table\n")
state.Pop(1)
return
}
state.GetField(-1, "client")
if !state.IsTable(-1) {
logger.Warning("[HTTP Module] 'http.client' is not a table\n")
state.Pop(2)
return
}
// Check for get function
state.GetField(-1, "get")
if !state.IsFunction(-1) {
logger.Warning("[HTTP Module] 'http.client.get' is not a function\n")
} else {
logger.Debug("[HTTP Module] 'http.client.get' is properly registered\n")
}
state.Pop(1)
// Check for the request function
state.GetField(-1, "request")
if !state.IsFunction(-1) {
logger.Warning("[HTTP Module] 'http.client.request' is not a function\n")
} else {
logger.Debug("[HTTP Module] 'http.client.request' is properly registered\n")
}
state.Pop(3) // Pop request, client, http
}
const LuaHTTPModule = `
-- Table to store response data
__http_responses = {}
-- HTTP module implementation
local http = {
-- Set HTTP status code
set_status = function(code)
if type(code) ~= "number" then
error("http.set_status: status code must be a number", 2)
end
local resp = __http_responses[1] or {}
resp.status = code
__http_responses[1] = resp
end,
-- Set HTTP header
set_header = function(name, value)
if type(name) ~= "string" or type(value) ~= "string" then
error("http.set_header: name and value must be strings", 2)
end
local resp = __http_responses[1] or {}
resp.headers = resp.headers or {}
resp.headers[name] = value
__http_responses[1] = resp
end,
-- Set content type; set_header helper
set_content_type = function(content_type)
http.set_header("Content-Type", content_type)
end,
-- HTTP client submodule
client = {
-- Generic request function
request = function(method, url, body, options)
if type(method) ~= "string" then
error("http.client.request: method must be a string", 2)
end
if type(url) ~= "string" then
error("http.client.request: url must be a string", 2)
end
-- Call native implementation (this is the critical part)
local result = __http_request(method, url, body, options)
return result
end,
-- Simple GET request
get = function(url, options)
return http.client.request("GET", url, nil, options)
end,
-- Simple POST request with automatic content-type
post = function(url, body, options)
options = options or {}
return http.client.request("POST", url, body, options)
end,
-- Simple PUT request with automatic content-type
put = function(url, body, options)
options = options or {}
return http.client.request("PUT", url, body, options)
end,
-- Simple DELETE request
delete = function(url, options)
return http.client.request("DELETE", url, nil, options)
end,
-- Simple PATCH request
patch = function(url, body, options)
options = options or {}
return http.client.request("PATCH", url, body, options)
end,
-- Simple HEAD request
head = function(url, options)
options = options or {}
local old_options = options
options = {headers = old_options.headers, timeout = old_options.timeout, query = old_options.query}
local response = http.client.request("HEAD", url, nil, options)
return response
end,
-- Simple OPTIONS request
options = function(url, options)
return http.client.request("OPTIONS", url, nil, options)
end,
-- Shorthand function to directly get JSON
get_json = function(url, options)
options = options or {}
local response = http.client.get(url, options)
if response.ok and response.json then
return response.json
end
return nil, response
end,
-- Utility to build a URL with query parameters
build_url = function(base_url, params)
if not params or type(params) ~= "table" then
return base_url
end
local query = {}
for k, v in pairs(params) do
if type(v) == "table" then
for _, item in ipairs(v) do
table.insert(query, k .. "=" .. tostring(item))
end
else
table.insert(query, k .. "=" .. tostring(v))
end
end
if #query > 0 then
if base_url:find("?") then
return base_url .. "&" .. table.concat(query, "&")
else
return base_url .. "?" .. table.concat(query, "&")
end
end
return base_url
end
}
}
-- Install HTTP module
_G.http = http
-- Clear previous responses when executing scripts
local old_execute_script = __execute_script
if old_execute_script then
__execute_script = function(fn, ctx)
-- Clear previous response
__http_responses[1] = nil
-- Execute original function
return old_execute_script(fn, ctx)
end
end
`

View File

@ -8,7 +8,6 @@ import (
"sync/atomic" "sync/atomic"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go" luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
"git.sharkk.net/Sky/Moonshark/core/logger"
) )
// Common errors // Common errors
@ -41,7 +40,6 @@ type LuaRunner struct {
bufferSize int // Size of the job queue buffer bufferSize int // Size of the job queue buffer
moduleLoader *NativeModuleLoader // Native module loader for require moduleLoader *NativeModuleLoader // Native module loader for require
sandbox *Sandbox // The sandbox environment sandbox *Sandbox // The sandbox environment
debug bool // Enable debug logging
} }
// WithBufferSize sets the job queue buffer size // WithBufferSize sets the job queue buffer size
@ -73,20 +71,12 @@ func WithLibDirs(dirs ...string) RunnerOption {
} }
} }
// WithDebugEnabled enables debug output
func WithDebugEnabled() RunnerOption {
return func(r *LuaRunner) {
r.debug = true
}
}
// NewRunner creates a new LuaRunner // NewRunner creates a new LuaRunner
func NewRunner(options ...RunnerOption) (*LuaRunner, error) { func NewRunner(options ...RunnerOption) (*LuaRunner, error) {
// Default configuration // Default configuration
runner := &LuaRunner{ runner := &LuaRunner{
bufferSize: 10, // Default buffer size bufferSize: 10, // Default buffer size
sandbox: NewSandbox(), sandbox: NewSandbox(),
debug: false,
} }
// Apply options // Apply options
@ -94,6 +84,13 @@ func NewRunner(options ...RunnerOption) (*LuaRunner, error) {
opt(runner) opt(runner)
} }
// Initialize Lua state
state := luajit.New()
if state == nil {
return nil, errors.New("failed to create Lua state")
}
runner.state = state
// Create job queue // Create job queue
runner.jobQueue = make(chan job, runner.bufferSize) runner.jobQueue = make(chan job, runner.bufferSize)
runner.isRunning.Store(true) runner.isRunning.Store(true)
@ -107,9 +104,36 @@ func NewRunner(options ...RunnerOption) (*LuaRunner, error) {
runner.moduleLoader = NewNativeModuleLoader(requireConfig) runner.moduleLoader = NewNativeModuleLoader(requireConfig)
} }
// Initialize Lua state // Set up require paths and mechanism
if err := runner.initState(true); err != nil { if err := runner.moduleLoader.SetupRequire(state); err != nil {
return nil, err state.Close()
return nil, ErrInitFailed
}
// Initialize all core modules from the registry
if err := GlobalRegistry.Initialize(state); err != nil {
state.Close()
return nil, ErrInitFailed
}
// Set up sandbox after core modules are initialized
if err := runner.sandbox.Setup(state); err != nil {
state.Close()
return nil, ErrInitFailed
}
// Preload all modules into package.loaded
if err := runner.moduleLoader.PreloadAllModules(state); err != nil {
state.Close()
return nil, errors.New("failed to preload modules")
}
// Run init function if provided
if runner.initFunc != nil {
if err := runner.initFunc(state); err != nil {
state.Close()
return nil, ErrInitFailed
}
} }
// Start the event loop // Start the event loop
@ -119,130 +143,10 @@ func NewRunner(options ...RunnerOption) (*LuaRunner, error) {
return runner, nil return runner, nil
} }
// debugLog logs a message if debug mode is enabled
func (r *LuaRunner) debugLog(format string, args ...interface{}) {
if r.debug {
logger.Debug("[LuaRunner] "+format, args...)
}
}
// initState initializes or reinitializes the Lua state
func (r *LuaRunner) initState(initial bool) error {
r.debugLog("Initializing Lua state (initial=%v)", initial)
// Clean up existing state if there is one
if r.state != nil {
r.debugLog("Cleaning up existing state")
// Always call Cleanup before Close to properly free function pointers
r.state.Cleanup()
r.state.Close()
r.state = nil
}
// Create fresh state
state := luajit.New()
if state == nil {
return errors.New("failed to create Lua state")
}
r.debugLog("Created new Lua state")
// Set up require paths and mechanism
if err := r.moduleLoader.SetupRequire(state); err != nil {
r.debugLog("Failed to set up require: %v", err)
state.Cleanup()
state.Close()
return ErrInitFailed
}
r.debugLog("Require system initialized")
// Initialize all core modules from the registry
if err := GlobalRegistry.Initialize(state); err != nil {
r.debugLog("Failed to initialize core modules: %v", err)
state.Cleanup()
state.Close()
return ErrInitFailed
}
r.debugLog("Core modules initialized")
// Check if http module is properly registered
testResult, err := state.ExecuteWithResult(`
if type(http) == "table" and type(http.client) == "table" and
type(http.client.get) == "function" then
return true
else
return false
end
`)
if err != nil || testResult != true {
r.debugLog("HTTP module verification failed: %v, result: %v", err, testResult)
} else {
r.debugLog("HTTP module verified OK")
}
// Verify __http_request function
testResult, _ = state.ExecuteWithResult(`return type(__http_request)`)
r.debugLog("__http_request function is of type: %v", testResult)
// Set up sandbox after core modules are initialized
if err := r.sandbox.Setup(state); err != nil {
r.debugLog("Failed to set up sandbox: %v", err)
state.Cleanup()
state.Close()
return ErrInitFailed
}
r.debugLog("Sandbox environment set up")
// Preload all modules into package.loaded
if err := r.moduleLoader.PreloadAllModules(state); err != nil {
r.debugLog("Failed to preload modules: %v", err)
state.Cleanup()
state.Close()
return errors.New("failed to preload modules")
}
r.debugLog("All modules preloaded")
// Run init function if provided
if r.initFunc != nil {
if err := r.initFunc(state); err != nil {
r.debugLog("Custom init function failed: %v", err)
state.Cleanup()
state.Close()
return ErrInitFailed
}
r.debugLog("Custom init function completed")
}
// Test for HTTP module again after full initialization
testResult, err = state.ExecuteWithResult(`
if type(http) == "table" and type(http.client) == "table" and
type(http.client.get) == "function" then
return true
else
return false
end
`)
if err != nil || testResult != true {
r.debugLog("Final HTTP module verification failed: %v, result: %v", err, testResult)
} else {
r.debugLog("Final HTTP module verification OK")
}
r.state = state
r.debugLog("State initialization complete")
return nil
}
// processJobs handles the job queue // processJobs handles the job queue
func (r *LuaRunner) processJobs() { func (r *LuaRunner) processJobs() {
defer r.wg.Done() defer r.wg.Done()
defer func() { defer r.state.Close()
if r.state != nil {
r.debugLog("Cleaning up Lua state in processJobs")
r.state.Cleanup()
r.state.Close()
r.state = nil
}
}()
for job := range r.jobQueue { for job := range r.jobQueue {
// Execute the job and send result // Execute the job and send result
@ -271,13 +175,6 @@ func (r *LuaRunner) executeJob(j job) JobResult {
ctx = j.Context.Values ctx = j.Context.Values
} }
r.mu.RLock()
defer r.mu.RUnlock()
if r.state == nil {
return JobResult{nil, errors.New("lua state is not initialized")}
}
// Execute in sandbox // Execute in sandbox
value, err := r.sandbox.Execute(r.state, j.Bytecode, ctx) value, err := r.sandbox.Execute(r.state, j.Bytecode, ctx)
return JobResult{value, err} return JobResult{value, err}
@ -363,26 +260,15 @@ func (r *LuaRunner) Close() error {
// NotifyFileChanged handles file change notifications from watchers // NotifyFileChanged handles file change notifications from watchers
func (r *LuaRunner) NotifyFileChanged(filePath string) bool { func (r *LuaRunner) NotifyFileChanged(filePath string) bool {
r.debugLog("File change detected: %s", filePath) if r.moduleLoader != nil {
return r.moduleLoader.NotifyFileChanged(r.state, filePath)
r.mu.Lock()
defer r.mu.Unlock()
// Reset the entire state on file changes
err := r.initState(false)
if err != nil {
r.debugLog("Failed to reinitialize state: %v", err)
return false
} }
return false
r.debugLog("State successfully reinitialized")
return true
} }
// ResetModuleCache clears non-core modules from package.loaded // ResetModuleCache clears non-core modules from package.loaded
func (r *LuaRunner) ResetModuleCache() { func (r *LuaRunner) ResetModuleCache() {
if r.moduleLoader != nil { if r.moduleLoader != nil {
r.debugLog("Resetting module cache")
r.moduleLoader.ResetModules(r.state) r.moduleLoader.ResetModules(r.state)
} }
} }
@ -390,7 +276,6 @@ func (r *LuaRunner) ResetModuleCache() {
// ReloadAllModules reloads all modules into package.loaded // ReloadAllModules reloads all modules into package.loaded
func (r *LuaRunner) ReloadAllModules() error { func (r *LuaRunner) ReloadAllModules() error {
if r.moduleLoader != nil { if r.moduleLoader != nil {
r.debugLog("Reloading all modules")
return r.moduleLoader.PreloadAllModules(r.state) return r.moduleLoader.PreloadAllModules(r.state)
} }
return nil return nil
@ -399,7 +284,6 @@ func (r *LuaRunner) ReloadAllModules() error {
// RefreshModuleByName invalidates a specific module in package.loaded // RefreshModuleByName invalidates a specific module in package.loaded
func (r *LuaRunner) RefreshModuleByName(modName string) bool { func (r *LuaRunner) RefreshModuleByName(modName string) bool {
if r.state != nil { if r.state != nil {
r.debugLog("Refreshing module: %s", modName)
if err := r.state.DoString(`package.loaded["` + modName + `"] = nil`); err != nil { if err := r.state.DoString(`package.loaded["` + modName + `"] = nil`); err != nil {
return false return false
} }
@ -410,7 +294,6 @@ func (r *LuaRunner) RefreshModuleByName(modName string) bool {
// AddModule adds a module to the sandbox environment // AddModule adds a module to the sandbox environment
func (r *LuaRunner) AddModule(name string, module any) { func (r *LuaRunner) AddModule(name string, module any) {
r.debugLog("Adding module: %s", name)
r.sandbox.AddModule(name, module) r.sandbox.AddModule(name, module)
} }

View File

@ -1,29 +1,23 @@
package runner package runner
import ( import (
"fmt"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go" luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
"git.sharkk.net/Sky/Moonshark/core/logger"
) )
// Sandbox manages a simplified Lua environment // Sandbox manages a sandboxed Lua environment
type Sandbox struct { type Sandbox struct {
modules map[string]any // Custom modules for environment modules map[string]any // Custom modules for environment
debug bool // Enable debug output initialized bool // Whether base environment is initialized
} }
// NewSandbox creates a new sandbox // NewSandbox creates a new sandbox
func NewSandbox() *Sandbox { func NewSandbox() *Sandbox {
return &Sandbox{ s := &Sandbox{
modules: make(map[string]any), modules: make(map[string]any),
debug: false, initialized: false,
} }
}
// EnableDebug turns on debug output return s
func (s *Sandbox) EnableDebug() {
s.debug = true
} }
// AddModule adds a module to the sandbox environment // AddModule adds a module to the sandbox environment
@ -31,109 +25,228 @@ func (s *Sandbox) AddModule(name string, module any) {
s.modules[name] = module s.modules[name] = module
} }
// debugLog prints debug messages if debug is enabled
func (s *Sandbox) debugLog(format string, args ...interface{}) {
if s.debug {
logger.Debug("[Sandbox Debug] "+format, args...)
}
}
// Setup initializes the sandbox in a Lua state // Setup initializes the sandbox in a Lua state
func (s *Sandbox) Setup(state *luajit.State) error { func (s *Sandbox) Setup(state *luajit.State) error {
s.debugLog("Setting up sandbox environment") // Register modules
if err := s.registerModules(state); err != nil {
// Register modules in the global environment
for name, module := range s.modules {
s.debugLog("Registering module: %s", name)
if err := state.PushValue(module); err != nil {
s.debugLog("Failed to register module %s: %v", name, err)
return err return err
} }
state.SetGlobal(name)
// Create high-performance persistent environment
return state.DoString(`
-- Global shared environment (created once)
__env_system = __env_system or {
base_env = nil, -- Template environment
initialized = false, -- Initialization flag
env_pool = {}, -- Pre-allocated environment pool
pool_size = 0, -- Current pool size
max_pool_size = 8 -- Maximum pool size
} }
// Initialize simple environment setup -- Initialize base environment once
err := state.DoString(` if not __env_system.initialized then
-- Global tables for response handling -- Create base environment with all standard libraries
__http_responses = __http_responses or {} local base = {}
-- Simple environment creation -- Safe standard libraries
function __create_env(ctx) base.string = string
-- Create environment inheriting from _G base.table = table
local env = setmetatable({}, {__index = _G}) base.math = math
base.os = {
time = os.time,
date = os.date,
difftime = os.difftime,
clock = os.clock
}
-- Add context if provided -- Basic functions
base.print = print
base.tonumber = tonumber
base.tostring = tostring
base.type = type
base.pairs = pairs
base.ipairs = ipairs
base.next = next
base.select = select
base.unpack = unpack
base.pcall = pcall
base.xpcall = xpcall
base.error = error
base.assert = assert
-- Package system is shared for performance
base.package = {
loaded = package.loaded,
path = package.path,
preload = package.preload
}
base.http = http
base.cookie = cookie
-- http_client module is now part of http.client
-- Add registered custom modules
if __sandbox_modules then
for name, mod in pairs(__sandbox_modules) do
base[name] = mod
end
end
-- Store base environment
__env_system.base_env = base
__env_system.initialized = true
end
-- Global variable for tracking current environment
__last_env = nil
-- Fast environment creation with pre-allocation
function __get_sandbox_env(ctx)
local env
-- Try to reuse from pool
if __env_system.pool_size > 0 then
env = table.remove(__env_system.env_pool)
__env_system.pool_size = __env_system.pool_size - 1
-- Clear any previous context
env.ctx = ctx or nil
-- Clear any previous response
env._response = nil
else
-- Create new environment with metatable inheritance
env = setmetatable({}, {
__index = __env_system.base_env
})
-- Set context if provided
if ctx then if ctx then
env.ctx = ctx env.ctx = ctx
end end
-- Install the fast require implementation
env.require = function(modname)
return __fast_require(env, modname)
end
-- Install cookie module methods directly into environment
env.cookie = {
get = function(name)
if type(name) ~= "string" then
error("cookie.get: name must be a string", 2)
end
if env.ctx and env.ctx.cookies and env.ctx.cookies[name] then
return tostring(env.ctx.cookies[name])
end
return nil
end,
set = cookie.set,
remove = cookie.remove
}
end
-- Store reference to current environment
__last_env = env
return env return env
end end
-- Execute script with clean environment -- Return environment to pool for reuse
function __execute_script(fn, ctx) function __recycle_env(env)
-- Clear previous responses -- Only recycle if pool isn't full
__http_responses[1] = nil if __env_system.pool_size < __env_system.max_pool_size then
-- Clear context reference to avoid memory leaks
env.ctx = nil
-- Don't clear response data - we need it for extraction
-- Create environment -- Add to pool
local env = __create_env(ctx) table.insert(__env_system.env_pool, env)
__env_system.pool_size = __env_system.pool_size + 1
end
end
-- Set environment for function -- Hyper-optimized sandbox executor
setfenv(fn, env) function __execute_sandbox(bytecode, ctx)
-- Get environment (from pool if available)
local env = __get_sandbox_env(ctx)
-- Set environment for bytecode
setfenv(bytecode, env)
-- Execute with protected call -- Execute with protected call
local ok, result = pcall(fn) local success, result = pcall(bytecode)
if not ok then
-- Recycle environment for future use
__recycle_env(env)
-- Process result
if not success then
error(result, 0) error(result, 0)
end end
return result return result
end end
`)
if err != nil { -- Run minimal GC for overall health
s.debugLog("Failed to set up sandbox: %v", err) collectgarbage("step", 10)
`)
}
// registerModules registers custom modules in the Lua state
func (s *Sandbox) registerModules(state *luajit.State) error {
// Create or get module registry table
state.GetGlobal("__sandbox_modules")
if state.IsNil(-1) {
// Table doesn't exist, create it
state.Pop(1)
state.NewTable()
state.SetGlobal("__sandbox_modules")
state.GetGlobal("__sandbox_modules")
}
// Add modules to registry
for name, module := range s.modules {
state.PushString(name)
if err := state.PushValue(module); err != nil {
state.Pop(2)
return err return err
} }
state.SetTable(-3)
s.debugLog("Sandbox setup complete")
// Verify HTTP module is accessible
httpResult, _ := state.ExecuteWithResult(`
if type(http) == "table" and
type(http.client) == "table" and
type(http.client.get) == "function" then
return "HTTP module verified OK"
else
local status = {
http = type(http),
client = type(http) == "table" and type(http.client) or "N/A",
get = type(http) == "table" and type(http.client) == "table" and type(http.client.get) or "N/A"
} }
return status
end
`)
s.debugLog("HTTP verification result: %v", httpResult) // Pop module table
state.Pop(1)
return nil return nil
} }
// Execute runs bytecode in the sandbox // Execute runs bytecode in the sandbox
func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx map[string]any) (any, error) { func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx map[string]any) (any, error) {
// Update custom modules if needed
if !s.initialized {
if err := s.registerModules(state); err != nil {
return nil, err
}
s.initialized = true
}
// Load bytecode // Load bytecode
if err := state.LoadBytecode(bytecode, "script"); err != nil { if err := state.LoadBytecode(bytecode, "script"); err != nil {
s.debugLog("Failed to load bytecode: %v", err)
return nil, err return nil, err
} }
// Prepare context // Create context table if provided
if ctx != nil { if len(ctx) > 0 {
// Preallocate table with appropriate size
state.CreateTable(0, len(ctx)) state.CreateTable(0, len(ctx))
// Add context entries
for k, v := range ctx { for k, v := range ctx {
state.PushString(k) state.PushString(k)
if err := state.PushValue(v); err != nil { if err := state.PushValue(v); err != nil {
state.Pop(2) state.Pop(2)
s.debugLog("Failed to push context value %s: %v", k, err)
return nil, err return nil, err
} }
state.SetTable(-3) state.SetTable(-3)
@ -142,37 +255,31 @@ func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx map[string]a
state.PushNil() // No context state.PushNil() // No context
} }
// Get execution function // Get optimized sandbox executor
state.GetGlobal("__execute_script") state.GetGlobal("__execute_sandbox")
if !state.IsFunction(-1) {
state.Pop(2) // Pop nil and non-function
s.debugLog("__execute_script is not a function")
return nil, fmt.Errorf("sandbox execution function not found")
}
// Push arguments // Setup call with correct argument order
state.PushCopy(-3) // bytecode function state.PushCopy(-3) // Copy bytecode function
state.PushCopy(-3) // context state.PushCopy(-3) // Copy context
// Clean up stack // Clean up stack
state.Remove(-5) // original bytecode state.Remove(-5) // Remove original bytecode
state.Remove(-4) // original context state.Remove(-4) // Remove original context
// Call with 2 args, 1 result // Call optimized sandbox executor
if err := state.Call(2, 1); err != nil { if err := state.Call(2, 1); err != nil {
s.debugLog("Execution failed: %v", err)
return nil, err return nil, err
} }
// Get result // Get result
result, err := state.ToValue(-1) result, err := state.ToValue(-1)
state.Pop(1) state.Pop(1) // Pop result
// Check for HTTP response // Check if HTTP response was set
httpResponse, hasResponse := GetHTTPResponse(state) httpResponse, hasHTTPResponse := GetHTTPResponse(state)
if hasResponse { if hasHTTPResponse {
httpResponse.Body = result httpResponse.Body = result
return httpResponse, nil return httpResponse, err
} }
return result, err return result, err

View File

@ -16,9 +16,9 @@ var (
) )
// GetWatcherManager returns the global watcher manager, creating it if needed // GetWatcherManager returns the global watcher manager, creating it if needed
func GetWatcherManager(adaptive bool) *WatcherManager { func GetWatcherManager(log *logger.Logger, adaptive bool) *WatcherManager {
globalManagerOnce.Do(func() { globalManagerOnce.Do(func() {
globalManager = NewWatcherManager(adaptive) globalManager = NewWatcherManager(log, adaptive)
}) })
return globalManager return globalManager
} }
@ -39,7 +39,7 @@ func WatchDirectory(config DirectoryWatcherConfig, manager *WatcherManager) (*Wa
manager: manager, manager: manager,
} }
logger.Debug("Started watching directory: %s", config.Dir) config.Log.Debug("Started watching directory: %s", config.Dir)
return w, nil return w, nil
} }
@ -56,22 +56,14 @@ func (w *Watcher) Close() error {
return nil return nil
} }
// WatchLuaRouter sets up a watcher for a LuaRouter's routes directory; also updates // WatchLuaRouter sets up a watcher for a LuaRouter's routes directory
// the LuaRunner so that the state can be rebuilt func WatchLuaRouter(router *routers.LuaRouter, routesDir string, log *logger.Logger) (*Watcher, error) {
func WatchLuaRouter(router *routers.LuaRouter, runner *runner.LuaRunner, routesDir string) (*Watcher, error) { manager := GetWatcherManager(log, true) // Use adaptive polling
manager := GetWatcherManager(true)
runnerRefresh := func() error {
logger.Debug("Refreshing LuaRunner state due to file change")
runner.NotifyFileChanged("")
return nil
}
combinedCallback := combineCallbacks(router.Refresh, runnerRefresh)
config := DirectoryWatcherConfig{ config := DirectoryWatcherConfig{
Dir: routesDir, Dir: routesDir,
Callback: combinedCallback, Callback: router.Refresh,
Log: log,
Recursive: true, Recursive: true,
} }
@ -80,17 +72,18 @@ func WatchLuaRouter(router *routers.LuaRouter, runner *runner.LuaRunner, routesD
return nil, err return nil, err
} }
logger.Info("Started watching Lua routes directory: %s", routesDir) log.Info("Started watching Lua routes directory: %s", routesDir)
return watcher, nil return watcher, nil
} }
// WatchStaticRouter sets up a watcher for a StaticRouter's root directory // WatchStaticRouter sets up a watcher for a StaticRouter's root directory
func WatchStaticRouter(router *routers.StaticRouter, staticDir string) (*Watcher, error) { func WatchStaticRouter(router *routers.StaticRouter, staticDir string, log *logger.Logger) (*Watcher, error) {
manager := GetWatcherManager(true) manager := GetWatcherManager(log, true) // Use adaptive polling
config := DirectoryWatcherConfig{ config := DirectoryWatcherConfig{
Dir: staticDir, Dir: staticDir,
Callback: router.Refresh, Callback: router.Refresh,
Log: log,
Recursive: true, Recursive: true,
} }
@ -99,13 +92,13 @@ func WatchStaticRouter(router *routers.StaticRouter, staticDir string) (*Watcher
return nil, err return nil, err
} }
logger.Info("Started watching static files directory: %s", staticDir) log.Info("Started watching static files directory: %s", staticDir)
return watcher, nil return watcher, nil
} }
// WatchLuaModules sets up watchers for Lua module directories // WatchLuaModules sets up watchers for Lua module directories
func WatchLuaModules(luaRunner *runner.LuaRunner, libDirs []string) ([]*Watcher, error) { func WatchLuaModules(luaRunner *runner.LuaRunner, libDirs []string, log *logger.Logger) ([]*Watcher, error) {
manager := GetWatcherManager(true) manager := GetWatcherManager(log, true) // Use adaptive polling
watchers := make([]*Watcher, 0, len(libDirs)) watchers := make([]*Watcher, 0, len(libDirs))
for _, dir := range libDirs { for _, dir := range libDirs {
@ -113,11 +106,11 @@ func WatchLuaModules(luaRunner *runner.LuaRunner, libDirs []string) ([]*Watcher,
dirCopy := dir // Capture for closure dirCopy := dir // Capture for closure
callback := func() error { callback := func() error {
logger.Debug("Detected changes in Lua module directory: %s", dirCopy) log.Debug("Detected changes in Lua module directory: %s", dirCopy)
// Reload modules from this directory // Reload modules from this directory
if err := luaRunner.ReloadAllModules(); err != nil { if err := luaRunner.ReloadAllModules(); err != nil {
logger.Warning("Error reloading modules: %v", err) log.Warning("Error reloading modules: %v", err)
} }
return nil return nil
@ -126,6 +119,7 @@ func WatchLuaModules(luaRunner *runner.LuaRunner, libDirs []string) ([]*Watcher,
config := DirectoryWatcherConfig{ config := DirectoryWatcherConfig{
Dir: dir, Dir: dir,
Callback: callback, Callback: callback,
Log: log,
Recursive: true, Recursive: true,
} }
@ -139,7 +133,7 @@ func WatchLuaModules(luaRunner *runner.LuaRunner, libDirs []string) ([]*Watcher,
} }
watchers = append(watchers, watcher) watchers = append(watchers, watcher)
logger.Info("Started watching Lua modules directory: %s", dir) log.Info("Started watching Lua modules directory: %s", dir)
} }
return watchers, nil return watchers, nil
@ -152,15 +146,3 @@ func ShutdownWatcherManager() {
globalManager = nil globalManager = nil
} }
} }
// combineCallbacks creates a single callback function from multiple callbacks
func combineCallbacks(callbacks ...func() error) func() error {
return func() error {
for _, callback := range callbacks {
if err := callback(); err != nil {
return err
}
}
return nil
}
}

View File

@ -31,6 +31,7 @@ type DirectoryWatcher struct {
// Configuration // Configuration
callback func() error callback func() error
log *logger.Logger
debounceTime time.Duration debounceTime time.Duration
recursive bool recursive bool
@ -48,6 +49,9 @@ type DirectoryWatcherConfig struct {
// Callback function to call when changes are detected // Callback function to call when changes are detected
Callback func() error Callback func() error
// Logger instance
Log *logger.Logger
// Debounce time (0 means use default) // Debounce time (0 means use default)
DebounceTime time.Duration DebounceTime time.Duration
@ -66,6 +70,7 @@ func NewDirectoryWatcher(config DirectoryWatcherConfig) (*DirectoryWatcher, erro
dir: config.Dir, dir: config.Dir,
files: make(map[string]FileInfo), files: make(map[string]FileInfo),
callback: config.Callback, callback: config.Callback,
log: config.Log,
debounceTime: debounceTime, debounceTime: debounceTime,
recursive: config.Recursive, recursive: config.Recursive,
} }
@ -212,15 +217,15 @@ func (w *DirectoryWatcher) notifyChange() {
// logDebug logs a debug message with the watcher's directory prefix // logDebug logs a debug message with the watcher's directory prefix
func (w *DirectoryWatcher) logDebug(format string, args ...any) { func (w *DirectoryWatcher) logDebug(format string, args ...any) {
logger.Debug("[Watcher] [%s] %s", w.dir, fmt.Sprintf(format, args...)) w.log.Debug("[Watcher] [%s] %s", w.dir, fmt.Sprintf(format, args...))
} }
// logWarning logs a warning message with the watcher's directory prefix // logWarning logs a warning message with the watcher's directory prefix
func (w *DirectoryWatcher) logWarning(format string, args ...any) { func (w *DirectoryWatcher) logWarning(format string, args ...any) {
logger.Warning("[Watcher] [%s] %s", w.dir, fmt.Sprintf(format, args...)) w.log.Warning("[Watcher] [%s] %s", w.dir, fmt.Sprintf(format, args...))
} }
// logError logs an error message with the watcher's directory prefix // logError logs an error message with the watcher's directory prefix
func (w *DirectoryWatcher) logError(format string, args ...any) { func (w *DirectoryWatcher) logError(format string, args ...any) {
logger.Error("[Watcher] [%s] %s", w.dir, fmt.Sprintf(format, args...)) w.log.Error("[Watcher] [%s] %s", w.dir, fmt.Sprintf(format, args...))
} }

View File

@ -29,18 +29,22 @@ type WatcherManager struct {
done chan struct{} done chan struct{}
ticker *time.Ticker ticker *time.Ticker
// Logger
log *logger.Logger
// Wait group for shutdown coordination // Wait group for shutdown coordination
wg sync.WaitGroup wg sync.WaitGroup
} }
// NewWatcherManager creates a new watcher manager // NewWatcherManager creates a new watcher manager
func NewWatcherManager(adaptive bool) *WatcherManager { func NewWatcherManager(log *logger.Logger, adaptive bool) *WatcherManager {
manager := &WatcherManager{ manager := &WatcherManager{
watchers: make(map[string]*DirectoryWatcher), watchers: make(map[string]*DirectoryWatcher),
pollInterval: defaultPollInterval, pollInterval: defaultPollInterval,
adaptive: adaptive, adaptive: adaptive,
lastActivity: time.Now(), lastActivity: time.Now(),
done: make(chan struct{}), done: make(chan struct{}),
log: log,
} }
// Start the polling loop // Start the polling loop
@ -67,7 +71,7 @@ func (m *WatcherManager) AddWatcher(watcher *DirectoryWatcher) {
defer m.mu.Unlock() defer m.mu.Unlock()
m.watchers[watcher.dir] = watcher m.watchers[watcher.dir] = watcher
logger.Debug("[WatcherManager] Added watcher for directory: %s", watcher.dir) m.log.Debug("[WatcherManager] Added watcher for directory: %s", watcher.dir)
} }
// RemoveWatcher unregisters a directory watcher // RemoveWatcher unregisters a directory watcher
@ -76,7 +80,7 @@ func (m *WatcherManager) RemoveWatcher(dir string) {
defer m.mu.Unlock() defer m.mu.Unlock()
delete(m.watchers, dir) delete(m.watchers, dir)
logger.Debug("[WatcherManager] Removed watcher for directory: %s", dir) m.log.Debug("[WatcherManager] Removed watcher for directory: %s", dir)
} }
// pollLoop is the main polling loop that checks all watched directories // pollLoop is the main polling loop that checks all watched directories
@ -96,7 +100,7 @@ func (m *WatcherManager) pollLoop() {
if m.pollInterval > defaultPollInterval { if m.pollInterval > defaultPollInterval {
m.pollInterval = defaultPollInterval m.pollInterval = defaultPollInterval
m.ticker.Reset(m.pollInterval) m.ticker.Reset(m.pollInterval)
logger.Debug("[WatcherManager] Reset to base polling interval: %v", m.pollInterval) m.log.Debug("[WatcherManager] Reset to base polling interval: %v", m.pollInterval)
} }
} else { } else {
// No activity, consider slowing down polling // No activity, consider slowing down polling
@ -104,7 +108,7 @@ func (m *WatcherManager) pollLoop() {
if m.pollInterval == defaultPollInterval && inactiveDuration > inactivityThreshold { if m.pollInterval == defaultPollInterval && inactiveDuration > inactivityThreshold {
m.pollInterval = extendedPollInterval m.pollInterval = extendedPollInterval
m.ticker.Reset(m.pollInterval) m.ticker.Reset(m.pollInterval)
logger.Debug("[WatcherManager] Extended polling interval to: %v after %v of inactivity", m.log.Debug("[WatcherManager] Extended polling interval to: %v after %v of inactivity",
m.pollInterval, inactiveDuration.Round(time.Minute)) m.pollInterval, inactiveDuration.Round(time.Minute))
} }
} }
@ -125,7 +129,7 @@ func (m *WatcherManager) checkAllDirectories() bool {
for _, watcher := range m.watchers { for _, watcher := range m.watchers {
changed, err := watcher.checkForChanges() changed, err := watcher.checkForChanges()
if err != nil { if err != nil {
logger.Error("[WatcherManager] Error checking directory %s: %v", watcher.dir, err) m.log.Error("[WatcherManager] Error checking directory %s: %v", watcher.dir, err)
continue continue
} }

View File

@ -26,7 +26,7 @@ type WatcherConfig struct {
} }
// initRouters sets up the Lua and static routers // initRouters sets up the Lua and static routers
func initRouters(routesDir, staticDir string) (*routers.LuaRouter, *routers.StaticRouter, error) { func initRouters(routesDir, staticDir string, log *logger.Logger) (*routers.LuaRouter, *routers.StaticRouter, error) {
// Ensure directories exist // Ensure directories exist
if err := utils.EnsureDir(routesDir); err != nil { if err := utils.EnsureDir(routesDir); err != nil {
return nil, nil, fmt.Errorf("routes directory doesn't exist, and could not create it: %v", err) return nil, nil, fmt.Errorf("routes directory doesn't exist, and could not create it: %v", err)
@ -41,12 +41,12 @@ func initRouters(routesDir, staticDir string) (*routers.LuaRouter, *routers.Stat
// Check if this is a compilation warning or a more serious error // Check if this is a compilation warning or a more serious error
if errors.Is(err, routers.ErrRoutesCompilationErrors) { if errors.Is(err, routers.ErrRoutesCompilationErrors) {
// Some routes failed to compile, but router is still usable // Some routes failed to compile, but router is still usable
logger.Warning("Some Lua routes failed to compile. Check logs for details.") log.Warning("Some Lua routes failed to compile. Check logs for details.")
// Log details about each failed route // Log details about each failed route
if failedRoutes := luaRouter.ReportFailedRoutes(); len(failedRoutes) > 0 { if failedRoutes := luaRouter.ReportFailedRoutes(); len(failedRoutes) > 0 {
for _, re := range failedRoutes { for _, re := range failedRoutes {
logger.Error("Route %s %s failed to compile: %v", re.Method, re.Path, re.Err) log.Error("Route %s %s failed to compile: %v", re.Method, re.Path, re.Err)
} }
} }
} else { } else {
@ -54,14 +54,14 @@ func initRouters(routesDir, staticDir string) (*routers.LuaRouter, *routers.Stat
return nil, nil, fmt.Errorf("failed to initialize Lua router: %v", err) return nil, nil, fmt.Errorf("failed to initialize Lua router: %v", err)
} }
} }
logger.Info("Lua router initialized with routes from %s", routesDir) log.Info("Lua router initialized with routes from %s", routesDir)
// Initialize static file router // Initialize static file router
staticRouter, err := routers.NewStaticRouter(staticDir) staticRouter, err := routers.NewStaticRouterWithLogger(staticDir, log)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("failed to initialize static router: %v", err) return nil, nil, fmt.Errorf("failed to initialize static router: %v", err)
} }
logger.Info("Static router initialized with files from %s", staticDir) log.Info("Static router initialized with files from %s", staticDir)
staticRouter.EnableDebugLog() staticRouter.EnableDebugLog()
return luaRouter, staticRouter, nil return luaRouter, staticRouter, nil
@ -70,15 +70,15 @@ func initRouters(routesDir, staticDir string) (*routers.LuaRouter, *routers.Stat
// setupWatchers initializes and starts all file watchers // setupWatchers initializes and starts all file watchers
func setupWatchers(luaRouter *routers.LuaRouter, staticRouter *routers.StaticRouter, func setupWatchers(luaRouter *routers.LuaRouter, staticRouter *routers.StaticRouter,
luaRunner *runner.LuaRunner, routesDir string, staticDir string, luaRunner *runner.LuaRunner, routesDir string, staticDir string,
libDirs []string, config WatcherConfig) ([]func() error, error) { libDirs []string, log *logger.Logger, config WatcherConfig) ([]func() error, error) {
var cleanupFuncs []func() error var cleanupFuncs []func() error
// Set up watcher for Lua routes // Set up watcher for Lua routes
if config.Routes { if config.Routes {
luaRouterWatcher, err := watchers.WatchLuaRouter(luaRouter, luaRunner, routesDir) luaRouterWatcher, err := watchers.WatchLuaRouter(luaRouter, routesDir, log)
if err != nil { if err != nil {
logger.Warning("Failed to watch routes directory: %v", err) log.Warning("Failed to watch routes directory: %v", err)
} else { } else {
cleanupFuncs = append(cleanupFuncs, luaRouterWatcher.Close) cleanupFuncs = append(cleanupFuncs, luaRouterWatcher.Close)
} }
@ -86,9 +86,9 @@ func setupWatchers(luaRouter *routers.LuaRouter, staticRouter *routers.StaticRou
// Set up watcher for static files // Set up watcher for static files
if config.Static { if config.Static {
staticWatcher, err := watchers.WatchStaticRouter(staticRouter, staticDir) staticWatcher, err := watchers.WatchStaticRouter(staticRouter, staticDir, log)
if err != nil { if err != nil {
logger.Warning("Failed to watch static directory: %v", err) log.Warning("Failed to watch static directory: %v", err)
} else { } else {
cleanupFuncs = append(cleanupFuncs, staticWatcher.Close) cleanupFuncs = append(cleanupFuncs, staticWatcher.Close)
} }
@ -96,15 +96,15 @@ func setupWatchers(luaRouter *routers.LuaRouter, staticRouter *routers.StaticRou
// Set up watchers for Lua modules libraries // Set up watchers for Lua modules libraries
if config.Modules && len(libDirs) > 0 { if config.Modules && len(libDirs) > 0 {
moduleWatchers, err := watchers.WatchLuaModules(luaRunner, libDirs) moduleWatchers, err := watchers.WatchLuaModules(luaRunner, libDirs, log)
if err != nil { if err != nil {
logger.Warning("Failed to watch Lua module directories: %v", err) log.Warning("Failed to watch Lua module directories: %v", err)
} else { } else {
for _, watcher := range moduleWatchers { for _, watcher := range moduleWatchers {
w := watcher // Capture variable for closure w := watcher // Capture variable for closure
cleanupFuncs = append(cleanupFuncs, w.Close) cleanupFuncs = append(cleanupFuncs, w.Close)
} }
logger.Info("File watchers active for %d Lua module directories", len(moduleWatchers)) log.Info("File watchers active for %d Lua module directories", len(moduleWatchers))
} }
} }
@ -112,38 +112,33 @@ func setupWatchers(luaRouter *routers.LuaRouter, staticRouter *routers.StaticRou
} }
func main() { func main() {
// Initialize global logger with debug level // Initialize logger
logger.InitGlobalLogger(logger.LevelDebug, true) log := logger.New(logger.LevelDebug, true)
logger.Server("Starting Moonshark server") log.Server("Starting Moonshark server")
// Load configuration from config.lua // Load configuration from config.lua
cfg, err := config.Load("config.lua") cfg, err := config.Load("config.lua")
if err != nil { if err != nil {
logger.Warning("Failed to load config.lua: %v", err) log.Warning("Failed to load config.lua: %v", err)
logger.Info("Using default configuration") log.Info("Using default configuration")
cfg = config.New() cfg = config.New()
} }
// Set log level from config // Set log level from config
switch cfg.GetString("log_level", "info") { switch cfg.GetString("log_level", "info") {
case "debug":
log.SetLevel(logger.LevelDebug)
case "warn": case "warn":
logger.SetLevel(logger.LevelWarning) log.SetLevel(logger.LevelWarning)
case "error": case "error":
logger.SetLevel(logger.LevelError) log.SetLevel(logger.LevelError)
case "server": case "server":
logger.SetLevel(logger.LevelServer) log.SetLevel(logger.LevelServer)
case "fatal": case "fatal":
logger.SetLevel(logger.LevelFatal) log.SetLevel(logger.LevelFatal)
default: default:
logger.SetLevel(logger.LevelInfo) log.SetLevel(logger.LevelInfo)
}
// Get debug mode setting
debugMode := cfg.GetBool("debug", false)
if debugMode {
logger.EnableDebug() // Force debug logs regardless of level
logger.Debug("Debug mode enabled")
} }
// Get configuration values // Get configuration values
@ -154,7 +149,7 @@ func main() {
bufferSize := cfg.GetInt("buffer_size", 20) bufferSize := cfg.GetInt("buffer_size", 20)
if err := utils.EnsureDir(overrideDir); err != nil { if err := utils.EnsureDir(overrideDir); err != nil {
logger.Warning("Override directory doesn't exist, and could not create it: %v", err) log.Warning("Override directory doesn't exist, and could not create it: %v", err)
overrideDir = "" // Disable overrides if directory can't be created overrideDir = "" // Disable overrides if directory can't be created
} }
@ -171,26 +166,25 @@ func main() {
// Ensure lib directories exist // Ensure lib directories exist
for _, dir := range libDirs { for _, dir := range libDirs {
if err := utils.EnsureDir(dir); err != nil { if err := utils.EnsureDir(dir); err != nil {
logger.Warning("Lib directory doesn't exist, and could not create it: %v", err) log.Warning("Lib directory doesn't exist, and could not create it: %v", err)
} }
} }
// Initialize routers // Initialize routers
luaRouter, staticRouter, err := initRouters(routesDir, staticDir) luaRouter, staticRouter, err := initRouters(routesDir, staticDir, log)
if err != nil { if err != nil {
logger.Fatal("Router initialization failed: %v", err) log.Fatal("Router initialization failed: %v", err)
} }
// Initialize Lua runner // Initialize Lua runner
luaRunner, err := runner.NewRunner( luaRunner, err := runner.NewRunner(
runner.WithBufferSize(bufferSize), runner.WithBufferSize(bufferSize),
runner.WithLibDirs(libDirs...), runner.WithLibDirs(libDirs...),
runner.WithDebugEnabled(),
) )
if err != nil { if err != nil {
logger.Fatal("Failed to initialize Lua runner: %v", err) log.Fatal("Failed to initialize Lua runner: %v", err)
} }
logger.Server("Lua runner initialized with buffer size %d", bufferSize) log.Server("Lua runner initialized with buffer size %d", bufferSize)
defer luaRunner.Close() defer luaRunner.Close()
// Set up file watchers if enabled // Set up file watchers if enabled
@ -219,29 +213,31 @@ func main() {
// Setup enabled watchers // Setup enabled watchers
cleanupFuncs, err = setupWatchers(luaRouter, staticRouter, luaRunner, cleanupFuncs, err = setupWatchers(luaRouter, staticRouter, luaRunner,
routesDir, staticDir, libDirs, watcherConfig) routesDir, staticDir, libDirs, log, watcherConfig)
if err != nil { if err != nil {
logger.Warning("Error setting up watchers: %v", err) log.Warning("Error setting up watchers: %v", err)
} }
// Register cleanup functions // Register cleanup functions
defer func() { defer func() {
for _, cleanup := range cleanupFuncs { for _, cleanup := range cleanupFuncs {
if err := cleanup(); err != nil { if err := cleanup(); err != nil {
logger.Warning("Cleanup error: %v", err) log.Warning("Cleanup error: %v", err)
} }
} }
}() }()
httpLoggingEnabled := cfg.GetBool("http_logging_enabled", true) httpLoggingEnabled := cfg.GetBool("http_logging_enabled", true)
if httpLoggingEnabled { if httpLoggingEnabled {
logger.Info("HTTP logging is enabled") log.Info("HTTP logging is enabled")
} else { } else {
logger.Info("HTTP logging is disabled") log.Info("HTTP logging is disabled")
} }
debugMode := cfg.GetBool("debug", false)
// Create HTTP server // Create HTTP server
server := http.New(luaRouter, staticRouter, luaRunner, httpLoggingEnabled, debugMode, overrideDir, cfg) server := http.New(luaRouter, staticRouter, luaRunner, log, httpLoggingEnabled, debugMode, overrideDir, cfg)
// Handle graceful shutdown // Handle graceful shutdown
stop := make(chan os.Signal, 1) stop := make(chan os.Signal, 1)
@ -251,24 +247,24 @@ func main() {
go func() { go func() {
if err := server.ListenAndServe(fmt.Sprintf(":%d", port)); err != nil { if err := server.ListenAndServe(fmt.Sprintf(":%d", port)); err != nil {
if err.Error() != "http: Server closed" { if err.Error() != "http: Server closed" {
logger.Error("Server error: %v", err) log.Error("Server error: %v", err)
} }
} }
}() }()
logger.Server("Server started on port %d", port) log.Server("Server started on port %d", port)
// Wait for interrupt signal // Wait for interrupt signal
<-stop <-stop
logger.Server("Shutdown signal received") log.Server("Shutdown signal received")
// Gracefully shut down the server // Gracefully shut down the server
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
if err := server.Shutdown(ctx); err != nil { if err := server.Shutdown(ctx); err != nil {
logger.Error("Server shutdown error: %v", err) log.Error("Server shutdown error: %v", err)
} }
logger.Server("Server stopped") log.Server("Server stopped")
} }