diff --git a/http/server.go b/http/server.go index 443ddd7..575fe8f 100644 --- a/http/server.go +++ b/http/server.go @@ -17,7 +17,6 @@ import ( "github.com/valyala/fasthttp" ) -// Server handles HTTP requests using Lua and static file routers type Server struct { luaRouter *routers.LuaRouter staticRouter *routers.StaticRouter @@ -31,12 +30,11 @@ type Server struct { ctxPool sync.Pool } -// New creates a new HTTP server func New(luaRouter *routers.LuaRouter, staticRouter *routers.StaticRouter, runner *runner.Runner, loggingEnabled bool, debugMode bool, overrideDir string, config *config.Config) *Server { - server := &Server{ + s := &Server{ luaRouter: luaRouter, staticRouter: staticRouter, luaRunner: runner, @@ -55,42 +53,34 @@ func New(luaRouter *routers.LuaRouter, staticRouter *routers.StaticRouter, }, } - server.fasthttpServer = &fasthttp.Server{ - Handler: server.handleRequest, + s.fasthttpServer = &fasthttp.Server{ + Handler: s.handleRequest, Name: "Moonshark/" + metadata.Version, ReadTimeout: 30 * time.Second, WriteTimeout: 30 * time.Second, - MaxRequestBodySize: 16 << 20, // 16MB - DisableKeepalive: false, + MaxRequestBodySize: 16 << 20, TCPKeepalive: true, TCPKeepalivePeriod: 60 * time.Second, ReduceMemoryUsage: true, - GetOnly: false, DisablePreParseMultipartForm: true, } - return server + return s } -// ListenAndServe starts the server on the given address func (s *Server) ListenAndServe(addr string) error { logger.Info("Catch the swell at %s", color.Apply("http://localhost"+addr, color.Cyan)) return s.fasthttpServer.ListenAndServe(addr) } -// Shutdown gracefully shuts down the server func (s *Server) Shutdown(ctx context.Context) error { return s.fasthttpServer.ShutdownWithContext(ctx) } -// handleRequest processes the HTTP request func (s *Server) handleRequest(ctx *fasthttp.RequestCtx) { start := time.Now() - methodBytes := ctx.Method() - pathBytes := ctx.Path() - - method := string(methodBytes) - path := string(pathBytes) + method := string(ctx.Method()) + path := string(ctx.Path()) if s.debugMode && path == "/debug/stats" { s.handleDebugStats(ctx) @@ -100,62 +90,50 @@ func (s *Server) handleRequest(ctx *fasthttp.RequestCtx) { return } - s.processRequest(ctx, method, path) + logger.Debug("Processing request %s %s", method, path) + + params := &routers.Params{} + bytecode, scriptPath, routeErr, found := s.luaRouter.GetRouteInfo(method, path, params) + + if found { + if len(bytecode) == 0 || routeErr != nil { + errorMsg := "Route exists but failed to compile. Check server logs for details." + if routeErr != nil { + errorMsg = routeErr.Error() + } + logger.Error("%s %s - %s", method, path, errorMsg) + ctx.SetContentType("text/html; charset=utf-8") + ctx.SetStatusCode(fasthttp.StatusInternalServerError) + ctx.SetBody([]byte(utils.InternalErrorPage(s.errorConfig, path, errorMsg))) + } else { + logger.Debug("Found Lua route match for %s %s with %d params", method, path, params.Count) + s.handleLuaRoute(ctx, bytecode, scriptPath, params, method, path) + } + } else if s.staticRouter != nil { + if _, found := s.staticRouter.Match(path); found { + s.staticRouter.ServeHTTP(ctx) + } else { + ctx.SetContentType("text/html; charset=utf-8") + ctx.SetStatusCode(fasthttp.StatusNotFound) + ctx.SetBody([]byte(utils.NotFoundPage(s.errorConfig, path))) + } + } else { + ctx.SetContentType("text/html; charset=utf-8") + ctx.SetStatusCode(fasthttp.StatusNotFound) + ctx.SetBody([]byte(utils.NotFoundPage(s.errorConfig, path))) + } if s.loggingEnabled { logger.LogRequest(ctx.Response.StatusCode(), method, path, time.Since(start)) } } -// processRequest handles the main request processing -func (s *Server) processRequest(ctx *fasthttp.RequestCtx, method, path string) { - logger.Debug("Processing request %s %s", method, path) - - params := &routers.Params{} - bytecode, scriptPath, routeErr, found := s.luaRouter.GetRouteInfo(method, path, params) - - if found && (len(bytecode) == 0 || routeErr != nil) { - errorMsg := "Route exists but failed to compile. Check server logs for details." - if routeErr != nil { - errorMsg = routeErr.Error() - } - - logger.Error("%s %s - %s", method, path, errorMsg) - ctx.SetContentType("text/html; charset=utf-8") - ctx.SetStatusCode(fasthttp.StatusInternalServerError) - ctx.SetBody([]byte(utils.InternalErrorPage(s.errorConfig, path, errorMsg))) - return - } - - if found { - logger.Debug("Found Lua route match for %s %s with %d params", method, path, params.Count) - s.handleLuaRoute(ctx, bytecode, scriptPath, params, method, path) - return - } - - // Try static router - if s.staticRouter != nil { - if _, found := s.staticRouter.Match(path); found { - s.staticRouter.ServeHTTP(ctx) - return - } - } - - // 404 - ctx.SetContentType("text/html; charset=utf-8") - ctx.SetStatusCode(fasthttp.StatusNotFound) - ctx.SetBody([]byte(utils.NotFoundPage(s.errorConfig, path))) -} - -// handleLuaRoute executes the combined middleware + handler script func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scriptPath string, params *routers.Params, method, path string) { luaCtx := runner.NewHTTPContext(ctx) defer luaCtx.Release() - // Get pooled map for session data sessionMap := s.ctxPool.Get().(map[string]any) defer func() { - // Clear and return to pool for k := range sessionMap { delete(sessionMap, k) } @@ -164,18 +142,16 @@ func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scrip session := s.sessionManager.GetSessionFromRequest(ctx) sessionMap["id"] = session.ID - - sessionMap["data"] = session.GetAll() // This now returns a deep copy + sessionMap["data"] = session.GetAll() luaCtx.Set("method", method) luaCtx.Set("path", path) luaCtx.Set("host", string(ctx.Host())) luaCtx.Set("session", sessionMap) - // Optimize params handling if params.Count > 0 { paramMap := make(map[string]any, params.Count) - for i := range params.Count { + for i := 0; i < params.Count; i++ { paramMap[params.Keys[i]] = params.Values[i] } luaCtx.Set("params", paramMap) @@ -183,7 +159,6 @@ func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scrip luaCtx.Set("params", emptyMap) } - // Optimize form handling for POST methods if method == "POST" || method == "PUT" || method == "PATCH" { if formData, err := ParseForm(ctx); err == nil { luaCtx.Set("form", formData) @@ -204,18 +179,16 @@ func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scrip return } - // Session handling optimization if _, clearAll := response.SessionData["__clear_all"]; clearAll { session.Clear() delete(response.SessionData, "__clear_all") } - // Apply session changes - now supports nested tables for k, v := range response.SessionData { if v == "__SESSION_DELETE_MARKER__" { session.Delete(k) } else { - session.Set(k, v) // This will handle tables through marshalling + session.Set(k, v) } } @@ -224,18 +197,14 @@ func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scrip runner.ReleaseResponse(response) } -// handleDebugStats displays debug statistics func (s *Server) handleDebugStats(ctx *fasthttp.RequestCtx) { stats := utils.CollectSystemStats(s.config) - routeCount, bytecodeBytes := s.luaRouter.GetRouteStats() - stats.Components = utils.ComponentStats{ RouteCount: routeCount, BytecodeBytes: bytecodeBytes, SessionStats: sessions.GlobalSessionManager.GetCacheStats(), } - ctx.SetContentType("text/html; charset=utf-8") ctx.SetStatusCode(fasthttp.StatusOK) ctx.SetBody([]byte(utils.DebugStatsPage(stats))) diff --git a/http/utils.go b/http/utils.go index 3128e78..94e6bde 100644 --- a/http/utils.go +++ b/http/utils.go @@ -10,14 +10,8 @@ import ( "github.com/valyala/fasthttp" ) -var emptyMap = make(map[string]any) - var ( - stringPool = sync.Pool{ - New: func() any { - return make([]string, 0, 4) - }, - } + emptyMap = make(map[string]any) formDataPool = sync.Pool{ New: func() any { return make(map[string]any, 16) @@ -25,44 +19,23 @@ var ( } ) -// QueryToLua converts HTTP query args to a Lua-friendly map func QueryToLua(ctx *fasthttp.RequestCtx) map[string]any { args := ctx.QueryArgs() if args.Len() == 0 { return emptyMap } - queryMap := make(map[string]any, args.Len()) // Pre-size - + queryMap := make(map[string]any, args.Len()) args.VisitAll(func(key, value []byte) { k := string(key) v := string(value) - - if existing, exists := queryMap[k]; exists { - // Handle multiple values more efficiently - switch typed := existing.(type) { - case []string: - queryMap[k] = append(typed, v) - case string: - // Get slice from pool - slice := stringPool.Get().([]string) - slice = slice[:0] // Reset length - slice = append(slice, typed, v) - queryMap[k] = slice - } - } else { - queryMap[k] = v - } + appendValue(queryMap, k, v) }) - return queryMap } -// ParseForm extracts form data from a request func ParseForm(ctx *fasthttp.RequestCtx) (map[string]any, error) { - contentType := string(ctx.Request.Header.ContentType()) - - if strings.Contains(contentType, "multipart/form-data") { + if strings.Contains(string(ctx.Request.Header.ContentType()), "multipart/form-data") { return parseMultipartForm(ctx) } @@ -72,7 +45,6 @@ func ParseForm(ctx *fasthttp.RequestCtx) (map[string]any, error) { } formData := formDataPool.Get().(map[string]any) - // Clear the map (should already be clean from pool) for k := range formData { delete(formData, k) } @@ -80,26 +52,11 @@ func ParseForm(ctx *fasthttp.RequestCtx) (map[string]any, error) { args.VisitAll(func(key, value []byte) { k := string(key) v := string(value) - - if existing, exists := formData[k]; exists { - switch typed := existing.(type) { - case []string: - formData[k] = append(typed, v) - case string: - slice := stringPool.Get().([]string) - slice = slice[:0] - slice = append(slice, typed, v) - formData[k] = slice - } - } else { - formData[k] = v - } + appendValue(formData, k, v) }) - return formData, nil } -// parseMultipartForm handles multipart/form-data requests func parseMultipartForm(ctx *fasthttp.RequestCtx) (map[string]any, error) { form, err := ctx.MultipartForm() if err != nil { @@ -111,26 +68,20 @@ func parseMultipartForm(ctx *fasthttp.RequestCtx) (map[string]any, error) { delete(formData, k) } - // Process form values for key, values := range form.Value { - switch len(values) { - case 0: - // Skip empty - case 1: + if len(values) == 1 { formData[key] = values[0] - default: + } else if len(values) > 1 { formData[key] = values } } - // Process files if present if len(form.File) > 0 { files := make(map[string]any, len(form.File)) for fieldName, fileHeaders := range form.File { - switch len(fileHeaders) { - case 1: + if len(fileHeaders) == 1 { files[fieldName] = fileInfoToMap(fileHeaders[0]) - default: + } else { fileInfos := make([]map[string]any, len(fileHeaders)) for i, fh := range fileHeaders { fileInfos[i] = fileInfoToMap(fh) @@ -144,41 +95,49 @@ func parseMultipartForm(ctx *fasthttp.RequestCtx) (map[string]any, error) { return formData, nil } -// fileInfoToMap converts a FileHeader to a map for Lua func fileInfoToMap(fh *multipart.FileHeader) map[string]any { + ct := fh.Header.Get("Content-Type") + if ct == "" { + ct = getMimeType(fh.Filename) + } return map[string]any{ "filename": fh.Filename, "size": fh.Size, - "mimetype": getMimeType(fh), + "mimetype": ct, } } -// getMimeType gets the mime type from a file header -func getMimeType(fh *multipart.FileHeader) string { - if fh.Header != nil { - contentType := fh.Header.Get("Content-Type") - if contentType != "" { - return contentType +func getMimeType(filename string) string { + if i := strings.LastIndex(filename, "."); i >= 0 { + switch filename[i:] { + case ".pdf": + return "application/pdf" + case ".png": + return "image/png" + case ".jpg", ".jpeg": + return "image/jpeg" + case ".gif": + return "image/gif" + case ".svg": + return "image/svg+xml" } } - - // Fallback to basic type detection from filename - if strings.HasSuffix(fh.Filename, ".pdf") { - return "application/pdf" - } else if strings.HasSuffix(fh.Filename, ".png") { - return "image/png" - } else if strings.HasSuffix(fh.Filename, ".jpg") || strings.HasSuffix(fh.Filename, ".jpeg") { - return "image/jpeg" - } else if strings.HasSuffix(fh.Filename, ".gif") { - return "image/gif" - } else if strings.HasSuffix(fh.Filename, ".svg") { - return "image/svg+xml" - } - return "application/octet-stream" } -// GenerateSecureToken creates a cryptographically secure random token +func appendValue(m map[string]any, k, v string) { + if existing, exists := m[k]; exists { + switch typed := existing.(type) { + case []string: + m[k] = append(typed, v) + case string: + m[k] = []string{typed, v} + } + } else { + m[k] = v + } +} + func GenerateSecureToken(length int) (string, error) { b := make([]byte, length) if _, err := rand.Read(b); err != nil {