fasthttp 1

This commit is contained in:
Sky Johnson 2025-04-03 13:14:45 -05:00
parent 5bba4ffcf8
commit 95eae40357
8 changed files with 255 additions and 262 deletions

View File

@ -1,20 +1,22 @@
package http package http
import ( import (
"net/http"
"git.sharkk.net/Sky/Moonshark/core/logger" "git.sharkk.net/Sky/Moonshark/core/logger"
"git.sharkk.net/Sky/Moonshark/core/utils" "git.sharkk.net/Sky/Moonshark/core/utils"
"github.com/valyala/fasthttp"
) )
// HandleCSRFError handles a CSRF validation error // HandleCSRFError handles a CSRF validation error
func HandleCSRFError(w http.ResponseWriter, r *http.Request, errorConfig utils.ErrorPageConfig) { func HandleCSRFError(ctx *fasthttp.RequestCtx, errorConfig utils.ErrorPageConfig) {
logger.Warning("CSRF validation failed for %s %s", r.Method, r.URL.Path) method := string(ctx.Method())
path := string(ctx.Path())
w.Header().Set("Content-Type", "text/html; charset=utf-8") logger.Warning("CSRF validation failed for %s %s", method, path)
w.WriteHeader(http.StatusForbidden)
ctx.SetContentType("text/html; charset=utf-8")
ctx.SetStatusCode(fasthttp.StatusForbidden)
errorMsg := "Invalid or missing CSRF token. This could be due to an expired form or a cross-site request forgery attempt." errorMsg := "Invalid or missing CSRF token. This could be due to an expired form or a cross-site request forgery attempt."
errorHTML := utils.ForbiddenPage(errorConfig, r.URL.Path, errorMsg) errorHTML := utils.ForbiddenPage(errorConfig, path, errorMsg)
w.Write([]byte(errorHTML)) ctx.SetBody([]byte(errorHTML))
} }

View File

@ -2,11 +2,10 @@ package http
import ( import (
"errors" "errors"
"io" "mime/multipart"
"mime"
"net/http"
"net/url"
"strings" "strings"
"github.com/valyala/fasthttp"
) )
// Maximum form parse size (16MB) // Maximum form parse size (16MB)
@ -20,113 +19,94 @@ var (
// ParseForm parses a POST request body into a map of values // ParseForm parses a POST request body into a map of values
// Supports both application/x-www-form-urlencoded and multipart/form-data content types // Supports both application/x-www-form-urlencoded and multipart/form-data content types
func ParseForm(r *http.Request) (map[string]any, error) { func ParseForm(ctx *fasthttp.RequestCtx) (map[string]any, error) {
// Only handle POST, PUT, PATCH // Only handle POST, PUT, PATCH
if r.Method != http.MethodPost && method := string(ctx.Method())
r.Method != http.MethodPut && if method != "POST" && method != "PUT" && method != "PATCH" {
r.Method != http.MethodPatch {
return make(map[string]any), nil return make(map[string]any), nil
} }
// Check content type // Check content type
contentType := r.Header.Get("Content-Type") contentType := string(ctx.Request.Header.ContentType())
if contentType == "" { if contentType == "" {
return make(map[string]any), nil return make(map[string]any), nil
} }
// Parse the media type
mediaType, params, err := mime.ParseMediaType(contentType)
if err != nil {
return nil, ErrInvalidFormType
}
result := make(map[string]any) result := make(map[string]any)
switch { // Check for content length to prevent DOS
case mediaType == "application/x-www-form-urlencoded": if len(ctx.Request.Body()) > maxFormSize {
// Handle URL-encoded form return nil, ErrFormSizeTooLarge
if err := parseURLEncodedForm(r, result); err != nil {
return nil, err
} }
case strings.HasPrefix(mediaType, "multipart/form-data"): // Handle by content type
// Handle multipart form if strings.HasPrefix(contentType, "application/x-www-form-urlencoded") {
boundary := params["boundary"] return parseURLEncodedForm(ctx)
if boundary == "" { } else if strings.HasPrefix(contentType, "multipart/form-data") {
return nil, ErrInvalidFormType return parseMultipartForm(ctx)
} }
if err := parseMultipartForm(r, boundary, result); err != nil {
return nil, err
}
default:
// Unrecognized content type // Unrecognized content type
return make(map[string]any), nil
}
return result, nil return result, nil
} }
// parseURLEncodedForm handles application/x-www-form-urlencoded forms // parseURLEncodedForm handles application/x-www-form-urlencoded forms
func parseURLEncodedForm(r *http.Request, result map[string]any) error { func parseURLEncodedForm(ctx *fasthttp.RequestCtx) (map[string]any, error) {
// Enforce size limit result := make(map[string]any)
r.Body = http.MaxBytesReader(nil, r.Body, maxFormSize)
// Read the entire body // Process form values directly from PostArgs()
body, err := io.ReadAll(r.Body) ctx.PostArgs().VisitAll(func(key, value []byte) {
if err != nil { keyStr := string(key)
if strings.Contains(err.Error(), "http: request body too large") { valStr := string(value)
return ErrFormSizeTooLarge
}
return err
}
// Parse form values // Check if we already have this key
form, err := url.ParseQuery(string(body)) if existing, ok := result[keyStr]; ok {
if err != nil { // If it's already a slice, append
return err if existingSlice, ok := existing.([]string); ok {
result[keyStr] = append(existingSlice, valStr)
} else {
// Convert to slice and append
result[keyStr] = []string{existing.(string), valStr}
} }
} else {
// New key
result[keyStr] = valStr
}
})
// Convert to map[string]any return result, nil
for key, values := range form {
if len(values) == 1 {
// Single value
result[key] = values[0]
} else if len(values) > 1 {
// Multiple values
result[key] = values
}
}
return nil
} }
// parseMultipartForm handles multipart/form-data forms // parseMultipartForm handles multipart/form-data forms
func parseMultipartForm(r *http.Request, boundary string, result map[string]any) error { func parseMultipartForm(ctx *fasthttp.RequestCtx) (map[string]any, error) {
// Limit the form size result := make(map[string]any)
if err := r.ParseMultipartForm(maxFormSize); err != nil {
if strings.Contains(err.Error(), "http: request body too large") { // Parse multipart form
return ErrFormSizeTooLarge form, err := ctx.MultipartForm()
if err != nil {
if err == multipart.ErrMessageTooLarge || strings.Contains(err.Error(), "too large") {
return nil, ErrFormSizeTooLarge
} }
return err return nil, err
} }
// Process form values // Process form values
for key, values := range r.MultipartForm.Value { for key, values := range form.Value {
if len(values) == 1 { if len(values) == 1 {
// Single value // Single value
result[key] = values[0] result[key] = values[0]
} else if len(values) > 1 { } else if len(values) > 1 {
// Multiple values // Multiple values - store as string slice
result[key] = values strValues := make([]string, len(values))
copy(strValues, values)
result[key] = strValues
} }
} }
// We don't handle file uploads here - could be extended in the future // We don't handle file uploads here - could be extended in the future
// if needed to support file uploads to Lua // if needed to support file uploads to Lua
return nil return result, nil
} }
// Usage: // Usage:

View File

@ -1,7 +1,6 @@
package http package http
import ( import (
"net/http"
"time" "time"
"git.sharkk.net/Sky/Moonshark/core/logger" "git.sharkk.net/Sky/Moonshark/core/logger"
@ -18,13 +17,13 @@ const (
) )
// LogRequest logs an HTTP request with custom formatting // LogRequest logs an HTTP request with custom formatting
func LogRequest(statusCode int, r *http.Request, duration time.Duration) { func LogRequest(statusCode int, method, path string, duration time.Duration) {
statusColor := getStatusColor(statusCode) statusColor := getStatusColor(statusCode)
// Use the logger's raw message writer to bypass the standard format // Use the logger's raw message writer to bypass the standard format
logger.LogRaw("%s%s%s %s%d %s%s %s %s(%v)%s", logger.LogRaw("%s%s%s %s%d %s%s %s %s(%v)%s",
colorGray, time.Now().Format(logger.TimeFormat()), colorReset, colorGray, time.Now().Format(logger.TimeFormat()), colorReset,
statusColor, statusCode, r.Method, colorReset, r.URL.Path, colorGray, duration, colorReset) statusColor, statusCode, method, colorReset, path, colorGray, duration, colorReset)
} }
// getStatusColor returns the ANSI color code for a status code // getStatusColor returns the ANSI color code for a status code

View File

@ -1,36 +1,42 @@
package http package http
import "net/http" import (
"github.com/valyala/fasthttp"
)
// QueryToLua converts HTTP query parameters to a map that can be used with LuaJIT. // QueryToLua converts HTTP query parameters to a map that can be used with LuaJIT.
// Single value parameters are stored as strings. // Single value parameters are stored as strings.
// Multi-value parameters are converted to []any arrays. // Multi-value parameters are converted to []any arrays.
func QueryToLua(r *http.Request) map[string]any { func QueryToLua(ctx *fasthttp.RequestCtx) map[string]any {
if r == nil || r.URL == nil { result := make(map[string]any)
return nil
}
query := r.URL.Query() // Use a map to track keys that have multiple values
if len(query) == 0 { multiValueKeys := make(map[string]bool)
return nil // Avoid allocation for empty queries
}
result := make(map[string]any, len(query)) // Process all query args
for key, values := range query { ctx.QueryArgs().VisitAll(func(key, value []byte) {
switch len(values) { keyStr := string(key)
case 0: valStr := string(value)
// Skip empty values
case 1: if _, exists := result[keyStr]; exists {
// Single value // This key already exists, convert to array if not already
result[key] = values[0] if !multiValueKeys[keyStr] {
default: // First duplicate, convert existing value to array
// Multiple values - convert to []any multiValueKeys[keyStr] = true
arr := make([]any, len(values)) result[keyStr] = []any{result[keyStr], valStr}
for i, v := range values { } else {
arr[i] = v // Already an array, append
result[keyStr] = append(result[keyStr].([]any), valStr)
} }
result[key] = arr } else {
// New key
result[keyStr] = valStr
} }
})
// If we don't have any query parameters, return empty map
if len(result) == 0 {
return make(map[string]any)
} }
return result return result

View File

@ -3,9 +3,7 @@ package http
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" // Added for fmt.Fprintf "fmt"
"net"
"net/http"
"time" "time"
"git.sharkk.net/Sky/Moonshark/core/config" "git.sharkk.net/Sky/Moonshark/core/config"
@ -13,6 +11,8 @@ import (
"git.sharkk.net/Sky/Moonshark/core/routers" "git.sharkk.net/Sky/Moonshark/core/routers"
"git.sharkk.net/Sky/Moonshark/core/runner" "git.sharkk.net/Sky/Moonshark/core/runner"
"git.sharkk.net/Sky/Moonshark/core/utils" "git.sharkk.net/Sky/Moonshark/core/utils"
"github.com/valyala/fasthttp"
) )
// Server handles HTTP requests using Lua and static file routers // Server handles HTTP requests using Lua and static file routers
@ -20,9 +20,9 @@ type Server struct {
luaRouter *routers.LuaRouter luaRouter *routers.LuaRouter
staticRouter *routers.StaticRouter staticRouter *routers.StaticRouter
luaRunner *runner.Runner luaRunner *runner.Runner
httpServer *http.Server fasthttpServer *fasthttp.Server
loggingEnabled bool loggingEnabled bool
debugMode bool // Controls whether to show error details debugMode bool
config *config.Config config *config.Config
errorConfig utils.ErrorPageConfig errorConfig utils.ErrorPageConfig
} }
@ -35,7 +35,6 @@ func New(luaRouter *routers.LuaRouter, staticRouter *routers.StaticRouter, runne
luaRouter: luaRouter, luaRouter: luaRouter,
staticRouter: staticRouter, staticRouter: staticRouter,
luaRunner: runner, luaRunner: runner,
httpServer: &http.Server{},
loggingEnabled: loggingEnabled, loggingEnabled: loggingEnabled,
debugMode: debugMode, debugMode: debugMode,
config: config, config: config,
@ -44,15 +43,20 @@ func New(luaRouter *routers.LuaRouter, staticRouter *routers.StaticRouter, runne
DebugMode: debugMode, DebugMode: debugMode,
}, },
} }
server.httpServer.Handler = server
// Set TCP keep-alive for connections // Configure fasthttp server
server.httpServer.ConnState = func(conn net.Conn, state http.ConnState) { server.fasthttpServer = &fasthttp.Server{
if state == http.StateNew { Handler: server.handleRequest,
if tcpConn, ok := conn.(*net.TCPConn); ok { Name: "Moonshark",
tcpConn.SetKeepAlive(true) ReadTimeout: 30 * time.Second,
} WriteTimeout: 30 * time.Second,
} MaxRequestBodySize: 16 << 20, // 16MB - consistent with Forms.go
DisableKeepalive: false,
TCPKeepalive: true,
TCPKeepalivePeriod: 60 * time.Second,
ReduceMemoryUsage: true,
GetOnly: false,
DisablePreParseMultipartForm: true, // We'll handle parsing manually
} }
return server return server
@ -60,135 +64,130 @@ func New(luaRouter *routers.LuaRouter, staticRouter *routers.StaticRouter, runne
// ListenAndServe starts the server on the given address // ListenAndServe starts the server on the given address
func (s *Server) ListenAndServe(addr string) error { func (s *Server) ListenAndServe(addr string) error {
s.httpServer.Addr = addr
logger.ServerCont("Catch the swell at http://localhost%s", addr) logger.ServerCont("Catch the swell at http://localhost%s", addr)
return s.httpServer.ListenAndServe() return s.fasthttpServer.ListenAndServe(addr)
} }
// Shutdown gracefully shuts down the server // Shutdown gracefully shuts down the server
func (s *Server) Shutdown(ctx context.Context) error { func (s *Server) Shutdown(ctx context.Context) error {
return s.httpServer.Shutdown(ctx) return s.fasthttpServer.ShutdownWithContext(ctx)
} }
// ServeHTTP handles HTTP requests // handleRequest processes the HTTP request
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (s *Server) handleRequest(ctx *fasthttp.RequestCtx) {
start := time.Now() start := time.Now()
method := string(ctx.Method())
path := string(ctx.Path())
// Special case for debug stats when debug mode is enabled // Special case for debug stats when debug mode is enabled
if s.debugMode && r.URL.Path == "/debug/stats" { if s.debugMode && path == "/debug/stats" {
s.handleDebugStats(w, r) s.handleDebugStats(ctx)
// Calculate and log request duration // Log request
duration := time.Since(start)
if s.loggingEnabled { if s.loggingEnabled {
LogRequest(http.StatusOK, r, duration) duration := time.Since(start)
LogRequest(ctx.Response.StatusCode(), method, path, duration)
} }
return return
} }
// Wrap the ResponseWriter to capture status code
wrappedWriter := newStatusCaptureWriter(w)
// Process the request // Process the request
s.handleRequest(wrappedWriter, r) s.processRequest(ctx)
// Calculate request duration
duration := time.Since(start)
// Get the status code
statusCode := wrappedWriter.StatusCode()
// Log the request with our custom format // Log the request with our custom format
if s.loggingEnabled { if s.loggingEnabled {
LogRequest(statusCode, r, duration) duration := time.Since(start)
LogRequest(ctx.Response.StatusCode(), method, path, duration)
} }
} }
// handleRequest processes the actual request // processRequest processes the actual request
func (s *Server) handleRequest(w http.ResponseWriter, r *http.Request) { func (s *Server) processRequest(ctx *fasthttp.RequestCtx) {
logger.Debug("Processing request %s %s", r.Method, r.URL.Path) method := string(ctx.Method())
path := string(ctx.Path())
logger.Debug("Processing request %s %s", method, path)
// Try Lua routes first // Try Lua routes first
params := &routers.Params{} params := &routers.Params{}
bytecode, scriptPath, found := s.luaRouter.GetBytecode(r.Method, r.URL.Path, params) bytecode, scriptPath, found := s.luaRouter.GetBytecode(method, path, params)
// Check if we found a route but it has no valid bytecode (compile error) // Check if we found a route but it has no valid bytecode (compile error)
if found && len(bytecode) == 0 { if found && len(bytecode) == 0 {
// Get the actual error from the router - this requires exposing the actual error // Get the actual error from the router
// from the node in the GetBytecode method
errorMsg := "Route exists but failed to compile. Check server logs for details." errorMsg := "Route exists but failed to compile. Check server logs for details."
// Get the actual node to access its error // Get the actual node to access its error
if node, _ := s.luaRouter.GetNodeWithError(r.Method, r.URL.Path, params); node != nil && node.Error != nil { if node, _ := s.luaRouter.GetNodeWithError(method, path, params); node != nil && node.Error != nil {
errorMsg = node.Error.Error() errorMsg = node.Error.Error()
} }
logger.Error("%s %s - %s", r.Method, r.URL.Path, errorMsg) logger.Error("%s %s - %s", method, path, errorMsg)
// Show error page with the actual error message // Show error page with the actual error message
w.Header().Set("Content-Type", "text/html; charset=utf-8") ctx.SetContentType("text/html; charset=utf-8")
w.WriteHeader(http.StatusInternalServerError) ctx.SetStatusCode(fasthttp.StatusInternalServerError)
errorHTML := utils.InternalErrorPage(s.errorConfig, r.URL.Path, errorMsg) errorHTML := utils.InternalErrorPage(s.errorConfig, path, errorMsg)
w.Write([]byte(errorHTML)) ctx.SetBody([]byte(errorHTML))
return return
} else if found { } else if found {
logger.Debug("Found Lua route match for %s %s with %d params", r.Method, r.URL.Path, params.Count) logger.Debug("Found Lua route match for %s %s with %d params", method, path, params.Count)
s.handleLuaRoute(w, r, bytecode, scriptPath, params) s.handleLuaRoute(ctx, bytecode, scriptPath, params)
return return
} }
// Then try static files // Then try static files
if filePath, found := s.staticRouter.Match(r.URL.Path); found { if filePath, found := s.staticRouter.Match(path); found {
http.ServeFile(w, r, filePath) ctx.SendFile(filePath)
return return
} }
// No route found - 404 Not Found // No route found - 404 Not Found
w.Header().Set("Content-Type", "text/html; charset=utf-8") ctx.SetContentType("text/html; charset=utf-8")
w.WriteHeader(http.StatusNotFound) ctx.SetStatusCode(fasthttp.StatusNotFound)
w.Write([]byte(utils.NotFoundPage(s.errorConfig, r.URL.Path))) ctx.SetBody([]byte(utils.NotFoundPage(s.errorConfig, path)))
} }
// HandleMethodNotAllowed responds with a 405 Method Not Allowed error // HandleMethodNotAllowed responds with a 405 Method Not Allowed error
func (s *Server) HandleMethodNotAllowed(w http.ResponseWriter, r *http.Request) { func HandleMethodNotAllowed(ctx *fasthttp.RequestCtx, errorConfig utils.ErrorPageConfig) {
w.Header().Set("Content-Type", "text/html; charset=utf-8") path := string(ctx.Path())
w.WriteHeader(http.StatusMethodNotAllowed) ctx.SetContentType("text/html; charset=utf-8")
w.Write([]byte(utils.MethodNotAllowedPage(s.errorConfig, r.URL.Path))) ctx.SetStatusCode(fasthttp.StatusMethodNotAllowed)
ctx.SetBody([]byte(utils.MethodNotAllowedPage(errorConfig, path)))
} }
// handleLuaRoute executes a Lua route // handleLuaRoute executes a Lua route
func (s *Server) handleLuaRoute(w http.ResponseWriter, r *http.Request, bytecode []byte, scriptPath string, params *routers.Params) { func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scriptPath string, params *routers.Params) {
ctx := runner.NewContext() luaCtx := runner.NewContext()
defer ctx.Release() defer luaCtx.Release()
// Set up context exactly as the original method := string(ctx.Method())
cookieMap := make(map[string]any) path := string(ctx.Path())
for _, cookie := range r.Cookies() { host := string(ctx.Host())
cookieMap[cookie.Name] = cookie.Value
} // Set up context
ctx.Set("_request_cookies", cookieMap) luaCtx.Set("method", method)
ctx.Set("method", r.Method) luaCtx.Set("path", path)
ctx.Set("path", r.URL.Path) luaCtx.Set("host", host)
ctx.Set("host", r.Host)
// Headers // Headers
headerMap := make(map[string]any, len(r.Header)) headerMap := make(map[string]any)
for name, values := range r.Header { ctx.Request.Header.VisitAll(func(key, value []byte) {
if len(values) == 1 { headerMap[string(key)] = string(value)
headerMap[name] = values[0] })
} else { luaCtx.Set("headers", headerMap)
headerMap[name] = values
}
}
ctx.Set("headers", headerMap)
// Cookies // Cookies
if cookies := r.Cookies(); len(cookies) > 0 { cookieMap := make(map[string]any)
cookieMap := make(map[string]any, len(cookies)) ctx.Request.Header.VisitAllCookie(func(key, value []byte) {
for _, cookie := range cookies { cookieMap[string(key)] = string(value)
cookieMap[cookie.Name] = cookie.Value })
} if len(cookieMap) > 0 {
ctx.Set("cookies", cookieMap) luaCtx.Set("cookies", cookieMap)
luaCtx.Set("_request_cookies", cookieMap) // For backward compatibility
} else {
luaCtx.Set("cookies", make(map[string]any))
luaCtx.Set("_request_cookies", make(map[string]any))
} }
// URL parameters // URL parameters
@ -197,45 +196,51 @@ func (s *Server) handleLuaRoute(w http.ResponseWriter, r *http.Request, bytecode
for i, key := range params.Keys { for i, key := range params.Keys {
paramMap[key] = params.Values[i] paramMap[key] = params.Values[i]
} }
ctx.Set("params", paramMap) luaCtx.Set("params", paramMap)
} else {
luaCtx.Set("params", make(map[string]any))
} }
// Query parameters // Query parameters
queryMap := QueryToLua(r) queryMap := QueryToLua(ctx)
if queryMap == nil { luaCtx.Set("query", queryMap)
ctx.Set("query", make(map[string]any))
} else {
ctx.Set("query", queryMap)
}
// Form data // Form data
if r.Method == http.MethodPost || r.Method == http.MethodPut || r.Method == http.MethodPatch { if method == "POST" || method == "PUT" || method == "PATCH" {
if formData, err := ParseForm(r); err == nil && len(formData) > 0 { formData, err := ParseForm(ctx)
ctx.Set("form", formData) if err == nil && len(formData) > 0 {
luaCtx.Set("form", formData)
} else if err != nil {
logger.Warning("Error parsing form: %v", err)
luaCtx.Set("form", make(map[string]any))
} else {
luaCtx.Set("form", make(map[string]any))
} }
} else {
luaCtx.Set("form", make(map[string]any))
} }
// Execute Lua script // Execute Lua script
result, err := s.luaRunner.Run(bytecode, ctx, scriptPath) result, err := s.luaRunner.Run(bytecode, luaCtx, scriptPath)
// Special handling for CSRF error // Special handling for CSRF error
if err != nil { if err != nil {
if csrfErr, ok := err.(*runner.CSRFError); ok { if csrfErr, ok := err.(*runner.CSRFError); ok {
logger.Warning("CSRF error executing Lua route: %v", csrfErr) logger.Warning("CSRF error executing Lua route: %v", csrfErr)
HandleCSRFError(w, r, s.errorConfig) HandleCSRFError(ctx, s.errorConfig)
return return
} }
// Normal error handling // Normal error handling
logger.Error("Error executing Lua route: %v", err) logger.Error("Error executing Lua route: %v", err)
w.Header().Set("Content-Type", "text/html; charset=utf-8") ctx.SetContentType("text/html; charset=utf-8")
w.WriteHeader(http.StatusInternalServerError) ctx.SetStatusCode(fasthttp.StatusInternalServerError)
errorHTML := utils.InternalErrorPage(s.errorConfig, r.URL.Path, err.Error()) errorHTML := utils.InternalErrorPage(s.errorConfig, path, err.Error())
w.Write([]byte(errorHTML)) ctx.SetBody([]byte(errorHTML))
return return
} }
writeResponse(w, result) writeResponse(ctx, result)
} }
// Content types for responses // Content types for responses
@ -245,9 +250,9 @@ const (
) )
// writeResponse writes the Lua result to the HTTP response // writeResponse writes the Lua result to the HTTP response
func writeResponse(w http.ResponseWriter, result any) { func writeResponse(ctx *fasthttp.RequestCtx, result any) {
if result == nil { if result == nil {
w.WriteHeader(http.StatusNoContent) ctx.SetStatusCode(fasthttp.StatusNoContent)
return return
} }
@ -257,16 +262,34 @@ func writeResponse(w http.ResponseWriter, result any) {
// Set response headers // Set response headers
for name, value := range httpResp.Headers { for name, value := range httpResp.Headers {
w.Header().Set(name, value) ctx.Response.Header.Set(name, value)
} }
// Set cookies // Set cookies
for _, cookie := range httpResp.Cookies { for _, cookie := range httpResp.Cookies {
http.SetCookie(w, cookie) // Convert net/http cookie to fasthttp cookie
var c fasthttp.Cookie
c.SetKey(cookie.Name)
c.SetValue(cookie.Value)
if cookie.Path != "" {
c.SetPath(cookie.Path)
}
if cookie.Domain != "" {
c.SetDomain(cookie.Domain)
}
if cookie.MaxAge > 0 {
c.SetMaxAge(cookie.MaxAge)
}
if cookie.Expires.After(time.Time{}) {
c.SetExpire(cookie.Expires)
}
c.SetSecure(cookie.Secure)
c.SetHTTPOnly(cookie.HttpOnly)
ctx.Response.Header.SetCookie(&c)
} }
// Set status code // Set status code
w.WriteHeader(httpResp.Status) ctx.SetStatusCode(httpResp.Status)
// Process the body based on its type // Process the body based on its type
if httpResp.Body == nil { if httpResp.Body == nil {
@ -285,39 +308,39 @@ func writeResponse(w http.ResponseWriter, result any) {
} }
if isJSON { if isJSON {
setContentTypeIfMissing(w, contentTypeJSON) setContentTypeIfMissing(ctx, contentTypeJSON)
data, err := json.Marshal(result) data, err := json.Marshal(result)
if err != nil { if err != nil {
logger.Error("Failed to marshal response: %v", err) logger.Error("Failed to marshal response: %v", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError) ctx.Error("Internal Server Error", fasthttp.StatusInternalServerError)
return return
} }
w.Write(data) ctx.SetBody(data)
return return
} }
// All other types - convert to plain text // All other types - convert to plain text
setContentTypeIfMissing(w, contentTypePlain) setContentTypeIfMissing(ctx, contentTypePlain)
switch r := result.(type) { switch r := result.(type) {
case string: case string:
w.Write([]byte(r)) ctx.SetBodyString(r)
case []byte: case []byte:
w.Write(r) ctx.SetBody(r)
default: default:
// Convert any other type to string // Convert any other type to string
fmt.Fprintf(w, "%v", r) ctx.SetBodyString(fmt.Sprintf("%v", r))
} }
} }
func setContentTypeIfMissing(w http.ResponseWriter, contentType string) { func setContentTypeIfMissing(ctx *fasthttp.RequestCtx, contentType string) {
if w.Header().Get("Content-Type") == "" { if len(ctx.Response.Header.ContentType()) == 0 {
w.Header().Set("Content-Type", contentType) ctx.SetContentType(contentType)
} }
} }
// handleDebugStats displays debug statistics // handleDebugStats displays debug statistics
func (s *Server) handleDebugStats(w http.ResponseWriter, _ *http.Request) { func (s *Server) handleDebugStats(ctx *fasthttp.RequestCtx) {
// Collect system stats // Collect system stats
stats := utils.CollectSystemStats(s.config) stats := utils.CollectSystemStats(s.config)
@ -335,7 +358,7 @@ func (s *Server) handleDebugStats(w http.ResponseWriter, _ *http.Request) {
html := utils.DebugStatsPage(stats) html := utils.DebugStatsPage(stats)
// Send the response // Send the response
w.Header().Set("Content-Type", "text/html; charset=utf-8") ctx.SetContentType("text/html; charset=utf-8")
w.WriteHeader(http.StatusOK) ctx.SetStatusCode(fasthttp.StatusOK)
w.Write([]byte(html)) ctx.SetBody([]byte(html))
} }

View File

@ -1,33 +0,0 @@
package http
import (
"net/http"
)
// statusCaptureWriter is a ResponseWriter that captures the status code
type statusCaptureWriter struct {
http.ResponseWriter
statusCode int
}
// WriteHeader captures the status code and passes it to the wrapped ResponseWriter
func (w *statusCaptureWriter) WriteHeader(code int) {
w.statusCode = code
w.ResponseWriter.WriteHeader(code)
}
// StatusCode returns the captured status code
func (w *statusCaptureWriter) StatusCode() int {
if w.statusCode == 0 {
return http.StatusOK // Default to 200 if not explicitly set
}
return w.statusCode
}
// newStatusCaptureWriter creates a new statusCaptureWriter
func newStatusCaptureWriter(w http.ResponseWriter) *statusCaptureWriter {
return &statusCaptureWriter{
ResponseWriter: w,
statusCode: 0,
}
}

7
go.mod
View File

@ -4,4 +4,11 @@ go 1.24.1
require git.sharkk.net/Sky/LuaJIT-to-Go v0.0.0 require git.sharkk.net/Sky/LuaJIT-to-Go v0.0.0
require (
github.com/andybalholm/brotli v1.1.1 // indirect
github.com/klauspost/compress v1.18.0 // indirect
github.com/valyala/bytebufferpool v1.0.0 // indirect
github.com/valyala/fasthttp v1.60.0 // indirect
)
replace git.sharkk.net/Sky/LuaJIT-to-Go => ./luajit replace git.sharkk.net/Sky/LuaJIT-to-Go => ./luajit

9
go.sum Normal file
View File

@ -0,0 +1,9 @@
github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA=
github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA=
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
github.com/valyala/fasthttp v1.60.0 h1:kBRYS0lOhVJ6V+bYN8PqAHELKHtXqwq9zNMLKx1MBsw=
github.com/valyala/fasthttp v1.60.0/go.mod h1:iY4kDgV3Gc6EqhRZ8icqcmlG6bqhcDXfuHgTO4FXCvc=
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=