Moonshark/core/http/Csrf.go
2025-04-09 16:19:51 -05:00

132 lines
3.4 KiB
Go

package http
import (
"Moonshark/core/runner"
luaCtx "Moonshark/core/runner/context"
"Moonshark/core/utils"
"Moonshark/core/utils/logger"
"crypto/subtle"
"github.com/valyala/fasthttp"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
// ValidateCSRFToken reports whether the CSRF token submitted with a request
// matches the token stored in the Lua session.
//
// Requests whose context "method" is not one of POST, PUT, PATCH or DELETE
// (including a missing or non-string method) are accepted without any check.
// For those four methods, validation fails, with a logged warning, when any
// of these is missing:
//   - the "form" map in ctx
//   - a non-empty string "csrf" field in that form
//   - the Lua global "session"
//   - a function session.get
//   - a non-nil result from session.get("_csrf_token")
//
// The submitted and stored tokens are compared in constant time.
func ValidateCSRFToken(state *luajit.State, ctx *luaCtx.Context) bool {
	switch verb, _ := ctx.Get("method").(string); verb {
	case "POST", "PUT", "PATCH", "DELETE":
		// State-changing request: a token is required.
	default:
		return true
	}

	form, isMap := ctx.Get("form").(map[string]any)
	if !isMap || form == nil {
		logger.Warning("CSRF validation failed: no form data")
		return false
	}

	submitted, isStr := form["csrf"].(string)
	if !isStr || submitted == "" {
		logger.Warning("CSRF validation failed: no token in form")
		return false
	}

	// Lua stack after this call: [session]
	state.GetGlobal("session")
	if state.IsNil(-1) {
		state.Pop(1)
		logger.Warning("CSRF validation failed: session module not available")
		return false
	}
	// NOTE(review): only nil is rejected here. A non-table "session" global
	// still reaches GetField below, which may raise a Lua error. Consider
	// checking for a table instead.

	// Lua stack: [session, session.get]
	state.GetField(-1, "get")
	if !state.IsFunction(-1) {
		state.Pop(2)
		logger.Warning("CSRF validation failed: session.get not available")
		return false
	}

	// session.get is called as a plain function (no self argument). The call
	// consumes the function and its single argument, leaving [session, result].
	state.PushString("_csrf_token")
	if err := state.Call(1, 1); err != nil {
		// Assumes a failed Call leaves its error value on the stack, i.e.
		// [session, err]. TODO confirm against LuaJIT-to-Go's Call.
		state.Pop(2)
		logger.Warning("CSRF validation failed: %v", err)
		return false
	}

	if state.IsNil(-1) {
		state.Pop(2)
		logger.Warning("CSRF validation failed: no token in session")
		return false
	}
	stored := state.ToString(-1)
	state.Pop(2)

	// Constant-time comparison so response timing does not reveal how much
	// of the submitted token matched.
	return subtle.ConstantTimeCompare([]byte(submitted), []byte(stored)) == 1
}
// WithCSRFProtection returns a runner option that installs an init hook
// enforcing CSRF validation.
//
// The hook only acts on POST, PUT, PATCH and DELETE requests whose context
// carries a non-nil "form" map. For those requests it delegates to
// ValidateCSRFToken and returns ErrCSRFValidationFailed when validation fails.
// Every other request passes through with a nil error, including
// state-changing requests without parsed form data.
//
// NOTE(review): because requests without form data skip validation entirely,
// a cross-site POST whose body is not parsed into "form" (for example a
// text/plain or JSON body) is not CSRF-checked here. Confirm that handlers
// for such requests are protected some other way.
func WithCSRFProtection() runner.RunnerOption {
	// The hook captures no state, so one closure can serve every runner
	// this option is applied to.
	csrfHook := func(state *luajit.State, ctx *luaCtx.Context) error {
		switch verb, _ := ctx.Get("method").(string); verb {
		case "POST", "PUT", "PATCH", "DELETE":
			// Candidate for validation; keep going.
		default:
			return nil
		}

		if form, isMap := ctx.Get("form").(map[string]any); !isMap || form == nil {
			return nil
		}

		if ValidateCSRFToken(state, ctx) {
			return nil
		}
		return ErrCSRFValidationFailed
	}

	return func(r *runner.Runner) {
		r.AddInitHook(csrfHook)
	}
}
// CSRFError is the error type used to report a failed CSRF check.
type CSRFError struct {
	message string // human-readable description returned by Error
}

// Error returns the stored message, satisfying the error interface.
func (e *CSRFError) Error() string { return e.message }

// ErrCSRFValidationFailed is the sentinel error returned when a request's CSRF
// token is rejected. It is a single shared pointer, so callers can detect it
// with == or errors.Is.
var ErrCSRFValidationFailed = &CSRFError{message: "CSRF token validation failed"}
// HandleCSRFError writes a 403 Forbidden HTML error page for a request that
// failed CSRF validation, and logs the request's method and path.
//
// errorConfig is passed straight through to utils.ForbiddenPage, which renders
// the body from the request path and a fixed explanatory message.
//
// NOTE(review): the request path is attacker-controlled. Confirm that
// utils.ForbiddenPage HTML-escapes it before embedding it in the page.
func HandleCSRFError(ctx *fasthttp.RequestCtx, errorConfig utils.ErrorPageConfig) {
	method := string(ctx.Method())
	path := string(ctx.Path())

	// fasthttp returns a URL-decoded path, so "%0A" in the request becomes a
	// real newline. %q escapes such control characters, so a crafted path
	// cannot forge additional log lines.
	logger.Warning("CSRF validation failed for %s %q", method, path)

	ctx.SetContentType("text/html; charset=utf-8")
	ctx.SetStatusCode(fasthttp.StatusForbidden)

	errorMsg := "Invalid or missing CSRF token. This could be due to an expired form or a cross-site request forgery attempt."
	errorHTML := utils.ForbiddenPage(errorConfig, path, errorMsg)
	ctx.SetBody([]byte(errorHTML))
}