Moonshark/core/runner/Csrf.go
2025-04-04 11:22:13 -05:00

224 lines
5.3 KiB
Go

package runner
import (
"crypto/subtle"
"Moonshark/core/logger"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
// LuaCSRFModule is Lua source that defines the "csrf" table (token generation,
// hidden form field rendering, and token verification backed by the session
// module) and installs it as a global and, when present, into
// __env_system.base_env.
const LuaCSRFModule = `
-- CSRF protection module
--
-- The table is created first and its functions are attached afterwards so the
-- functions capture this local as an upvalue. Inside a "local x = { ... }"
-- constructor the name x is not yet in scope, so references to it there would
-- silently resolve to the global x instead.
local csrf = {
	-- Session key where the token is stored
	TOKEN_KEY = "_csrf_token",
	-- Default form field name
	DEFAULT_FIELD = "csrf",
}

-- Raise a descriptive error when the session module is unavailable.
-- level is passed to error() so the message blames the right caller.
local function require_session(level)
	if not session then
		error("CSRF protection requires the session module", level)
	end
end

-- Characters that must be escaped inside a double- or single-quoted HTML
-- attribute value.
local attr_escapes = {
	["&"] = "&amp;",
	["<"] = "&lt;",
	[">"] = "&gt;",
	['"'] = "&quot;",
	["'"] = "&#39;",
}

-- Escape a value for safe interpolation into an HTML attribute.
local function escape_attr(value)
	return (tostring(value):gsub("[&<>\"']", attr_escapes))
end

-- Generate a new CSRF token, store it in the session and return it.
-- length: token length passed to go.generate_token (default 32, minimum 16).
function csrf.generate(length)
	length = length or 32
	if length < 16 then
		-- Enforce minimum security
		length = 16
	end
	require_session(3)
	-- Use Go's secure token generation
	local token = go.generate_token(length)
	session.set(csrf.TOKEN_KEY, token)
	return token
end

-- Return the token stored in the session, generating one if none exists.
function csrf.token()
	require_session(3)
	local token = session.get(csrf.TOKEN_KEY)
	if not token then
		token = csrf.generate()
	end
	return token
end

-- Render a hidden form input carrying the current CSRF token.
-- field_name: input name (default csrf.DEFAULT_FIELD).
function csrf.field(field_name)
	field_name = field_name or csrf.DEFAULT_FIELD
	local token = csrf.token()
	return string.format('<input type="hidden" name="%s" value="%s">',
		escape_attr(field_name), escape_attr(token))
end

-- Verify a token against the one stored in the session.
-- token: explicit token to check; when nil it is read from ctx.form of the
--        calling environment using field_name (default csrf.DEFAULT_FIELD).
-- Returns true only for a non-empty string token equal to the session token.
function csrf.verify(token, field_name)
	field_name = field_name or csrf.DEFAULT_FIELD
	local env = getfenv(2)
	local form = env.ctx and env.ctx.form
	if not form then
		return false
	end
	token = token or form[field_name]
	-- Reject missing, empty or non-string values (e.g. a table from a
	-- repeated form field) instead of erroring or matching an empty token.
	if type(token) ~= "string" or token == "" then
		return false
	end
	require_session(3)
	local session_token = session.get(csrf.TOKEN_KEY)
	if type(session_token) ~= "string" then
		return false
	end
	-- The length check reveals only the token length, which is not secret.
	if #token ~= #session_token then
		return false
	end
	-- Visit every byte even after a mismatch so the running time does not
	-- reveal the length of the matching prefix.
	local mismatch = false
	for i = 1, #token do
		if token:byte(i) ~= session_token:byte(i) then
			mismatch = true
		end
	end
	return not mismatch
end

-- Install CSRF module
_G.csrf = csrf
-- Make sure the CSRF module is accessible in sandbox
if __env_system and __env_system.base_env then
	__env_system.base_env.csrf = csrf
end
`
// CSRFModuleInitFunc builds a StateInitFunc that loads the CSRF Lua module
// into a state by running LuaCSRFModule through DoString; the DoString error,
// if any, is returned as-is.
func CSRFModuleInitFunc() StateInitFunc {
	loadModule := func(state *luajit.State) error {
		return state.DoString(LuaCSRFModule)
	}
	return loadModule
}
// ValidateCSRFToken reports whether the request described by ctx carries a
// CSRF token matching the one stored in the Lua session.
//
// Requests whose "method" is missing, or is not POST, PUT, PATCH or DELETE,
// are accepted without checks. Otherwise the token is read from the "csrf"
// field of ctx's "form" map and compared in constant time with the value
// returned by session.get("_csrf_token") in state. Every rejection path logs
// a warning and returns false.
func ValidateCSRFToken(state *luajit.State, ctx *Context) bool {
	method, hasMethod := ctx.Get("method").(string)
	if !hasMethod {
		return true
	}
	switch method {
	case "POST", "PUT", "PATCH", "DELETE":
		// State-changing request: validate below.
	default:
		return true
	}

	form, hasForm := ctx.Get("form").(map[string]any)
	if !hasForm || form == nil {
		logger.Warning("CSRF validation failed: no form data")
		return false
	}

	submitted, hasToken := form["csrf"].(string)
	if !hasToken || submitted == "" {
		logger.Warning("CSRF validation failed: no token in form")
		return false
	}

	// Stack: [session]
	state.GetGlobal("session")
	if state.IsNil(-1) {
		state.Pop(1)
		logger.Warning("CSRF validation failed: session module not available")
		return false
	}

	// Stack: [session, session.get]
	state.GetField(-1, "get")
	if !state.IsFunction(-1) {
		state.Pop(2)
		logger.Warning("CSRF validation failed: session.get not available")
		return false
	}

	// Call a copy of session.get so the original stays put; on success the
	// stack becomes [session, session.get, result].
	state.PushCopy(-1)
	state.PushString("_csrf_token")
	if err := state.Call(1, 1); err != nil {
		// NOTE(review): the three-slot pop assumes Call leaves exactly one
		// value (the error) on the stack when it fails — confirm against the
		// LuaJIT-to-Go binding.
		state.Pop(3)
		logger.Warning("CSRF validation failed: %v", err)
		return false
	}

	if state.IsNil(-1) {
		state.Pop(3)
		logger.Warning("CSRF validation failed: no token in session")
		return false
	}
	stored := state.ToString(-1)
	state.Pop(3)

	// Constant-time comparison so response timing does not leak token bytes.
	return subtle.ConstantTimeCompare([]byte(submitted), []byte(stored)) == 1
}
// WithCSRFProtection returns a RunnerOption that registers an init hook
// enforcing CSRF validation via ValidateCSRFToken.
//
// The hook validates only requests whose "method" is POST, PUT, PATCH or
// DELETE and whose ctx holds a non-nil "form" map; every other request passes
// through. A failed validation makes the hook return ErrCSRFValidationFailed.
//
// NOTE(review): state-changing requests without parsed form data (for example
// JSON or text/plain bodies) skip validation here entirely, even though
// ValidateCSRFToken would reject them — confirm this is intended.
func WithCSRFProtection() RunnerOption {
	hook := func(state *luajit.State, ctx *Context) error {
		method, _ := ctx.Get("method").(string)
		switch method {
		case "POST", "PUT", "PATCH", "DELETE":
			// Candidate for validation.
		default:
			return nil
		}

		if form, isForm := ctx.Get("form").(map[string]any); !isForm || form == nil {
			return nil
		}

		if ValidateCSRFToken(state, ctx) {
			return nil
		}
		return ErrCSRFValidationFailed
	}

	return func(r *Runner) {
		r.AddInitHook(hook)
	}
}
// CSRFError is the error type used to report a failed CSRF check.
type CSRFError struct {
	message string
}

// Error satisfies the error interface by returning the stored message.
func (ce *CSRFError) Error() string {
	return ce.message
}

// ErrCSRFValidationFailed is the shared error returned by the hook installed
// by WithCSRFProtection when a request's CSRF token does not validate. Being a
// single pointer value, it can be matched with == or errors.Is.
var ErrCSRFValidationFailed = &CSRFError{message: "CSRF token validation failed"}