Compare commits
No commits in common. "ab6135e98adbc1498e6ec95e94e281a62f8b5da4" and "35ce09d66ec42b03b04bc157d11c7e687e5f1b5e" have entirely different histories.
ab6135e98a
...
35ce09d66e
104
core/http/Csrf.go
Normal file
104
core/http/Csrf.go
Normal file
|
@ -0,0 +1,104 @@
|
|||
package http
|
||||
|
||||
import (
|
||||
"Moonshark/core/runner"
|
||||
"Moonshark/core/utils"
|
||||
"Moonshark/core/utils/logger"
|
||||
"crypto/subtle"
|
||||
"errors"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// Error for CSRF validation failure
|
||||
var ErrCSRFValidationFailed = errors.New("CSRF token validation failed")
|
||||
|
||||
// ValidateCSRFToken checks if the CSRF token is valid for a request
|
||||
func ValidateCSRFToken(ctx *runner.Context) bool {
|
||||
// Only validate for form submissions
|
||||
method, ok := ctx.Get("method").(string)
|
||||
if !ok || (method != "POST" && method != "PUT" && method != "PATCH" && method != "DELETE") {
|
||||
return true
|
||||
}
|
||||
|
||||
// Get form data
|
||||
formData, ok := ctx.Get("form").(map[string]any)
|
||||
if !ok || formData == nil {
|
||||
logger.Warning("CSRF validation failed: no form data")
|
||||
return false
|
||||
}
|
||||
|
||||
// Get token from form
|
||||
formToken, ok := formData["csrf"].(string)
|
||||
if !ok || formToken == "" {
|
||||
logger.Warning("CSRF validation failed: no token in form")
|
||||
return false
|
||||
}
|
||||
|
||||
// Get token from session
|
||||
sessionData := ctx.SessionData
|
||||
if sessionData == nil {
|
||||
logger.Warning("CSRF validation failed: no session data")
|
||||
return false
|
||||
}
|
||||
|
||||
sessionToken, ok := sessionData["_csrf_token"].(string)
|
||||
if !ok || sessionToken == "" {
|
||||
logger.Warning("CSRF validation failed: no token in session")
|
||||
return false
|
||||
}
|
||||
|
||||
// Constant-time comparison to prevent timing attacks
|
||||
return subtle.ConstantTimeCompare([]byte(formToken), []byte(sessionToken)) == 1
|
||||
}
|
||||
|
||||
// HandleCSRFError handles a CSRF validation error
|
||||
func HandleCSRFError(ctx *fasthttp.RequestCtx, errorConfig utils.ErrorPageConfig) {
|
||||
method := string(ctx.Method())
|
||||
path := string(ctx.Path())
|
||||
|
||||
logger.Warning("CSRF validation failed for %s %s", method, path)
|
||||
|
||||
ctx.SetContentType("text/html; charset=utf-8")
|
||||
ctx.SetStatusCode(fasthttp.StatusForbidden)
|
||||
|
||||
errorMsg := "Invalid or missing CSRF token. This could be due to an expired form or a cross-site request forgery attempt."
|
||||
errorHTML := utils.ForbiddenPage(errorConfig, path, errorMsg)
|
||||
ctx.SetBody([]byte(errorHTML))
|
||||
}
|
||||
|
||||
// GenerateCSRFToken creates a new CSRF token and stores it in the session
|
||||
func GenerateCSRFToken(ctx *runner.Context, length int) (string, error) {
|
||||
if length < 16 {
|
||||
length = 16 // Minimum token length for security
|
||||
}
|
||||
|
||||
// Create secure random token
|
||||
token, err := GenerateSecureToken(length)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Store token in session
|
||||
ctx.SessionData["_csrf_token"] = token
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// GetCSRFToken retrieves the current CSRF token or generates a new one
|
||||
func GetCSRFToken(ctx *runner.Context) (string, error) {
|
||||
// Check if token already exists in session
|
||||
if token, ok := ctx.SessionData["_csrf_token"].(string); ok && token != "" {
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// Generate new token
|
||||
return GenerateCSRFToken(ctx, 32)
|
||||
}
|
||||
|
||||
// CSRFMiddleware validates CSRF tokens for state-changing requests
|
||||
func CSRFMiddleware(ctx *runner.Context) error {
|
||||
if !ValidateCSRFToken(ctx) {
|
||||
return ErrCSRFValidationFailed
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -2,6 +2,7 @@ package http
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"Moonshark/core/metadata"
|
||||
|
@ -166,6 +167,11 @@ func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scrip
|
|||
luaCtx.Set("path", path)
|
||||
luaCtx.Set("host", host)
|
||||
|
||||
// Initialize session
|
||||
session := s.sessionManager.GetSessionFromRequest(ctx)
|
||||
luaCtx.SessionID = session.ID
|
||||
luaCtx.SessionData = session.GetAll()
|
||||
|
||||
// URL parameters
|
||||
if params.Count > 0 {
|
||||
paramMap := make(map[string]any, params.Count)
|
||||
|
@ -192,11 +198,25 @@ func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scrip
|
|||
luaCtx.Set("form", make(map[string]any))
|
||||
}
|
||||
|
||||
// CSRF middleware for state-changing requests
|
||||
if method == "POST" || method == "PUT" || method == "PATCH" || method == "DELETE" {
|
||||
if !ValidateCSRFToken(luaCtx) {
|
||||
HandleCSRFError(ctx, s.errorConfig)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Execute Lua script
|
||||
response, err := s.luaRunner.Run(bytecode, luaCtx, scriptPath)
|
||||
if err != nil {
|
||||
logger.Error("Error executing Lua route: %v", err)
|
||||
|
||||
// Special handling for specific errors
|
||||
if errors.Is(err, ErrCSRFValidationFailed) {
|
||||
HandleCSRFError(ctx, s.errorConfig)
|
||||
return
|
||||
}
|
||||
|
||||
// General error handling
|
||||
ctx.SetContentType("text/html; charset=utf-8")
|
||||
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
||||
|
@ -205,6 +225,15 @@ func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scrip
|
|||
return
|
||||
}
|
||||
|
||||
// Update session if modified
|
||||
if response.SessionModified {
|
||||
for k, v := range response.SessionData {
|
||||
session.Set(k, v)
|
||||
}
|
||||
|
||||
s.sessionManager.ApplySessionCookie(ctx, session)
|
||||
}
|
||||
|
||||
// Apply response to HTTP context
|
||||
runner.ApplyResponse(response, ctx)
|
||||
|
||||
|
|
|
@ -15,6 +15,10 @@ type Context struct {
|
|||
// FastHTTP context if this was created from an HTTP request
|
||||
RequestCtx *fasthttp.RequestCtx
|
||||
|
||||
// Session information
|
||||
SessionID string
|
||||
SessionData map[string]any
|
||||
|
||||
// Buffer for efficient string operations
|
||||
buffer *bytebufferpool.ByteBuffer
|
||||
}
|
||||
|
@ -23,7 +27,8 @@ type Context struct {
|
|||
var contextPool = sync.Pool{
|
||||
New: func() any {
|
||||
return &Context{
|
||||
Values: make(map[string]any, 32),
|
||||
Values: make(map[string]any, 16),
|
||||
SessionData: make(map[string]any, 8),
|
||||
}
|
||||
},
|
||||
}
|
||||
|
@ -85,6 +90,13 @@ func (c *Context) Release() {
|
|||
delete(c.Values, k)
|
||||
}
|
||||
|
||||
for k := range c.SessionData {
|
||||
delete(c.SessionData, k)
|
||||
}
|
||||
|
||||
// Reset session info
|
||||
c.SessionID = ""
|
||||
|
||||
// Reset request context
|
||||
c.RequestCtx = nil
|
||||
|
||||
|
@ -114,3 +126,13 @@ func (c *Context) Set(key string, value any) {
|
|||
func (c *Context) Get(key string) any {
|
||||
return c.Values[key]
|
||||
}
|
||||
|
||||
// SetSession sets a session data value
|
||||
func (c *Context) SetSession(key string, value any) {
|
||||
c.SessionData[key] = value
|
||||
}
|
||||
|
||||
// GetSession retrieves a session data value
|
||||
func (c *Context) GetSession(key string) any {
|
||||
return c.SessionData[key]
|
||||
}
|
||||
|
|
|
@ -24,7 +24,7 @@ func precompileSandboxCode() {
|
|||
// Create temporary state for compilation
|
||||
tempState := luajit.New()
|
||||
if tempState == nil {
|
||||
logger.ErrorCont("Failed to create temp Lua state for bytecode compilation")
|
||||
logger.Error("Failed to create temp Lua state for bytecode compilation")
|
||||
return
|
||||
}
|
||||
defer tempState.Close()
|
||||
|
@ -32,7 +32,7 @@ func precompileSandboxCode() {
|
|||
|
||||
code, err := tempState.CompileBytecode(sandboxLuaCode, "sandbox.lua")
|
||||
if err != nil {
|
||||
logger.ErrorCont("Failed to compile sandbox code: %v", err)
|
||||
logger.Error("Failed to compile sandbox code: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -40,20 +40,22 @@ func precompileSandboxCode() {
|
|||
copy(bytecode, code)
|
||||
sandboxBytecode.Store(&bytecode)
|
||||
|
||||
logger.ServerCont("Successfully precompiled sandbox.lua to bytecode (%d bytes)", len(code))
|
||||
logger.Debug("Successfully precompiled sandbox.lua to bytecode (%d bytes)", len(code))
|
||||
}
|
||||
|
||||
// loadSandboxIntoState loads the sandbox code into a Lua state
|
||||
func loadSandboxIntoState(state *luajit.State) error {
|
||||
// Initialize bytecode once
|
||||
bytecodeOnce.Do(precompileSandboxCode)
|
||||
|
||||
// Use precompiled bytecode if available
|
||||
bytecode := sandboxBytecode.Load()
|
||||
if bytecode != nil && len(*bytecode) > 0 {
|
||||
logger.ServerCont("Loading sandbox.lua from precompiled bytecode") // piggyback off Sandbox.go's Setup()
|
||||
logger.Debug("Loading sandbox.lua from precompiled bytecode")
|
||||
return state.LoadAndRunBytecode(*bytecode, "sandbox.lua")
|
||||
}
|
||||
|
||||
// Fallback to direct execution
|
||||
logger.WarningCont("Using non-precompiled sandbox.lua (bytecode compilation failed)")
|
||||
logger.Warning("Using non-precompiled sandbox.lua (bytecode compilation failed)")
|
||||
return state.DoString(sandboxLuaCode)
|
||||
}
|
||||
|
|
|
@ -147,12 +147,17 @@ func (r *Runner) createState(index int) (*State, error) {
|
|||
r.debugLog("Creating Lua state %d", index)
|
||||
}
|
||||
|
||||
// Create a new state
|
||||
L := luajit.New()
|
||||
if L == nil {
|
||||
return nil, errors.New("failed to create Lua state")
|
||||
}
|
||||
|
||||
// Create sandbox
|
||||
sb := NewSandbox()
|
||||
if r.debug {
|
||||
sb.EnableDebug()
|
||||
}
|
||||
|
||||
// Set up sandbox
|
||||
if err := sb.Setup(L); err != nil {
|
||||
|
|
|
@ -40,50 +40,67 @@ func NewSandbox() *Sandbox {
|
|||
}
|
||||
}
|
||||
|
||||
// EnableDebug turns on debug logging
|
||||
func (s *Sandbox) EnableDebug() {
|
||||
s.debug = true
|
||||
}
|
||||
|
||||
// debugLog logs a message if debug mode is enabled
|
||||
func (s *Sandbox) debugLog(format string, args ...interface{}) {
|
||||
if s.debug {
|
||||
logger.Debug("Sandbox "+format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
// AddModule adds a module to the sandbox environment
|
||||
func (s *Sandbox) AddModule(name string, module any) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.modules[name] = module
|
||||
logger.Debug("Added module: %s", name)
|
||||
s.debugLog("Added module: %s", name)
|
||||
}
|
||||
|
||||
// Setup initializes the sandbox in a Lua state
|
||||
func (s *Sandbox) Setup(state *luajit.State) error {
|
||||
logger.Server("Setting up sandbox...")
|
||||
s.debugLog("Setting up sandbox...")
|
||||
|
||||
// Load the sandbox code
|
||||
if err := loadSandboxIntoState(state); err != nil {
|
||||
logger.ErrorCont("Failed to load sandbox: %v", err)
|
||||
s.debugLog("Failed to load sandbox: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Register core functions
|
||||
if err := s.registerCoreFunctions(state); err != nil {
|
||||
logger.ErrorCont("Failed to register core functions: %v", err)
|
||||
s.debugLog("Failed to register core functions: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Register custom modules in the global environment
|
||||
s.mu.RLock()
|
||||
for name, module := range s.modules {
|
||||
logger.DebugCont("Registering module: %s", name)
|
||||
s.debugLog("Registering module: %s", name)
|
||||
if err := state.PushValue(module); err != nil {
|
||||
s.mu.RUnlock()
|
||||
logger.ErrorCont("Failed to register module %s: %v", name, err)
|
||||
s.debugLog("Failed to register module %s: %v", name, err)
|
||||
return err
|
||||
}
|
||||
state.SetGlobal(name)
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
|
||||
logger.ServerCont("Sandbox setup complete")
|
||||
s.debugLog("Sandbox setup complete")
|
||||
return nil
|
||||
}
|
||||
|
||||
// registerCoreFunctions registers all built-in functions in the Lua state
|
||||
func (s *Sandbox) registerCoreFunctions(state *luajit.State) error {
|
||||
// Register HTTP functions
|
||||
if err := state.RegisterGoFunction("__http_request", httpRequest); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Register utility functions
|
||||
if err := state.RegisterGoFunction("__generate_token", generateToken); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -95,41 +112,62 @@ func (s *Sandbox) registerCoreFunctions(state *luajit.State) error {
|
|||
|
||||
// Execute runs a Lua script in the sandbox with the given context
|
||||
func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx *Context) (*Response, error) {
|
||||
// Get the execution function first
|
||||
state.GetGlobal("__execute_script")
|
||||
if !state.IsFunction(-1) {
|
||||
state.Pop(1)
|
||||
return nil, ErrSandboxNotInitialized
|
||||
}
|
||||
// Create a response object
|
||||
response := NewResponse()
|
||||
|
||||
// Load bytecode
|
||||
if err := state.LoadBytecode(bytecode, "script"); err != nil {
|
||||
state.Pop(1) // Pop the __execute_script function
|
||||
ReleaseResponse(response)
|
||||
return nil, fmt.Errorf("failed to load script: %w", err)
|
||||
}
|
||||
|
||||
// Push context values
|
||||
if err := state.PushTable(ctx.Values); err != nil {
|
||||
state.Pop(2) // Pop bytecode and __execute_script
|
||||
// Add session data to context
|
||||
contextWithSession := make(map[string]any)
|
||||
maps.Copy(contextWithSession, ctx.Values)
|
||||
|
||||
// Pass session data through context
|
||||
if ctx.SessionID != "" {
|
||||
contextWithSession["session_id"] = ctx.SessionID
|
||||
contextWithSession["session_data"] = ctx.SessionData
|
||||
}
|
||||
|
||||
// Set up context values for execution
|
||||
if err := state.PushTable(contextWithSession); err != nil {
|
||||
ReleaseResponse(response)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get the execution function
|
||||
state.GetGlobal("__execute_script")
|
||||
if !state.IsFunction(-1) {
|
||||
state.Pop(1)
|
||||
ReleaseResponse(response)
|
||||
return nil, ErrSandboxNotInitialized
|
||||
}
|
||||
|
||||
// Push function and bytecode
|
||||
state.PushCopy(-2) // Bytecode
|
||||
state.PushCopy(-2) // Context
|
||||
state.Remove(-4) // Remove bytecode duplicate
|
||||
state.Remove(-3) // Remove context duplicate
|
||||
|
||||
// Execute with 2 args, 1 result
|
||||
if err := state.Call(2, 1); err != nil {
|
||||
ReleaseResponse(response)
|
||||
return nil, fmt.Errorf("script execution failed: %w", err)
|
||||
}
|
||||
|
||||
// Get result value
|
||||
// Set response body from result
|
||||
body, err := state.ToValue(-1)
|
||||
state.Pop(1)
|
||||
|
||||
response := NewResponse()
|
||||
if err == nil {
|
||||
response.Body = body
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
extractHTTPResponseData(state, response)
|
||||
|
||||
extractSessionData(state, response)
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
|
@ -191,7 +229,9 @@ func extractHTTPResponseData(state *luajit.State, response *Response) {
|
|||
if state.IsTable(-1) {
|
||||
table, err := state.ToTable(-1)
|
||||
if err == nil {
|
||||
maps.Copy(response.Metadata, table)
|
||||
for k, v := range table {
|
||||
response.Metadata[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
state.Pop(1)
|
||||
|
@ -258,3 +298,69 @@ func extractCookie(state *luajit.State, response *Response) {
|
|||
|
||||
response.Cookies = append(response.Cookies, cookie)
|
||||
}
|
||||
|
||||
// Extract session data if modified
|
||||
func extractSessionData(state *luajit.State, response *Response) {
|
||||
logger.Debug("extractSessionData: Starting extraction")
|
||||
|
||||
// Get HTTP response table
|
||||
state.GetGlobal("__http_responses")
|
||||
if !state.IsTable(-1) {
|
||||
logger.Debug("extractSessionData: __http_responses is not a table")
|
||||
state.Pop(1)
|
||||
return
|
||||
}
|
||||
|
||||
// Get first response
|
||||
state.PushNumber(1)
|
||||
state.GetTable(-2)
|
||||
if !state.IsTable(-1) {
|
||||
logger.Debug("extractSessionData: __http_responses[1] is not a table")
|
||||
state.Pop(2)
|
||||
return
|
||||
}
|
||||
|
||||
// Check session_modified flag
|
||||
state.GetField(-1, "session_modified")
|
||||
if !state.IsBoolean(-1) || !state.ToBoolean(-1) {
|
||||
logger.Debug("extractSessionData: session_modified is not true")
|
||||
state.Pop(3)
|
||||
return
|
||||
}
|
||||
logger.Debug("extractSessionData: Found session_modified=true")
|
||||
state.Pop(1)
|
||||
|
||||
// Get session ID
|
||||
state.GetField(-1, "session_id")
|
||||
if state.IsString(-1) {
|
||||
response.SessionID = state.ToString(-1)
|
||||
logger.Debug("extractSessionData: Found session ID: %s", response.SessionID)
|
||||
} else {
|
||||
logger.Debug("extractSessionData: session_id not found or not a string")
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
// Get session data
|
||||
state.GetField(-1, "session_data")
|
||||
if state.IsTable(-1) {
|
||||
logger.Debug("extractSessionData: Found session_data table")
|
||||
sessionData, err := state.ToTable(-1)
|
||||
if err == nil {
|
||||
logger.Debug("extractSessionData: Converted session data, size=%d", len(sessionData))
|
||||
for k, v := range sessionData {
|
||||
response.SessionData[k] = v
|
||||
logger.Debug("extractSessionData: Added session key=%s, value=%v", k, v)
|
||||
}
|
||||
response.SessionModified = true
|
||||
} else {
|
||||
logger.Debug("extractSessionData: Failed to convert session data: %v", err)
|
||||
}
|
||||
} else {
|
||||
logger.Debug("extractSessionData: session_data not found or not a table")
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
// Clean up stack
|
||||
state.Pop(2)
|
||||
logger.Debug("extractSessionData: Finished extraction, modified=%v", response.SessionModified)
|
||||
}
|
||||
|
|
|
@ -6,10 +6,14 @@ including core modules and utilities. It's designed to be embedded in the
|
|||
Go binary at build time.
|
||||
]]--
|
||||
|
||||
__http_response = {}
|
||||
-- Global tables for execution context
|
||||
__http_responses = {}
|
||||
__module_paths = {}
|
||||
__module_bytecode = {}
|
||||
__ready_modules = {}
|
||||
__session_data = {}
|
||||
__session_id = nil
|
||||
__session_modified = false
|
||||
|
||||
-- ======================================================================
|
||||
-- CORE SANDBOX FUNCTIONALITY
|
||||
|
@ -17,12 +21,15 @@ __ready_modules = {}
|
|||
|
||||
-- Create environment inheriting from _G
|
||||
function __create_env(ctx)
|
||||
-- Create environment with metatable inheriting from _G
|
||||
local env = setmetatable({}, {__index = _G})
|
||||
|
||||
-- Add context if provided
|
||||
if ctx then
|
||||
env.ctx = ctx
|
||||
end
|
||||
|
||||
-- Add proper require function to this environment
|
||||
if __setup_require then
|
||||
__setup_require(env)
|
||||
end
|
||||
|
@ -32,149 +39,198 @@ end
|
|||
|
||||
-- Execute script with clean environment
|
||||
function __execute_script(fn, ctx)
|
||||
__http_response = nil
|
||||
-- Clear previous responses
|
||||
__http_responses[1] = nil
|
||||
|
||||
local env = __create_env(ctx)
|
||||
-- Create environment with metatable inheriting from _G
|
||||
local env = setmetatable({}, {__index = _G})
|
||||
|
||||
-- Add context if provided
|
||||
if ctx then
|
||||
env.ctx = ctx
|
||||
end
|
||||
|
||||
print("INIT SESSION DATA:", util.json_encode(ctx.session_data or {}))
|
||||
|
||||
-- Initialize local session variables in the environment
|
||||
env.__session_data = ctx.session_data or {}
|
||||
env.__session_id = ctx.session_id
|
||||
env.__session_modified = false
|
||||
|
||||
-- Add proper require function to this environment
|
||||
if __setup_require then
|
||||
__setup_require(env)
|
||||
end
|
||||
|
||||
-- Set environment for function
|
||||
setfenv(fn, env)
|
||||
|
||||
-- Execute with protected call
|
||||
local ok, result = pcall(fn)
|
||||
if not ok then
|
||||
error(result, 0)
|
||||
end
|
||||
|
||||
return result
|
||||
end
|
||||
|
||||
-- Ensure __http_response exists, then return it
|
||||
function __ensure_response()
|
||||
if not __http_response then
|
||||
__http_response = {}
|
||||
-- If session was modified, add to response
|
||||
if env.__session_modified then
|
||||
__http_responses[1] = __http_responses[1] or {}
|
||||
__http_responses[1].session_data = env.__session_data
|
||||
__http_responses[1].session_id = env.__session_id
|
||||
__http_responses[1].session_modified = true
|
||||
end
|
||||
return __http_response
|
||||
|
||||
print("SESSION MODIFIED:", env.__session_modified)
|
||||
print("FINAL DATA:", util.json_encode(env.__session_data or {}))
|
||||
|
||||
return result
|
||||
end
|
||||
|
||||
-- ======================================================================
|
||||
-- HTTP MODULE
|
||||
-- ======================================================================
|
||||
|
||||
-- HTTP module implementation
|
||||
local http = {
|
||||
-- Set HTTP status code
|
||||
set_status = function(code)
|
||||
if type(code) ~= "number" then
|
||||
error("http.set_status: status code must be a number", 2)
|
||||
end
|
||||
-- Set HTTP status code
|
||||
set_status = function(code)
|
||||
if type(code) ~= "number" then
|
||||
error("http.set_status: status code must be a number", 2)
|
||||
end
|
||||
|
||||
local resp = __ensure_response()
|
||||
resp.status = code
|
||||
end,
|
||||
local resp = __http_responses[1] or {}
|
||||
resp.status = code
|
||||
__http_responses[1] = resp
|
||||
end,
|
||||
|
||||
-- Set HTTP header
|
||||
set_header = function(name, value)
|
||||
if type(name) ~= "string" or type(value) ~= "string" then
|
||||
error("http.set_header: name and value must be strings", 2)
|
||||
end
|
||||
-- Set HTTP header
|
||||
set_header = function(name, value)
|
||||
if type(name) ~= "string" or type(value) ~= "string" then
|
||||
error("http.set_header: name and value must be strings", 2)
|
||||
end
|
||||
|
||||
local resp = __ensure_response()
|
||||
resp.headers = resp.headers or {}
|
||||
resp.headers[name] = value
|
||||
end,
|
||||
local resp = __http_responses[1] or {}
|
||||
resp.headers = resp.headers or {}
|
||||
resp.headers[name] = value
|
||||
__http_responses[1] = resp
|
||||
end,
|
||||
|
||||
-- Set content type; set_header helper
|
||||
set_content_type = function(content_type)
|
||||
http.set_header("Content-Type", content_type)
|
||||
end,
|
||||
-- Set content type; set_header helper
|
||||
set_content_type = function(content_type)
|
||||
http.set_header("Content-Type", content_type)
|
||||
end,
|
||||
|
||||
-- Set metadata (arbitrary data to be returned with response)
|
||||
set_metadata = function(key, value)
|
||||
if type(key) ~= "string" then
|
||||
error("http.set_metadata: key must be a string", 2)
|
||||
end
|
||||
-- Set metadata (arbitrary data to be returned with response)
|
||||
set_metadata = function(key, value)
|
||||
if type(key) ~= "string" then
|
||||
error("http.set_metadata: key must be a string", 2)
|
||||
end
|
||||
|
||||
local resp = __ensure_response()
|
||||
resp.metadata = resp.metadata or {}
|
||||
resp.metadata[key] = value
|
||||
end,
|
||||
local resp = __http_responses[1] or {}
|
||||
resp.metadata = resp.metadata or {}
|
||||
resp.metadata[key] = value
|
||||
__http_responses[1] = resp
|
||||
end,
|
||||
|
||||
-- HTTP client submodule
|
||||
client = {
|
||||
-- Generic request function
|
||||
request = function(method, url, body, options)
|
||||
if type(method) ~= "string" then
|
||||
error("http.client.request: method must be a string", 2)
|
||||
end
|
||||
if type(url) ~= "string" then
|
||||
error("http.client.request: url must be a string", 2)
|
||||
end
|
||||
-- HTTP client submodule
|
||||
client = {
|
||||
-- Generic request function
|
||||
request = function(method, url, body, options)
|
||||
if type(method) ~= "string" then
|
||||
error("http.client.request: method must be a string", 2)
|
||||
end
|
||||
if type(url) ~= "string" then
|
||||
error("http.client.request: url must be a string", 2)
|
||||
end
|
||||
|
||||
-- Call native implementation
|
||||
local result = __http_request(method, url, body, options)
|
||||
return result
|
||||
end,
|
||||
-- Call native implementation
|
||||
local result = __http_request(method, url, body, options)
|
||||
return result
|
||||
end,
|
||||
|
||||
-- Shorthand function to directly get JSON
|
||||
get_json = function(url, options)
|
||||
options = options or {}
|
||||
local response = http.client.get(url, options)
|
||||
if response.ok and response.json then
|
||||
return response.json
|
||||
end
|
||||
return nil, response
|
||||
end,
|
||||
-- Simple GET request
|
||||
get = function(url, options)
|
||||
return http.client.request("GET", url, nil, options)
|
||||
end,
|
||||
|
||||
-- Utility to build a URL with query parameters
|
||||
build_url = function(base_url, params)
|
||||
if not params or type(params) ~= "table" then
|
||||
return base_url
|
||||
end
|
||||
-- Simple POST request with automatic content-type
|
||||
post = function(url, body, options)
|
||||
options = options or {}
|
||||
return http.client.request("POST", url, body, options)
|
||||
end,
|
||||
|
||||
local query = {}
|
||||
for k, v in pairs(params) do
|
||||
if type(v) == "table" then
|
||||
for _, item in ipairs(v) do
|
||||
table.insert(query, k .. "=" .. tostring(item))
|
||||
end
|
||||
else
|
||||
table.insert(query, k .. "=" .. tostring(v))
|
||||
end
|
||||
end
|
||||
-- Simple PUT request with automatic content-type
|
||||
put = function(url, body, options)
|
||||
options = options or {}
|
||||
return http.client.request("PUT", url, body, options)
|
||||
end,
|
||||
|
||||
if #query > 0 then
|
||||
if base_url:find("?") then
|
||||
return base_url .. "&" .. table.concat(query, "&")
|
||||
else
|
||||
return base_url .. "?" .. table.concat(query, "&")
|
||||
end
|
||||
end
|
||||
-- Simple DELETE request
|
||||
delete = function(url, options)
|
||||
return http.client.request("DELETE", url, nil, options)
|
||||
end,
|
||||
|
||||
return base_url
|
||||
end
|
||||
}
|
||||
-- Simple PATCH request
|
||||
patch = function(url, body, options)
|
||||
options = options or {}
|
||||
return http.client.request("PATCH", url, body, options)
|
||||
end,
|
||||
|
||||
-- Simple HEAD request
|
||||
head = function(url, options)
|
||||
options = options or {}
|
||||
return http.client.request("HEAD", url, nil, options)
|
||||
end,
|
||||
|
||||
-- Simple OPTIONS request
|
||||
options = function(url, options)
|
||||
return http.client.request("OPTIONS", url, nil, options)
|
||||
end,
|
||||
|
||||
-- Shorthand function to directly get JSON
|
||||
get_json = function(url, options)
|
||||
options = options or {}
|
||||
local response = http.client.get(url, options)
|
||||
if response.ok and response.json then
|
||||
return response.json
|
||||
end
|
||||
return nil, response
|
||||
end,
|
||||
|
||||
-- Utility to build a URL with query parameters
|
||||
build_url = function(base_url, params)
|
||||
if not params or type(params) ~= "table" then
|
||||
return base_url
|
||||
end
|
||||
|
||||
local query = {}
|
||||
for k, v in pairs(params) do
|
||||
if type(v) == "table" then
|
||||
for _, item in ipairs(v) do
|
||||
table.insert(query, k .. "=" .. tostring(item))
|
||||
end
|
||||
else
|
||||
table.insert(query, k .. "=" .. tostring(v))
|
||||
end
|
||||
end
|
||||
|
||||
if #query > 0 then
|
||||
if base_url:find("?") then
|
||||
return base_url .. "&" .. table.concat(query, "&")
|
||||
else
|
||||
return base_url .. "?" .. table.concat(query, "&")
|
||||
end
|
||||
end
|
||||
|
||||
return base_url
|
||||
end
|
||||
}
|
||||
}
|
||||
|
||||
local function make_method(method, needs_body)
|
||||
return function(url, body_or_options, options)
|
||||
if needs_body then
|
||||
options = options or {}
|
||||
return http.client.request(method, url, body_or_options, options)
|
||||
else
|
||||
body_or_options = body_or_options or {}
|
||||
return http.client.request(method, url, nil, body_or_options)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
http.client.get = make_method("GET", false)
|
||||
http.client.delete = make_method("DELETE", false)
|
||||
http.client.head = make_method("HEAD", false)
|
||||
http.client.options = make_method("OPTIONS", false)
|
||||
http.client.post = make_method("POST", true)
|
||||
http.client.put = make_method("PUT", true)
|
||||
http.client.patch = make_method("PATCH", true)
|
||||
|
||||
-- ======================================================================
|
||||
-- COOKIE MODULE
|
||||
-- ======================================================================
|
||||
|
||||
-- Cookie module implementation
|
||||
local cookie = {
|
||||
-- Set a cookie
|
||||
set = function(name, value, options)
|
||||
|
@ -182,10 +238,15 @@ local cookie = {
|
|||
error("cookie.set: name must be a string", 2)
|
||||
end
|
||||
|
||||
local resp = __ensure_response()
|
||||
-- Get or create response
|
||||
local resp = __http_responses[1] or {}
|
||||
resp.cookies = resp.cookies or {}
|
||||
__http_responses[1] = resp
|
||||
|
||||
-- Handle options as table
|
||||
local opts = options or {}
|
||||
|
||||
-- Create cookie table
|
||||
local cookie = {
|
||||
name = name,
|
||||
value = value or "",
|
||||
|
@ -193,6 +254,7 @@ local cookie = {
|
|||
domain = opts.domain
|
||||
}
|
||||
|
||||
-- Handle expiry
|
||||
if opts.expires then
|
||||
if type(opts.expires) == "number" then
|
||||
if opts.expires > 0 then
|
||||
|
@ -207,28 +269,14 @@ local cookie = {
|
|||
end
|
||||
end
|
||||
|
||||
-- Security flags
|
||||
cookie.secure = (opts.secure ~= false)
|
||||
cookie.http_only = (opts.http_only ~= false)
|
||||
|
||||
if opts.same_site then
|
||||
local valid_values = {none = true, lax = true, strict = true}
|
||||
local same_site = string.lower(opts.same_site)
|
||||
-- Store in cookies table
|
||||
local n = #resp.cookies + 1
|
||||
resp.cookies[n] = cookie
|
||||
|
||||
if not valid_values[same_site] then
|
||||
error("cookie.set: same_site must be one of 'None', 'Lax', or 'Strict'", 2)
|
||||
end
|
||||
|
||||
-- If SameSite=None, the cookie must be secure
|
||||
if same_site == "none" and not cookie.secure then
|
||||
cookie.secure = true
|
||||
end
|
||||
|
||||
cookie.same_site = opts.same_site
|
||||
else
|
||||
cookie.same_site = "Lax"
|
||||
end
|
||||
|
||||
table.insert(resp.cookies, cookie)
|
||||
return true
|
||||
end,
|
||||
|
||||
|
@ -238,12 +286,15 @@ local cookie = {
|
|||
error("cookie.get: name must be a string", 2)
|
||||
end
|
||||
|
||||
-- Access values directly from current environment
|
||||
local env = getfenv(2)
|
||||
|
||||
-- Check if context exists and has cookies
|
||||
if env.ctx and env.ctx.cookies then
|
||||
return env.ctx.cookies[name]
|
||||
end
|
||||
|
||||
-- If context has request_cookies map
|
||||
if env.ctx and env.ctx._request_cookies then
|
||||
return env.ctx._request_cookies[name]
|
||||
end
|
||||
|
@ -257,10 +308,185 @@ local cookie = {
|
|||
error("cookie.remove: name must be a string", 2)
|
||||
end
|
||||
|
||||
-- Create an expired cookie
|
||||
return cookie.set(name, "", {expires = 0, path = path or "/", domain = domain})
|
||||
end
|
||||
}
|
||||
|
||||
-- ======================================================================
|
||||
-- SESSION MODULE
|
||||
-- ======================================================================
|
||||
|
||||
local session = {
|
||||
-- Get session value
|
||||
get = function(key)
|
||||
if type(key) ~= "string" then
|
||||
error("session.get: key must be a string", 2)
|
||||
end
|
||||
local env = getfenv(2)
|
||||
return env.__session_data and env.__session_data[key]
|
||||
end,
|
||||
|
||||
-- Set session value
|
||||
set = function(key, value)
|
||||
if type(key) ~= "string" then
|
||||
error("session.set: key must be a string", 2)
|
||||
end
|
||||
|
||||
local env = getfenv(2)
|
||||
print("SET ENV:", tostring(env)) -- Debug the environment
|
||||
|
||||
if not env.__session_data then
|
||||
env.__session_data = {}
|
||||
print("CREATED NEW SESSION TABLE")
|
||||
end
|
||||
|
||||
env.__session_data[key] = value
|
||||
env.__session_modified = true
|
||||
print("SET:", key, "=", tostring(value), "MODIFIED:", env.__session_modified)
|
||||
return true
|
||||
end,
|
||||
|
||||
-- Delete session value
|
||||
delete = function(key)
|
||||
if type(key) ~= "string" then
|
||||
error("session.delete: key must be a string", 2)
|
||||
end
|
||||
|
||||
local env = getfenv(2)
|
||||
if env.__session_data and env.__session_data[key] ~= nil then
|
||||
env.__session_data[key] = nil
|
||||
env.__session_modified = true
|
||||
end
|
||||
return true
|
||||
end,
|
||||
|
||||
-- Clear all session data
|
||||
clear = function()
|
||||
local env = getfenv(2)
|
||||
if env.__session_data and next(env.__session_data) then
|
||||
env.__session_data = {}
|
||||
env.__session_modified = true
|
||||
end
|
||||
return true
|
||||
end,
|
||||
|
||||
-- Get session ID
|
||||
get_id = function()
|
||||
local env = getfenv(2)
|
||||
return env.__session_id or ""
|
||||
end,
|
||||
|
||||
-- Get all session data
|
||||
get_all = function()
|
||||
local env = getfenv(2)
|
||||
return env.__session_data or {}
|
||||
end,
|
||||
|
||||
-- Check if session has key
|
||||
has = function(key)
|
||||
if type(key) ~= "string" then
|
||||
error("session.has: key must be a string", 2)
|
||||
end
|
||||
local env = getfenv(2)
|
||||
return env.__session_data ~= nil and env.__session_data[key] ~= nil
|
||||
end
|
||||
}
|
||||
|
||||
-- ======================================================================
|
||||
-- CSRF MODULE
|
||||
-- ======================================================================
|
||||
|
||||
-- CSRF protection module
|
||||
local csrf = {
|
||||
-- Session key where the token is stored
|
||||
TOKEN_KEY = "_csrf_token",
|
||||
|
||||
-- Default form field name
|
||||
DEFAULT_FIELD = "csrf",
|
||||
|
||||
-- Generate a new CSRF token and store it in the session
|
||||
generate = function(length)
|
||||
-- Default length is 32 characters
|
||||
length = length or 32
|
||||
|
||||
if length < 16 then
|
||||
-- Enforce minimum security
|
||||
length = 16
|
||||
end
|
||||
|
||||
-- Check if we have a session module
|
||||
if not session then
|
||||
error("CSRF protection requires the session module", 2)
|
||||
end
|
||||
|
||||
local token = __generate_token(length)
|
||||
session.set(csrf.TOKEN_KEY, token)
|
||||
return token
|
||||
end,
|
||||
|
||||
-- Get the current token or generate a new one
|
||||
token = function()
|
||||
-- Get from session if exists
|
||||
local token = session.get(csrf.TOKEN_KEY)
|
||||
|
||||
-- Generate if needed
|
||||
if not token then
|
||||
token = csrf.generate()
|
||||
end
|
||||
|
||||
return token
|
||||
end,
|
||||
|
||||
-- Generate a hidden form field with the CSRF token
|
||||
field = function(field_name)
|
||||
field_name = field_name or csrf.DEFAULT_FIELD
|
||||
local token = csrf.token()
|
||||
return string.format('<input type="hidden" name="%s" value="%s">', field_name, token)
|
||||
end,
|
||||
|
||||
-- Verify a given token against the session token
|
||||
verify = function(token, field_name)
|
||||
field_name = field_name or csrf.DEFAULT_FIELD
|
||||
|
||||
local env = getfenv(2)
|
||||
|
||||
local form = nil
|
||||
if env.ctx and env.ctx._request_form then
|
||||
form = env.ctx._request_form
|
||||
elseif env.ctx and env.ctx.form then
|
||||
form = env.ctx.form
|
||||
else
|
||||
return false
|
||||
end
|
||||
|
||||
token = token or form[field_name]
|
||||
if not token then
|
||||
return false
|
||||
end
|
||||
|
||||
local session_token = session.get(csrf.TOKEN_KEY)
|
||||
if not session_token then
|
||||
return false
|
||||
end
|
||||
|
||||
-- Constant-time comparison to prevent timing attacks
|
||||
if #token ~= #session_token then
|
||||
return false
|
||||
end
|
||||
|
||||
local result = true
|
||||
for i = 1, #token do
|
||||
if token:sub(i, i) ~= session_token:sub(i, i) then
|
||||
result = false
|
||||
-- Don't break early - continue to prevent timing attacks
|
||||
end
|
||||
end
|
||||
|
||||
return result
|
||||
end
|
||||
}
|
||||
|
||||
-- ======================================================================
|
||||
-- UTIL MODULE
|
||||
-- ======================================================================
|
||||
|
@ -349,6 +575,9 @@ local util = {
|
|||
-- REGISTER MODULES GLOBALLY
|
||||
-- ======================================================================
|
||||
|
||||
-- Install modules in global scope
|
||||
_G.http = http
|
||||
_G.cookie = cookie
|
||||
_G.session = session
|
||||
_G.csrf = csrf
|
||||
_G.util = util
|
||||
|
|
Loading…
Reference in New Issue
Block a user