package runner

import (
	"crypto/subtle"

	luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
	"git.sharkk.net/Sky/Moonshark/core/logger"
)

// LuaCSRFModule is the Lua source of the `csrf` module exposed to scripts.
//
// It generates and looks up a per-session token (stored through the `session`
// module under TOKEN_KEY), renders a hidden form input carrying that token,
// and verifies a submitted token against the session copy. The module is
// installed as the global `csrf` and, when present, into
// `__env_system.base_env`.
const LuaCSRFModule = `
-- CSRF protection module
local csrf = {
	-- Session key where the token is stored
	TOKEN_KEY = "_csrf_token",

	-- Default form field name
	DEFAULT_FIELD = "csrf",
}

-- Characters that must be escaped inside a double-quoted HTML attribute.
local attr_escapes = {
	["&"] = "&amp;",
	["<"] = "&lt;",
	[">"] = "&gt;",
	['"'] = "&quot;",
	["'"] = "&#39;",
}

-- Escape a value for safe use inside a double-quoted HTML attribute.
local function escape_attr(value)
	return (tostring(value):gsub("[&<>\"']", attr_escapes))
end

-- Raise a descriptive error when the session module is missing, instead of
-- failing later with "attempt to index global 'session'". level has the same
-- meaning as error()'s level, counted from the function calling this helper.
local function require_session(level)
	if not session then
		error("CSRF protection requires the session module", level + 1)
	end
end

-- Generate a new CSRF token, store it in the session and return it.
-- length defaults to 32 and is raised to at least 16.
function csrf.generate(length)
	length = length or 32
	if length < 16 then
		-- Enforce minimum security
		length = 16
	end

	require_session(2)

	-- Use Go's secure token generation
	local token = go.generate_token(length)

	session.set(csrf.TOKEN_KEY, token)
	return token
end

-- Return the session's current token, generating one if none exists yet.
function csrf.token()
	require_session(2)

	local token = session.get(csrf.TOKEN_KEY)
	if not token then
		token = csrf.generate()
	end
	return token
end

-- Render a hidden form input carrying the current token.
-- field_name defaults to DEFAULT_FIELD; both attribute values are escaped.
function csrf.field(field_name)
	field_name = field_name or csrf.DEFAULT_FIELD
	local token = csrf.token()
	return string.format('<input type="hidden" name="%s" value="%s">',
		escape_attr(field_name), escape_attr(token))
end

-- Verify a token against the one stored in the session.
-- When token is not given, it is read from ctx.form[field_name] in the
-- calling function's environment. Returns true only for a non-empty string
-- token equal to a non-empty session token.
function csrf.verify(token, field_name)
	field_name = field_name or csrf.DEFAULT_FIELD

	if not token then
		-- Fall back to the form submitted with the caller's request.
		local env = getfenv(2)
		local ctx = env and env.ctx
		local form = ctx and ctx.form
		if not form then
			return false
		end
		token = form[field_name]
	end

	-- Reject missing, empty and non-string values (e.g. repeated form fields).
	if type(token) ~= "string" or token == "" then
		return false
	end

	require_session(2)
	local session_token = session.get(csrf.TOKEN_KEY)
	if type(session_token) ~= "string" or session_token == "" then
		return false
	end

	-- The token length is not secret, so exiting early on it is acceptable.
	if #token ~= #session_token then
		return false
	end

	-- Compare every byte without stopping at the first mismatch, so the loop
	-- length does not reveal where the tokens first differ.
	local mismatches = 0
	for i = 1, #token do
		if token:byte(i) ~= session_token:byte(i) then
			mismatches = mismatches + 1
		end
	end
	return mismatches == 0
end

-- Install CSRF module
_G.csrf = csrf

-- Make sure the CSRF module is accessible in sandbox
if __env_system and __env_system.base_env then
	__env_system.base_env.csrf = csrf
end
`

// CSRFModuleInitFunc returns a StateInitFunc that executes LuaCSRFModule in
// the given Lua state, installing the global `csrf` table.
func CSRFModuleInitFunc() StateInitFunc {
	return func(state *luajit.State) error {
		return state.DoString(LuaCSRFModule)
	}
}

// ValidateCSRFToken reports whether the request in ctx carries a valid CSRF
// token.
//
// Requests whose "method" is not POST, PUT, PATCH or DELETE (including a
// missing or non-string method) are always accepted. For those methods, the
// "csrf" entry of the form map must be a non-empty string. It must also equal
// the value returned by the Lua call session.get("_csrf_token"), compared in
// constant time.
//
// The field and key names must stay in sync with DEFAULT_FIELD and TOKEN_KEY
// in LuaCSRFModule. The Lua stack height is restored on every return path.
func ValidateCSRFToken(state *luajit.State, ctx *Context) bool {
	// Only state-changing methods need a token.
	method, _ := ctx.Get("method").(string)
	switch method {
	case "POST", "PUT", "PATCH", "DELETE":
	default:
		return true
	}

	formData, ok := ctx.Get("form").(map[string]any)
	if !ok || formData == nil {
		logger.Warning("CSRF validation failed: no form data")
		return false
	}

	formToken, ok := formData["csrf"].(string)
	if !ok || formToken == "" {
		logger.Warning("CSRF validation failed: no token in form")
		return false
	}

	// Restore the stack on every exit. This avoids hard-coded pop counts that
	// depend on whether Call leaves an error value behind.
	top := state.GetTop()
	defer state.SetTop(top)

	state.GetGlobal("session")
	if state.IsNil(-1) {
		logger.Warning("CSRF validation failed: session module not available")
		return false
	}

	state.GetField(-1, "get")
	if !state.IsFunction(-1) {
		logger.Warning("CSRF validation failed: session.get not available")
		return false
	}

	// session.get is called with dot syntax (no self), matching the Lua module.
	state.PushString("_csrf_token")
	if err := state.Call(1, 1); err != nil {
		logger.Warning("CSRF validation failed: %v", err)
		return false
	}

	if state.IsNil(-1) {
		logger.Warning("CSRF validation failed: no token in session")
		return false
	}
	sessionToken := state.ToString(-1)

	// Constant-time comparison to prevent timing attacks
	return subtle.ConstantTimeCompare([]byte(formToken), []byte(sessionToken)) == 1
}
WithCSRFProtection creates a runner option to add CSRF protection func WithCSRFProtection() RunnerOption { return func(r *LuaRunner) { r.AddInitHook(func(state *luajit.State, ctx *Context) error { // Get request method method, ok := ctx.Get("method").(string) if !ok { return nil } // Only validate for form submissions if method != "POST" && method != "PUT" && method != "PATCH" && method != "DELETE" { return nil } // Check for form data form, ok := ctx.Get("form").(map[string]any) if !ok || form == nil { return nil } // Validate CSRF token if !ValidateCSRFToken(state, ctx) { return ErrCSRFValidationFailed } return nil }) } } // Error for CSRF validation failure var ErrCSRFValidationFailed = &CSRFError{message: "CSRF token validation failed"} // CSRFError represents a CSRF validation error type CSRFError struct { message string } // Error implements the error interface func (e *CSRFError) Error() string { return e.message }