diff --git a/go.mod b/go.mod index 6ba8fc3..157815b 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( git.sharkk.net/Go/Color v1.1.0 git.sharkk.net/Go/LRU v1.0.0 git.sharkk.net/Sharkk/Fin v1.3.0 - git.sharkk.net/Sky/LuaJIT-to-Go v0.5.2 + git.sharkk.net/Sky/LuaJIT-to-Go v0.5.3 github.com/VictoriaMetrics/fastcache v1.12.4 github.com/alexedwards/argon2id v1.0.0 github.com/deneonet/benc v1.1.8 diff --git a/go.sum b/go.sum index de6139b..ae66c7f 100644 --- a/go.sum +++ b/go.sum @@ -4,8 +4,8 @@ git.sharkk.net/Go/LRU v1.0.0 h1:/KqdRVhHldi23aVfQZ4ss6vhCWZqA3vFiQyf1MJPpQc= git.sharkk.net/Go/LRU v1.0.0/go.mod h1:8tdTyl85mss9a+KKwo+Wj9gKHOizhfLfpJhz1ltYz50= git.sharkk.net/Sharkk/Fin v1.3.0 h1:6/f7+h382jJOeo09cgdzH+PGb5XdvajZZRiES52sBkI= git.sharkk.net/Sharkk/Fin v1.3.0/go.mod h1:ca0Ej9yCM/vHh1o3YMvBZspme3EtbwoEL2UXN5UPXMo= -git.sharkk.net/Sky/LuaJIT-to-Go v0.5.2 h1:BgsZPkoqJjQ7Rb+bSs7QQ24+wwLzyc2AALbnpB/n9Kw= -git.sharkk.net/Sky/LuaJIT-to-Go v0.5.2/go.mod h1:HQz+D7AFxOfNbTIogjxP+shEBtz1KKrLlLucU+w07c8= +git.sharkk.net/Sky/LuaJIT-to-Go v0.5.3 h1:SuLz4X/k+sMy+Uj1lhEy6brJtvtzHLdivUcu5K91y+o= +git.sharkk.net/Sky/LuaJIT-to-Go v0.5.3/go.mod h1:HQz+D7AFxOfNbTIogjxP+shEBtz1KKrLlLucU+w07c8= github.com/VictoriaMetrics/fastcache v1.12.4 h1:2xvmwZBW+9QtHsXggfzAZRs1FZWCsBs8QDg22bMidf0= github.com/VictoriaMetrics/fastcache v1.12.4/go.mod h1:K+JGPBn0sueFlLjZ8rcVM0cKkWKNElKyQXmw57QOoYI= github.com/alexedwards/argon2id v1.0.0 h1:wJzDx66hqWX7siL/SRUmgz3F8YMrd/nfX/xHHcQQP0w= diff --git a/http/server.go b/http/server.go index cb9f824..a423fbe 100644 --- a/http/server.go +++ b/http/server.go @@ -2,24 +2,20 @@ package http import ( - "context" + "fmt" "strings" "time" "Moonshark/config" - "Moonshark/logger" "Moonshark/metadata" "Moonshark/runner" - "Moonshark/sessions" "Moonshark/utils" - "git.sharkk.net/Go/Color" - + "github.com/goccy/go-json" + "github.com/valyala/bytebufferpool" "github.com/valyala/fasthttp" ) -var emptyMap = make(map[string]any) - func NewHttpServer(cfg *config.Config, handler fasthttp.RequestHandler, dbg bool) *fasthttp.Server { return &fasthttp.Server{ Handler: handler, @@ -55,113 +51,109 @@ func NewPublicHandler(pubDir, prefix string) fasthttp.RequestHandler { return fs.NewRequestHandler() } -func (s *Server) ListenAndServe(addr string) error { - logger.Infof("Catch the swell at %s", color.Cyan("http://localhost"+addr)) - return s.fasthttpServer.ListenAndServe(addr) -} - -func (s *Server) Shutdown(ctx context.Context) error { - return s.fasthttpServer.ShutdownWithContext(ctx) -} - -func (s *Server) handleRequest(ctx *fasthttp.RequestCtx) { - start := time.Now() - method := string(ctx.Method()) - path := string(ctx.Path()) - - // Route lookup - bytecode, params, found := s.luaRouter.Lookup(method, path) - if !found { - s.send404(ctx) - s.logRequest(ctx, method, path, time.Since(start)) - return - } - - if len(bytecode) == 0 { - s.send500(ctx, nil) - s.logRequest(ctx, method, path, time.Since(start)) - return - } - - // Get session - session := s.sessionManager.GetSessionFromRequest(ctx) - - // Execute Lua script - response, err := s.luaRunner.ExecuteHTTP(bytecode, ctx, params, session) - if err != nil { - logger.Errorf("Lua execution error: %v", err) - s.send500(ctx, err) - s.logRequest(ctx, method, path, time.Since(start)) - return - } - - // Apply response - s.applyResponse(ctx, response, session) - runner.ReleaseResponse(response) - - s.logRequest(ctx, method, path, time.Since(start)) -} - -func (s *Server) applyResponse(ctx *fasthttp.RequestCtx, resp *runner.Response, session *sessions.Session) { - // Handle session updates - if len(resp.SessionData) > 0 { - if _, clearAll := resp.SessionData["__clear_all"]; clearAll { - session.Clear() - session.ClearFlash() - delete(resp.SessionData, "__clear_all") - } - - for k, v := range resp.SessionData { - if v == "__DELETE__" { - session.Delete(k) - } else { - session.Set(k, v) - } - } - } - - // Handle flash data - if flashData, ok := resp.Metadata["flash"].(map[string]any); ok { - for k, v := range flashData { - if err := session.FlashSafe(k, v); err != nil && s.debugMode { - logger.Warnf("Error setting flash data %s: %v", k, err) - } - } - } - - // Apply session cookie - s.sessionManager.ApplySessionCookie(ctx, session) - - // Apply HTTP response - runner.ApplyResponse(resp, ctx) -} - -func (s *Server) send404(ctx *fasthttp.RequestCtx) { +func Send404(ctx *fasthttp.RequestCtx) { ctx.SetContentType("text/html; charset=utf-8") ctx.SetStatusCode(fasthttp.StatusNotFound) - cacheMu.RLock() - ctx.SetBody(cached404) - cacheMu.RUnlock() + ctx.SetBody([]byte(utils.NotFoundPage(ctx.URI().String()))) } -func (s *Server) send500(ctx *fasthttp.RequestCtx, err error) { +func Send500(ctx *fasthttp.RequestCtx, err error) { ctx.SetContentType("text/html; charset=utf-8") ctx.SetStatusCode(fasthttp.StatusInternalServerError) if err == nil { - cacheMu.RLock() - ctx.SetBody(cached500) - cacheMu.RUnlock() + ctx.SetBody([]byte(utils.InternalErrorPage(string(ctx.Path()), ""))) } else { - errorConfig := utils.ErrorPageConfig{ - OverrideDir: s.cfg.Dirs.Override, - DebugMode: s.debugMode, - } - ctx.SetBody([]byte(utils.InternalErrorPage(errorConfig, string(ctx.Path()), err.Error()))) + ctx.SetBody([]byte(utils.InternalErrorPage(string(ctx.Path()), err.Error()))) } } -func (s *Server) handleDebugStats(ctx *fasthttp.RequestCtx) { +// ApplyResponse applies a Response to a fasthttp.RequestCtx +func ApplyResponse(resp *runner.Response, ctx *fasthttp.RequestCtx) { + // Set status code + ctx.SetStatusCode(resp.Status) + + // Set headers + for name, value := range resp.Headers { + ctx.Response.Header.Set(name, value) + } + + // Set cookies + for _, cookie := range resp.Cookies { + ctx.Response.Header.SetCookie(cookie) + } + + // Process the body based on its type + if resp.Body == nil { + return + } + + // Check if Content-Type was manually set + contentTypeSet := false + for name := range resp.Headers { + if strings.ToLower(name) == "content-type" { + contentTypeSet = true + break + } + } + + // Get a buffer from the pool + buf := bytebufferpool.Get() + defer bytebufferpool.Put(buf) + + // Set body based on type + switch body := resp.Body.(type) { + case string: + if !contentTypeSet { + ctx.Response.Header.SetContentType("text/plain; charset=utf-8") + } + ctx.SetBodyString(body) + case []byte: + if !contentTypeSet { + ctx.Response.Header.SetContentType("text/plain; charset=utf-8") + } + ctx.SetBody(body) + case map[string]any, map[any]any, []any, []float64, []string, []int, []map[string]any: + // Marshal JSON + if err := json.NewEncoder(buf).Encode(body); err == nil { + if !contentTypeSet { + ctx.Response.Header.SetContentType("application/json") + } + ctx.SetBody(buf.Bytes()) + } else { + // Fallback to string representation + if !contentTypeSet { + ctx.Response.Header.SetContentType("text/plain; charset=utf-8") + } + ctx.SetBodyString(fmt.Sprintf("%v", body)) + } + default: + // Check if it's any other map or slice type + typeStr := fmt.Sprintf("%T", body) + if typeStr[0] == '[' || (len(typeStr) > 3 && typeStr[:3] == "map") { + if err := json.NewEncoder(buf).Encode(body); err == nil { + if !contentTypeSet { + ctx.Response.Header.SetContentType("application/json") + } + ctx.SetBody(buf.Bytes()) + } else { + if !contentTypeSet { + ctx.Response.Header.SetContentType("text/plain; charset=utf-8") + } + ctx.SetBodyString(fmt.Sprintf("%v", body)) + } + } else { + // Default to string representation + if !contentTypeSet { + ctx.Response.Header.SetContentType("text/plain; charset=utf-8") + } + ctx.SetBodyString(fmt.Sprintf("%v", body)) + } + } +} + +/* +func HandleDebugStats(ctx *fasthttp.RequestCtx) { stats := utils.CollectSystemStats(s.cfg) stats.Components = utils.ComponentStats{ RouteCount: 0, // TODO: Get from router @@ -172,9 +164,4 @@ func (s *Server) handleDebugStats(ctx *fasthttp.RequestCtx) { ctx.SetStatusCode(fasthttp.StatusOK) ctx.SetBody([]byte(utils.DebugStatsPage(stats))) } - -func (s *Server) logRequest(ctx *fasthttp.RequestCtx, method, path string, duration time.Duration) { - if s.cfg.Server.HTTPLogging { - logger.Request(ctx.Response.StatusCode(), method, path, duration) - } -} +*/ diff --git a/logger/logger.go b/logger/logger.go index fc3fbab..2cc5f38 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -43,6 +43,7 @@ type Logger struct { enabled atomic.Bool timestamp atomic.Bool debug atomic.Bool + http atomic.Bool colors atomic.Bool mu sync.Mutex } @@ -134,11 +135,17 @@ func Raw(format string, args ...any) { global.mu.Unlock() } -func Request(status int, method, path string, duration time.Duration) { +// Attempts to write the HTTP request result to the log. Will not print if +// http is disabled on the logger. Method and path take byte slices for convenience. +func Request(status int, method, path []byte, duration time.Duration) { if !global.enabled.Load() { return } + if !global.http.Load() { + return + } + var statusColor func(string) string switch { case status < 300: @@ -170,9 +177,9 @@ func Request(status int, method, path string, duration time.Duration) { } parts = append(parts, - applyColor("["+method+"]", color.Gray), + applyColor("["+string(method)+"]", color.Gray), applyColor(fmt.Sprintf("%d", status), statusColor), - applyColor(path, color.Gray), + applyColor(string(path), color.Gray), applyColor(dur, color.Gray), ) @@ -198,3 +205,4 @@ func ColorsEnabled() bool { return global.colors.Load() } func Timestamp(enabled bool) { global.timestamp.Store(enabled) } func Debug(enabled bool) { global.debug.Store(enabled) } func IsDebug() bool { return global.debug.Load() } +func Http(enabled bool) { global.debug.Store(enabled) } diff --git a/main.go b/main.go deleted file mode 100644 index defefb1..0000000 --- a/main.go +++ /dev/null @@ -1,88 +0,0 @@ -package main - -import ( - "Moonshark/config" - "Moonshark/http" - "Moonshark/logger" - "Moonshark/metadata" - "Moonshark/router" - "Moonshark/runner" - "Moonshark/sessions" - "bytes" - "flag" - "fmt" - "time" - - "git.sharkk.net/Go/Color" - - fin "git.sharkk.net/Sharkk/Fin" - "github.com/valyala/fasthttp" -) - -var ( - cfg *config.Config // Server config from Fin file - rtr *router.Router // Lua file router - rnr *runner.Runner // Lua runner - svr *fasthttp.Server // FastHTTP server - pub fasthttp.RequestHandler // Public asset handler - snm *sessions.SessionManager // Session data manager - dbg bool // Debug mode flag - pubPfx []byte // Cached public asset prefix -) - -func main() { - cfgPath := flag.String("config", "config", "Path to Fin config file") - dbgFlag := flag.Bool("debug", false, "Force debug mode") - sptPath := flag.String("script", "", "Path to Lua script to execute once") - flag.Parse() - - sptMode := *sptPath != "" - color.SetColors(color.DetectShellColors()) - banner(sptMode) - - cfg = config.New(fin.LoadFromFile(*cfgPath)) - dbg = *dbgFlag || cfg.Server.Debug - logger.Debug(dbg) - if dbg { - logger.Debugf("Debug logging enabled") - } - - svr = http.NewHttpServer(cfg, requestMux, dbg) - pub = http.NewPublicHandler(cfg.Dirs.Public, cfg.Server.PublicPrefix) - pubPfx = []byte(cfg.Server.PublicPrefix) -} - -// This is the primary request handler mux - determines whether we need to handle a Lua -// route or if we're serving a static file. -func requestMux(ctx *fasthttp.RequestCtx) { - start := time.Now() - method := ctx.Method() - path := ctx.Path() - - // Handle static file request - if bytes.HasPrefix(path, pubPfx) { - pub(ctx) - if cfg.Server.HTTPLogging { - logger.Request(ctx.Response.StatusCode(), string(method), string(path), time.Since(start)) - } - return - } -} - -// Print our super-awesome banner with the current version! -func banner(scriptMode bool) { - if scriptMode { - fmt.Println(color.Blue(fmt.Sprintf("Moonshark %s << Script Mode >>", metadata.Version))) - return - } - - banner := ` - _____ _________.__ __ - / \ ____ ____ ____ / _____/| |__ _____ _______| | __ - / \ / \ / _ \ / _ \ / \ \_____ \ | | \\__ \\_ __ \ |/ / -/ Y ( <_> | <_> ) | \/ \| Y \/ __ \| | \/ < -\____|__ /\____/ \____/|___| /_______ /|___| (____ /__| |__|_ \ %s - \/ \/ \/ \/ \/ \/ - ` - fmt.Println(color.Blue(fmt.Sprintf(banner, metadata.Version))) -} diff --git a/moonshark-old.go b/moonshark-old.go new file mode 100644 index 0000000..c958bc8 --- /dev/null +++ b/moonshark-old.go @@ -0,0 +1,35 @@ +package main + +/* +func (s *Moonshark) setupWatchers() { + manager := watchers.GetWatcherManager() + + // Watch routes + if routeWatcher, err := watchers.WatchLuaRouter(s.LuaRouter, s.LuaRunner, s.Config.Dirs.Routes); err != nil { + logger.Warnf("Routes directory watch failed: %v", err) + } else { + routesDir := routeWatcher.GetDir() + s.cleanupFuncs = append(s.cleanupFuncs, func() error { + return manager.UnwatchDirectory(routesDir) + }) + } + + // Watch modules + if moduleWatchers, err := watchers.WatchLuaModules(s.LuaRunner, s.Config.Dirs.Libs); err != nil { + logger.Warnf("Module directories watch failed: %v", err) + } else { + for _, watcher := range moduleWatchers { + dirPath := watcher.GetDir() + s.cleanupFuncs = append(s.cleanupFuncs, func() error { + return manager.UnwatchDirectory(dirPath) + }) + } + + plural := "directory" + if len(moduleWatchers) != 1 { + plural = "directories" + } + logger.Infof("Watching %s module %s.", color.Yellow(strconv.Itoa(len(moduleWatchers))), plural) + } +} +*/ diff --git a/moonshark.go b/moonshark.go index 39d66fd..b739613 100644 --- a/moonshark.go +++ b/moonshark.go @@ -1,220 +1,252 @@ package main import ( + "Moonshark/config" + "Moonshark/http" + "Moonshark/logger" + "Moonshark/metadata" + "Moonshark/router" + "Moonshark/runner" + "Moonshark/sessions" + "Moonshark/utils" + "bytes" "context" - "errors" "flag" "fmt" "os" "os/signal" "path/filepath" - "runtime" "strconv" "syscall" "time" - "Moonshark/config" - "Moonshark/http" - "Moonshark/logger" - "Moonshark/router" - "Moonshark/runner" - "Moonshark/runner/lualibs" - "Moonshark/sessions" - "Moonshark/watchers" + color "git.sharkk.net/Go/Color" - "git.sharkk.net/Go/Color" + fin "git.sharkk.net/Sharkk/Fin" + "github.com/valyala/fasthttp" ) -type Moonshark struct { - Config *config.Config - LuaRouter *router.Router - LuaRunner *runner.Runner - HTTPServer *http.Server - cleanupFuncs []func() error - scriptMode bool -} +var ( + cfg *config.Config // Server config from Fin file + rtr *router.Router // Lua file router + rnr *runner.Runner // Lua runner + svr *fasthttp.Server // FastHTTP server + pub fasthttp.RequestHandler // Public asset handler + snm *sessions.SessionManager // Session data manager + dbg bool // Debug mode flag + pubPfx []byte // Cached public asset prefix +) func main() { - configPath := flag.String("config", "config", "Path to configuration file") - debugFlag := flag.Bool("debug", false, "Enable debug mode") - scriptPath := flag.String("script", "", "Path to Lua script to execute once") + cfgPath := flag.String("config", "config", "Path to Fin config file") + dbgFlag := flag.Bool("debug", false, "Force debug mode") + sptPath := flag.String("script", "", "Path to Lua script to execute once") flag.Parse() - scriptMode := *scriptPath != "" + // Init sequence + sptMode := *sptPath != "" color.SetColors(color.DetectShellColors()) - banner(scriptMode) + banner(sptMode) - cfg := config.New(readConfig(*configPath)) - debug := *debugFlag || cfg.Server.Debug - logger.Debug(debug) + // Load Fin-based config + cfg = config.New(fin.LoadFromFile(*cfgPath)) - if debug { - logger.Debugf("Debug logging enabled") + // Setup debug mode + dbg = *dbgFlag || cfg.Server.Debug + logger.Debug(dbg) + logger.Debugf("Debug logging enabled") // Only prints if dbg is true + utils.Debug(dbg) // @TODO find a better way to do this + + // Determine Lua runner pool size + poolSize := cfg.Runner.PoolSize + if sptMode { + poolSize = 1 } - moonshark, err := newMoonshark(cfg, debug, scriptMode) - if err != nil { - logger.Fatalf("Initialization failed: %v", err) + // Set up the Lua runner + if err := initRunner(poolSize); err != nil { + logger.Fatalf("Runner failed to init: %v", err) } - defer func() { - if err := moonshark.Shutdown(); err != nil { - logger.Errorf("Error during shutdown: %v", err) - os.Exit(1) - } - }() - - if scriptMode { - if err := moonshark.RunScript(*scriptPath); err != nil { + // If in script mode, attempt to run the Lua script at the given path + if sptMode { + if err := handleScriptMode(*sptPath); err != nil { logger.Fatalf("Script execution failed: %v", err) } + + shutdown() return } - if err := moonshark.Start(); err != nil { - logger.Fatalf("Failed to start server: %v", err) + // Set up the Lua router + if err := initRouter(); err != nil { + logger.Fatalf("Router failed to init: %s", color.Red(err.Error())) } + // Set up the HTTP portion of the server + logger.Http(cfg.Server.HTTPLogging) // Whether we'll log HTTP request results + svr = http.NewHttpServer(cfg, requestMux, dbg) + pub = http.NewPublicHandler(cfg.Dirs.Public, cfg.Server.PublicPrefix) + pubPfx = []byte(cfg.Server.PublicPrefix) // Avoids casting to []byte when check prefixes + snm = sessions.NewSessionManager(sessions.DefaultMaxSessions) + + // Start the HTTP server + logger.Infof("Surf's up on port %s!", color.Cyan(strconv.Itoa(cfg.Server.Port))) + go func() { + if err := svr.ListenAndServe(":" + strconv.Itoa(cfg.Server.Port)); err != nil { + if err.Error() != "http: Server closed" { + logger.Errorf("Server error: %v", err) + } + } + }() + + // Handle a shutdown signal stop := make(chan os.Signal, 1) signal.Notify(stop, os.Interrupt, syscall.SIGTERM) <-stop fmt.Print("\n") logger.Infof("Shutdown signal received") + shutdown() } -func newMoonshark(cfg *config.Config, debug, scriptMode bool) (*Moonshark, error) { - s := &Moonshark{Config: cfg, scriptMode: scriptMode} +// This is the primary request handler mux - determines whether we need to handle a Lua +// route or if we're serving a static file. +func requestMux(ctx *fasthttp.RequestCtx) { + start := time.Now() + method := ctx.Method() + path := ctx.Path() - if debug { - cfg.Server.Debug = true + // Handle static file request + if bytes.HasPrefix(path, pubPfx) { + pub(ctx) + logRequest(ctx, method, path, start) + return } - poolSize := cfg.Runner.PoolSize - if scriptMode { - poolSize = 1 - } - if poolSize == 0 { - poolSize = runtime.GOMAXPROCS(0) + // See if the requested route even exists + bytecode, params, found := rtr.Lookup(method, path) + if !found { + http.Send404(ctx) + logRequest(ctx, method, path, start) + return } - // Initialize runner first (needed for both modes) - if err := s.initRunner(poolSize); err != nil { - return nil, err + // If there's no bytecode then it's an internal server error + if len(bytecode) == 0 { + http.Send500(ctx, nil) + logRequest(ctx, method, path, start) } - if scriptMode { - logger.Debugf("Script mode initialized") - return s, nil + // We've made it this far so the endpoint will likely load. Let's get any session data + // for this request + session := snm.GetSessionFromRequest(ctx) + + // Let's build an HTTP context for the Lua runner to consume + luaCtx := runner.NewHTTPContext(ctx, params, session) + defer luaCtx.Release() + + // Ask the runner to execute our endpoint with our context + res, err := rnr.Execute(bytecode, luaCtx) + if err != nil { + logger.Errorf("Lua execution error: %v", err) + http.Send500(ctx, err) + logRequest(ctx, method, path, start) + return } - // Server mode: initialize router, watchers, and HTTP server - if err := s.initRouter(); err != nil { - return nil, err - } - - s.setupWatchers() - s.HTTPServer = http.New(s.LuaRouter, s.LuaRunner, cfg, debug) - - // Log static directory status - if dirExists(cfg.Dirs.Static) { - logger.Infof("Static files enabled: %s", color.Yellow(cfg.Dirs.Static)) - } else { - logger.Warnf("Static directory not found: %s", color.Yellow(cfg.Dirs.Static)) - } - - return s, nil + // Sweet, our execution went through! Let's now use the Response we got and build the HTTP response, then return + // the response object to be cleaned. After, we'll log our request cus we are *done* + applyResponse(ctx, res, session) + runner.ReleaseResponse(res) + logRequest(ctx, method, path, start) } -func (s *Moonshark) initRunner(poolSize int) error { - // Warn about missing directories but continue - if !dirExists(s.Config.Dirs.Override) { - logger.Warnf("Override directory not found... %s", color.Yellow(s.Config.Dirs.Override)) - s.Config.Dirs.Override = "" +func applyResponse(ctx *fasthttp.RequestCtx, resp *runner.Response, session *sessions.Session) { + // Handle session updates + if len(resp.SessionData) > 0 { + if _, clearAll := resp.SessionData["__clear_all"]; clearAll { + session.Clear() + session.ClearFlash() + delete(resp.SessionData, "__clear_all") + } + + for k, v := range resp.SessionData { + if v == "__DELETE__" { + session.Delete(k) + } else { + session.Set(k, v) + } + } } - for _, dir := range s.Config.Dirs.Libs { + // Handle flash data + if flashData, ok := resp.Metadata["flash"].(map[string]any); ok { + for k, v := range flashData { + if err := session.FlashSafe(k, v); err != nil && dbg { + logger.Warnf("Error setting flash data %s: %v", k, err) + } + } + } + + // Apply session cookie + snm.ApplySessionCookie(ctx, session) + + // Apply HTTP response + http.ApplyResponse(resp, ctx) +} + +// Attempts to start the Lua runner. poolSize allows overriding the config, like for script mode. A poolSize of +// 0 will default to the config, and if the config is 0 then it will default to GOMAXPROCS. +func initRunner(poolSize int) error { + for _, dir := range cfg.Dirs.Libs { if !dirExists(dir) { logger.Warnf("Lib directory not found... %s", color.Yellow(dir)) } } - if err := lualibs.InitEnv(s.Config.Dirs.Data); err != nil { - logger.Warnf("Environment initialization failed: %v", err) - } - - sessions.GlobalSessionManager.SetCookieOptions("MoonsharkSID", "/", "", false, true, 86400) - - var err error - s.LuaRunner, err = runner.NewRunner(poolSize, s.Config.Dirs.Data, s.Config.Dirs.FS, s.Config.Dirs.Libs) + runner, err := runner.NewRunner(cfg, poolSize) if err != nil { return fmt.Errorf("lua runner init failed: %v", err) } + rnr = runner logger.Infof("LuaRunner is g2g with %s states!", color.Yellow(strconv.Itoa(poolSize))) return nil } -func (s *Moonshark) initRouter() error { - if err := os.MkdirAll(s.Config.Dirs.Routes, 0755); err != nil { +// Attempt to spin up the Lua router. Attempts to create the routes directory if it doesn't exist, +// since it's required for Moonshark to work. +func initRouter() error { + if err := os.MkdirAll(cfg.Dirs.Routes, 0755); err != nil { return fmt.Errorf("failed to create routes directory: %w", err) } - var err error - s.LuaRouter, err = router.New(s.Config.Dirs.Routes) + router, err := router.New(cfg.Dirs.Routes) if err != nil { return fmt.Errorf("lua router init failed: %v", err) } + rtr = router - logger.Infof("LuaRouter is g2g! %s", color.Yellow(s.Config.Dirs.Routes)) + logger.Infof("LuaRouter is g2g! %s", color.Yellow(cfg.Dirs.Routes)) return nil } -func (s *Moonshark) setupWatchers() { - manager := watchers.GetWatcherManager() - - // Watch routes - if routeWatcher, err := watchers.WatchLuaRouter(s.LuaRouter, s.LuaRunner, s.Config.Dirs.Routes); err != nil { - logger.Warnf("Routes directory watch failed: %v", err) - } else { - routesDir := routeWatcher.GetDir() - s.cleanupFuncs = append(s.cleanupFuncs, func() error { - return manager.UnwatchDirectory(routesDir) - }) - } - - // Watch modules - if moduleWatchers, err := watchers.WatchLuaModules(s.LuaRunner, s.Config.Dirs.Libs); err != nil { - logger.Warnf("Module directories watch failed: %v", err) - } else { - for _, watcher := range moduleWatchers { - dirPath := watcher.GetDir() - s.cleanupFuncs = append(s.cleanupFuncs, func() error { - return manager.UnwatchDirectory(dirPath) - }) - } - - plural := "directory" - if len(moduleWatchers) != 1 { - plural = "directories" - } - logger.Infof("Watching %s module %s.", color.Yellow(strconv.Itoa(len(moduleWatchers))), plural) - } -} - -func (s *Moonshark) RunScript(scriptPath string) error { - scriptPath, err := filepath.Abs(scriptPath) +// Attempts to execute the Lua script at the given path inside a fully initialized sandbox environment. Handy +// for pre-launch tasks and the like. +func handleScriptMode(path string) error { + path, err := filepath.Abs(path) if err != nil { return fmt.Errorf("failed to resolve script path: %v", err) } - if _, err := os.Stat(scriptPath); os.IsNotExist(err) { - return fmt.Errorf("script file not found: %s", scriptPath) + if _, err := os.Stat(path); os.IsNotExist(err) { + return fmt.Errorf("script file not found: %s", path) } - logger.Infof("Executing: %s", scriptPath) + logger.Infof("Executing: %s", path) - resp, err := s.LuaRunner.RunScriptFile(scriptPath) + resp, err := rnr.RunScriptFile(path) if err != nil { return fmt.Errorf("execution failed: %v", err) } @@ -228,55 +260,48 @@ func (s *Moonshark) RunScript(scriptPath string) error { return nil } -func (s *Moonshark) Start() error { - if s.scriptMode { - return errors.New("cannot start server in script mode") - } - - logger.Infof("Surf's up on port %s!", color.Cyan(strconv.Itoa(s.Config.Server.Port))) - - go func() { - if err := s.HTTPServer.ListenAndServe(fmt.Sprintf(":%d", s.Config.Server.Port)); err != nil { - if err.Error() != "http: Server closed" { - logger.Errorf("Server error: %v", err) - } - } - }() - - return nil -} - -func (s *Moonshark) Shutdown() error { +func shutdown() { logger.Infof("Shutting down...") - if !s.scriptMode && s.HTTPServer != nil { + if svr != nil { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - if err := s.HTTPServer.Shutdown(ctx); err != nil { + if err := svr.ShutdownWithContext(ctx); err != nil { logger.Errorf("HTTP server shutdown error: %v", err) } } - for _, cleanup := range s.cleanupFuncs { - if err := cleanup(); err != nil { - logger.Warnf("Cleanup error: %v", err) - } - } - - if s.LuaRunner != nil { - s.LuaRunner.Close() - } - - if err := lualibs.CleanupEnv(); err != nil { - logger.Warnf("Environment cleanup failed: %v", err) + if rnr != nil { + rnr.Close() } logger.Infof("Shutdown complete") - return nil +} + +// Print our super-awesome banner with the current version! +func banner(scriptMode bool) { + if scriptMode { + fmt.Println(color.Blue(fmt.Sprintf("Moonshark %s << Script Mode >>", metadata.Version))) + return + } + + banner := ` + _____ _________.__ __ + / \ ____ ____ ____ / _____/| |__ _____ _______| | __ + / \ / \ / _ \ / _ \ / \ \_____ \ | | \\__ \\_ __ \ |/ / +/ Y ( <_> | <_> ) | \/ \| Y \/ __ \| | \/ < +\____|__ /\____/ \____/|___| /_______ /|___| (____ /__| |__|_ \ %s + \/ \/ \/ \/ \/ \/ + ` + fmt.Println(color.Blue(fmt.Sprintf(banner, metadata.Version))) } func dirExists(path string) bool { info, err := os.Stat(path) return err == nil && info.IsDir() } + +func logRequest(ctx *fasthttp.RequestCtx, method, path []byte, start time.Time) { + logger.Request(ctx.Response.StatusCode(), method, path, time.Since(start)) +} diff --git a/router/router.go b/router/router.go index 590c9d1..2a923ff 100644 --- a/router/router.go +++ b/router/router.go @@ -1,6 +1,7 @@ package router import ( + "bytes" "errors" "os" "path/filepath" @@ -11,9 +12,18 @@ import ( "github.com/VictoriaMetrics/fastcache" ) +var ( + slash = []byte("/") + get = []byte("GET") + post = []byte("POST") + put = []byte("PUT") + patch = []byte("PATCH") + delete = []byte("DELETE") +) + // node represents a node in the radix trie type node struct { - segment string + segment []byte bytecode []byte scriptPath string children []*node @@ -30,7 +40,7 @@ type Router struct { compileState *luajit.State compileMu sync.Mutex paramsBuffer []string - middlewareFiles map[string][]string // filesystem path -> middleware file paths + middlewareFiles map[string][]string } // Params holds URL parameters @@ -71,7 +81,7 @@ func New(routesDir string) (*Router, error) { put: &node{}, patch: &node{}, delete: &node{}, - bytecodeCache: fastcache.New(32 * 1024 * 1024), // 32MB + bytecodeCache: fastcache.New(32 * 1024 * 1024), compileState: compileState, paramsBuffer: make([]string, 64), middlewareFiles: make(map[string][]string), @@ -81,17 +91,17 @@ func New(routesDir string) (*Router, error) { } // methodNode returns the root node for a method -func (r *Router) methodNode(method string) *node { - switch method { - case "GET": +func (r *Router) methodNode(method []byte) *node { + switch { + case bytes.Equal(method, get): return r.get - case "POST": + case bytes.Equal(method, post): return r.post - case "PUT": + case bytes.Equal(method, put): return r.put - case "PATCH": + case bytes.Equal(method, patch): return r.patch - case "DELETE": + case bytes.Equal(method, delete): return r.delete default: return nil @@ -108,7 +118,8 @@ func (r *Router) buildRoutes() error { return err } - if strings.TrimSuffix(info.Name(), ".lua") == "middleware" { + fileName := strings.TrimSuffix(info.Name(), ".lua") + if fileName == "middleware" { relDir, err := filepath.Rel(r.routesDir, filepath.Dir(path)) if err != nil { return err @@ -121,7 +132,6 @@ func (r *Router) buildRoutes() error { r.middlewareFiles[fsPath] = append(r.middlewareFiles[fsPath], path) } - return nil }) @@ -136,39 +146,36 @@ func (r *Router) buildRoutes() error { } fileName := strings.TrimSuffix(info.Name(), ".lua") - - // Skip middleware files if fileName == "middleware" { return nil } - // Get relative path from routes directory relPath, err := filepath.Rel(r.routesDir, path) if err != nil { return err } - // Get filesystem path (includes groups) fsPath := "/" + strings.ReplaceAll(filepath.Dir(relPath), "\\", "/") if fsPath == "/." { fsPath = "/" } - // Get URL path (excludes groups) urlPath := r.parseURLPath(fsPath) + urlPathBytes := []byte(urlPath) - // Handle method files (get.lua, post.lua, etc.) - method := strings.ToUpper(fileName) - root := r.methodNode(method) + // Handle method files + methodBytes := []byte(strings.ToUpper(fileName)) + root := r.methodNode(methodBytes) if root != nil { - return r.addRoute(root, urlPath, fsPath, path) + return r.addRoute(root, urlPathBytes, fsPath, path) } - // Handle index files - register for all methods + // Handle index files if fileName == "index" { - for _, method := range []string{"GET", "POST", "PUT", "PATCH", "DELETE"} { + methods := [][]byte{get, post, put, patch, delete} + for _, method := range methods { if root := r.methodNode(method); root != nil { - if err := r.addRoute(root, urlPath, fsPath, path); err != nil { + if err := r.addRoute(root, urlPathBytes, fsPath, path); err != nil { return err } } @@ -176,12 +183,13 @@ func (r *Router) buildRoutes() error { return nil } - // Handle named route files - register as GET by default - namedPath := urlPath + // Handle named route files + var namedPath []byte if urlPath == "/" { - namedPath = "/" + fileName + namedPath = append(slash, fileName...) } else { - namedPath = urlPath + "/" + fileName + namedPath = append(urlPathBytes, '/') + namedPath = append(namedPath, fileName...) } return r.addRoute(r.get, namedPath, fsPath, path) }) @@ -196,7 +204,6 @@ func (r *Router) parseURLPath(fsPath string) string { if segment == "" { continue } - // Skip group segments (enclosed in parentheses) if strings.HasPrefix(segment, "(") && strings.HasSuffix(segment, ")") { continue } @@ -218,12 +225,10 @@ func (r *Router) getMiddlewareChain(fsPath string) []string { pathParts = []string{} } - // Add root middleware if mw, exists := r.middlewareFiles["/"]; exists { chain = append(chain, mw...) } - // Add middleware from each path level (including groups) currentPath := "" for _, part := range pathParts { currentPath += "/" + part @@ -239,7 +244,6 @@ func (r *Router) getMiddlewareChain(fsPath string) []string { func (r *Router) buildCombinedSource(fsPath, scriptPath string) (string, error) { var combined strings.Builder - // Add middleware in order middlewareChain := r.getMiddlewareChain(fsPath) for _, mwPath := range middlewareChain { content, err := os.ReadFile(mwPath) @@ -253,7 +257,6 @@ func (r *Router) buildCombinedSource(fsPath, scriptPath string) (string, error) combined.WriteString("\n") } - // Add main handler content, err := os.ReadFile(scriptPath) if err != nil { return "", err @@ -267,14 +270,12 @@ func (r *Router) buildCombinedSource(fsPath, scriptPath string) (string, error) } // addRoute adds a new route to the trie with bytecode compilation -func (r *Router) addRoute(root *node, urlPath, fsPath, scriptPath string) error { - // Build combined source with middleware +func (r *Router) addRoute(root *node, urlPath []byte, fsPath, scriptPath string) error { combinedSource, err := r.buildCombinedSource(fsPath, scriptPath) if err != nil { return err } - // Compile bytecode r.compileMu.Lock() bytecode, err := r.compileState.CompileBytecode(combinedSource, scriptPath) r.compileMu.Unlock() @@ -283,11 +284,10 @@ func (r *Router) addRoute(root *node, urlPath, fsPath, scriptPath string) error return err } - // Cache bytecode cacheKey := hashString(scriptPath) r.bytecodeCache.Set(uint64ToBytes(cacheKey), bytecode) - if urlPath == "/" { + if len(urlPath) == 1 && urlPath[0] == '/' { root.bytecode = bytecode root.scriptPath = scriptPath return nil @@ -298,8 +298,8 @@ func (r *Router) addRoute(root *node, urlPath, fsPath, scriptPath string) error paramCount := uint8(0) for { - seg, newPos, more := readSegment(urlPath, pos) - if seg == "" { + seg, newPos, more := readSegmentBytes(urlPath, pos) + if len(seg) == 0 { break } @@ -314,10 +314,9 @@ func (r *Router) addRoute(root *node, urlPath, fsPath, scriptPath string) error paramCount++ } - // Find or create child var child *node for _, c := range current.children { - if c.segment == seg { + if bytes.Equal(c.segment, seg) { child = c break } @@ -325,7 +324,7 @@ func (r *Router) addRoute(root *node, urlPath, fsPath, scriptPath string) error if child == nil { child = &node{ - segment: seg, + segment: append([]byte(nil), seg...), isDynamic: isDyn, isWildcard: isWC, } @@ -345,16 +344,16 @@ func (r *Router) addRoute(root *node, urlPath, fsPath, scriptPath string) error return nil } -// readSegment extracts the next path segment -func readSegment(path string, start int) (segment string, end int, hasMore bool) { +// readSegmentBytes extracts the next path segment from byte slice +func readSegmentBytes(path []byte, start int) (segment []byte, end int, hasMore bool) { if start >= len(path) { - return "", start, false + return nil, start, false } if path[start] == '/' { start++ } if start >= len(path) { - return "", start, false + return nil, start, false } end = start for end < len(path) && path[end] != '/' { @@ -364,20 +363,19 @@ func readSegment(path string, start int) (segment string, end int, hasMore bool) } // Lookup finds bytecode and parameters for a method and path -func (r *Router) Lookup(method, path string) ([]byte, *Params, bool) { +func (r *Router) Lookup(method, path []byte) ([]byte, *Params, bool) { root := r.methodNode(method) if root == nil { return nil, nil, false } - if path == "/" { + if len(path) == 1 && path[0] == '/' { if root.bytecode != nil { return root.bytecode, &Params{}, true } return nil, nil, false } - // Prepare params buffer buffer := r.paramsBuffer if cap(buffer) < int(root.maxParams) { buffer = make([]string, root.maxParams) @@ -400,7 +398,7 @@ func (r *Router) Lookup(method, path string) ([]byte, *Params, bool) { } // match traverses the trie to find bytecode -func (r *Router) match(current *node, path string, start int, params *[]string, keys *[]string) ([]byte, int, bool) { +func (r *Router) match(current *node, path []byte, start int, params *[]string, keys *[]string) ([]byte, int, bool) { paramCount := 0 // Check wildcard first @@ -410,22 +408,23 @@ func (r *Router) match(current *node, path string, start int, params *[]string, if len(rem) > 0 && rem[0] == '/' { rem = rem[1:] } - *params = append(*params, rem) - *keys = append(*keys, strings.TrimPrefix(c.segment, "*")) + *params = append(*params, string(rem)) + paramName := string(c.segment[1:]) // Remove * + *keys = append(*keys, paramName) return c.bytecode, 1, c.bytecode != nil } } - seg, pos, more := readSegment(path, start) - if seg == "" { + seg, pos, more := readSegmentBytes(path, start) + if len(seg) == 0 { return current.bytecode, 0, current.bytecode != nil } for _, c := range current.children { - if c.segment == seg || c.isDynamic { + if bytes.Equal(c.segment, seg) || c.isDynamic { if c.isDynamic { - *params = append(*params, seg) - paramName := c.segment[1 : len(c.segment)-1] // Remove [ ] + *params = append(*params, string(seg)) + paramName := string(c.segment[1 : len(c.segment)-1]) // Remove [ ] *keys = append(*keys, paramName) paramCount++ } @@ -478,7 +477,6 @@ func (r *Router) Close() { r.compileMu.Unlock() } -// Helper functions from cache.go func hashString(s string) uint64 { h := uint64(5381) for i := 0; i < len(s); i++ { diff --git a/runner/context.go b/runner/context.go index 382e682..d7388ad 100644 --- a/runner/context.go +++ b/runner/context.go @@ -2,21 +2,19 @@ package runner import ( "sync" - - "github.com/valyala/bytebufferpool" - "github.com/valyala/fasthttp" ) -// Context represents execution context for a Lua script +// Generic interface to support different types of execution contexts for the runner +type ExecutionContext interface { + Get(key string) any + Set(key string, value any) + ToMap() map[string]any + Release() +} + +// This is a generic context that satisfies the runner's ExecutionContext interface type Context struct { - // Values stores any context values (route params, HTTP request info, etc.) - Values map[string]any - - // FastHTTP context if this was created from an HTTP request - RequestCtx *fasthttp.RequestCtx - - // Buffer for efficient string operations - buffer *bytebufferpool.ByteBuffer + Values map[string]any // Any data we want to pass to the state's global ctx table. } // Context pool to reduce allocations @@ -28,57 +26,12 @@ var contextPool = sync.Pool{ }, } -// NewContext creates a new context, potentially reusing one from the pool +// Gets a new context from the pool func NewContext() *Context { ctx := contextPool.Get().(*Context) return ctx } -// NewHTTPContext creates a new context from a fasthttp RequestCtx -func NewHTTPContext(requestCtx *fasthttp.RequestCtx) *Context { - ctx := NewContext() - ctx.RequestCtx = requestCtx - - // Extract common HTTP values that Lua might need - if requestCtx != nil { - ctx.Values["_request_method"] = string(requestCtx.Method()) - ctx.Values["_request_path"] = string(requestCtx.Path()) - ctx.Values["_request_url"] = string(requestCtx.RequestURI()) - - // Extract cookies - cookies := make(map[string]any) - requestCtx.Request.Header.VisitAllCookie(func(key, value []byte) { - cookies[string(key)] = string(value) - }) - ctx.Values["_request_cookies"] = cookies - - // Extract query params - query := make(map[string]any) - requestCtx.QueryArgs().VisitAll(func(key, value []byte) { - query[string(key)] = string(value) - }) - ctx.Values["_request_query"] = query - - // Extract form data if present - if requestCtx.IsPost() || requestCtx.IsPut() { - form := make(map[string]any) - requestCtx.PostArgs().VisitAll(func(key, value []byte) { - form[string(key)] = string(value) - }) - ctx.Values["_request_form"] = form - } - - // Extract headers - headers := make(map[string]any) - requestCtx.Request.Header.VisitAll(func(key, value []byte) { - headers[string(key)] = string(value) - }) - ctx.Values["_request_headers"] = headers - } - - return ctx -} - // Release returns the context to the pool after clearing its values func (c *Context) Release() { // Clear all values to prevent data leakage @@ -86,26 +39,9 @@ func (c *Context) Release() { delete(c.Values, k) } - // Reset request context - c.RequestCtx = nil - - // Return buffer to pool if we have one - if c.buffer != nil { - bytebufferpool.Put(c.buffer) - c.buffer = nil - } - contextPool.Put(c) } -// GetBuffer returns a byte buffer for efficient string operations -func (c *Context) GetBuffer() *bytebufferpool.ByteBuffer { - if c.buffer == nil { - c.buffer = bytebufferpool.Get() - } - return c.buffer -} - // Set adds a value to the context func (c *Context) Set(key string, value any) { c.Values[key] = value @@ -115,3 +51,8 @@ func (c *Context) Set(key string, value any) { func (c *Context) Get(key string) any { return c.Values[key] } + +// We can just return the Values map as it's already g2g for Lua +func (c *Context) ToMap() map[string]any { + return c.Values +} diff --git a/runner/httpContext.go b/runner/httpContext.go new file mode 100644 index 0000000..254a51a --- /dev/null +++ b/runner/httpContext.go @@ -0,0 +1,165 @@ +package runner + +import ( + "Moonshark/router" + "Moonshark/runner/lualibs" + "Moonshark/sessions" + "Moonshark/utils" + "sync" + + "github.com/valyala/fasthttp" +) + +// A prebuilt, ready-to-go context for HTTP requests to the runner. +type HTTPContext struct { + Method []byte + Path []byte + Host []byte + Headers map[string]any + Cookies map[string]string + Query map[string]string + Params map[string]any + Form map[string]any + Session map[string]any + Env map[string]any + Values map[string]any // Extra context vars just in case +} + +// HTTP context pool to reduce allocations +var httpContextPool = sync.Pool{ + New: func() any { + return &HTTPContext{ + Headers: make(map[string]any, 16), + Cookies: make(map[string]string, 8), + Query: make(map[string]string, 8), + Params: make(map[string]any, 4), + Form: make(map[string]any, 8), + Session: make(map[string]any, 4), + Env: make(map[string]any, 16), + Values: make(map[string]any, 32), + } + }, +} + +// Get a clean HTTP context from the pool and build it up with an HTTP request, router params and session data +func NewHTTPContext(httpCtx *fasthttp.RequestCtx, params *router.Params, session *sessions.Session) *HTTPContext { + ctx := httpContextPool.Get().(*HTTPContext) + + // Extract basic HTTP info + ctx.Method = httpCtx.Method() + ctx.Path = httpCtx.Path() + ctx.Host = httpCtx.Host() + + // Extract headers + httpCtx.Request.Header.VisitAll(func(key, value []byte) { + ctx.Headers[string(key)] = string(value) + }) + + // Extract cookies + httpCtx.Request.Header.VisitAllCookie(func(key, value []byte) { + ctx.Cookies[string(key)] = string(value) + }) + + // Extract query params + httpCtx.QueryArgs().VisitAll(func(key, value []byte) { + ctx.Query[string(key)] = string(value) + }) + + // Extract route parameters + if params != nil { + for i := 0; i < min(len(params.Keys), len(params.Values)); i++ { + ctx.Params[params.Keys[i]] = params.Values[i] + } + } + + // Extract form data if present + if httpCtx.IsPost() || httpCtx.IsPut() || httpCtx.IsPatch() { + if form, err := utils.ParseForm(httpCtx); err == nil { + for k, v := range form { + ctx.Form[k] = v + } + } + } + + // Extract session data + session.AdvanceFlash() + ctx.Session["id"] = session.ID + if session.IsEmpty() { + ctx.Session["data"] = emptyMap + ctx.Session["flash"] = emptyMap + } else { + ctx.Session["data"] = session.GetAll() + ctx.Session["flash"] = session.GetAllFlash() + } + + // Add environment vars + if envMgr := lualibs.GetGlobalEnvManager(); envMgr != nil { + for k, v := range envMgr.GetAll() { + ctx.Env[k] = v + } + } + + return ctx +} + +// Clear out all the request data from the context and give it back to the pool. Keeps the contexts and inner maps +// allocated to prevent GC churn. +func (c *HTTPContext) Release() { + for k := range c.Headers { + delete(c.Headers, k) + } + for k := range c.Cookies { + delete(c.Cookies, k) + } + for k := range c.Query { + delete(c.Query, k) + } + for k := range c.Params { + delete(c.Params, k) + } + for k := range c.Form { + delete(c.Form, k) + } + for k := range c.Session { + delete(c.Session, k) + } + for k := range c.Env { + delete(c.Env, k) + } + for k := range c.Values { + delete(c.Values, k) + } + + c.Method = nil + c.Path = nil + c.Host = nil + + httpContextPool.Put(c) +} + +// Add a value to the extras map +func (c *HTTPContext) Set(key string, value any) { + c.Values[key] = value +} + +// Get a value from the extras map +func (c *HTTPContext) Get(key string) any { + return c.Values[key] +} + +// Returns a representation of the context ready for Lua +func (c *HTTPContext) ToMap() map[string]any { + return map[string]any{ + "method": string(c.Method), + "path": string(c.Path), + "host": string(c.Host), + "headers": c.Headers, + "cookies": c.Cookies, + "query": c.Query, + "params": c.Params, + "form": c.Form, + "session": c.Session, + "env": c.Env, + "values": c.Values, + } +} diff --git a/runner/lualibs/env.go b/runner/lualibs/env.go index 93bb26d..d68a5d8 100644 --- a/runner/lualibs/env.go +++ b/runner/lualibs/env.go @@ -10,9 +10,10 @@ import ( "strings" "sync" - "Moonshark/color" "Moonshark/logger" + "git.sharkk.net/Go/Color" + luajit "git.sharkk.net/Sky/LuaJIT-to-Go" ) diff --git a/runner/lualibs/util.go b/runner/lualibs/util.go index 5cf45fa..83ec874 100644 --- a/runner/lualibs/util.go +++ b/runner/lualibs/util.go @@ -2,10 +2,11 @@ package lualibs import ( "encoding/base64" - "encoding/json" "html" "strings" + "github.com/goccy/go-json" + luajit "git.sharkk.net/Sky/LuaJIT-to-Go" ) diff --git a/runner/runner.go b/runner/runner.go index 5d6a83a..0aafa2d 100644 --- a/runner/runner.go +++ b/runner/runner.go @@ -6,21 +6,16 @@ import ( "os" "path/filepath" "runtime" - "strings" "sync" "sync/atomic" "time" + "Moonshark/config" "Moonshark/logger" - "Moonshark/router" "Moonshark/runner/lualibs" "Moonshark/runner/sqlite" - "Moonshark/sessions" luajit "git.sharkk.net/Sky/LuaJIT-to-Go" - "github.com/goccy/go-json" - "github.com/valyala/bytebufferpool" - "github.com/valyala/fasthttp" ) var emptyMap = make(map[string]any) @@ -52,14 +47,13 @@ type Runner struct { paramsPool sync.Pool } -func NewRunner(poolSize int, dataDir, fsDir string, libDirs []string) (*Runner, error) { - if poolSize <= 0 { +func NewRunner(cfg *config.Config, poolSize int) (*Runner, error) { + if poolSize <= 0 && cfg.Runner.PoolSize <= 0 { poolSize = runtime.GOMAXPROCS(0) } - // Configure module loader with lib directories moduleConfig := &ModuleConfig{ - LibDirs: libDirs, + LibDirs: cfg.Dirs.Libs, } r := &Runner{ @@ -73,9 +67,10 @@ func NewRunner(poolSize int, dataDir, fsDir string, libDirs []string) (*Runner, }, } - sqlite.InitSQLite(dataDir) - lualibs.InitFS(fsDir) + sqlite.InitSQLite(cfg.Dirs.Data) sqlite.SetSQLitePoolSize(poolSize) + lualibs.InitFS(cfg.Dirs.FS) + lualibs.InitEnv(cfg.Dirs.Data) r.states = make([]*State, poolSize) r.statePool = make(chan int, poolSize) @@ -89,15 +84,11 @@ func NewRunner(poolSize int, dataDir, fsDir string, libDirs []string) (*Runner, return r, nil } -// Single entry point for HTTP execution -func (r *Runner) ExecuteHTTP(bytecode []byte, httpCtx *fasthttp.RequestCtx, - params *router.Params, session *sessions.Session) (*Response, error) { - +func (r *Runner) Execute(bytecode []byte, ctx ExecutionContext) (*Response, error) { if !r.isRunning.Load() { return nil, ErrRunnerClosed } - // Get state with timeout var stateIndex int select { case stateIndex = <-r.statePool: @@ -106,8 +97,12 @@ func (r *Runner) ExecuteHTTP(bytecode []byte, httpCtx *fasthttp.RequestCtx, } state := r.states[stateIndex] - state.inUse.Store(true) + if state == nil { + r.statePool <- stateIndex + return nil, ErrStateNotReady + } + state.inUse.Store(true) defer func() { state.inUse.Store(false) if r.isRunning.Load() { @@ -118,114 +113,7 @@ func (r *Runner) ExecuteHTTP(bytecode []byte, httpCtx *fasthttp.RequestCtx, } }() - // Build Lua context directly from HTTP request - luaCtx := r.buildHTTPContext(httpCtx, params, session) - defer r.releaseHTTPContext(luaCtx) - - return state.sandbox.Execute(state.L, bytecode, luaCtx, state.index) -} - -// Build Lua context from HTTP request -func (r *Runner) buildHTTPContext(ctx *fasthttp.RequestCtx, params *router.Params, session *sessions.Session) *Context { - luaCtx := NewContext() - - // Basic request info - luaCtx.Set("method", string(ctx.Method())) - luaCtx.Set("path", string(ctx.Path())) - luaCtx.Set("host", string(ctx.Host())) - - // Headers - headers := r.ctxPool.Get().(map[string]any) - ctx.Request.Header.VisitAll(func(key, value []byte) { - headers[string(key)] = string(value) - }) - luaCtx.Set("headers", headers) - - // Cookies - cookies := r.ctxPool.Get().(map[string]any) - ctx.Request.Header.VisitAllCookie(func(key, value []byte) { - cookies[string(key)] = string(value) - }) - luaCtx.Set("cookies", cookies) - - // Route parameters - if params != nil && len(params.Keys) > 0 { - paramMap := r.paramsPool.Get().(map[string]any) - for i, key := range params.Keys { - if i < len(params.Values) { - paramMap[key] = params.Values[i] - } - } - luaCtx.Set("params", paramMap) - } else { - luaCtx.Set("params", emptyMap) - } - - // Form data for POST/PUT/PATCH - method := ctx.Method() - if string(method) == "POST" || string(method) == "PUT" || string(method) == "PATCH" { - if formData := parseForm(ctx); formData != nil { - luaCtx.Set("form", formData) - } else { - luaCtx.Set("form", emptyMap) - } - } else { - luaCtx.Set("form", emptyMap) - } - - // Session data - sessionMap := r.ctxPool.Get().(map[string]any) - session.AdvanceFlash() - sessionMap["id"] = session.ID - - if !session.IsEmpty() { - sessionMap["data"] = session.GetAll() - sessionMap["flash"] = session.GetAllFlash() - } else { - sessionMap["data"] = emptyMap - sessionMap["flash"] = emptyMap - } - luaCtx.Set("session", sessionMap) - - // Environment variables - if envMgr := lualibs.GetGlobalEnvManager(); envMgr != nil { - luaCtx.Set("env", envMgr.GetAll()) - } - - return luaCtx -} - -// Releases the HTTP context's maps back to their pool -func (r *Runner) releaseHTTPContext(luaCtx *Context) { - if headers, ok := luaCtx.Get("headers").(map[string]any); ok { - for k := range headers { - delete(headers, k) - } - r.ctxPool.Put(headers) - } - - if cookies, ok := luaCtx.Get("cookies").(map[string]any); ok { - for k := range cookies { - delete(cookies, k) - } - r.ctxPool.Put(cookies) - } - - if params, ok := luaCtx.Get("params").(map[string]any); ok && len(params) > 0 { - for k := range params { - delete(params, k) - } - r.paramsPool.Put(params) - } - - if sessionMap, ok := luaCtx.Get("session").(map[string]any); ok { - for k := range sessionMap { - delete(sessionMap, k) - } - r.ctxPool.Put(sessionMap) - } - - luaCtx.Release() + return state.sandbox.Execute(state.L, bytecode, ctx, state.index) } func (r *Runner) initStates() error { @@ -316,35 +204,10 @@ cleanup: lualibs.CleanupFS() sqlite.CleanupSQLite() + lualibs.CleanupEnv() return nil } -// parseForm extracts form data from HTTP request -func parseForm(ctx *fasthttp.RequestCtx) map[string]any { - form := make(map[string]any) - - // Parse POST form data - ctx.PostArgs().VisitAll(func(key, value []byte) { - form[string(key)] = string(value) - }) - - // Parse multipart form if present - if multipartForm, err := ctx.MultipartForm(); err == nil { - for key, values := range multipartForm.Value { - if len(values) == 1 { - form[key] = values[0] - } else { - form[key] = values - } - } - } - - if len(form) == 0 { - return nil - } - return form -} - // NotifyFileChanged alerts the runner about file changes func (r *Runner) NotifyFileChanged(filePath string) bool { logger.Debugf("Runner notified of file change: %s", filePath) @@ -470,87 +333,3 @@ func (r *Runner) RunScriptFile(filePath string) (*Response, error) { return response, nil } - -// ApplyResponse applies a Response to a fasthttp.RequestCtx -func ApplyResponse(resp *Response, ctx *fasthttp.RequestCtx) { - // Set status code - ctx.SetStatusCode(resp.Status) - - // Set headers - for name, value := range resp.Headers { - ctx.Response.Header.Set(name, value) - } - - // Set cookies - for _, cookie := range resp.Cookies { - ctx.Response.Header.SetCookie(cookie) - } - - // Process the body based on its type - if resp.Body == nil { - return - } - - // Check if Content-Type was manually set - contentTypeSet := false - for name := range resp.Headers { - if strings.ToLower(name) == "content-type" { - contentTypeSet = true - break - } - } - - // Get a buffer from the pool - buf := bytebufferpool.Get() - defer bytebufferpool.Put(buf) - - // Set body based on type - switch body := resp.Body.(type) { - case string: - if !contentTypeSet { - ctx.Response.Header.SetContentType("text/plain; charset=utf-8") - } - ctx.SetBodyString(body) - case []byte: - if !contentTypeSet { - ctx.Response.Header.SetContentType("text/plain; charset=utf-8") - } - ctx.SetBody(body) - case map[string]any, map[any]any, []any, []float64, []string, []int, []map[string]any: - // Marshal JSON - if err := json.NewEncoder(buf).Encode(body); err == nil { - if !contentTypeSet { - ctx.Response.Header.SetContentType("application/json") - } - ctx.SetBody(buf.Bytes()) - } else { - // Fallback to string representation - if !contentTypeSet { - ctx.Response.Header.SetContentType("text/plain; charset=utf-8") - } - ctx.SetBodyString(fmt.Sprintf("%v", body)) - } - default: - // Check if it's any other map or slice type - typeStr := fmt.Sprintf("%T", body) - if typeStr[0] == '[' || (len(typeStr) > 3 && typeStr[:3] == "map") { - if err := json.NewEncoder(buf).Encode(body); err == nil { - if !contentTypeSet { - ctx.Response.Header.SetContentType("application/json") - } - ctx.SetBody(buf.Bytes()) - } else { - if !contentTypeSet { - ctx.Response.Header.SetContentType("text/plain; charset=utf-8") - } - ctx.SetBodyString(fmt.Sprintf("%v", body)) - } - } else { - // Default to string representation - if !contentTypeSet { - ctx.Response.Header.SetContentType("text/plain; charset=utf-8") - } - ctx.SetBodyString(fmt.Sprintf("%v", body)) - } - } -} diff --git a/runner/sandbox.go b/runner/sandbox.go index 61f91db..5393390 100644 --- a/runner/sandbox.go +++ b/runner/sandbox.go @@ -49,7 +49,7 @@ func (s *Sandbox) Setup(state *luajit.State, stateIndex int, verbose bool) error } // Execute runs a Lua script in the sandbox with the given context -func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx *Context, stateIndex int) (*Response, error) { +func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx ExecutionContext, stateIndex int) (*Response, error) { // Load script and executor if err := state.LoadBytecode(bytecode, "script"); err != nil { return nil, fmt.Errorf("failed to load bytecode: %w", err) @@ -78,7 +78,7 @@ func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx *Context, st // Call __execute(script_func, ctx, response) state.PushCopy(-2) // script function - state.PushValue(ctx.Values) + state.PushValue(ctx.ToMap()) state.PushValue(response) if err := state.Call(3, 1); err != nil { diff --git a/sessions/manager.go b/sessions/manager.go index 0e4cdf5..5d8dd2d 100644 --- a/sessions/manager.go +++ b/sessions/manager.go @@ -41,13 +41,15 @@ func NewSessionManager(maxSessions int) *SessionManager { cache: fastcache.New(maxSessions * 4096), cookieName: DefaultCookieName, cookiePath: DefaultCookiePath, + cookieDomain: "", + cookieSecure: false, cookieHTTPOnly: true, cookieMaxAge: DefaultMaxAge, cleanupDone: make(chan struct{}), } // Pre-populate session pool - for i := 0; i < 100; i++ { + for range 100 { s := NewSession("", 0) s.Release() } @@ -207,6 +209,3 @@ func (sm *SessionManager) GetCacheStats() map[string]uint64 { "misses": stats.Misses, } } - -// GlobalSessionManager is the default session manager instance -var GlobalSessionManager = NewSessionManager(DefaultMaxSessions) diff --git a/utils/errorPages.go b/utils/errorPages.go index 956cd1c..ba8437d 100644 --- a/utils/errorPages.go +++ b/utils/errorPages.go @@ -2,15 +2,9 @@ package utils import ( "math/rand" - "os" - "path/filepath" ) -// ErrorPageConfig holds configuration for generating error pages -type ErrorPageConfig struct { - OverrideDir string // Directory where override templates are stored - DebugMode bool // Whether to show debug information -} +var dbg bool // ErrorType represents HTTP error types type ErrorType int @@ -22,31 +16,11 @@ const ( ErrorTypeForbidden ErrorType = 403 // Added CSRF/Forbidden error type ) +func Debug(enabled bool) { dbg = enabled } + // ErrorPage generates an HTML error page based on the error type // It first checks for an override file, and if not found, generates a default page -func ErrorPage(config ErrorPageConfig, errorType ErrorType, url string, errMsg string) string { - // Check for override file - if config.OverrideDir != "" { - var filename string - switch errorType { - case ErrorTypeNotFound: - filename = "404.html" - case ErrorTypeMethodNotAllowed: - filename = "405.html" - case ErrorTypeInternalError: - filename = "500.html" - case ErrorTypeForbidden: - filename = "403.html" - } - - if filename != "" { - overridePath := filepath.Join(config.OverrideDir, filename) - if content, err := os.ReadFile(overridePath); err == nil { - return string(content) - } - } - } - +func ErrorPage(errorType ErrorType, url string, errMsg string) string { // No override found, generate default page switch errorType { case ErrorTypeNotFound: @@ -54,33 +28,33 @@ func ErrorPage(config ErrorPageConfig, errorType ErrorType, url string, errMsg s case ErrorTypeMethodNotAllowed: return generateMethodNotAllowedHTML(url) case ErrorTypeInternalError: - return generateInternalErrorHTML(config.DebugMode, url, errMsg) + return generateInternalErrorHTML(dbg, url, errMsg) case ErrorTypeForbidden: - return generateForbiddenHTML(config.DebugMode, url, errMsg) + return generateForbiddenHTML(dbg, url, errMsg) default: // Fallback to internal error - return generateInternalErrorHTML(config.DebugMode, url, errMsg) + return generateInternalErrorHTML(dbg, url, errMsg) } } // NotFoundPage generates a 404 Not Found error page -func NotFoundPage(config ErrorPageConfig, url string) string { - return ErrorPage(config, ErrorTypeNotFound, url, "") +func NotFoundPage(url string) string { + return ErrorPage(ErrorTypeNotFound, url, "") } // MethodNotAllowedPage generates a 405 Method Not Allowed error page -func MethodNotAllowedPage(config ErrorPageConfig, url string) string { - return ErrorPage(config, ErrorTypeMethodNotAllowed, url, "") +func MethodNotAllowedPage(url string) string { + return ErrorPage(ErrorTypeMethodNotAllowed, url, "") } // InternalErrorPage generates a 500 Internal Server Error page -func InternalErrorPage(config ErrorPageConfig, url string, errMsg string) string { - return ErrorPage(config, ErrorTypeInternalError, url, errMsg) +func InternalErrorPage(url string, errMsg string) string { + return ErrorPage(ErrorTypeInternalError, url, errMsg) } // ForbiddenPage generates a 403 Forbidden error page -func ForbiddenPage(config ErrorPageConfig, url string, errMsg string) string { - return ErrorPage(config, ErrorTypeForbidden, url, errMsg) +func ForbiddenPage(url string, errMsg string) string { + return ErrorPage(ErrorTypeForbidden, url, errMsg) } // generateInternalErrorHTML creates a 500 Internal Server Error page diff --git a/http/utils.go b/utils/formData.go similarity index 50% rename from http/utils.go rename to utils/formData.go index 07d3a51..e29e1d4 100644 --- a/http/utils.go +++ b/utils/formData.go @@ -1,68 +1,46 @@ -package http +package utils import ( - "crypto/rand" - "encoding/base64" "mime/multipart" "strings" - "sync" + + "github.com/goccy/go-json" "github.com/valyala/fasthttp" ) -var formDataPool = sync.Pool{ - New: func() any { - return make(map[string]any, 16) - }, -} - -func QueryToLua(ctx *fasthttp.RequestCtx) map[string]any { - args := ctx.QueryArgs() - if args.Len() == 0 { - return emptyMap - } - - queryMap := make(map[string]any, args.Len()) - args.VisitAll(func(key, value []byte) { - k := string(key) - v := string(value) - appendValue(queryMap, k, v) - }) - return queryMap -} - func ParseForm(ctx *fasthttp.RequestCtx) (map[string]any, error) { - if strings.Contains(string(ctx.Request.Header.ContentType()), "multipart/form-data") { - return parseMultipartForm(ctx) + contentType := string(ctx.Request.Header.ContentType()) + formData := make(map[string]any) + + switch { + case strings.Contains(contentType, "multipart/form-data"): + if err := parseMultipartInto(ctx, formData); err != nil { + return nil, err + } + + case strings.Contains(contentType, "application/x-www-form-urlencoded"): + args := ctx.PostArgs() + args.VisitAll(func(key, value []byte) { + appendValue(formData, string(key), string(value)) + }) + + case strings.Contains(contentType, "application/json"): + if err := json.Unmarshal(ctx.PostBody(), &formData); err != nil { + return nil, err + } + + default: + // Leave formData empty if content-type is unrecognized } - args := ctx.PostArgs() - if args.Len() == 0 { - return emptyMap, nil - } - - formData := formDataPool.Get().(map[string]any) - for k := range formData { - delete(formData, k) - } - - args.VisitAll(func(key, value []byte) { - k := string(key) - v := string(value) - appendValue(formData, k, v) - }) return formData, nil } -func parseMultipartForm(ctx *fasthttp.RequestCtx) (map[string]any, error) { +func parseMultipartInto(ctx *fasthttp.RequestCtx, formData map[string]any) error { form, err := ctx.MultipartForm() if err != nil { - return nil, err - } - - formData := formDataPool.Get().(map[string]any) - for k := range formData { - delete(formData, k) + return err } for key, values := range form.Value { @@ -89,7 +67,20 @@ func parseMultipartForm(ctx *fasthttp.RequestCtx) (map[string]any, error) { formData["_files"] = files } - return formData, nil + return nil +} + +func appendValue(formData map[string]any, key, value string) { + if existing, exists := formData[key]; exists { + switch v := existing.(type) { + case string: + formData[key] = []string{v, value} + case []string: + formData[key] = append(v, value) + } + } else { + formData[key] = value + } } func fileInfoToMap(fh *multipart.FileHeader) map[string]any { @@ -97,6 +88,7 @@ func fileInfoToMap(fh *multipart.FileHeader) map[string]any { if ct == "" { ct = getMimeType(fh.Filename) } + return map[string]any{ "filename": fh.Filename, "size": fh.Size, @@ -121,24 +113,3 @@ func getMimeType(filename string) string { } return "application/octet-stream" } - -func appendValue(m map[string]any, k, v string) { - if existing, exists := m[k]; exists { - switch typed := existing.(type) { - case []string: - m[k] = append(typed, v) - case string: - m[k] = []string{typed, v} - } - } else { - m[k] = v - } -} - -func GenerateSecureToken(length int) (string, error) { - b := make([]byte, length) - if _, err := rand.Read(b); err != nil { - return "", err - } - return base64.URLEncoding.EncodeToString(b)[:length], nil -}