Compare commits

..

3 Commits

Author SHA1 Message Date
ab6135e98a clean up 2025-04-10 09:48:58 -05:00
ba9a3db0a0 optimize sandbox 2025-04-10 09:26:14 -05:00
0abf31ed3a work on sessions 2025-04-10 07:51:15 -05:00
7 changed files with 158 additions and 655 deletions

View File

@ -1,104 +0,0 @@
package http
import (
"Moonshark/core/runner"
"Moonshark/core/utils"
"Moonshark/core/utils/logger"
"crypto/subtle"
"errors"
"github.com/valyala/fasthttp"
)
// Error for CSRF validation failure
var ErrCSRFValidationFailed = errors.New("CSRF token validation failed")
// ValidateCSRFToken checks if the CSRF token is valid for a request
func ValidateCSRFToken(ctx *runner.Context) bool {
// Only validate for form submissions
method, ok := ctx.Get("method").(string)
if !ok || (method != "POST" && method != "PUT" && method != "PATCH" && method != "DELETE") {
return true
}
// Get form data
formData, ok := ctx.Get("form").(map[string]any)
if !ok || formData == nil {
logger.Warning("CSRF validation failed: no form data")
return false
}
// Get token from form
formToken, ok := formData["csrf"].(string)
if !ok || formToken == "" {
logger.Warning("CSRF validation failed: no token in form")
return false
}
// Get token from session
sessionData := ctx.SessionData
if sessionData == nil {
logger.Warning("CSRF validation failed: no session data")
return false
}
sessionToken, ok := sessionData["_csrf_token"].(string)
if !ok || sessionToken == "" {
logger.Warning("CSRF validation failed: no token in session")
return false
}
// Constant-time comparison to prevent timing attacks
return subtle.ConstantTimeCompare([]byte(formToken), []byte(sessionToken)) == 1
}
// HandleCSRFError handles a CSRF validation error
func HandleCSRFError(ctx *fasthttp.RequestCtx, errorConfig utils.ErrorPageConfig) {
method := string(ctx.Method())
path := string(ctx.Path())
logger.Warning("CSRF validation failed for %s %s", method, path)
ctx.SetContentType("text/html; charset=utf-8")
ctx.SetStatusCode(fasthttp.StatusForbidden)
errorMsg := "Invalid or missing CSRF token. This could be due to an expired form or a cross-site request forgery attempt."
errorHTML := utils.ForbiddenPage(errorConfig, path, errorMsg)
ctx.SetBody([]byte(errorHTML))
}
// GenerateCSRFToken creates a new CSRF token and stores it in the session
func GenerateCSRFToken(ctx *runner.Context, length int) (string, error) {
if length < 16 {
length = 16 // Minimum token length for security
}
// Create secure random token
token, err := GenerateSecureToken(length)
if err != nil {
return "", err
}
// Store token in session
ctx.SessionData["_csrf_token"] = token
return token, nil
}
// GetCSRFToken retrieves the current CSRF token or generates a new one
func GetCSRFToken(ctx *runner.Context) (string, error) {
// Check if token already exists in session
if token, ok := ctx.SessionData["_csrf_token"].(string); ok && token != "" {
return token, nil
}
// Generate new token
return GenerateCSRFToken(ctx, 32)
}
// CSRFMiddleware validates CSRF tokens for state-changing requests
func CSRFMiddleware(ctx *runner.Context) error {
if !ValidateCSRFToken(ctx) {
return ErrCSRFValidationFailed
}
return nil
}

View File

@ -2,7 +2,6 @@ package http
import ( import (
"context" "context"
"errors"
"time" "time"
"Moonshark/core/metadata" "Moonshark/core/metadata"
@ -167,11 +166,6 @@ func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scrip
luaCtx.Set("path", path) luaCtx.Set("path", path)
luaCtx.Set("host", host) luaCtx.Set("host", host)
// Initialize session
session := s.sessionManager.GetSessionFromRequest(ctx)
luaCtx.SessionID = session.ID
luaCtx.SessionData = session.GetAll()
// URL parameters // URL parameters
if params.Count > 0 { if params.Count > 0 {
paramMap := make(map[string]any, params.Count) paramMap := make(map[string]any, params.Count)
@ -198,25 +192,11 @@ func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scrip
luaCtx.Set("form", make(map[string]any)) luaCtx.Set("form", make(map[string]any))
} }
// CSRF middleware for state-changing requests
if method == "POST" || method == "PUT" || method == "PATCH" || method == "DELETE" {
if !ValidateCSRFToken(luaCtx) {
HandleCSRFError(ctx, s.errorConfig)
return
}
}
// Execute Lua script // Execute Lua script
response, err := s.luaRunner.Run(bytecode, luaCtx, scriptPath) response, err := s.luaRunner.Run(bytecode, luaCtx, scriptPath)
if err != nil { if err != nil {
logger.Error("Error executing Lua route: %v", err) logger.Error("Error executing Lua route: %v", err)
// Special handling for specific errors
if errors.Is(err, ErrCSRFValidationFailed) {
HandleCSRFError(ctx, s.errorConfig)
return
}
// General error handling // General error handling
ctx.SetContentType("text/html; charset=utf-8") ctx.SetContentType("text/html; charset=utf-8")
ctx.SetStatusCode(fasthttp.StatusInternalServerError) ctx.SetStatusCode(fasthttp.StatusInternalServerError)
@ -225,15 +205,6 @@ func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scrip
return return
} }
// Update session if modified
if response.SessionModified {
for k, v := range response.SessionData {
session.Set(k, v)
}
s.sessionManager.ApplySessionCookie(ctx, session)
}
// Apply response to HTTP context // Apply response to HTTP context
runner.ApplyResponse(response, ctx) runner.ApplyResponse(response, ctx)

View File

@ -15,10 +15,6 @@ type Context struct {
// FastHTTP context if this was created from an HTTP request // FastHTTP context if this was created from an HTTP request
RequestCtx *fasthttp.RequestCtx RequestCtx *fasthttp.RequestCtx
// Session information
SessionID string
SessionData map[string]any
// Buffer for efficient string operations // Buffer for efficient string operations
buffer *bytebufferpool.ByteBuffer buffer *bytebufferpool.ByteBuffer
} }
@ -27,8 +23,7 @@ type Context struct {
var contextPool = sync.Pool{ var contextPool = sync.Pool{
New: func() any { New: func() any {
return &Context{ return &Context{
Values: make(map[string]any, 16), Values: make(map[string]any, 32),
SessionData: make(map[string]any, 8),
} }
}, },
} }
@ -90,13 +85,6 @@ func (c *Context) Release() {
delete(c.Values, k) delete(c.Values, k)
} }
for k := range c.SessionData {
delete(c.SessionData, k)
}
// Reset session info
c.SessionID = ""
// Reset request context // Reset request context
c.RequestCtx = nil c.RequestCtx = nil
@ -126,13 +114,3 @@ func (c *Context) Set(key string, value any) {
func (c *Context) Get(key string) any { func (c *Context) Get(key string) any {
return c.Values[key] return c.Values[key]
} }
// SetSession sets a session data value
func (c *Context) SetSession(key string, value any) {
c.SessionData[key] = value
}
// GetSession retrieves a session data value
func (c *Context) GetSession(key string) any {
return c.SessionData[key]
}

View File

@ -24,7 +24,7 @@ func precompileSandboxCode() {
// Create temporary state for compilation // Create temporary state for compilation
tempState := luajit.New() tempState := luajit.New()
if tempState == nil { if tempState == nil {
logger.Error("Failed to create temp Lua state for bytecode compilation") logger.ErrorCont("Failed to create temp Lua state for bytecode compilation")
return return
} }
defer tempState.Close() defer tempState.Close()
@ -32,7 +32,7 @@ func precompileSandboxCode() {
code, err := tempState.CompileBytecode(sandboxLuaCode, "sandbox.lua") code, err := tempState.CompileBytecode(sandboxLuaCode, "sandbox.lua")
if err != nil { if err != nil {
logger.Error("Failed to compile sandbox code: %v", err) logger.ErrorCont("Failed to compile sandbox code: %v", err)
return return
} }
@ -40,22 +40,20 @@ func precompileSandboxCode() {
copy(bytecode, code) copy(bytecode, code)
sandboxBytecode.Store(&bytecode) sandboxBytecode.Store(&bytecode)
logger.Debug("Successfully precompiled sandbox.lua to bytecode (%d bytes)", len(code)) logger.ServerCont("Successfully precompiled sandbox.lua to bytecode (%d bytes)", len(code))
} }
// loadSandboxIntoState loads the sandbox code into a Lua state // loadSandboxIntoState loads the sandbox code into a Lua state
func loadSandboxIntoState(state *luajit.State) error { func loadSandboxIntoState(state *luajit.State) error {
// Initialize bytecode once
bytecodeOnce.Do(precompileSandboxCode) bytecodeOnce.Do(precompileSandboxCode)
// Use precompiled bytecode if available
bytecode := sandboxBytecode.Load() bytecode := sandboxBytecode.Load()
if bytecode != nil && len(*bytecode) > 0 { if bytecode != nil && len(*bytecode) > 0 {
logger.Debug("Loading sandbox.lua from precompiled bytecode") logger.ServerCont("Loading sandbox.lua from precompiled bytecode") // piggyback off Sandbox.go's Setup()
return state.LoadAndRunBytecode(*bytecode, "sandbox.lua") return state.LoadAndRunBytecode(*bytecode, "sandbox.lua")
} }
// Fallback to direct execution // Fallback to direct execution
logger.Warning("Using non-precompiled sandbox.lua (bytecode compilation failed)") logger.WarningCont("Using non-precompiled sandbox.lua (bytecode compilation failed)")
return state.DoString(sandboxLuaCode) return state.DoString(sandboxLuaCode)
} }

View File

@ -147,17 +147,12 @@ func (r *Runner) createState(index int) (*State, error) {
r.debugLog("Creating Lua state %d", index) r.debugLog("Creating Lua state %d", index)
} }
// Create a new state
L := luajit.New() L := luajit.New()
if L == nil { if L == nil {
return nil, errors.New("failed to create Lua state") return nil, errors.New("failed to create Lua state")
} }
// Create sandbox
sb := NewSandbox() sb := NewSandbox()
if r.debug {
sb.EnableDebug()
}
// Set up sandbox // Set up sandbox
if err := sb.Setup(L); err != nil { if err := sb.Setup(L); err != nil {

View File

@ -40,67 +40,50 @@ func NewSandbox() *Sandbox {
} }
} }
// EnableDebug turns on debug logging
func (s *Sandbox) EnableDebug() {
s.debug = true
}
// debugLog logs a message if debug mode is enabled
func (s *Sandbox) debugLog(format string, args ...interface{}) {
if s.debug {
logger.Debug("Sandbox "+format, args...)
}
}
// AddModule adds a module to the sandbox environment // AddModule adds a module to the sandbox environment
func (s *Sandbox) AddModule(name string, module any) { func (s *Sandbox) AddModule(name string, module any) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
s.modules[name] = module s.modules[name] = module
s.debugLog("Added module: %s", name) logger.Debug("Added module: %s", name)
} }
// Setup initializes the sandbox in a Lua state // Setup initializes the sandbox in a Lua state
func (s *Sandbox) Setup(state *luajit.State) error { func (s *Sandbox) Setup(state *luajit.State) error {
s.debugLog("Setting up sandbox...") logger.Server("Setting up sandbox...")
// Load the sandbox code
if err := loadSandboxIntoState(state); err != nil { if err := loadSandboxIntoState(state); err != nil {
s.debugLog("Failed to load sandbox: %v", err) logger.ErrorCont("Failed to load sandbox: %v", err)
return err return err
} }
// Register core functions
if err := s.registerCoreFunctions(state); err != nil { if err := s.registerCoreFunctions(state); err != nil {
s.debugLog("Failed to register core functions: %v", err) logger.ErrorCont("Failed to register core functions: %v", err)
return err return err
} }
// Register custom modules in the global environment
s.mu.RLock() s.mu.RLock()
for name, module := range s.modules { for name, module := range s.modules {
s.debugLog("Registering module: %s", name) logger.DebugCont("Registering module: %s", name)
if err := state.PushValue(module); err != nil { if err := state.PushValue(module); err != nil {
s.mu.RUnlock() s.mu.RUnlock()
s.debugLog("Failed to register module %s: %v", name, err) logger.ErrorCont("Failed to register module %s: %v", name, err)
return err return err
} }
state.SetGlobal(name) state.SetGlobal(name)
} }
s.mu.RUnlock() s.mu.RUnlock()
s.debugLog("Sandbox setup complete") logger.ServerCont("Sandbox setup complete")
return nil return nil
} }
// registerCoreFunctions registers all built-in functions in the Lua state // registerCoreFunctions registers all built-in functions in the Lua state
func (s *Sandbox) registerCoreFunctions(state *luajit.State) error { func (s *Sandbox) registerCoreFunctions(state *luajit.State) error {
// Register HTTP functions
if err := state.RegisterGoFunction("__http_request", httpRequest); err != nil { if err := state.RegisterGoFunction("__http_request", httpRequest); err != nil {
return err return err
} }
// Register utility functions
if err := state.RegisterGoFunction("__generate_token", generateToken); err != nil { if err := state.RegisterGoFunction("__generate_token", generateToken); err != nil {
return err return err
} }
@ -112,62 +95,41 @@ func (s *Sandbox) registerCoreFunctions(state *luajit.State) error {
// Execute runs a Lua script in the sandbox with the given context // Execute runs a Lua script in the sandbox with the given context
func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx *Context) (*Response, error) { func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx *Context) (*Response, error) {
// Create a response object // Get the execution function first
response := NewResponse()
// Load bytecode
if err := state.LoadBytecode(bytecode, "script"); err != nil {
ReleaseResponse(response)
return nil, fmt.Errorf("failed to load script: %w", err)
}
// Add session data to context
contextWithSession := make(map[string]any)
maps.Copy(contextWithSession, ctx.Values)
// Pass session data through context
if ctx.SessionID != "" {
contextWithSession["session_id"] = ctx.SessionID
contextWithSession["session_data"] = ctx.SessionData
}
// Set up context values for execution
if err := state.PushTable(contextWithSession); err != nil {
ReleaseResponse(response)
return nil, err
}
// Get the execution function
state.GetGlobal("__execute_script") state.GetGlobal("__execute_script")
if !state.IsFunction(-1) { if !state.IsFunction(-1) {
state.Pop(1) state.Pop(1)
ReleaseResponse(response)
return nil, ErrSandboxNotInitialized return nil, ErrSandboxNotInitialized
} }
// Push function and bytecode // Load bytecode
state.PushCopy(-2) // Bytecode if err := state.LoadBytecode(bytecode, "script"); err != nil {
state.PushCopy(-2) // Context state.Pop(1) // Pop the __execute_script function
state.Remove(-4) // Remove bytecode duplicate return nil, fmt.Errorf("failed to load script: %w", err)
state.Remove(-3) // Remove context duplicate }
// Push context values
if err := state.PushTable(ctx.Values); err != nil {
state.Pop(2) // Pop bytecode and __execute_script
return nil, err
}
// Execute with 2 args, 1 result // Execute with 2 args, 1 result
if err := state.Call(2, 1); err != nil { if err := state.Call(2, 1); err != nil {
ReleaseResponse(response)
return nil, fmt.Errorf("script execution failed: %w", err) return nil, fmt.Errorf("script execution failed: %w", err)
} }
// Set response body from result // Get result value
body, err := state.ToValue(-1) body, err := state.ToValue(-1)
state.Pop(1)
response := NewResponse()
if err == nil { if err == nil {
response.Body = body response.Body = body
} }
state.Pop(1)
extractHTTPResponseData(state, response) extractHTTPResponseData(state, response)
extractSessionData(state, response)
return response, nil return response, nil
} }
@ -229,9 +191,7 @@ func extractHTTPResponseData(state *luajit.State, response *Response) {
if state.IsTable(-1) { if state.IsTable(-1) {
table, err := state.ToTable(-1) table, err := state.ToTable(-1)
if err == nil { if err == nil {
for k, v := range table { maps.Copy(response.Metadata, table)
response.Metadata[k] = v
}
} }
} }
state.Pop(1) state.Pop(1)
@ -298,69 +258,3 @@ func extractCookie(state *luajit.State, response *Response) {
response.Cookies = append(response.Cookies, cookie) response.Cookies = append(response.Cookies, cookie)
} }
// Extract session data if modified
func extractSessionData(state *luajit.State, response *Response) {
logger.Debug("extractSessionData: Starting extraction")
// Get HTTP response table
state.GetGlobal("__http_responses")
if !state.IsTable(-1) {
logger.Debug("extractSessionData: __http_responses is not a table")
state.Pop(1)
return
}
// Get first response
state.PushNumber(1)
state.GetTable(-2)
if !state.IsTable(-1) {
logger.Debug("extractSessionData: __http_responses[1] is not a table")
state.Pop(2)
return
}
// Check session_modified flag
state.GetField(-1, "session_modified")
if !state.IsBoolean(-1) || !state.ToBoolean(-1) {
logger.Debug("extractSessionData: session_modified is not true")
state.Pop(3)
return
}
logger.Debug("extractSessionData: Found session_modified=true")
state.Pop(1)
// Get session ID
state.GetField(-1, "session_id")
if state.IsString(-1) {
response.SessionID = state.ToString(-1)
logger.Debug("extractSessionData: Found session ID: %s", response.SessionID)
} else {
logger.Debug("extractSessionData: session_id not found or not a string")
}
state.Pop(1)
// Get session data
state.GetField(-1, "session_data")
if state.IsTable(-1) {
logger.Debug("extractSessionData: Found session_data table")
sessionData, err := state.ToTable(-1)
if err == nil {
logger.Debug("extractSessionData: Converted session data, size=%d", len(sessionData))
for k, v := range sessionData {
response.SessionData[k] = v
logger.Debug("extractSessionData: Added session key=%s, value=%v", k, v)
}
response.SessionModified = true
} else {
logger.Debug("extractSessionData: Failed to convert session data: %v", err)
}
} else {
logger.Debug("extractSessionData: session_data not found or not a table")
}
state.Pop(1)
// Clean up stack
state.Pop(2)
logger.Debug("extractSessionData: Finished extraction, modified=%v", response.SessionModified)
}

View File

@ -6,14 +6,10 @@ including core modules and utilities. It's designed to be embedded in the
Go binary at build time. Go binary at build time.
]]-- ]]--
-- Global tables for execution context __http_response = {}
__http_responses = {}
__module_paths = {} __module_paths = {}
__module_bytecode = {} __module_bytecode = {}
__ready_modules = {} __ready_modules = {}
__session_data = {}
__session_id = nil
__session_modified = false
-- ====================================================================== -- ======================================================================
-- CORE SANDBOX FUNCTIONALITY -- CORE SANDBOX FUNCTIONALITY
@ -21,15 +17,12 @@ __session_modified = false
-- Create environment inheriting from _G -- Create environment inheriting from _G
function __create_env(ctx) function __create_env(ctx)
-- Create environment with metatable inheriting from _G
local env = setmetatable({}, {__index = _G}) local env = setmetatable({}, {__index = _G})
-- Add context if provided
if ctx then if ctx then
env.ctx = ctx env.ctx = ctx
end end
-- Add proper require function to this environment
if __setup_require then if __setup_require then
__setup_require(env) __setup_require(env)
end end
@ -39,57 +32,31 @@ end
-- Execute script with clean environment -- Execute script with clean environment
function __execute_script(fn, ctx) function __execute_script(fn, ctx)
-- Clear previous responses __http_response = nil
__http_responses[1] = nil
-- Create environment with metatable inheriting from _G local env = __create_env(ctx)
local env = setmetatable({}, {__index = _G})
-- Add context if provided
if ctx then
env.ctx = ctx
end
print("INIT SESSION DATA:", util.json_encode(ctx.session_data or {}))
-- Initialize local session variables in the environment
env.__session_data = ctx.session_data or {}
env.__session_id = ctx.session_id
env.__session_modified = false
-- Add proper require function to this environment
if __setup_require then
__setup_require(env)
end
-- Set environment for function
setfenv(fn, env) setfenv(fn, env)
-- Execute with protected call
local ok, result = pcall(fn) local ok, result = pcall(fn)
if not ok then if not ok then
error(result, 0) error(result, 0)
end end
-- If session was modified, add to response return result
if env.__session_modified then
__http_responses[1] = __http_responses[1] or {}
__http_responses[1].session_data = env.__session_data
__http_responses[1].session_id = env.__session_id
__http_responses[1].session_modified = true
end end
print("SESSION MODIFIED:", env.__session_modified) -- Ensure __http_response exists, then return it
print("FINAL DATA:", util.json_encode(env.__session_data or {})) function __ensure_response()
if not __http_response then
return result __http_response = {}
end
return __http_response
end end
-- ====================================================================== -- ======================================================================
-- HTTP MODULE -- HTTP MODULE
-- ====================================================================== -- ======================================================================
-- HTTP module implementation
local http = { local http = {
-- Set HTTP status code -- Set HTTP status code
set_status = function(code) set_status = function(code)
@ -97,9 +64,8 @@ local http = {
error("http.set_status: status code must be a number", 2) error("http.set_status: status code must be a number", 2)
end end
local resp = __http_responses[1] or {} local resp = __ensure_response()
resp.status = code resp.status = code
__http_responses[1] = resp
end, end,
-- Set HTTP header -- Set HTTP header
@ -108,10 +74,9 @@ local http = {
error("http.set_header: name and value must be strings", 2) error("http.set_header: name and value must be strings", 2)
end end
local resp = __http_responses[1] or {} local resp = __ensure_response()
resp.headers = resp.headers or {} resp.headers = resp.headers or {}
resp.headers[name] = value resp.headers[name] = value
__http_responses[1] = resp
end, end,
-- Set content type; set_header helper -- Set content type; set_header helper
@ -125,10 +90,9 @@ local http = {
error("http.set_metadata: key must be a string", 2) error("http.set_metadata: key must be a string", 2)
end end
local resp = __http_responses[1] or {} local resp = __ensure_response()
resp.metadata = resp.metadata or {} resp.metadata = resp.metadata or {}
resp.metadata[key] = value resp.metadata[key] = value
__http_responses[1] = resp
end, end,
-- HTTP client submodule -- HTTP client submodule
@ -147,45 +111,6 @@ local http = {
return result return result
end, end,
-- Simple GET request
get = function(url, options)
return http.client.request("GET", url, nil, options)
end,
-- Simple POST request with automatic content-type
post = function(url, body, options)
options = options or {}
return http.client.request("POST", url, body, options)
end,
-- Simple PUT request with automatic content-type
put = function(url, body, options)
options = options or {}
return http.client.request("PUT", url, body, options)
end,
-- Simple DELETE request
delete = function(url, options)
return http.client.request("DELETE", url, nil, options)
end,
-- Simple PATCH request
patch = function(url, body, options)
options = options or {}
return http.client.request("PATCH", url, body, options)
end,
-- Simple HEAD request
head = function(url, options)
options = options or {}
return http.client.request("HEAD", url, nil, options)
end,
-- Simple OPTIONS request
options = function(url, options)
return http.client.request("OPTIONS", url, nil, options)
end,
-- Shorthand function to directly get JSON -- Shorthand function to directly get JSON
get_json = function(url, options) get_json = function(url, options)
options = options or {} options = options or {}
@ -226,11 +151,30 @@ local http = {
} }
} }
local function make_method(method, needs_body)
return function(url, body_or_options, options)
if needs_body then
options = options or {}
return http.client.request(method, url, body_or_options, options)
else
body_or_options = body_or_options or {}
return http.client.request(method, url, nil, body_or_options)
end
end
end
http.client.get = make_method("GET", false)
http.client.delete = make_method("DELETE", false)
http.client.head = make_method("HEAD", false)
http.client.options = make_method("OPTIONS", false)
http.client.post = make_method("POST", true)
http.client.put = make_method("PUT", true)
http.client.patch = make_method("PATCH", true)
-- ====================================================================== -- ======================================================================
-- COOKIE MODULE -- COOKIE MODULE
-- ====================================================================== -- ======================================================================
-- Cookie module implementation
local cookie = { local cookie = {
-- Set a cookie -- Set a cookie
set = function(name, value, options) set = function(name, value, options)
@ -238,15 +182,10 @@ local cookie = {
error("cookie.set: name must be a string", 2) error("cookie.set: name must be a string", 2)
end end
-- Get or create response local resp = __ensure_response()
local resp = __http_responses[1] or {}
resp.cookies = resp.cookies or {} resp.cookies = resp.cookies or {}
__http_responses[1] = resp
-- Handle options as table
local opts = options or {} local opts = options or {}
-- Create cookie table
local cookie = { local cookie = {
name = name, name = name,
value = value or "", value = value or "",
@ -254,7 +193,6 @@ local cookie = {
domain = opts.domain domain = opts.domain
} }
-- Handle expiry
if opts.expires then if opts.expires then
if type(opts.expires) == "number" then if type(opts.expires) == "number" then
if opts.expires > 0 then if opts.expires > 0 then
@ -269,14 +207,28 @@ local cookie = {
end end
end end
-- Security flags
cookie.secure = (opts.secure ~= false) cookie.secure = (opts.secure ~= false)
cookie.http_only = (opts.http_only ~= false) cookie.http_only = (opts.http_only ~= false)
-- Store in cookies table if opts.same_site then
local n = #resp.cookies + 1 local valid_values = {none = true, lax = true, strict = true}
resp.cookies[n] = cookie local same_site = string.lower(opts.same_site)
if not valid_values[same_site] then
error("cookie.set: same_site must be one of 'None', 'Lax', or 'Strict'", 2)
end
-- If SameSite=None, the cookie must be secure
if same_site == "none" and not cookie.secure then
cookie.secure = true
end
cookie.same_site = opts.same_site
else
cookie.same_site = "Lax"
end
table.insert(resp.cookies, cookie)
return true return true
end, end,
@ -286,15 +238,12 @@ local cookie = {
error("cookie.get: name must be a string", 2) error("cookie.get: name must be a string", 2)
end end
-- Access values directly from current environment
local env = getfenv(2) local env = getfenv(2)
-- Check if context exists and has cookies
if env.ctx and env.ctx.cookies then if env.ctx and env.ctx.cookies then
return env.ctx.cookies[name] return env.ctx.cookies[name]
end end
-- If context has request_cookies map
if env.ctx and env.ctx._request_cookies then if env.ctx and env.ctx._request_cookies then
return env.ctx._request_cookies[name] return env.ctx._request_cookies[name]
end end
@ -308,185 +257,10 @@ local cookie = {
error("cookie.remove: name must be a string", 2) error("cookie.remove: name must be a string", 2)
end end
-- Create an expired cookie
return cookie.set(name, "", {expires = 0, path = path or "/", domain = domain}) return cookie.set(name, "", {expires = 0, path = path or "/", domain = domain})
end end
} }
-- ======================================================================
-- SESSION MODULE
-- ======================================================================
local session = {
-- Get session value
get = function(key)
if type(key) ~= "string" then
error("session.get: key must be a string", 2)
end
local env = getfenv(2)
return env.__session_data and env.__session_data[key]
end,
-- Set session value
set = function(key, value)
if type(key) ~= "string" then
error("session.set: key must be a string", 2)
end
local env = getfenv(2)
print("SET ENV:", tostring(env)) -- Debug the environment
if not env.__session_data then
env.__session_data = {}
print("CREATED NEW SESSION TABLE")
end
env.__session_data[key] = value
env.__session_modified = true
print("SET:", key, "=", tostring(value), "MODIFIED:", env.__session_modified)
return true
end,
-- Delete session value
delete = function(key)
if type(key) ~= "string" then
error("session.delete: key must be a string", 2)
end
local env = getfenv(2)
if env.__session_data and env.__session_data[key] ~= nil then
env.__session_data[key] = nil
env.__session_modified = true
end
return true
end,
-- Clear all session data
clear = function()
local env = getfenv(2)
if env.__session_data and next(env.__session_data) then
env.__session_data = {}
env.__session_modified = true
end
return true
end,
-- Get session ID
get_id = function()
local env = getfenv(2)
return env.__session_id or ""
end,
-- Get all session data
get_all = function()
local env = getfenv(2)
return env.__session_data or {}
end,
-- Check if session has key
has = function(key)
if type(key) ~= "string" then
error("session.has: key must be a string", 2)
end
local env = getfenv(2)
return env.__session_data ~= nil and env.__session_data[key] ~= nil
end
}
-- ======================================================================
-- CSRF MODULE
-- ======================================================================
-- CSRF protection module
local csrf = {
-- Session key where the token is stored
TOKEN_KEY = "_csrf_token",
-- Default form field name
DEFAULT_FIELD = "csrf",
-- Generate a new CSRF token and store it in the session
generate = function(length)
-- Default length is 32 characters
length = length or 32
if length < 16 then
-- Enforce minimum security
length = 16
end
-- Check if we have a session module
if not session then
error("CSRF protection requires the session module", 2)
end
local token = __generate_token(length)
session.set(csrf.TOKEN_KEY, token)
return token
end,
-- Get the current token or generate a new one
token = function()
-- Get from session if exists
local token = session.get(csrf.TOKEN_KEY)
-- Generate if needed
if not token then
token = csrf.generate()
end
return token
end,
-- Generate a hidden form field with the CSRF token
field = function(field_name)
field_name = field_name or csrf.DEFAULT_FIELD
local token = csrf.token()
return string.format('<input type="hidden" name="%s" value="%s">', field_name, token)
end,
-- Verify a given token against the session token
verify = function(token, field_name)
field_name = field_name or csrf.DEFAULT_FIELD
local env = getfenv(2)
local form = nil
if env.ctx and env.ctx._request_form then
form = env.ctx._request_form
elseif env.ctx and env.ctx.form then
form = env.ctx.form
else
return false
end
token = token or form[field_name]
if not token then
return false
end
local session_token = session.get(csrf.TOKEN_KEY)
if not session_token then
return false
end
-- Constant-time comparison to prevent timing attacks
if #token ~= #session_token then
return false
end
local result = true
for i = 1, #token do
if token:sub(i, i) ~= session_token:sub(i, i) then
result = false
-- Don't break early - continue to prevent timing attacks
end
end
return result
end
}
-- ====================================================================== -- ======================================================================
-- UTIL MODULE -- UTIL MODULE
-- ====================================================================== -- ======================================================================
@ -575,9 +349,6 @@ local util = {
-- REGISTER MODULES GLOBALLY -- REGISTER MODULES GLOBALLY
-- ====================================================================== -- ======================================================================
-- Install modules in global scope
_G.http = http _G.http = http
_G.cookie = cookie _G.cookie = cookie
_G.session = session
_G.csrf = csrf
_G.util = util _G.util = util