231 lines
5.7 KiB
Go
231 lines
5.7 KiB
Go
package runner
|
|
|
|
import (
|
|
"crypto/subtle"
|
|
|
|
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
|
"git.sharkk.net/Sky/Moonshark/core/logger"
|
|
)
|
|
|
|
// LuaCSRFModule provides CSRF protection functionality to Lua scripts
|
|
const LuaCSRFModule = `
|
|
-- CSRF protection module
|
|
local csrf = {
|
|
-- Session key where the token is stored
|
|
TOKEN_KEY = "_csrf_token",
|
|
|
|
-- Default form field name
|
|
DEFAULT_FIELD = "csrf",
|
|
|
|
-- Generate a new CSRF token and store it in the session
|
|
generate = function(length)
|
|
-- Default length is 32 characters
|
|
length = length or 32
|
|
|
|
if length < 16 then
|
|
-- Enforce minimum security
|
|
length = 16
|
|
end
|
|
|
|
-- Check if we have a session module
|
|
if not session then
|
|
error("CSRF protection requires the session module", 2)
|
|
end
|
|
|
|
-- Generate a secure random token using os.time and math.random
|
|
local token = ""
|
|
local chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
|
|
|
-- Seed the random generator with current time
|
|
math.randomseed(os.time())
|
|
|
|
-- Generate random string
|
|
for i = 1, length do
|
|
local idx = math.random(1, #chars)
|
|
token = token .. chars:sub(idx, idx)
|
|
end
|
|
|
|
-- Store in session
|
|
session.set(csrf.TOKEN_KEY, token)
|
|
|
|
return token
|
|
end,
|
|
|
|
-- Get the current token or generate a new one
|
|
token = function()
|
|
-- Get from session if exists
|
|
local token = session.get(csrf.TOKEN_KEY)
|
|
|
|
-- Generate if needed
|
|
if not token then
|
|
token = csrf.generate()
|
|
end
|
|
|
|
return token
|
|
end,
|
|
|
|
-- Generate a hidden form field with the CSRF token
|
|
field = function(field_name)
|
|
field_name = field_name or csrf.DEFAULT_FIELD
|
|
local token = csrf.token()
|
|
return string.format('<input type="hidden" name="%s" value="%s">', field_name, token)
|
|
end,
|
|
|
|
-- Verify a given token against the session token
|
|
verify = function(token, field_name)
|
|
field_name = field_name or csrf.DEFAULT_FIELD
|
|
|
|
local env = getfenv(2)
|
|
|
|
local form = nil
|
|
if env.ctx and env.ctx.form then
|
|
form = env.ctx.form
|
|
else
|
|
return false
|
|
end
|
|
|
|
token = token or form[field_name]
|
|
if not token then
|
|
return false
|
|
end
|
|
|
|
local session_token = session.get(csrf.TOKEN_KEY)
|
|
if not session_token then
|
|
return false
|
|
end
|
|
|
|
if #token ~= #session_token then
|
|
return false
|
|
end
|
|
|
|
local result = true
|
|
for i = 1, #token do
|
|
if token:sub(i, i) ~= session_token:sub(i, i) then
|
|
result = false
|
|
-- Don't break early - continue to prevent timing attacks
|
|
end
|
|
end
|
|
|
|
return result
|
|
end
|
|
}
|
|
|
|
-- Install CSRF module
|
|
_G.csrf = csrf
|
|
|
|
-- Make sure the CSRF module is accessible in sandbox
|
|
if __env_system and __env_system.base_env then
|
|
__env_system.base_env.csrf = csrf
|
|
end
|
|
`
|
|
|
|
// CSRFModuleInitFunc returns an initializer for the CSRF module
|
|
func CSRFModuleInitFunc() StateInitFunc {
|
|
return func(state *luajit.State) error {
|
|
return state.DoString(LuaCSRFModule)
|
|
}
|
|
}
|
|
|
|
// ValidateCSRFToken checks if the CSRF token is valid for a request
|
|
func ValidateCSRFToken(state *luajit.State, ctx *Context) bool {
|
|
// Only validate for form submissions
|
|
method, ok := ctx.Get("method").(string)
|
|
if !ok || (method != "POST" && method != "PUT" && method != "PATCH" && method != "DELETE") {
|
|
return true
|
|
}
|
|
|
|
// Get form data
|
|
formData, ok := ctx.Get("form").(map[string]any)
|
|
if !ok || formData == nil {
|
|
logger.Warning("CSRF validation failed: no form data")
|
|
return false
|
|
}
|
|
|
|
// Get token from form
|
|
formToken, ok := formData["csrf"].(string)
|
|
if !ok || formToken == "" {
|
|
logger.Warning("CSRF validation failed: no token in form")
|
|
return false
|
|
}
|
|
|
|
// Get session token
|
|
state.GetGlobal("session")
|
|
if state.IsNil(-1) {
|
|
state.Pop(1)
|
|
logger.Warning("CSRF validation failed: session module not available")
|
|
return false
|
|
}
|
|
|
|
state.GetField(-1, "get")
|
|
if !state.IsFunction(-1) {
|
|
state.Pop(2)
|
|
logger.Warning("CSRF validation failed: session.get not available")
|
|
return false
|
|
}
|
|
|
|
state.PushCopy(-1) // Duplicate function
|
|
state.PushString("_csrf_token")
|
|
|
|
if err := state.Call(1, 1); err != nil {
|
|
state.Pop(3) // Pop error, function and session table
|
|
logger.Warning("CSRF validation failed: %v", err)
|
|
return false
|
|
}
|
|
|
|
if state.IsNil(-1) {
|
|
state.Pop(3) // Pop nil, function and session table
|
|
logger.Warning("CSRF validation failed: no token in session")
|
|
return false
|
|
}
|
|
|
|
sessionToken := state.ToString(-1)
|
|
state.Pop(3) // Pop token, function and session table
|
|
|
|
// Constant-time comparison to prevent timing attacks
|
|
return subtle.ConstantTimeCompare([]byte(formToken), []byte(sessionToken)) == 1
|
|
}
|
|
|
|
// WithCSRFProtection creates a runner option to add CSRF protection
|
|
func WithCSRFProtection() RunnerOption {
|
|
return func(r *LuaRunner) {
|
|
r.AddInitHook(func(state *luajit.State, ctx *Context) error {
|
|
// Get request method
|
|
method, ok := ctx.Get("method").(string)
|
|
if !ok {
|
|
return nil
|
|
}
|
|
|
|
// Only validate for form submissions
|
|
if method != "POST" && method != "PUT" && method != "PATCH" && method != "DELETE" {
|
|
return nil
|
|
}
|
|
|
|
// Check for form data
|
|
form, ok := ctx.Get("form").(map[string]any)
|
|
if !ok || form == nil {
|
|
return nil
|
|
}
|
|
|
|
// Validate CSRF token
|
|
if !ValidateCSRFToken(state, ctx) {
|
|
return ErrCSRFValidationFailed
|
|
}
|
|
|
|
return nil
|
|
})
|
|
}
|
|
}
|
|
|
|
// CSRFError is the error type used to report a CSRF validation failure.
type CSRFError struct {
	// message is the text returned by Error.
	message string
}

// Error returns the error's message, satisfying the error interface.
func (e *CSRFError) Error() string {
	return e.message
}

// ErrCSRFValidationFailed is the shared error value signalling that a
// request's CSRF token did not validate.
var ErrCSRFValidationFailed = &CSRFError{
	message: "CSRF token validation failed",
}
|