This commit is contained in:
Sky Johnson 2025-04-02 22:22:03 -05:00
parent 472d175093
commit eea5ba8c8a
7 changed files with 305 additions and 22 deletions

View File

@ -185,6 +185,7 @@ func (s *Moonshark) initRunner() error {
runner.WithPoolSize(s.Config.PoolSize), runner.WithPoolSize(s.Config.PoolSize),
runner.WithLibDirs(s.Config.LibDirs...), runner.WithLibDirs(s.Config.LibDirs...),
runner.WithSessionManager(sessionManager), runner.WithSessionManager(sessionManager),
runner.WithCSRFProtection(),
} }
// Add debug option conditionally // Add debug option conditionally

20
core/http/Csrf.go Normal file
View File

@ -0,0 +1,20 @@
package http
import (
"net/http"
"git.sharkk.net/Sky/Moonshark/core/logger"
"git.sharkk.net/Sky/Moonshark/core/utils"
)
// HandleCSRFError handles a CSRF validation error
func HandleCSRFError(w http.ResponseWriter, r *http.Request, errorConfig utils.ErrorPageConfig) {
logger.Warning("CSRF validation failed for %s %s", r.Method, r.URL.Path)
w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.WriteHeader(http.StatusForbidden)
errorMsg := "Invalid or missing CSRF token. This could be due to an expired form or a cross-site request forgery attempt."
errorHTML := utils.ForbiddenPage(errorConfig, r.URL.Path, errorMsg)
w.Write([]byte(errorHTML))
}

View File

@ -162,24 +162,17 @@ func (s *Server) handleLuaRoute(w http.ResponseWriter, r *http.Request, bytecode
ctx := runner.NewContext() ctx := runner.NewContext()
defer ctx.Release() defer ctx.Release()
// Log bytecode size // Set up context exactly as the original
logger.Debug("Executing Lua route with %d bytes of bytecode", len(bytecode))
// Extract cookies instead of storing the raw request
cookieMap := make(map[string]any) cookieMap := make(map[string]any)
for _, cookie := range r.Cookies() { for _, cookie := range r.Cookies() {
cookieMap[cookie.Name] = cookie.Value cookieMap[cookie.Name] = cookie.Value
} }
// Store cookie map instead of raw request
ctx.Set("_request_cookies", cookieMap) ctx.Set("_request_cookies", cookieMap)
// Add request info directly to context
ctx.Set("method", r.Method) ctx.Set("method", r.Method)
ctx.Set("path", r.URL.Path) ctx.Set("path", r.URL.Path)
ctx.Set("host", r.Host) ctx.Set("host", r.Host)
// Add headers to context // Headers
headerMap := make(map[string]any, len(r.Header)) headerMap := make(map[string]any, len(r.Header))
for name, values := range r.Header { for name, values := range r.Header {
if len(values) == 1 { if len(values) == 1 {
@ -190,7 +183,7 @@ func (s *Server) handleLuaRoute(w http.ResponseWriter, r *http.Request, bytecode
} }
ctx.Set("headers", headerMap) ctx.Set("headers", headerMap)
// Add cookies to context // Cookies
if cookies := r.Cookies(); len(cookies) > 0 { if cookies := r.Cookies(); len(cookies) > 0 {
cookieMap := make(map[string]any, len(cookies)) cookieMap := make(map[string]any, len(cookies))
for _, cookie := range cookies { for _, cookie := range cookies {
@ -199,7 +192,7 @@ func (s *Server) handleLuaRoute(w http.ResponseWriter, r *http.Request, bytecode
ctx.Set("cookies", cookieMap) ctx.Set("cookies", cookieMap)
} }
// Add URL parameters // URL parameters
if params.Count > 0 { if params.Count > 0 {
paramMap := make(map[string]any, params.Count) paramMap := make(map[string]any, params.Count)
for i, key := range params.Keys { for i, key := range params.Keys {
@ -208,7 +201,7 @@ func (s *Server) handleLuaRoute(w http.ResponseWriter, r *http.Request, bytecode
ctx.Set("params", paramMap) ctx.Set("params", paramMap)
} }
// Parse query parameters only if present // Query parameters
queryMap := QueryToLua(r) queryMap := QueryToLua(r)
if queryMap == nil { if queryMap == nil {
ctx.Set("query", make(map[string]any)) ctx.Set("query", make(map[string]any))
@ -216,7 +209,7 @@ func (s *Server) handleLuaRoute(w http.ResponseWriter, r *http.Request, bytecode
ctx.Set("query", queryMap) ctx.Set("query", queryMap)
} }
// Add form data for POST/PUT/PATCH only when needed // Form data
if r.Method == http.MethodPost || r.Method == http.MethodPut || r.Method == http.MethodPatch { if r.Method == http.MethodPost || r.Method == http.MethodPut || r.Method == http.MethodPatch {
if formData, err := ParseForm(r); err == nil && len(formData) > 0 { if formData, err := ParseForm(r); err == nil && len(formData) > 0 {
ctx.Set("form", formData) ctx.Set("form", formData)
@ -225,16 +218,20 @@ func (s *Server) handleLuaRoute(w http.ResponseWriter, r *http.Request, bytecode
// Execute Lua script // Execute Lua script
result, err := s.luaRunner.Run(bytecode, ctx, scriptPath) result, err := s.luaRunner.Run(bytecode, ctx, scriptPath)
if err != nil {
logger.Error("Error executing Lua route: %v", err)
// Set content type to HTML // Special handling for CSRF error
if err != nil {
if csrfErr, ok := err.(*runner.CSRFError); ok {
logger.Warning("CSRF error executing Lua route: %v", csrfErr)
HandleCSRFError(w, r, s.errorConfig)
return
}
// Normal error handling
logger.Error("Error executing Lua route: %v", err)
w.Header().Set("Content-Type", "text/html; charset=utf-8") w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
errorHTML := utils.InternalErrorPage(s.errorConfig, r.URL.Path, err.Error())
// Generate error page with error message
errorMsg := err.Error()
errorHTML := utils.InternalErrorPage(s.errorConfig, r.URL.Path, errorMsg)
w.Write([]byte(errorHTML)) w.Write([]byte(errorHTML))
return return
} }
@ -321,7 +318,7 @@ func setContentTypeIfMissing(w http.ResponseWriter, contentType string) {
} }
// handleDebugStats displays debug statistics // handleDebugStats displays debug statistics
func (s *Server) handleDebugStats(w http.ResponseWriter, r *http.Request) { func (s *Server) handleDebugStats(w http.ResponseWriter, _ *http.Request) {
// Collect system stats // Collect system stats
stats := utils.CollectSystemStats(s.config) stats := utils.CollectSystemStats(s.config)

View File

@ -148,7 +148,7 @@ local cookie = {
end end
-- Access values directly from current environment -- Access values directly from current environment
local env = getfenv(1) local env = getfenv(2)
-- Check if context exists and has cookies -- Check if context exists and has cookies
if env.ctx and env.ctx.cookies and env.ctx.cookies[name] then if env.ctx and env.ctx.cookies and env.ctx.cookies[name] then

View File

@ -132,6 +132,7 @@ func init() {
GlobalRegistry.EnableDebug() // Enable debugging by default GlobalRegistry.EnableDebug() // Enable debugging by default
GlobalRegistry.Register("http", HTTPModuleInitFunc()) GlobalRegistry.Register("http", HTTPModuleInitFunc())
GlobalRegistry.Register("cookie", CookieModuleInitFunc()) GlobalRegistry.Register("cookie", CookieModuleInitFunc())
GlobalRegistry.Register("csrf", CSRFModuleInitFunc())
logger.Debug("[CoreModuleRegistry] Core modules registered in init()") logger.Debug("[CoreModuleRegistry] Core modules registered in init()")
} }

230
core/runner/Csrf.go Normal file
View File

@ -0,0 +1,230 @@
package runner
import (
"crypto/subtle"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
"git.sharkk.net/Sky/Moonshark/core/logger"
)
// LuaCSRFModule provides CSRF protection functionality to Lua scripts
const LuaCSRFModule = `
-- CSRF protection module
local csrf = {
-- Session key where the token is stored
TOKEN_KEY = "_csrf_token",
-- Default form field name
DEFAULT_FIELD = "csrf",
-- Generate a new CSRF token and store it in the session
generate = function(length)
-- Default length is 32 characters
length = length or 32
if length < 16 then
-- Enforce minimum security
length = 16
end
-- Check if we have a session module
if not session then
error("CSRF protection requires the session module", 2)
end
-- Generate a secure random token using os.time and math.random
local token = ""
local chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
-- Seed the random generator with current time
math.randomseed(os.time())
-- Generate random string
for i = 1, length do
local idx = math.random(1, #chars)
token = token .. chars:sub(idx, idx)
end
-- Store in session
session.set(csrf.TOKEN_KEY, token)
return token
end,
-- Get the current token or generate a new one
token = function()
-- Get from session if exists
local token = session.get(csrf.TOKEN_KEY)
-- Generate if needed
if not token then
token = csrf.generate()
end
return token
end,
-- Generate a hidden form field with the CSRF token
field = function(field_name)
field_name = field_name or csrf.DEFAULT_FIELD
local token = csrf.token()
return string.format('<input type="hidden" name="%s" value="%s">', field_name, token)
end,
-- Verify a given token against the session token
verify = function(token, field_name)
field_name = field_name or csrf.DEFAULT_FIELD
local env = getfenv(2)
local form = nil
if env.ctx and env.ctx.form then
form = env.ctx.form
else
return false
end
token = token or form[field_name]
if not token then
return false
end
local session_token = session.get(csrf.TOKEN_KEY)
if not session_token then
return false
end
if #token ~= #session_token then
return false
end
local result = true
for i = 1, #token do
if token:sub(i, i) ~= session_token:sub(i, i) then
result = false
-- Don't break early - continue to prevent timing attacks
end
end
return result
end
}
-- Install CSRF module
_G.csrf = csrf
-- Make sure the CSRF module is accessible in sandbox
if __env_system and __env_system.base_env then
__env_system.base_env.csrf = csrf
end
`
// CSRFModuleInitFunc returns an initializer for the CSRF module
func CSRFModuleInitFunc() StateInitFunc {
return func(state *luajit.State) error {
return state.DoString(LuaCSRFModule)
}
}
// ValidateCSRFToken checks if the CSRF token is valid for a request
func ValidateCSRFToken(state *luajit.State, ctx *Context) bool {
// Only validate for form submissions
method, ok := ctx.Get("method").(string)
if !ok || (method != "POST" && method != "PUT" && method != "PATCH" && method != "DELETE") {
return true
}
// Get form data
formData, ok := ctx.Get("form").(map[string]any)
if !ok || formData == nil {
logger.Warning("CSRF validation failed: no form data")
return false
}
// Get token from form
formToken, ok := formData["csrf"].(string)
if !ok || formToken == "" {
logger.Warning("CSRF validation failed: no token in form")
return false
}
// Get session token
state.GetGlobal("session")
if state.IsNil(-1) {
state.Pop(1)
logger.Warning("CSRF validation failed: session module not available")
return false
}
state.GetField(-1, "get")
if !state.IsFunction(-1) {
state.Pop(2)
logger.Warning("CSRF validation failed: session.get not available")
return false
}
state.PushCopy(-1) // Duplicate function
state.PushString("_csrf_token")
if err := state.Call(1, 1); err != nil {
state.Pop(3) // Pop error, function and session table
logger.Warning("CSRF validation failed: %v", err)
return false
}
if state.IsNil(-1) {
state.Pop(3) // Pop nil, function and session table
logger.Warning("CSRF validation failed: no token in session")
return false
}
sessionToken := state.ToString(-1)
state.Pop(3) // Pop token, function and session table
// Constant-time comparison to prevent timing attacks
return subtle.ConstantTimeCompare([]byte(formToken), []byte(sessionToken)) == 1
}
// WithCSRFProtection creates a runner option to add CSRF protection
func WithCSRFProtection() RunnerOption {
return func(r *LuaRunner) {
r.AddInitHook(func(state *luajit.State, ctx *Context) error {
// Get request method
method, ok := ctx.Get("method").(string)
if !ok {
return nil
}
// Only validate for form submissions
if method != "POST" && method != "PUT" && method != "PATCH" && method != "DELETE" {
return nil
}
// Check for form data
form, ok := ctx.Get("form").(map[string]any)
if !ok || form == nil {
return nil
}
// Validate CSRF token
if !ValidateCSRFToken(state, ctx) {
return ErrCSRFValidationFailed
}
return nil
})
}
}
// Error for CSRF validation failure
var ErrCSRFValidationFailed = &CSRFError{message: "CSRF token validation failed"}
// CSRFError represents a CSRF validation error
type CSRFError struct {
message string
}
// Error implements the error interface
func (e *CSRFError) Error() string {
return e.message
}

View File

@ -19,6 +19,7 @@ const (
ErrorTypeNotFound ErrorType = 404 ErrorTypeNotFound ErrorType = 404
ErrorTypeMethodNotAllowed ErrorType = 405 ErrorTypeMethodNotAllowed ErrorType = 405
ErrorTypeInternalError ErrorType = 500 ErrorTypeInternalError ErrorType = 500
ErrorTypeForbidden ErrorType = 403 // Added CSRF/Forbidden error type
) )
// ErrorPage generates an HTML error page based on the error type // ErrorPage generates an HTML error page based on the error type
@ -34,6 +35,8 @@ func ErrorPage(config ErrorPageConfig, errorType ErrorType, url string, errMsg s
filename = "405.html" filename = "405.html"
case ErrorTypeInternalError: case ErrorTypeInternalError:
filename = "500.html" filename = "500.html"
case ErrorTypeForbidden:
filename = "403.html"
} }
if filename != "" { if filename != "" {
@ -52,6 +55,8 @@ func ErrorPage(config ErrorPageConfig, errorType ErrorType, url string, errMsg s
return generateMethodNotAllowedHTML(url) return generateMethodNotAllowedHTML(url)
case ErrorTypeInternalError: case ErrorTypeInternalError:
return generateInternalErrorHTML(config.DebugMode, url, errMsg) return generateInternalErrorHTML(config.DebugMode, url, errMsg)
case ErrorTypeForbidden:
return generateForbiddenHTML(config.DebugMode, url, errMsg)
default: default:
// Fallback to internal error // Fallback to internal error
return generateInternalErrorHTML(config.DebugMode, url, errMsg) return generateInternalErrorHTML(config.DebugMode, url, errMsg)
@ -73,6 +78,11 @@ func InternalErrorPage(config ErrorPageConfig, url string, errMsg string) string
return ErrorPage(config, ErrorTypeInternalError, url, errMsg) return ErrorPage(config, ErrorTypeInternalError, url, errMsg)
} }
// ForbiddenPage generates a 403 Forbidden error page
func ForbiddenPage(config ErrorPageConfig, url string, errMsg string) string {
return ErrorPage(config, ErrorTypeForbidden, url, errMsg)
}
// generateInternalErrorHTML creates a 500 Internal Server Error page // generateInternalErrorHTML creates a 500 Internal Server Error page
func generateInternalErrorHTML(debugMode bool, url string, errMsg string) string { func generateInternalErrorHTML(debugMode bool, url string, errMsg string) string {
errorMessages := []string{ errorMessages := []string{
@ -92,6 +102,30 @@ func generateInternalErrorHTML(debugMode bool, url string, errMsg string) string
return generateErrorHTML("500", randomMessage, "Internal Server Error", debugMode, errMsg) return generateErrorHTML("500", randomMessage, "Internal Server Error", debugMode, errMsg)
} }
// generateForbiddenHTML creates a 403 Forbidden error page
func generateForbiddenHTML(debugMode bool, url string, errMsg string) string {
errorMessages := []string{
"Access denied",
"You shall not pass",
"This area is off-limits",
"Security check failed",
"Invalid security token",
"Request blocked for security reasons",
"Permission denied",
"Security violation detected",
"This request was rejected",
"Security first, access second",
}
defaultMsg := "Invalid or missing CSRF token. This could be due to an expired form or a cross-site request forgery attempt."
if errMsg == "" {
errMsg = defaultMsg
}
randomMessage := errorMessages[rand.Intn(len(errorMessages))]
return generateErrorHTML("403", randomMessage, "Forbidden", debugMode, errMsg)
}
// generateNotFoundHTML creates a 404 Not Found error page // generateNotFoundHTML creates a 404 Not Found error page
func generateNotFoundHTML(url string) string { func generateNotFoundHTML(url string) string {
errorMessages := []string{ errorMessages := []string{