From eea5ba8c8acbdb7489fc7644f59a31a19f2adcbf Mon Sep 17 00:00:00 2001 From: Sky Johnson Date: Wed, 2 Apr 2025 22:22:03 -0500 Subject: [PATCH] csrf 1 --- Moonshark.go | 1 + core/http/Csrf.go | 20 ++++ core/http/Server.go | 39 +++---- core/runner/Cookies.go | 2 +- core/runner/CoreModules.go | 1 + core/runner/Csrf.go | 230 +++++++++++++++++++++++++++++++++++++ core/utils/ErrorPages.go | 34 ++++++ 7 files changed, 305 insertions(+), 22 deletions(-) create mode 100644 core/http/Csrf.go create mode 100644 core/runner/Csrf.go diff --git a/Moonshark.go b/Moonshark.go index bb20f14..a659cc7 100644 --- a/Moonshark.go +++ b/Moonshark.go @@ -185,6 +185,7 @@ func (s *Moonshark) initRunner() error { runner.WithPoolSize(s.Config.PoolSize), runner.WithLibDirs(s.Config.LibDirs...), runner.WithSessionManager(sessionManager), + runner.WithCSRFProtection(), } // Add debug option conditionally diff --git a/core/http/Csrf.go b/core/http/Csrf.go new file mode 100644 index 0000000..6a986a6 --- /dev/null +++ b/core/http/Csrf.go @@ -0,0 +1,20 @@ +package http + +import ( + "net/http" + + "git.sharkk.net/Sky/Moonshark/core/logger" + "git.sharkk.net/Sky/Moonshark/core/utils" +) + +// HandleCSRFError handles a CSRF validation error +func HandleCSRFError(w http.ResponseWriter, r *http.Request, errorConfig utils.ErrorPageConfig) { + logger.Warning("CSRF validation failed for %s %s", r.Method, r.URL.Path) + + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(http.StatusForbidden) + + errorMsg := "Invalid or missing CSRF token. This could be due to an expired form or a cross-site request forgery attempt." + errorHTML := utils.ForbiddenPage(errorConfig, r.URL.Path, errorMsg) + w.Write([]byte(errorHTML)) +} diff --git a/core/http/Server.go b/core/http/Server.go index d510a83..350fbdd 100644 --- a/core/http/Server.go +++ b/core/http/Server.go @@ -162,24 +162,17 @@ func (s *Server) handleLuaRoute(w http.ResponseWriter, r *http.Request, bytecode ctx := runner.NewContext() defer ctx.Release() - // Log bytecode size - logger.Debug("Executing Lua route with %d bytes of bytecode", len(bytecode)) - - // Extract cookies instead of storing the raw request + // Set up context exactly as the original cookieMap := make(map[string]any) for _, cookie := range r.Cookies() { cookieMap[cookie.Name] = cookie.Value } - - // Store cookie map instead of raw request ctx.Set("_request_cookies", cookieMap) - - // Add request info directly to context ctx.Set("method", r.Method) ctx.Set("path", r.URL.Path) ctx.Set("host", r.Host) - // Add headers to context + // Headers headerMap := make(map[string]any, len(r.Header)) for name, values := range r.Header { if len(values) == 1 { @@ -190,7 +183,7 @@ func (s *Server) handleLuaRoute(w http.ResponseWriter, r *http.Request, bytecode } ctx.Set("headers", headerMap) - // Add cookies to context + // Cookies if cookies := r.Cookies(); len(cookies) > 0 { cookieMap := make(map[string]any, len(cookies)) for _, cookie := range cookies { @@ -199,7 +192,7 @@ func (s *Server) handleLuaRoute(w http.ResponseWriter, r *http.Request, bytecode ctx.Set("cookies", cookieMap) } - // Add URL parameters + // URL parameters if params.Count > 0 { paramMap := make(map[string]any, params.Count) for i, key := range params.Keys { @@ -208,7 +201,7 @@ func (s *Server) handleLuaRoute(w http.ResponseWriter, r *http.Request, bytecode ctx.Set("params", paramMap) } - // Parse query parameters only if present + // Query parameters queryMap := QueryToLua(r) if queryMap == nil { ctx.Set("query", make(map[string]any)) @@ -216,7 +209,7 @@ func (s *Server) handleLuaRoute(w http.ResponseWriter, r *http.Request, bytecode ctx.Set("query", queryMap) } - // Add form data for POST/PUT/PATCH only when needed + // Form data if r.Method == http.MethodPost || r.Method == http.MethodPut || r.Method == http.MethodPatch { if formData, err := ParseForm(r); err == nil && len(formData) > 0 { ctx.Set("form", formData) @@ -225,16 +218,20 @@ func (s *Server) handleLuaRoute(w http.ResponseWriter, r *http.Request, bytecode // Execute Lua script result, err := s.luaRunner.Run(bytecode, ctx, scriptPath) - if err != nil { - logger.Error("Error executing Lua route: %v", err) - // Set content type to HTML + // Special handling for CSRF error + if err != nil { + if csrfErr, ok := err.(*runner.CSRFError); ok { + logger.Warning("CSRF error executing Lua route: %v", csrfErr) + HandleCSRFError(w, r, s.errorConfig) + return + } + + // Normal error handling + logger.Error("Error executing Lua route: %v", err) w.Header().Set("Content-Type", "text/html; charset=utf-8") w.WriteHeader(http.StatusInternalServerError) - - // Generate error page with error message - errorMsg := err.Error() - errorHTML := utils.InternalErrorPage(s.errorConfig, r.URL.Path, errorMsg) + errorHTML := utils.InternalErrorPage(s.errorConfig, r.URL.Path, err.Error()) w.Write([]byte(errorHTML)) return } @@ -321,7 +318,7 @@ func setContentTypeIfMissing(w http.ResponseWriter, contentType string) { } // handleDebugStats displays debug statistics -func (s *Server) handleDebugStats(w http.ResponseWriter, r *http.Request) { +func (s *Server) handleDebugStats(w http.ResponseWriter, _ *http.Request) { // Collect system stats stats := utils.CollectSystemStats(s.config) diff --git a/core/runner/Cookies.go b/core/runner/Cookies.go index b75f6d4..7047952 100644 --- a/core/runner/Cookies.go +++ b/core/runner/Cookies.go @@ -148,7 +148,7 @@ local cookie = { end -- Access values directly from current environment - local env = getfenv(1) + local env = getfenv(2) -- Check if context exists and has cookies if env.ctx and env.ctx.cookies and env.ctx.cookies[name] then diff --git a/core/runner/CoreModules.go b/core/runner/CoreModules.go index 25ab135..fc98e33 100644 --- a/core/runner/CoreModules.go +++ b/core/runner/CoreModules.go @@ -132,6 +132,7 @@ func init() { GlobalRegistry.EnableDebug() // Enable debugging by default GlobalRegistry.Register("http", HTTPModuleInitFunc()) GlobalRegistry.Register("cookie", CookieModuleInitFunc()) + GlobalRegistry.Register("csrf", CSRFModuleInitFunc()) logger.Debug("[CoreModuleRegistry] Core modules registered in init()") } diff --git a/core/runner/Csrf.go b/core/runner/Csrf.go new file mode 100644 index 0000000..8c0bd04 --- /dev/null +++ b/core/runner/Csrf.go @@ -0,0 +1,230 @@ +package runner + +import ( + "crypto/subtle" + + luajit "git.sharkk.net/Sky/LuaJIT-to-Go" + "git.sharkk.net/Sky/Moonshark/core/logger" +) + +// LuaCSRFModule provides CSRF protection functionality to Lua scripts +const LuaCSRFModule = ` +-- CSRF protection module +local csrf = { + -- Session key where the token is stored + TOKEN_KEY = "_csrf_token", + + -- Default form field name + DEFAULT_FIELD = "csrf", + + -- Generate a new CSRF token and store it in the session + generate = function(length) + -- Default length is 32 characters + length = length or 32 + + if length < 16 then + -- Enforce minimum security + length = 16 + end + + -- Check if we have a session module + if not session then + error("CSRF protection requires the session module", 2) + end + + -- Generate a secure random token using os.time and math.random + local token = "" + local chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + + -- Seed the random generator with current time + math.randomseed(os.time()) + + -- Generate random string + for i = 1, length do + local idx = math.random(1, #chars) + token = token .. chars:sub(idx, idx) + end + + -- Store in session + session.set(csrf.TOKEN_KEY, token) + + return token + end, + + -- Get the current token or generate a new one + token = function() + -- Get from session if exists + local token = session.get(csrf.TOKEN_KEY) + + -- Generate if needed + if not token then + token = csrf.generate() + end + + return token + end, + + -- Generate a hidden form field with the CSRF token + field = function(field_name) + field_name = field_name or csrf.DEFAULT_FIELD + local token = csrf.token() + return string.format('', field_name, token) + end, + + -- Verify a given token against the session token + verify = function(token, field_name) + field_name = field_name or csrf.DEFAULT_FIELD + + local env = getfenv(2) + + local form = nil + if env.ctx and env.ctx.form then + form = env.ctx.form + else + return false + end + + token = token or form[field_name] + if not token then + return false + end + + local session_token = session.get(csrf.TOKEN_KEY) + if not session_token then + return false + end + + if #token ~= #session_token then + return false + end + + local result = true + for i = 1, #token do + if token:sub(i, i) ~= session_token:sub(i, i) then + result = false + -- Don't break early - continue to prevent timing attacks + end + end + + return result + end +} + +-- Install CSRF module +_G.csrf = csrf + +-- Make sure the CSRF module is accessible in sandbox +if __env_system and __env_system.base_env then + __env_system.base_env.csrf = csrf +end +` + +// CSRFModuleInitFunc returns an initializer for the CSRF module +func CSRFModuleInitFunc() StateInitFunc { + return func(state *luajit.State) error { + return state.DoString(LuaCSRFModule) + } +} + +// ValidateCSRFToken checks if the CSRF token is valid for a request +func ValidateCSRFToken(state *luajit.State, ctx *Context) bool { + // Only validate for form submissions + method, ok := ctx.Get("method").(string) + if !ok || (method != "POST" && method != "PUT" && method != "PATCH" && method != "DELETE") { + return true + } + + // Get form data + formData, ok := ctx.Get("form").(map[string]any) + if !ok || formData == nil { + logger.Warning("CSRF validation failed: no form data") + return false + } + + // Get token from form + formToken, ok := formData["csrf"].(string) + if !ok || formToken == "" { + logger.Warning("CSRF validation failed: no token in form") + return false + } + + // Get session token + state.GetGlobal("session") + if state.IsNil(-1) { + state.Pop(1) + logger.Warning("CSRF validation failed: session module not available") + return false + } + + state.GetField(-1, "get") + if !state.IsFunction(-1) { + state.Pop(2) + logger.Warning("CSRF validation failed: session.get not available") + return false + } + + state.PushCopy(-1) // Duplicate function + state.PushString("_csrf_token") + + if err := state.Call(1, 1); err != nil { + state.Pop(3) // Pop error, function and session table + logger.Warning("CSRF validation failed: %v", err) + return false + } + + if state.IsNil(-1) { + state.Pop(3) // Pop nil, function and session table + logger.Warning("CSRF validation failed: no token in session") + return false + } + + sessionToken := state.ToString(-1) + state.Pop(3) // Pop token, function and session table + + // Constant-time comparison to prevent timing attacks + return subtle.ConstantTimeCompare([]byte(formToken), []byte(sessionToken)) == 1 +} + +// WithCSRFProtection creates a runner option to add CSRF protection +func WithCSRFProtection() RunnerOption { + return func(r *LuaRunner) { + r.AddInitHook(func(state *luajit.State, ctx *Context) error { + // Get request method + method, ok := ctx.Get("method").(string) + if !ok { + return nil + } + + // Only validate for form submissions + if method != "POST" && method != "PUT" && method != "PATCH" && method != "DELETE" { + return nil + } + + // Check for form data + form, ok := ctx.Get("form").(map[string]any) + if !ok || form == nil { + return nil + } + + // Validate CSRF token + if !ValidateCSRFToken(state, ctx) { + return ErrCSRFValidationFailed + } + + return nil + }) + } +} + +// Error for CSRF validation failure +var ErrCSRFValidationFailed = &CSRFError{message: "CSRF token validation failed"} + +// CSRFError represents a CSRF validation error +type CSRFError struct { + message string +} + +// Error implements the error interface +func (e *CSRFError) Error() string { + return e.message +} diff --git a/core/utils/ErrorPages.go b/core/utils/ErrorPages.go index dafe216..956cd1c 100644 --- a/core/utils/ErrorPages.go +++ b/core/utils/ErrorPages.go @@ -19,6 +19,7 @@ const ( ErrorTypeNotFound ErrorType = 404 ErrorTypeMethodNotAllowed ErrorType = 405 ErrorTypeInternalError ErrorType = 500 + ErrorTypeForbidden ErrorType = 403 // Added CSRF/Forbidden error type ) // ErrorPage generates an HTML error page based on the error type @@ -34,6 +35,8 @@ func ErrorPage(config ErrorPageConfig, errorType ErrorType, url string, errMsg s filename = "405.html" case ErrorTypeInternalError: filename = "500.html" + case ErrorTypeForbidden: + filename = "403.html" } if filename != "" { @@ -52,6 +55,8 @@ func ErrorPage(config ErrorPageConfig, errorType ErrorType, url string, errMsg s return generateMethodNotAllowedHTML(url) case ErrorTypeInternalError: return generateInternalErrorHTML(config.DebugMode, url, errMsg) + case ErrorTypeForbidden: + return generateForbiddenHTML(config.DebugMode, url, errMsg) default: // Fallback to internal error return generateInternalErrorHTML(config.DebugMode, url, errMsg) @@ -73,6 +78,11 @@ func InternalErrorPage(config ErrorPageConfig, url string, errMsg string) string return ErrorPage(config, ErrorTypeInternalError, url, errMsg) } +// ForbiddenPage generates a 403 Forbidden error page +func ForbiddenPage(config ErrorPageConfig, url string, errMsg string) string { + return ErrorPage(config, ErrorTypeForbidden, url, errMsg) +} + // generateInternalErrorHTML creates a 500 Internal Server Error page func generateInternalErrorHTML(debugMode bool, url string, errMsg string) string { errorMessages := []string{ @@ -92,6 +102,30 @@ func generateInternalErrorHTML(debugMode bool, url string, errMsg string) string return generateErrorHTML("500", randomMessage, "Internal Server Error", debugMode, errMsg) } +// generateForbiddenHTML creates a 403 Forbidden error page +func generateForbiddenHTML(debugMode bool, url string, errMsg string) string { + errorMessages := []string{ + "Access denied", + "You shall not pass", + "This area is off-limits", + "Security check failed", + "Invalid security token", + "Request blocked for security reasons", + "Permission denied", + "Security violation detected", + "This request was rejected", + "Security first, access second", + } + + defaultMsg := "Invalid or missing CSRF token. This could be due to an expired form or a cross-site request forgery attempt." + if errMsg == "" { + errMsg = defaultMsg + } + + randomMessage := errorMessages[rand.Intn(len(errorMessages))] + return generateErrorHTML("403", randomMessage, "Forbidden", debugMode, errMsg) +} + // generateNotFoundHTML creates a 404 Not Found error page func generateNotFoundHTML(url string) string { errorMessages := []string{