csrf 1
This commit is contained in:
parent
472d175093
commit
eea5ba8c8a
|
@ -185,6 +185,7 @@ func (s *Moonshark) initRunner() error {
|
|||
runner.WithPoolSize(s.Config.PoolSize),
|
||||
runner.WithLibDirs(s.Config.LibDirs...),
|
||||
runner.WithSessionManager(sessionManager),
|
||||
runner.WithCSRFProtection(),
|
||||
}
|
||||
|
||||
// Add debug option conditionally
|
||||
|
|
20
core/http/Csrf.go
Normal file
20
core/http/Csrf.go
Normal file
|
@ -0,0 +1,20 @@
|
|||
package http
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"git.sharkk.net/Sky/Moonshark/core/logger"
|
||||
"git.sharkk.net/Sky/Moonshark/core/utils"
|
||||
)
|
||||
|
||||
// HandleCSRFError handles a CSRF validation error
|
||||
func HandleCSRFError(w http.ResponseWriter, r *http.Request, errorConfig utils.ErrorPageConfig) {
|
||||
logger.Warning("CSRF validation failed for %s %s", r.Method, r.URL.Path)
|
||||
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
|
||||
errorMsg := "Invalid or missing CSRF token. This could be due to an expired form or a cross-site request forgery attempt."
|
||||
errorHTML := utils.ForbiddenPage(errorConfig, r.URL.Path, errorMsg)
|
||||
w.Write([]byte(errorHTML))
|
||||
}
|
|
@ -162,24 +162,17 @@ func (s *Server) handleLuaRoute(w http.ResponseWriter, r *http.Request, bytecode
|
|||
ctx := runner.NewContext()
|
||||
defer ctx.Release()
|
||||
|
||||
// Log bytecode size
|
||||
logger.Debug("Executing Lua route with %d bytes of bytecode", len(bytecode))
|
||||
|
||||
// Extract cookies instead of storing the raw request
|
||||
// Set up context exactly as the original
|
||||
cookieMap := make(map[string]any)
|
||||
for _, cookie := range r.Cookies() {
|
||||
cookieMap[cookie.Name] = cookie.Value
|
||||
}
|
||||
|
||||
// Store cookie map instead of raw request
|
||||
ctx.Set("_request_cookies", cookieMap)
|
||||
|
||||
// Add request info directly to context
|
||||
ctx.Set("method", r.Method)
|
||||
ctx.Set("path", r.URL.Path)
|
||||
ctx.Set("host", r.Host)
|
||||
|
||||
// Add headers to context
|
||||
// Headers
|
||||
headerMap := make(map[string]any, len(r.Header))
|
||||
for name, values := range r.Header {
|
||||
if len(values) == 1 {
|
||||
|
@ -190,7 +183,7 @@ func (s *Server) handleLuaRoute(w http.ResponseWriter, r *http.Request, bytecode
|
|||
}
|
||||
ctx.Set("headers", headerMap)
|
||||
|
||||
// Add cookies to context
|
||||
// Cookies
|
||||
if cookies := r.Cookies(); len(cookies) > 0 {
|
||||
cookieMap := make(map[string]any, len(cookies))
|
||||
for _, cookie := range cookies {
|
||||
|
@ -199,7 +192,7 @@ func (s *Server) handleLuaRoute(w http.ResponseWriter, r *http.Request, bytecode
|
|||
ctx.Set("cookies", cookieMap)
|
||||
}
|
||||
|
||||
// Add URL parameters
|
||||
// URL parameters
|
||||
if params.Count > 0 {
|
||||
paramMap := make(map[string]any, params.Count)
|
||||
for i, key := range params.Keys {
|
||||
|
@ -208,7 +201,7 @@ func (s *Server) handleLuaRoute(w http.ResponseWriter, r *http.Request, bytecode
|
|||
ctx.Set("params", paramMap)
|
||||
}
|
||||
|
||||
// Parse query parameters only if present
|
||||
// Query parameters
|
||||
queryMap := QueryToLua(r)
|
||||
if queryMap == nil {
|
||||
ctx.Set("query", make(map[string]any))
|
||||
|
@ -216,7 +209,7 @@ func (s *Server) handleLuaRoute(w http.ResponseWriter, r *http.Request, bytecode
|
|||
ctx.Set("query", queryMap)
|
||||
}
|
||||
|
||||
// Add form data for POST/PUT/PATCH only when needed
|
||||
// Form data
|
||||
if r.Method == http.MethodPost || r.Method == http.MethodPut || r.Method == http.MethodPatch {
|
||||
if formData, err := ParseForm(r); err == nil && len(formData) > 0 {
|
||||
ctx.Set("form", formData)
|
||||
|
@ -225,16 +218,20 @@ func (s *Server) handleLuaRoute(w http.ResponseWriter, r *http.Request, bytecode
|
|||
|
||||
// Execute Lua script
|
||||
result, err := s.luaRunner.Run(bytecode, ctx, scriptPath)
|
||||
if err != nil {
|
||||
logger.Error("Error executing Lua route: %v", err)
|
||||
|
||||
// Set content type to HTML
|
||||
// Special handling for CSRF error
|
||||
if err != nil {
|
||||
if csrfErr, ok := err.(*runner.CSRFError); ok {
|
||||
logger.Warning("CSRF error executing Lua route: %v", csrfErr)
|
||||
HandleCSRFError(w, r, s.errorConfig)
|
||||
return
|
||||
}
|
||||
|
||||
// Normal error handling
|
||||
logger.Error("Error executing Lua route: %v", err)
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
|
||||
// Generate error page with error message
|
||||
errorMsg := err.Error()
|
||||
errorHTML := utils.InternalErrorPage(s.errorConfig, r.URL.Path, errorMsg)
|
||||
errorHTML := utils.InternalErrorPage(s.errorConfig, r.URL.Path, err.Error())
|
||||
w.Write([]byte(errorHTML))
|
||||
return
|
||||
}
|
||||
|
@ -321,7 +318,7 @@ func setContentTypeIfMissing(w http.ResponseWriter, contentType string) {
|
|||
}
|
||||
|
||||
// handleDebugStats displays debug statistics
|
||||
func (s *Server) handleDebugStats(w http.ResponseWriter, r *http.Request) {
|
||||
func (s *Server) handleDebugStats(w http.ResponseWriter, _ *http.Request) {
|
||||
// Collect system stats
|
||||
stats := utils.CollectSystemStats(s.config)
|
||||
|
||||
|
|
|
@ -148,7 +148,7 @@ local cookie = {
|
|||
end
|
||||
|
||||
-- Access values directly from current environment
|
||||
local env = getfenv(1)
|
||||
local env = getfenv(2)
|
||||
|
||||
-- Check if context exists and has cookies
|
||||
if env.ctx and env.ctx.cookies and env.ctx.cookies[name] then
|
||||
|
|
|
@ -132,6 +132,7 @@ func init() {
|
|||
GlobalRegistry.EnableDebug() // Enable debugging by default
|
||||
GlobalRegistry.Register("http", HTTPModuleInitFunc())
|
||||
GlobalRegistry.Register("cookie", CookieModuleInitFunc())
|
||||
GlobalRegistry.Register("csrf", CSRFModuleInitFunc())
|
||||
logger.Debug("[CoreModuleRegistry] Core modules registered in init()")
|
||||
}
|
||||
|
||||
|
|
230
core/runner/Csrf.go
Normal file
230
core/runner/Csrf.go
Normal file
|
@ -0,0 +1,230 @@
|
|||
package runner
|
||||
|
||||
import (
|
||||
"crypto/subtle"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
"git.sharkk.net/Sky/Moonshark/core/logger"
|
||||
)
|
||||
|
||||
// LuaCSRFModule provides CSRF protection functionality to Lua scripts
|
||||
const LuaCSRFModule = `
|
||||
-- CSRF protection module
|
||||
local csrf = {
|
||||
-- Session key where the token is stored
|
||||
TOKEN_KEY = "_csrf_token",
|
||||
|
||||
-- Default form field name
|
||||
DEFAULT_FIELD = "csrf",
|
||||
|
||||
-- Generate a new CSRF token and store it in the session
|
||||
generate = function(length)
|
||||
-- Default length is 32 characters
|
||||
length = length or 32
|
||||
|
||||
if length < 16 then
|
||||
-- Enforce minimum security
|
||||
length = 16
|
||||
end
|
||||
|
||||
-- Check if we have a session module
|
||||
if not session then
|
||||
error("CSRF protection requires the session module", 2)
|
||||
end
|
||||
|
||||
-- Generate a secure random token using os.time and math.random
|
||||
local token = ""
|
||||
local chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
|
||||
-- Seed the random generator with current time
|
||||
math.randomseed(os.time())
|
||||
|
||||
-- Generate random string
|
||||
for i = 1, length do
|
||||
local idx = math.random(1, #chars)
|
||||
token = token .. chars:sub(idx, idx)
|
||||
end
|
||||
|
||||
-- Store in session
|
||||
session.set(csrf.TOKEN_KEY, token)
|
||||
|
||||
return token
|
||||
end,
|
||||
|
||||
-- Get the current token or generate a new one
|
||||
token = function()
|
||||
-- Get from session if exists
|
||||
local token = session.get(csrf.TOKEN_KEY)
|
||||
|
||||
-- Generate if needed
|
||||
if not token then
|
||||
token = csrf.generate()
|
||||
end
|
||||
|
||||
return token
|
||||
end,
|
||||
|
||||
-- Generate a hidden form field with the CSRF token
|
||||
field = function(field_name)
|
||||
field_name = field_name or csrf.DEFAULT_FIELD
|
||||
local token = csrf.token()
|
||||
return string.format('<input type="hidden" name="%s" value="%s">', field_name, token)
|
||||
end,
|
||||
|
||||
-- Verify a given token against the session token
|
||||
verify = function(token, field_name)
|
||||
field_name = field_name or csrf.DEFAULT_FIELD
|
||||
|
||||
local env = getfenv(2)
|
||||
|
||||
local form = nil
|
||||
if env.ctx and env.ctx.form then
|
||||
form = env.ctx.form
|
||||
else
|
||||
return false
|
||||
end
|
||||
|
||||
token = token or form[field_name]
|
||||
if not token then
|
||||
return false
|
||||
end
|
||||
|
||||
local session_token = session.get(csrf.TOKEN_KEY)
|
||||
if not session_token then
|
||||
return false
|
||||
end
|
||||
|
||||
if #token ~= #session_token then
|
||||
return false
|
||||
end
|
||||
|
||||
local result = true
|
||||
for i = 1, #token do
|
||||
if token:sub(i, i) ~= session_token:sub(i, i) then
|
||||
result = false
|
||||
-- Don't break early - continue to prevent timing attacks
|
||||
end
|
||||
end
|
||||
|
||||
return result
|
||||
end
|
||||
}
|
||||
|
||||
-- Install CSRF module
|
||||
_G.csrf = csrf
|
||||
|
||||
-- Make sure the CSRF module is accessible in sandbox
|
||||
if __env_system and __env_system.base_env then
|
||||
__env_system.base_env.csrf = csrf
|
||||
end
|
||||
`
|
||||
|
||||
// CSRFModuleInitFunc returns an initializer for the CSRF module
|
||||
func CSRFModuleInitFunc() StateInitFunc {
|
||||
return func(state *luajit.State) error {
|
||||
return state.DoString(LuaCSRFModule)
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateCSRFToken checks if the CSRF token is valid for a request
|
||||
func ValidateCSRFToken(state *luajit.State, ctx *Context) bool {
|
||||
// Only validate for form submissions
|
||||
method, ok := ctx.Get("method").(string)
|
||||
if !ok || (method != "POST" && method != "PUT" && method != "PATCH" && method != "DELETE") {
|
||||
return true
|
||||
}
|
||||
|
||||
// Get form data
|
||||
formData, ok := ctx.Get("form").(map[string]any)
|
||||
if !ok || formData == nil {
|
||||
logger.Warning("CSRF validation failed: no form data")
|
||||
return false
|
||||
}
|
||||
|
||||
// Get token from form
|
||||
formToken, ok := formData["csrf"].(string)
|
||||
if !ok || formToken == "" {
|
||||
logger.Warning("CSRF validation failed: no token in form")
|
||||
return false
|
||||
}
|
||||
|
||||
// Get session token
|
||||
state.GetGlobal("session")
|
||||
if state.IsNil(-1) {
|
||||
state.Pop(1)
|
||||
logger.Warning("CSRF validation failed: session module not available")
|
||||
return false
|
||||
}
|
||||
|
||||
state.GetField(-1, "get")
|
||||
if !state.IsFunction(-1) {
|
||||
state.Pop(2)
|
||||
logger.Warning("CSRF validation failed: session.get not available")
|
||||
return false
|
||||
}
|
||||
|
||||
state.PushCopy(-1) // Duplicate function
|
||||
state.PushString("_csrf_token")
|
||||
|
||||
if err := state.Call(1, 1); err != nil {
|
||||
state.Pop(3) // Pop error, function and session table
|
||||
logger.Warning("CSRF validation failed: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
if state.IsNil(-1) {
|
||||
state.Pop(3) // Pop nil, function and session table
|
||||
logger.Warning("CSRF validation failed: no token in session")
|
||||
return false
|
||||
}
|
||||
|
||||
sessionToken := state.ToString(-1)
|
||||
state.Pop(3) // Pop token, function and session table
|
||||
|
||||
// Constant-time comparison to prevent timing attacks
|
||||
return subtle.ConstantTimeCompare([]byte(formToken), []byte(sessionToken)) == 1
|
||||
}
|
||||
|
||||
// WithCSRFProtection creates a runner option to add CSRF protection
|
||||
func WithCSRFProtection() RunnerOption {
|
||||
return func(r *LuaRunner) {
|
||||
r.AddInitHook(func(state *luajit.State, ctx *Context) error {
|
||||
// Get request method
|
||||
method, ok := ctx.Get("method").(string)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Only validate for form submissions
|
||||
if method != "POST" && method != "PUT" && method != "PATCH" && method != "DELETE" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check for form data
|
||||
form, ok := ctx.Get("form").(map[string]any)
|
||||
if !ok || form == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate CSRF token
|
||||
if !ValidateCSRFToken(state, ctx) {
|
||||
return ErrCSRFValidationFailed
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Error for CSRF validation failure
|
||||
var ErrCSRFValidationFailed = &CSRFError{message: "CSRF token validation failed"}
|
||||
|
||||
// CSRFError represents a CSRF validation error
|
||||
type CSRFError struct {
|
||||
message string
|
||||
}
|
||||
|
||||
// Error implements the error interface
|
||||
func (e *CSRFError) Error() string {
|
||||
return e.message
|
||||
}
|
|
@ -19,6 +19,7 @@ const (
|
|||
ErrorTypeNotFound ErrorType = 404
|
||||
ErrorTypeMethodNotAllowed ErrorType = 405
|
||||
ErrorTypeInternalError ErrorType = 500
|
||||
ErrorTypeForbidden ErrorType = 403 // Added CSRF/Forbidden error type
|
||||
)
|
||||
|
||||
// ErrorPage generates an HTML error page based on the error type
|
||||
|
@ -34,6 +35,8 @@ func ErrorPage(config ErrorPageConfig, errorType ErrorType, url string, errMsg s
|
|||
filename = "405.html"
|
||||
case ErrorTypeInternalError:
|
||||
filename = "500.html"
|
||||
case ErrorTypeForbidden:
|
||||
filename = "403.html"
|
||||
}
|
||||
|
||||
if filename != "" {
|
||||
|
@ -52,6 +55,8 @@ func ErrorPage(config ErrorPageConfig, errorType ErrorType, url string, errMsg s
|
|||
return generateMethodNotAllowedHTML(url)
|
||||
case ErrorTypeInternalError:
|
||||
return generateInternalErrorHTML(config.DebugMode, url, errMsg)
|
||||
case ErrorTypeForbidden:
|
||||
return generateForbiddenHTML(config.DebugMode, url, errMsg)
|
||||
default:
|
||||
// Fallback to internal error
|
||||
return generateInternalErrorHTML(config.DebugMode, url, errMsg)
|
||||
|
@ -73,6 +78,11 @@ func InternalErrorPage(config ErrorPageConfig, url string, errMsg string) string
|
|||
return ErrorPage(config, ErrorTypeInternalError, url, errMsg)
|
||||
}
|
||||
|
||||
// ForbiddenPage generates a 403 Forbidden error page
|
||||
func ForbiddenPage(config ErrorPageConfig, url string, errMsg string) string {
|
||||
return ErrorPage(config, ErrorTypeForbidden, url, errMsg)
|
||||
}
|
||||
|
||||
// generateInternalErrorHTML creates a 500 Internal Server Error page
|
||||
func generateInternalErrorHTML(debugMode bool, url string, errMsg string) string {
|
||||
errorMessages := []string{
|
||||
|
@ -92,6 +102,30 @@ func generateInternalErrorHTML(debugMode bool, url string, errMsg string) string
|
|||
return generateErrorHTML("500", randomMessage, "Internal Server Error", debugMode, errMsg)
|
||||
}
|
||||
|
||||
// generateForbiddenHTML creates a 403 Forbidden error page
|
||||
func generateForbiddenHTML(debugMode bool, url string, errMsg string) string {
|
||||
errorMessages := []string{
|
||||
"Access denied",
|
||||
"You shall not pass",
|
||||
"This area is off-limits",
|
||||
"Security check failed",
|
||||
"Invalid security token",
|
||||
"Request blocked for security reasons",
|
||||
"Permission denied",
|
||||
"Security violation detected",
|
||||
"This request was rejected",
|
||||
"Security first, access second",
|
||||
}
|
||||
|
||||
defaultMsg := "Invalid or missing CSRF token. This could be due to an expired form or a cross-site request forgery attempt."
|
||||
if errMsg == "" {
|
||||
errMsg = defaultMsg
|
||||
}
|
||||
|
||||
randomMessage := errorMessages[rand.Intn(len(errorMessages))]
|
||||
return generateErrorHTML("403", randomMessage, "Forbidden", debugMode, errMsg)
|
||||
}
|
||||
|
||||
// generateNotFoundHTML creates a 404 Not Found error page
|
||||
func generateNotFoundHTML(url string) string {
|
||||
errorMessages := []string{
|
||||
|
|
Loading…
Reference in New Issue
Block a user