package http import ( "Moonshark/core/runner" luaCtx "Moonshark/core/runner/context" "Moonshark/core/utils" "Moonshark/core/utils/logger" "crypto/subtle" "github.com/valyala/fasthttp" luajit "git.sharkk.net/Sky/LuaJIT-to-Go" ) // ValidateCSRFToken checks if the CSRF token is valid for a request func ValidateCSRFToken(state *luajit.State, ctx *luaCtx.Context) bool { // Only validate for form submissions method, ok := ctx.Get("method").(string) if !ok || (method != "POST" && method != "PUT" && method != "PATCH" && method != "DELETE") { return true } // Get form data formData, ok := ctx.Get("form").(map[string]any) if !ok || formData == nil { logger.Warning("CSRF validation failed: no form data") return false } // Get token from form formToken, ok := formData["csrf"].(string) if !ok || formToken == "" { logger.Warning("CSRF validation failed: no token in form") return false } // Get session token state.GetGlobal("session") if state.IsNil(-1) { state.Pop(1) logger.Warning("CSRF validation failed: session module not available") return false } state.GetField(-1, "get") if !state.IsFunction(-1) { state.Pop(2) logger.Warning("CSRF validation failed: session.get not available") return false } state.PushCopy(-1) // Duplicate function state.PushString("_csrf_token") if err := state.Call(1, 1); err != nil { state.Pop(3) // Pop error, function and session table logger.Warning("CSRF validation failed: %v", err) return false } if state.IsNil(-1) { state.Pop(3) // Pop nil, function and session table logger.Warning("CSRF validation failed: no token in session") return false } sessionToken := state.ToString(-1) state.Pop(3) // Pop token, function and session table // Constant-time comparison to prevent timing attacks return subtle.ConstantTimeCompare([]byte(formToken), []byte(sessionToken)) == 1 } // WithCSRFProtection creates a runner option to add CSRF protection func WithCSRFProtection() runner.RunnerOption { return func(r *runner.Runner) { r.AddInitHook(func(state *luajit.State, ctx *luaCtx.Context) error { // Get request method method, ok := ctx.Get("method").(string) if !ok { return nil } // Only validate for form submissions if method != "POST" && method != "PUT" && method != "PATCH" && method != "DELETE" { return nil } // Check for form data form, ok := ctx.Get("form").(map[string]any) if !ok || form == nil { return nil } // Validate CSRF token if !ValidateCSRFToken(state, ctx) { return ErrCSRFValidationFailed } return nil }) } } // Error for CSRF validation failure var ErrCSRFValidationFailed = &CSRFError{message: "CSRF token validation failed"} // CSRFError represents a CSRF validation error type CSRFError struct { message string } // Error implements the error interface func (e *CSRFError) Error() string { return e.message } // HandleCSRFError handles a CSRF validation error func HandleCSRFError(ctx *fasthttp.RequestCtx, errorConfig utils.ErrorPageConfig) { method := string(ctx.Method()) path := string(ctx.Path()) logger.Warning("CSRF validation failed for %s %s", method, path) ctx.SetContentType("text/html; charset=utf-8") ctx.SetStatusCode(fasthttp.StatusForbidden) errorMsg := "Invalid or missing CSRF token. This could be due to an expired form or a cross-site request forgery attempt." errorHTML := utils.ForbiddenPage(errorConfig, path, errorMsg) ctx.SetBody([]byte(errorHTML)) }