fasthttp 1
This commit is contained in:
parent
5bba4ffcf8
commit
95eae40357
|
@ -1,20 +1,22 @@
|
|||
package http
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"git.sharkk.net/Sky/Moonshark/core/logger"
|
||||
"git.sharkk.net/Sky/Moonshark/core/utils"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// HandleCSRFError handles a CSRF validation error
|
||||
func HandleCSRFError(w http.ResponseWriter, r *http.Request, errorConfig utils.ErrorPageConfig) {
|
||||
logger.Warning("CSRF validation failed for %s %s", r.Method, r.URL.Path)
|
||||
func HandleCSRFError(ctx *fasthttp.RequestCtx, errorConfig utils.ErrorPageConfig) {
|
||||
method := string(ctx.Method())
|
||||
path := string(ctx.Path())
|
||||
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
logger.Warning("CSRF validation failed for %s %s", method, path)
|
||||
|
||||
ctx.SetContentType("text/html; charset=utf-8")
|
||||
ctx.SetStatusCode(fasthttp.StatusForbidden)
|
||||
|
||||
errorMsg := "Invalid or missing CSRF token. This could be due to an expired form or a cross-site request forgery attempt."
|
||||
errorHTML := utils.ForbiddenPage(errorConfig, r.URL.Path, errorMsg)
|
||||
w.Write([]byte(errorHTML))
|
||||
errorHTML := utils.ForbiddenPage(errorConfig, path, errorMsg)
|
||||
ctx.SetBody([]byte(errorHTML))
|
||||
}
|
||||
|
|
|
@ -2,11 +2,10 @@ package http
|
|||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"mime"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"mime/multipart"
|
||||
"strings"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// Maximum form parse size (16MB)
|
||||
|
@ -20,113 +19,94 @@ var (
|
|||
|
||||
// ParseForm parses a POST request body into a map of values
|
||||
// Supports both application/x-www-form-urlencoded and multipart/form-data content types
|
||||
func ParseForm(r *http.Request) (map[string]any, error) {
|
||||
func ParseForm(ctx *fasthttp.RequestCtx) (map[string]any, error) {
|
||||
// Only handle POST, PUT, PATCH
|
||||
if r.Method != http.MethodPost &&
|
||||
r.Method != http.MethodPut &&
|
||||
r.Method != http.MethodPatch {
|
||||
method := string(ctx.Method())
|
||||
if method != "POST" && method != "PUT" && method != "PATCH" {
|
||||
return make(map[string]any), nil
|
||||
}
|
||||
|
||||
// Check content type
|
||||
contentType := r.Header.Get("Content-Type")
|
||||
contentType := string(ctx.Request.Header.ContentType())
|
||||
if contentType == "" {
|
||||
return make(map[string]any), nil
|
||||
}
|
||||
|
||||
// Parse the media type
|
||||
mediaType, params, err := mime.ParseMediaType(contentType)
|
||||
if err != nil {
|
||||
return nil, ErrInvalidFormType
|
||||
}
|
||||
|
||||
result := make(map[string]any)
|
||||
|
||||
switch {
|
||||
case mediaType == "application/x-www-form-urlencoded":
|
||||
// Handle URL-encoded form
|
||||
if err := parseURLEncodedForm(r, result); err != nil {
|
||||
return nil, err
|
||||
// Check for content length to prevent DOS
|
||||
if len(ctx.Request.Body()) > maxFormSize {
|
||||
return nil, ErrFormSizeTooLarge
|
||||
}
|
||||
|
||||
case strings.HasPrefix(mediaType, "multipart/form-data"):
|
||||
// Handle multipart form
|
||||
boundary := params["boundary"]
|
||||
if boundary == "" {
|
||||
return nil, ErrInvalidFormType
|
||||
// Handle by content type
|
||||
if strings.HasPrefix(contentType, "application/x-www-form-urlencoded") {
|
||||
return parseURLEncodedForm(ctx)
|
||||
} else if strings.HasPrefix(contentType, "multipart/form-data") {
|
||||
return parseMultipartForm(ctx)
|
||||
}
|
||||
|
||||
if err := parseMultipartForm(r, boundary, result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
default:
|
||||
// Unrecognized content type
|
||||
return make(map[string]any), nil
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// parseURLEncodedForm handles application/x-www-form-urlencoded forms
|
||||
func parseURLEncodedForm(r *http.Request, result map[string]any) error {
|
||||
// Enforce size limit
|
||||
r.Body = http.MaxBytesReader(nil, r.Body, maxFormSize)
|
||||
func parseURLEncodedForm(ctx *fasthttp.RequestCtx) (map[string]any, error) {
|
||||
result := make(map[string]any)
|
||||
|
||||
// Read the entire body
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "http: request body too large") {
|
||||
return ErrFormSizeTooLarge
|
||||
}
|
||||
return err
|
||||
}
|
||||
// Process form values directly from PostArgs()
|
||||
ctx.PostArgs().VisitAll(func(key, value []byte) {
|
||||
keyStr := string(key)
|
||||
valStr := string(value)
|
||||
|
||||
// Parse form values
|
||||
form, err := url.ParseQuery(string(body))
|
||||
if err != nil {
|
||||
return err
|
||||
// Check if we already have this key
|
||||
if existing, ok := result[keyStr]; ok {
|
||||
// If it's already a slice, append
|
||||
if existingSlice, ok := existing.([]string); ok {
|
||||
result[keyStr] = append(existingSlice, valStr)
|
||||
} else {
|
||||
// Convert to slice and append
|
||||
result[keyStr] = []string{existing.(string), valStr}
|
||||
}
|
||||
} else {
|
||||
// New key
|
||||
result[keyStr] = valStr
|
||||
}
|
||||
})
|
||||
|
||||
// Convert to map[string]any
|
||||
for key, values := range form {
|
||||
if len(values) == 1 {
|
||||
// Single value
|
||||
result[key] = values[0]
|
||||
} else if len(values) > 1 {
|
||||
// Multiple values
|
||||
result[key] = values
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// parseMultipartForm handles multipart/form-data forms
|
||||
func parseMultipartForm(r *http.Request, boundary string, result map[string]any) error {
|
||||
// Limit the form size
|
||||
if err := r.ParseMultipartForm(maxFormSize); err != nil {
|
||||
if strings.Contains(err.Error(), "http: request body too large") {
|
||||
return ErrFormSizeTooLarge
|
||||
func parseMultipartForm(ctx *fasthttp.RequestCtx) (map[string]any, error) {
|
||||
result := make(map[string]any)
|
||||
|
||||
// Parse multipart form
|
||||
form, err := ctx.MultipartForm()
|
||||
if err != nil {
|
||||
if err == multipart.ErrMessageTooLarge || strings.Contains(err.Error(), "too large") {
|
||||
return nil, ErrFormSizeTooLarge
|
||||
}
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Process form values
|
||||
for key, values := range r.MultipartForm.Value {
|
||||
for key, values := range form.Value {
|
||||
if len(values) == 1 {
|
||||
// Single value
|
||||
result[key] = values[0]
|
||||
} else if len(values) > 1 {
|
||||
// Multiple values
|
||||
result[key] = values
|
||||
// Multiple values - store as string slice
|
||||
strValues := make([]string, len(values))
|
||||
copy(strValues, values)
|
||||
result[key] = strValues
|
||||
}
|
||||
}
|
||||
|
||||
// We don't handle file uploads here - could be extended in the future
|
||||
// if needed to support file uploads to Lua
|
||||
|
||||
return nil
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Usage:
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package http
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"git.sharkk.net/Sky/Moonshark/core/logger"
|
||||
|
@ -18,13 +17,13 @@ const (
|
|||
)
|
||||
|
||||
// LogRequest logs an HTTP request with custom formatting
|
||||
func LogRequest(statusCode int, r *http.Request, duration time.Duration) {
|
||||
func LogRequest(statusCode int, method, path string, duration time.Duration) {
|
||||
statusColor := getStatusColor(statusCode)
|
||||
|
||||
// Use the logger's raw message writer to bypass the standard format
|
||||
logger.LogRaw("%s%s%s %s%d %s%s %s %s(%v)%s",
|
||||
colorGray, time.Now().Format(logger.TimeFormat()), colorReset,
|
||||
statusColor, statusCode, r.Method, colorReset, r.URL.Path, colorGray, duration, colorReset)
|
||||
statusColor, statusCode, method, colorReset, path, colorGray, duration, colorReset)
|
||||
}
|
||||
|
||||
// getStatusColor returns the ANSI color code for a status code
|
||||
|
|
|
@ -1,36 +1,42 @@
|
|||
package http
|
||||
|
||||
import "net/http"
|
||||
import (
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// QueryToLua converts HTTP query parameters to a map that can be used with LuaJIT.
|
||||
// Single value parameters are stored as strings.
|
||||
// Multi-value parameters are converted to []any arrays.
|
||||
func QueryToLua(r *http.Request) map[string]any {
|
||||
if r == nil || r.URL == nil {
|
||||
return nil
|
||||
}
|
||||
func QueryToLua(ctx *fasthttp.RequestCtx) map[string]any {
|
||||
result := make(map[string]any)
|
||||
|
||||
query := r.URL.Query()
|
||||
if len(query) == 0 {
|
||||
return nil // Avoid allocation for empty queries
|
||||
}
|
||||
// Use a map to track keys that have multiple values
|
||||
multiValueKeys := make(map[string]bool)
|
||||
|
||||
result := make(map[string]any, len(query))
|
||||
for key, values := range query {
|
||||
switch len(values) {
|
||||
case 0:
|
||||
// Skip empty values
|
||||
case 1:
|
||||
// Single value
|
||||
result[key] = values[0]
|
||||
default:
|
||||
// Multiple values - convert to []any
|
||||
arr := make([]any, len(values))
|
||||
for i, v := range values {
|
||||
arr[i] = v
|
||||
// Process all query args
|
||||
ctx.QueryArgs().VisitAll(func(key, value []byte) {
|
||||
keyStr := string(key)
|
||||
valStr := string(value)
|
||||
|
||||
if _, exists := result[keyStr]; exists {
|
||||
// This key already exists, convert to array if not already
|
||||
if !multiValueKeys[keyStr] {
|
||||
// First duplicate, convert existing value to array
|
||||
multiValueKeys[keyStr] = true
|
||||
result[keyStr] = []any{result[keyStr], valStr}
|
||||
} else {
|
||||
// Already an array, append
|
||||
result[keyStr] = append(result[keyStr].([]any), valStr)
|
||||
}
|
||||
result[key] = arr
|
||||
} else {
|
||||
// New key
|
||||
result[keyStr] = valStr
|
||||
}
|
||||
})
|
||||
|
||||
// If we don't have any query parameters, return empty map
|
||||
if len(result) == 0 {
|
||||
return make(map[string]any)
|
||||
}
|
||||
|
||||
return result
|
||||
|
|
|
@ -3,9 +3,7 @@ package http
|
|||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt" // Added for fmt.Fprintf
|
||||
"net"
|
||||
"net/http"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"git.sharkk.net/Sky/Moonshark/core/config"
|
||||
|
@ -13,6 +11,8 @@ import (
|
|||
"git.sharkk.net/Sky/Moonshark/core/routers"
|
||||
"git.sharkk.net/Sky/Moonshark/core/runner"
|
||||
"git.sharkk.net/Sky/Moonshark/core/utils"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// Server handles HTTP requests using Lua and static file routers
|
||||
|
@ -20,9 +20,9 @@ type Server struct {
|
|||
luaRouter *routers.LuaRouter
|
||||
staticRouter *routers.StaticRouter
|
||||
luaRunner *runner.Runner
|
||||
httpServer *http.Server
|
||||
fasthttpServer *fasthttp.Server
|
||||
loggingEnabled bool
|
||||
debugMode bool // Controls whether to show error details
|
||||
debugMode bool
|
||||
config *config.Config
|
||||
errorConfig utils.ErrorPageConfig
|
||||
}
|
||||
|
@ -35,7 +35,6 @@ func New(luaRouter *routers.LuaRouter, staticRouter *routers.StaticRouter, runne
|
|||
luaRouter: luaRouter,
|
||||
staticRouter: staticRouter,
|
||||
luaRunner: runner,
|
||||
httpServer: &http.Server{},
|
||||
loggingEnabled: loggingEnabled,
|
||||
debugMode: debugMode,
|
||||
config: config,
|
||||
|
@ -44,15 +43,20 @@ func New(luaRouter *routers.LuaRouter, staticRouter *routers.StaticRouter, runne
|
|||
DebugMode: debugMode,
|
||||
},
|
||||
}
|
||||
server.httpServer.Handler = server
|
||||
|
||||
// Set TCP keep-alive for connections
|
||||
server.httpServer.ConnState = func(conn net.Conn, state http.ConnState) {
|
||||
if state == http.StateNew {
|
||||
if tcpConn, ok := conn.(*net.TCPConn); ok {
|
||||
tcpConn.SetKeepAlive(true)
|
||||
}
|
||||
}
|
||||
// Configure fasthttp server
|
||||
server.fasthttpServer = &fasthttp.Server{
|
||||
Handler: server.handleRequest,
|
||||
Name: "Moonshark",
|
||||
ReadTimeout: 30 * time.Second,
|
||||
WriteTimeout: 30 * time.Second,
|
||||
MaxRequestBodySize: 16 << 20, // 16MB - consistent with Forms.go
|
||||
DisableKeepalive: false,
|
||||
TCPKeepalive: true,
|
||||
TCPKeepalivePeriod: 60 * time.Second,
|
||||
ReduceMemoryUsage: true,
|
||||
GetOnly: false,
|
||||
DisablePreParseMultipartForm: true, // We'll handle parsing manually
|
||||
}
|
||||
|
||||
return server
|
||||
|
@ -60,135 +64,130 @@ func New(luaRouter *routers.LuaRouter, staticRouter *routers.StaticRouter, runne
|
|||
|
||||
// ListenAndServe starts the server on the given address
|
||||
func (s *Server) ListenAndServe(addr string) error {
|
||||
s.httpServer.Addr = addr
|
||||
logger.ServerCont("Catch the swell at http://localhost%s", addr)
|
||||
return s.httpServer.ListenAndServe()
|
||||
return s.fasthttpServer.ListenAndServe(addr)
|
||||
}
|
||||
|
||||
// Shutdown gracefully shuts down the server
|
||||
func (s *Server) Shutdown(ctx context.Context) error {
|
||||
return s.httpServer.Shutdown(ctx)
|
||||
return s.fasthttpServer.ShutdownWithContext(ctx)
|
||||
}
|
||||
|
||||
// ServeHTTP handles HTTP requests
|
||||
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// handleRequest processes the HTTP request
|
||||
func (s *Server) handleRequest(ctx *fasthttp.RequestCtx) {
|
||||
start := time.Now()
|
||||
method := string(ctx.Method())
|
||||
path := string(ctx.Path())
|
||||
|
||||
// Special case for debug stats when debug mode is enabled
|
||||
if s.debugMode && r.URL.Path == "/debug/stats" {
|
||||
s.handleDebugStats(w, r)
|
||||
if s.debugMode && path == "/debug/stats" {
|
||||
s.handleDebugStats(ctx)
|
||||
|
||||
// Calculate and log request duration
|
||||
duration := time.Since(start)
|
||||
// Log request
|
||||
if s.loggingEnabled {
|
||||
LogRequest(http.StatusOK, r, duration)
|
||||
duration := time.Since(start)
|
||||
LogRequest(ctx.Response.StatusCode(), method, path, duration)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Wrap the ResponseWriter to capture status code
|
||||
wrappedWriter := newStatusCaptureWriter(w)
|
||||
|
||||
// Process the request
|
||||
s.handleRequest(wrappedWriter, r)
|
||||
|
||||
// Calculate request duration
|
||||
duration := time.Since(start)
|
||||
|
||||
// Get the status code
|
||||
statusCode := wrappedWriter.StatusCode()
|
||||
s.processRequest(ctx)
|
||||
|
||||
// Log the request with our custom format
|
||||
if s.loggingEnabled {
|
||||
LogRequest(statusCode, r, duration)
|
||||
duration := time.Since(start)
|
||||
LogRequest(ctx.Response.StatusCode(), method, path, duration)
|
||||
}
|
||||
}
|
||||
|
||||
// handleRequest processes the actual request
|
||||
func (s *Server) handleRequest(w http.ResponseWriter, r *http.Request) {
|
||||
logger.Debug("Processing request %s %s", r.Method, r.URL.Path)
|
||||
// processRequest processes the actual request
|
||||
func (s *Server) processRequest(ctx *fasthttp.RequestCtx) {
|
||||
method := string(ctx.Method())
|
||||
path := string(ctx.Path())
|
||||
|
||||
logger.Debug("Processing request %s %s", method, path)
|
||||
|
||||
// Try Lua routes first
|
||||
params := &routers.Params{}
|
||||
bytecode, scriptPath, found := s.luaRouter.GetBytecode(r.Method, r.URL.Path, params)
|
||||
bytecode, scriptPath, found := s.luaRouter.GetBytecode(method, path, params)
|
||||
|
||||
// Check if we found a route but it has no valid bytecode (compile error)
|
||||
if found && len(bytecode) == 0 {
|
||||
// Get the actual error from the router - this requires exposing the actual error
|
||||
// from the node in the GetBytecode method
|
||||
// Get the actual error from the router
|
||||
errorMsg := "Route exists but failed to compile. Check server logs for details."
|
||||
|
||||
// Get the actual node to access its error
|
||||
if node, _ := s.luaRouter.GetNodeWithError(r.Method, r.URL.Path, params); node != nil && node.Error != nil {
|
||||
if node, _ := s.luaRouter.GetNodeWithError(method, path, params); node != nil && node.Error != nil {
|
||||
errorMsg = node.Error.Error()
|
||||
}
|
||||
|
||||
logger.Error("%s %s - %s", r.Method, r.URL.Path, errorMsg)
|
||||
logger.Error("%s %s - %s", method, path, errorMsg)
|
||||
|
||||
// Show error page with the actual error message
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
errorHTML := utils.InternalErrorPage(s.errorConfig, r.URL.Path, errorMsg)
|
||||
w.Write([]byte(errorHTML))
|
||||
ctx.SetContentType("text/html; charset=utf-8")
|
||||
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
||||
errorHTML := utils.InternalErrorPage(s.errorConfig, path, errorMsg)
|
||||
ctx.SetBody([]byte(errorHTML))
|
||||
return
|
||||
} else if found {
|
||||
logger.Debug("Found Lua route match for %s %s with %d params", r.Method, r.URL.Path, params.Count)
|
||||
s.handleLuaRoute(w, r, bytecode, scriptPath, params)
|
||||
logger.Debug("Found Lua route match for %s %s with %d params", method, path, params.Count)
|
||||
s.handleLuaRoute(ctx, bytecode, scriptPath, params)
|
||||
return
|
||||
}
|
||||
|
||||
// Then try static files
|
||||
if filePath, found := s.staticRouter.Match(r.URL.Path); found {
|
||||
http.ServeFile(w, r, filePath)
|
||||
if filePath, found := s.staticRouter.Match(path); found {
|
||||
ctx.SendFile(filePath)
|
||||
return
|
||||
}
|
||||
|
||||
// No route found - 404 Not Found
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
w.Write([]byte(utils.NotFoundPage(s.errorConfig, r.URL.Path)))
|
||||
ctx.SetContentType("text/html; charset=utf-8")
|
||||
ctx.SetStatusCode(fasthttp.StatusNotFound)
|
||||
ctx.SetBody([]byte(utils.NotFoundPage(s.errorConfig, path)))
|
||||
}
|
||||
|
||||
// HandleMethodNotAllowed responds with a 405 Method Not Allowed error
|
||||
func (s *Server) HandleMethodNotAllowed(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
w.Write([]byte(utils.MethodNotAllowedPage(s.errorConfig, r.URL.Path)))
|
||||
func HandleMethodNotAllowed(ctx *fasthttp.RequestCtx, errorConfig utils.ErrorPageConfig) {
|
||||
path := string(ctx.Path())
|
||||
ctx.SetContentType("text/html; charset=utf-8")
|
||||
ctx.SetStatusCode(fasthttp.StatusMethodNotAllowed)
|
||||
ctx.SetBody([]byte(utils.MethodNotAllowedPage(errorConfig, path)))
|
||||
}
|
||||
|
||||
// handleLuaRoute executes a Lua route
|
||||
func (s *Server) handleLuaRoute(w http.ResponseWriter, r *http.Request, bytecode []byte, scriptPath string, params *routers.Params) {
|
||||
ctx := runner.NewContext()
|
||||
defer ctx.Release()
|
||||
func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scriptPath string, params *routers.Params) {
|
||||
luaCtx := runner.NewContext()
|
||||
defer luaCtx.Release()
|
||||
|
||||
// Set up context exactly as the original
|
||||
cookieMap := make(map[string]any)
|
||||
for _, cookie := range r.Cookies() {
|
||||
cookieMap[cookie.Name] = cookie.Value
|
||||
}
|
||||
ctx.Set("_request_cookies", cookieMap)
|
||||
ctx.Set("method", r.Method)
|
||||
ctx.Set("path", r.URL.Path)
|
||||
ctx.Set("host", r.Host)
|
||||
method := string(ctx.Method())
|
||||
path := string(ctx.Path())
|
||||
host := string(ctx.Host())
|
||||
|
||||
// Set up context
|
||||
luaCtx.Set("method", method)
|
||||
luaCtx.Set("path", path)
|
||||
luaCtx.Set("host", host)
|
||||
|
||||
// Headers
|
||||
headerMap := make(map[string]any, len(r.Header))
|
||||
for name, values := range r.Header {
|
||||
if len(values) == 1 {
|
||||
headerMap[name] = values[0]
|
||||
} else {
|
||||
headerMap[name] = values
|
||||
}
|
||||
}
|
||||
ctx.Set("headers", headerMap)
|
||||
headerMap := make(map[string]any)
|
||||
ctx.Request.Header.VisitAll(func(key, value []byte) {
|
||||
headerMap[string(key)] = string(value)
|
||||
})
|
||||
luaCtx.Set("headers", headerMap)
|
||||
|
||||
// Cookies
|
||||
if cookies := r.Cookies(); len(cookies) > 0 {
|
||||
cookieMap := make(map[string]any, len(cookies))
|
||||
for _, cookie := range cookies {
|
||||
cookieMap[cookie.Name] = cookie.Value
|
||||
}
|
||||
ctx.Set("cookies", cookieMap)
|
||||
cookieMap := make(map[string]any)
|
||||
ctx.Request.Header.VisitAllCookie(func(key, value []byte) {
|
||||
cookieMap[string(key)] = string(value)
|
||||
})
|
||||
if len(cookieMap) > 0 {
|
||||
luaCtx.Set("cookies", cookieMap)
|
||||
luaCtx.Set("_request_cookies", cookieMap) // For backward compatibility
|
||||
} else {
|
||||
luaCtx.Set("cookies", make(map[string]any))
|
||||
luaCtx.Set("_request_cookies", make(map[string]any))
|
||||
}
|
||||
|
||||
// URL parameters
|
||||
|
@ -197,45 +196,51 @@ func (s *Server) handleLuaRoute(w http.ResponseWriter, r *http.Request, bytecode
|
|||
for i, key := range params.Keys {
|
||||
paramMap[key] = params.Values[i]
|
||||
}
|
||||
ctx.Set("params", paramMap)
|
||||
luaCtx.Set("params", paramMap)
|
||||
} else {
|
||||
luaCtx.Set("params", make(map[string]any))
|
||||
}
|
||||
|
||||
// Query parameters
|
||||
queryMap := QueryToLua(r)
|
||||
if queryMap == nil {
|
||||
ctx.Set("query", make(map[string]any))
|
||||
} else {
|
||||
ctx.Set("query", queryMap)
|
||||
}
|
||||
queryMap := QueryToLua(ctx)
|
||||
luaCtx.Set("query", queryMap)
|
||||
|
||||
// Form data
|
||||
if r.Method == http.MethodPost || r.Method == http.MethodPut || r.Method == http.MethodPatch {
|
||||
if formData, err := ParseForm(r); err == nil && len(formData) > 0 {
|
||||
ctx.Set("form", formData)
|
||||
if method == "POST" || method == "PUT" || method == "PATCH" {
|
||||
formData, err := ParseForm(ctx)
|
||||
if err == nil && len(formData) > 0 {
|
||||
luaCtx.Set("form", formData)
|
||||
} else if err != nil {
|
||||
logger.Warning("Error parsing form: %v", err)
|
||||
luaCtx.Set("form", make(map[string]any))
|
||||
} else {
|
||||
luaCtx.Set("form", make(map[string]any))
|
||||
}
|
||||
} else {
|
||||
luaCtx.Set("form", make(map[string]any))
|
||||
}
|
||||
|
||||
// Execute Lua script
|
||||
result, err := s.luaRunner.Run(bytecode, ctx, scriptPath)
|
||||
result, err := s.luaRunner.Run(bytecode, luaCtx, scriptPath)
|
||||
|
||||
// Special handling for CSRF error
|
||||
if err != nil {
|
||||
if csrfErr, ok := err.(*runner.CSRFError); ok {
|
||||
logger.Warning("CSRF error executing Lua route: %v", csrfErr)
|
||||
HandleCSRFError(w, r, s.errorConfig)
|
||||
HandleCSRFError(ctx, s.errorConfig)
|
||||
return
|
||||
}
|
||||
|
||||
// Normal error handling
|
||||
logger.Error("Error executing Lua route: %v", err)
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
errorHTML := utils.InternalErrorPage(s.errorConfig, r.URL.Path, err.Error())
|
||||
w.Write([]byte(errorHTML))
|
||||
ctx.SetContentType("text/html; charset=utf-8")
|
||||
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
||||
errorHTML := utils.InternalErrorPage(s.errorConfig, path, err.Error())
|
||||
ctx.SetBody([]byte(errorHTML))
|
||||
return
|
||||
}
|
||||
|
||||
writeResponse(w, result)
|
||||
writeResponse(ctx, result)
|
||||
}
|
||||
|
||||
// Content types for responses
|
||||
|
@ -245,9 +250,9 @@ const (
|
|||
)
|
||||
|
||||
// writeResponse writes the Lua result to the HTTP response
|
||||
func writeResponse(w http.ResponseWriter, result any) {
|
||||
func writeResponse(ctx *fasthttp.RequestCtx, result any) {
|
||||
if result == nil {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
ctx.SetStatusCode(fasthttp.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -257,16 +262,34 @@ func writeResponse(w http.ResponseWriter, result any) {
|
|||
|
||||
// Set response headers
|
||||
for name, value := range httpResp.Headers {
|
||||
w.Header().Set(name, value)
|
||||
ctx.Response.Header.Set(name, value)
|
||||
}
|
||||
|
||||
// Set cookies
|
||||
for _, cookie := range httpResp.Cookies {
|
||||
http.SetCookie(w, cookie)
|
||||
// Convert net/http cookie to fasthttp cookie
|
||||
var c fasthttp.Cookie
|
||||
c.SetKey(cookie.Name)
|
||||
c.SetValue(cookie.Value)
|
||||
if cookie.Path != "" {
|
||||
c.SetPath(cookie.Path)
|
||||
}
|
||||
if cookie.Domain != "" {
|
||||
c.SetDomain(cookie.Domain)
|
||||
}
|
||||
if cookie.MaxAge > 0 {
|
||||
c.SetMaxAge(cookie.MaxAge)
|
||||
}
|
||||
if cookie.Expires.After(time.Time{}) {
|
||||
c.SetExpire(cookie.Expires)
|
||||
}
|
||||
c.SetSecure(cookie.Secure)
|
||||
c.SetHTTPOnly(cookie.HttpOnly)
|
||||
ctx.Response.Header.SetCookie(&c)
|
||||
}
|
||||
|
||||
// Set status code
|
||||
w.WriteHeader(httpResp.Status)
|
||||
ctx.SetStatusCode(httpResp.Status)
|
||||
|
||||
// Process the body based on its type
|
||||
if httpResp.Body == nil {
|
||||
|
@ -285,39 +308,39 @@ func writeResponse(w http.ResponseWriter, result any) {
|
|||
}
|
||||
|
||||
if isJSON {
|
||||
setContentTypeIfMissing(w, contentTypeJSON)
|
||||
setContentTypeIfMissing(ctx, contentTypeJSON)
|
||||
data, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
logger.Error("Failed to marshal response: %v", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
ctx.Error("Internal Server Error", fasthttp.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.Write(data)
|
||||
ctx.SetBody(data)
|
||||
return
|
||||
}
|
||||
|
||||
// All other types - convert to plain text
|
||||
setContentTypeIfMissing(w, contentTypePlain)
|
||||
setContentTypeIfMissing(ctx, contentTypePlain)
|
||||
|
||||
switch r := result.(type) {
|
||||
case string:
|
||||
w.Write([]byte(r))
|
||||
ctx.SetBodyString(r)
|
||||
case []byte:
|
||||
w.Write(r)
|
||||
ctx.SetBody(r)
|
||||
default:
|
||||
// Convert any other type to string
|
||||
fmt.Fprintf(w, "%v", r)
|
||||
ctx.SetBodyString(fmt.Sprintf("%v", r))
|
||||
}
|
||||
}
|
||||
|
||||
func setContentTypeIfMissing(w http.ResponseWriter, contentType string) {
|
||||
if w.Header().Get("Content-Type") == "" {
|
||||
w.Header().Set("Content-Type", contentType)
|
||||
func setContentTypeIfMissing(ctx *fasthttp.RequestCtx, contentType string) {
|
||||
if len(ctx.Response.Header.ContentType()) == 0 {
|
||||
ctx.SetContentType(contentType)
|
||||
}
|
||||
}
|
||||
|
||||
// handleDebugStats displays debug statistics
|
||||
func (s *Server) handleDebugStats(w http.ResponseWriter, _ *http.Request) {
|
||||
func (s *Server) handleDebugStats(ctx *fasthttp.RequestCtx) {
|
||||
// Collect system stats
|
||||
stats := utils.CollectSystemStats(s.config)
|
||||
|
||||
|
@ -335,7 +358,7 @@ func (s *Server) handleDebugStats(w http.ResponseWriter, _ *http.Request) {
|
|||
html := utils.DebugStatsPage(stats)
|
||||
|
||||
// Send the response
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(html))
|
||||
ctx.SetContentType("text/html; charset=utf-8")
|
||||
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||
ctx.SetBody([]byte(html))
|
||||
}
|
||||
|
|
|
@ -1,33 +0,0 @@
|
|||
package http
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// statusCaptureWriter is a ResponseWriter that captures the status code
|
||||
type statusCaptureWriter struct {
|
||||
http.ResponseWriter
|
||||
statusCode int
|
||||
}
|
||||
|
||||
// WriteHeader captures the status code and passes it to the wrapped ResponseWriter
|
||||
func (w *statusCaptureWriter) WriteHeader(code int) {
|
||||
w.statusCode = code
|
||||
w.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
|
||||
// StatusCode returns the captured status code
|
||||
func (w *statusCaptureWriter) StatusCode() int {
|
||||
if w.statusCode == 0 {
|
||||
return http.StatusOK // Default to 200 if not explicitly set
|
||||
}
|
||||
return w.statusCode
|
||||
}
|
||||
|
||||
// newStatusCaptureWriter creates a new statusCaptureWriter
|
||||
func newStatusCaptureWriter(w http.ResponseWriter) *statusCaptureWriter {
|
||||
return &statusCaptureWriter{
|
||||
ResponseWriter: w,
|
||||
statusCode: 0,
|
||||
}
|
||||
}
|
7
go.mod
7
go.mod
|
@ -4,4 +4,11 @@ go 1.24.1
|
|||
|
||||
require git.sharkk.net/Sky/LuaJIT-to-Go v0.0.0
|
||||
|
||||
require (
|
||||
github.com/andybalholm/brotli v1.1.1 // indirect
|
||||
github.com/klauspost/compress v1.18.0 // indirect
|
||||
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
||||
github.com/valyala/fasthttp v1.60.0 // indirect
|
||||
)
|
||||
|
||||
replace git.sharkk.net/Sky/LuaJIT-to-Go => ./luajit
|
||||
|
|
9
go.sum
Normal file
9
go.sum
Normal file
|
@ -0,0 +1,9 @@
|
|||
github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA=
|
||||
github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA=
|
||||
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
|
||||
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
|
||||
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
|
||||
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
|
||||
github.com/valyala/fasthttp v1.60.0 h1:kBRYS0lOhVJ6V+bYN8PqAHELKHtXqwq9zNMLKx1MBsw=
|
||||
github.com/valyala/fasthttp v1.60.0/go.mod h1:iY4kDgV3Gc6EqhRZ8icqcmlG6bqhcDXfuHgTO4FXCvc=
|
||||
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
|
Loading…
Reference in New Issue
Block a user