optimize http server

This commit is contained in:
Sky Johnson 2025-05-26 13:03:29 -05:00
parent e4cd490f0f
commit 6264407d02
2 changed files with 82 additions and 154 deletions

View File

@ -17,7 +17,6 @@ import (
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
) )
// Server handles HTTP requests using Lua and static file routers
type Server struct { type Server struct {
luaRouter *routers.LuaRouter luaRouter *routers.LuaRouter
staticRouter *routers.StaticRouter staticRouter *routers.StaticRouter
@ -31,12 +30,11 @@ type Server struct {
ctxPool sync.Pool ctxPool sync.Pool
} }
// New creates a new HTTP server
func New(luaRouter *routers.LuaRouter, staticRouter *routers.StaticRouter, func New(luaRouter *routers.LuaRouter, staticRouter *routers.StaticRouter,
runner *runner.Runner, loggingEnabled bool, debugMode bool, runner *runner.Runner, loggingEnabled bool, debugMode bool,
overrideDir string, config *config.Config) *Server { overrideDir string, config *config.Config) *Server {
server := &Server{ s := &Server{
luaRouter: luaRouter, luaRouter: luaRouter,
staticRouter: staticRouter, staticRouter: staticRouter,
luaRunner: runner, luaRunner: runner,
@ -55,42 +53,34 @@ func New(luaRouter *routers.LuaRouter, staticRouter *routers.StaticRouter,
}, },
} }
server.fasthttpServer = &fasthttp.Server{ s.fasthttpServer = &fasthttp.Server{
Handler: server.handleRequest, Handler: s.handleRequest,
Name: "Moonshark/" + metadata.Version, Name: "Moonshark/" + metadata.Version,
ReadTimeout: 30 * time.Second, ReadTimeout: 30 * time.Second,
WriteTimeout: 30 * time.Second, WriteTimeout: 30 * time.Second,
MaxRequestBodySize: 16 << 20, // 16MB MaxRequestBodySize: 16 << 20,
DisableKeepalive: false,
TCPKeepalive: true, TCPKeepalive: true,
TCPKeepalivePeriod: 60 * time.Second, TCPKeepalivePeriod: 60 * time.Second,
ReduceMemoryUsage: true, ReduceMemoryUsage: true,
GetOnly: false,
DisablePreParseMultipartForm: true, DisablePreParseMultipartForm: true,
} }
return server return s
} }
// ListenAndServe starts the server on the given address
func (s *Server) ListenAndServe(addr string) error { func (s *Server) ListenAndServe(addr string) error {
logger.Info("Catch the swell at %s", color.Apply("http://localhost"+addr, color.Cyan)) logger.Info("Catch the swell at %s", color.Apply("http://localhost"+addr, color.Cyan))
return s.fasthttpServer.ListenAndServe(addr) return s.fasthttpServer.ListenAndServe(addr)
} }
// Shutdown gracefully shuts down the server
func (s *Server) Shutdown(ctx context.Context) error { func (s *Server) Shutdown(ctx context.Context) error {
return s.fasthttpServer.ShutdownWithContext(ctx) return s.fasthttpServer.ShutdownWithContext(ctx)
} }
// handleRequest processes the HTTP request
func (s *Server) handleRequest(ctx *fasthttp.RequestCtx) { func (s *Server) handleRequest(ctx *fasthttp.RequestCtx) {
start := time.Now() start := time.Now()
methodBytes := ctx.Method() method := string(ctx.Method())
pathBytes := ctx.Path() path := string(ctx.Path())
method := string(methodBytes)
path := string(pathBytes)
if s.debugMode && path == "/debug/stats" { if s.debugMode && path == "/debug/stats" {
s.handleDebugStats(ctx) s.handleDebugStats(ctx)
@ -100,62 +90,50 @@ func (s *Server) handleRequest(ctx *fasthttp.RequestCtx) {
return return
} }
s.processRequest(ctx, method, path) logger.Debug("Processing request %s %s", method, path)
params := &routers.Params{}
bytecode, scriptPath, routeErr, found := s.luaRouter.GetRouteInfo(method, path, params)
if found {
if len(bytecode) == 0 || routeErr != nil {
errorMsg := "Route exists but failed to compile. Check server logs for details."
if routeErr != nil {
errorMsg = routeErr.Error()
}
logger.Error("%s %s - %s", method, path, errorMsg)
ctx.SetContentType("text/html; charset=utf-8")
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
ctx.SetBody([]byte(utils.InternalErrorPage(s.errorConfig, path, errorMsg)))
} else {
logger.Debug("Found Lua route match for %s %s with %d params", method, path, params.Count)
s.handleLuaRoute(ctx, bytecode, scriptPath, params, method, path)
}
} else if s.staticRouter != nil {
if _, found := s.staticRouter.Match(path); found {
s.staticRouter.ServeHTTP(ctx)
} else {
ctx.SetContentType("text/html; charset=utf-8")
ctx.SetStatusCode(fasthttp.StatusNotFound)
ctx.SetBody([]byte(utils.NotFoundPage(s.errorConfig, path)))
}
} else {
ctx.SetContentType("text/html; charset=utf-8")
ctx.SetStatusCode(fasthttp.StatusNotFound)
ctx.SetBody([]byte(utils.NotFoundPage(s.errorConfig, path)))
}
if s.loggingEnabled { if s.loggingEnabled {
logger.LogRequest(ctx.Response.StatusCode(), method, path, time.Since(start)) logger.LogRequest(ctx.Response.StatusCode(), method, path, time.Since(start))
} }
} }
// processRequest handles the main request processing
func (s *Server) processRequest(ctx *fasthttp.RequestCtx, method, path string) {
logger.Debug("Processing request %s %s", method, path)
params := &routers.Params{}
bytecode, scriptPath, routeErr, found := s.luaRouter.GetRouteInfo(method, path, params)
if found && (len(bytecode) == 0 || routeErr != nil) {
errorMsg := "Route exists but failed to compile. Check server logs for details."
if routeErr != nil {
errorMsg = routeErr.Error()
}
logger.Error("%s %s - %s", method, path, errorMsg)
ctx.SetContentType("text/html; charset=utf-8")
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
ctx.SetBody([]byte(utils.InternalErrorPage(s.errorConfig, path, errorMsg)))
return
}
if found {
logger.Debug("Found Lua route match for %s %s with %d params", method, path, params.Count)
s.handleLuaRoute(ctx, bytecode, scriptPath, params, method, path)
return
}
// Try static router
if s.staticRouter != nil {
if _, found := s.staticRouter.Match(path); found {
s.staticRouter.ServeHTTP(ctx)
return
}
}
// 404
ctx.SetContentType("text/html; charset=utf-8")
ctx.SetStatusCode(fasthttp.StatusNotFound)
ctx.SetBody([]byte(utils.NotFoundPage(s.errorConfig, path)))
}
// handleLuaRoute executes the combined middleware + handler script
func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scriptPath string, params *routers.Params, method, path string) { func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scriptPath string, params *routers.Params, method, path string) {
luaCtx := runner.NewHTTPContext(ctx) luaCtx := runner.NewHTTPContext(ctx)
defer luaCtx.Release() defer luaCtx.Release()
// Get pooled map for session data
sessionMap := s.ctxPool.Get().(map[string]any) sessionMap := s.ctxPool.Get().(map[string]any)
defer func() { defer func() {
// Clear and return to pool
for k := range sessionMap { for k := range sessionMap {
delete(sessionMap, k) delete(sessionMap, k)
} }
@ -164,18 +142,16 @@ func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scrip
session := s.sessionManager.GetSessionFromRequest(ctx) session := s.sessionManager.GetSessionFromRequest(ctx)
sessionMap["id"] = session.ID sessionMap["id"] = session.ID
sessionMap["data"] = session.GetAll()
sessionMap["data"] = session.GetAll() // This now returns a deep copy
luaCtx.Set("method", method) luaCtx.Set("method", method)
luaCtx.Set("path", path) luaCtx.Set("path", path)
luaCtx.Set("host", string(ctx.Host())) luaCtx.Set("host", string(ctx.Host()))
luaCtx.Set("session", sessionMap) luaCtx.Set("session", sessionMap)
// Optimize params handling
if params.Count > 0 { if params.Count > 0 {
paramMap := make(map[string]any, params.Count) paramMap := make(map[string]any, params.Count)
for i := range params.Count { for i := 0; i < params.Count; i++ {
paramMap[params.Keys[i]] = params.Values[i] paramMap[params.Keys[i]] = params.Values[i]
} }
luaCtx.Set("params", paramMap) luaCtx.Set("params", paramMap)
@ -183,7 +159,6 @@ func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scrip
luaCtx.Set("params", emptyMap) luaCtx.Set("params", emptyMap)
} }
// Optimize form handling for POST methods
if method == "POST" || method == "PUT" || method == "PATCH" { if method == "POST" || method == "PUT" || method == "PATCH" {
if formData, err := ParseForm(ctx); err == nil { if formData, err := ParseForm(ctx); err == nil {
luaCtx.Set("form", formData) luaCtx.Set("form", formData)
@ -204,18 +179,16 @@ func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scrip
return return
} }
// Session handling optimization
if _, clearAll := response.SessionData["__clear_all"]; clearAll { if _, clearAll := response.SessionData["__clear_all"]; clearAll {
session.Clear() session.Clear()
delete(response.SessionData, "__clear_all") delete(response.SessionData, "__clear_all")
} }
// Apply session changes - now supports nested tables
for k, v := range response.SessionData { for k, v := range response.SessionData {
if v == "__SESSION_DELETE_MARKER__" { if v == "__SESSION_DELETE_MARKER__" {
session.Delete(k) session.Delete(k)
} else { } else {
session.Set(k, v) // This will handle tables through marshalling session.Set(k, v)
} }
} }
@ -224,18 +197,14 @@ func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scrip
runner.ReleaseResponse(response) runner.ReleaseResponse(response)
} }
// handleDebugStats displays debug statistics
func (s *Server) handleDebugStats(ctx *fasthttp.RequestCtx) { func (s *Server) handleDebugStats(ctx *fasthttp.RequestCtx) {
stats := utils.CollectSystemStats(s.config) stats := utils.CollectSystemStats(s.config)
routeCount, bytecodeBytes := s.luaRouter.GetRouteStats() routeCount, bytecodeBytes := s.luaRouter.GetRouteStats()
stats.Components = utils.ComponentStats{ stats.Components = utils.ComponentStats{
RouteCount: routeCount, RouteCount: routeCount,
BytecodeBytes: bytecodeBytes, BytecodeBytes: bytecodeBytes,
SessionStats: sessions.GlobalSessionManager.GetCacheStats(), SessionStats: sessions.GlobalSessionManager.GetCacheStats(),
} }
ctx.SetContentType("text/html; charset=utf-8") ctx.SetContentType("text/html; charset=utf-8")
ctx.SetStatusCode(fasthttp.StatusOK) ctx.SetStatusCode(fasthttp.StatusOK)
ctx.SetBody([]byte(utils.DebugStatsPage(stats))) ctx.SetBody([]byte(utils.DebugStatsPage(stats)))

View File

@ -10,14 +10,8 @@ import (
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
) )
var emptyMap = make(map[string]any)
var ( var (
stringPool = sync.Pool{ emptyMap = make(map[string]any)
New: func() any {
return make([]string, 0, 4)
},
}
formDataPool = sync.Pool{ formDataPool = sync.Pool{
New: func() any { New: func() any {
return make(map[string]any, 16) return make(map[string]any, 16)
@ -25,44 +19,23 @@ var (
} }
) )
// QueryToLua converts HTTP query args to a Lua-friendly map
func QueryToLua(ctx *fasthttp.RequestCtx) map[string]any { func QueryToLua(ctx *fasthttp.RequestCtx) map[string]any {
args := ctx.QueryArgs() args := ctx.QueryArgs()
if args.Len() == 0 { if args.Len() == 0 {
return emptyMap return emptyMap
} }
queryMap := make(map[string]any, args.Len()) // Pre-size queryMap := make(map[string]any, args.Len())
args.VisitAll(func(key, value []byte) { args.VisitAll(func(key, value []byte) {
k := string(key) k := string(key)
v := string(value) v := string(value)
appendValue(queryMap, k, v)
if existing, exists := queryMap[k]; exists {
// Handle multiple values more efficiently
switch typed := existing.(type) {
case []string:
queryMap[k] = append(typed, v)
case string:
// Get slice from pool
slice := stringPool.Get().([]string)
slice = slice[:0] // Reset length
slice = append(slice, typed, v)
queryMap[k] = slice
}
} else {
queryMap[k] = v
}
}) })
return queryMap return queryMap
} }
// ParseForm extracts form data from a request
func ParseForm(ctx *fasthttp.RequestCtx) (map[string]any, error) { func ParseForm(ctx *fasthttp.RequestCtx) (map[string]any, error) {
contentType := string(ctx.Request.Header.ContentType()) if strings.Contains(string(ctx.Request.Header.ContentType()), "multipart/form-data") {
if strings.Contains(contentType, "multipart/form-data") {
return parseMultipartForm(ctx) return parseMultipartForm(ctx)
} }
@ -72,7 +45,6 @@ func ParseForm(ctx *fasthttp.RequestCtx) (map[string]any, error) {
} }
formData := formDataPool.Get().(map[string]any) formData := formDataPool.Get().(map[string]any)
// Clear the map (should already be clean from pool)
for k := range formData { for k := range formData {
delete(formData, k) delete(formData, k)
} }
@ -80,26 +52,11 @@ func ParseForm(ctx *fasthttp.RequestCtx) (map[string]any, error) {
args.VisitAll(func(key, value []byte) { args.VisitAll(func(key, value []byte) {
k := string(key) k := string(key)
v := string(value) v := string(value)
appendValue(formData, k, v)
if existing, exists := formData[k]; exists {
switch typed := existing.(type) {
case []string:
formData[k] = append(typed, v)
case string:
slice := stringPool.Get().([]string)
slice = slice[:0]
slice = append(slice, typed, v)
formData[k] = slice
}
} else {
formData[k] = v
}
}) })
return formData, nil return formData, nil
} }
// parseMultipartForm handles multipart/form-data requests
func parseMultipartForm(ctx *fasthttp.RequestCtx) (map[string]any, error) { func parseMultipartForm(ctx *fasthttp.RequestCtx) (map[string]any, error) {
form, err := ctx.MultipartForm() form, err := ctx.MultipartForm()
if err != nil { if err != nil {
@ -111,26 +68,20 @@ func parseMultipartForm(ctx *fasthttp.RequestCtx) (map[string]any, error) {
delete(formData, k) delete(formData, k)
} }
// Process form values
for key, values := range form.Value { for key, values := range form.Value {
switch len(values) { if len(values) == 1 {
case 0:
// Skip empty
case 1:
formData[key] = values[0] formData[key] = values[0]
default: } else if len(values) > 1 {
formData[key] = values formData[key] = values
} }
} }
// Process files if present
if len(form.File) > 0 { if len(form.File) > 0 {
files := make(map[string]any, len(form.File)) files := make(map[string]any, len(form.File))
for fieldName, fileHeaders := range form.File { for fieldName, fileHeaders := range form.File {
switch len(fileHeaders) { if len(fileHeaders) == 1 {
case 1:
files[fieldName] = fileInfoToMap(fileHeaders[0]) files[fieldName] = fileInfoToMap(fileHeaders[0])
default: } else {
fileInfos := make([]map[string]any, len(fileHeaders)) fileInfos := make([]map[string]any, len(fileHeaders))
for i, fh := range fileHeaders { for i, fh := range fileHeaders {
fileInfos[i] = fileInfoToMap(fh) fileInfos[i] = fileInfoToMap(fh)
@ -144,41 +95,49 @@ func parseMultipartForm(ctx *fasthttp.RequestCtx) (map[string]any, error) {
return formData, nil return formData, nil
} }
// fileInfoToMap converts a FileHeader to a map for Lua
func fileInfoToMap(fh *multipart.FileHeader) map[string]any { func fileInfoToMap(fh *multipart.FileHeader) map[string]any {
ct := fh.Header.Get("Content-Type")
if ct == "" {
ct = getMimeType(fh.Filename)
}
return map[string]any{ return map[string]any{
"filename": fh.Filename, "filename": fh.Filename,
"size": fh.Size, "size": fh.Size,
"mimetype": getMimeType(fh), "mimetype": ct,
} }
} }
// getMimeType gets the mime type from a file header func getMimeType(filename string) string {
func getMimeType(fh *multipart.FileHeader) string { if i := strings.LastIndex(filename, "."); i >= 0 {
if fh.Header != nil { switch filename[i:] {
contentType := fh.Header.Get("Content-Type") case ".pdf":
if contentType != "" {
return contentType
}
}
// Fallback to basic type detection from filename
if strings.HasSuffix(fh.Filename, ".pdf") {
return "application/pdf" return "application/pdf"
} else if strings.HasSuffix(fh.Filename, ".png") { case ".png":
return "image/png" return "image/png"
} else if strings.HasSuffix(fh.Filename, ".jpg") || strings.HasSuffix(fh.Filename, ".jpeg") { case ".jpg", ".jpeg":
return "image/jpeg" return "image/jpeg"
} else if strings.HasSuffix(fh.Filename, ".gif") { case ".gif":
return "image/gif" return "image/gif"
} else if strings.HasSuffix(fh.Filename, ".svg") { case ".svg":
return "image/svg+xml" return "image/svg+xml"
} }
}
return "application/octet-stream" return "application/octet-stream"
} }
// GenerateSecureToken creates a cryptographically secure random token func appendValue(m map[string]any, k, v string) {
if existing, exists := m[k]; exists {
switch typed := existing.(type) {
case []string:
m[k] = append(typed, v)
case string:
m[k] = []string{typed, v}
}
} else {
m[k] = v
}
}
func GenerateSecureToken(length int) (string, error) { func GenerateSecureToken(length int) (string, error) {
b := make([]byte, length) b := make([]byte, length)
if _, err := rand.Read(b); err != nil { if _, err := rand.Read(b); err != nil {