diff --git a/core/Moonshark.go b/core/Moonshark.go index 752c238..d4a103f 100644 --- a/core/Moonshark.go +++ b/core/Moonshark.go @@ -185,7 +185,7 @@ func (s *Moonshark) initRunner() error { runner.WithPoolSize(s.Config.Runner.PoolSize), runner.WithLibDirs(s.Config.Dirs.Libs...), runner.WithSessionManager(sessionManager), - runner.WithCSRFProtection(), + http.WithCSRFProtection(), } // Add debug option conditionally diff --git a/core/http/Csrf.go b/core/http/Csrf.go index 1830b78..7347ad4 100644 --- a/core/http/Csrf.go +++ b/core/http/Csrf.go @@ -1,12 +1,119 @@ package http import ( + "Moonshark/core/runner" "Moonshark/core/utils" "Moonshark/core/utils/logger" + "crypto/subtle" "github.com/valyala/fasthttp" + + luajit "git.sharkk.net/Sky/LuaJIT-to-Go" ) +// ValidateCSRFToken checks if the CSRF token is valid for a request +func ValidateCSRFToken(state *luajit.State, ctx *runner.Context) bool { + // Only validate for form submissions + method, ok := ctx.Get("method").(string) + if !ok || (method != "POST" && method != "PUT" && method != "PATCH" && method != "DELETE") { + return true + } + + // Get form data + formData, ok := ctx.Get("form").(map[string]any) + if !ok || formData == nil { + logger.Warning("CSRF validation failed: no form data") + return false + } + + // Get token from form + formToken, ok := formData["csrf"].(string) + if !ok || formToken == "" { + logger.Warning("CSRF validation failed: no token in form") + return false + } + + // Get session token + state.GetGlobal("session") + if state.IsNil(-1) { + state.Pop(1) + logger.Warning("CSRF validation failed: session module not available") + return false + } + + state.GetField(-1, "get") + if !state.IsFunction(-1) { + state.Pop(2) + logger.Warning("CSRF validation failed: session.get not available") + return false + } + + state.PushCopy(-1) // Duplicate function + state.PushString("_csrf_token") + + if err := state.Call(1, 1); err != nil { + state.Pop(3) // Pop error, function and session table + logger.Warning("CSRF validation failed: %v", err) + return false + } + + if state.IsNil(-1) { + state.Pop(3) // Pop nil, function and session table + logger.Warning("CSRF validation failed: no token in session") + return false + } + + sessionToken := state.ToString(-1) + state.Pop(3) // Pop token, function and session table + + // Constant-time comparison to prevent timing attacks + return subtle.ConstantTimeCompare([]byte(formToken), []byte(sessionToken)) == 1 +} + +// WithCSRFProtection creates a runner option to add CSRF protection +func WithCSRFProtection() runner.RunnerOption { + return func(r *runner.Runner) { + r.AddInitHook(func(state *luajit.State, ctx *runner.Context) error { + // Get request method + method, ok := ctx.Get("method").(string) + if !ok { + return nil + } + + // Only validate for form submissions + if method != "POST" && method != "PUT" && method != "PATCH" && method != "DELETE" { + return nil + } + + // Check for form data + form, ok := ctx.Get("form").(map[string]any) + if !ok || form == nil { + return nil + } + + // Validate CSRF token + if !ValidateCSRFToken(state, ctx) { + return ErrCSRFValidationFailed + } + + return nil + }) + } +} + +// Error for CSRF validation failure +var ErrCSRFValidationFailed = &CSRFError{message: "CSRF token validation failed"} + +// CSRFError represents a CSRF validation error +type CSRFError struct { + message string +} + +// Error implements the error interface +func (e *CSRFError) Error() string { + return e.message +} + // HandleCSRFError handles a CSRF validation error func HandleCSRFError(ctx *fasthttp.RequestCtx, errorConfig utils.ErrorPageConfig) { method := string(ctx.Method()) diff --git a/core/http/Server.go b/core/http/Server.go index 6b87c0a..fa23a49 100644 --- a/core/http/Server.go +++ b/core/http/Server.go @@ -8,6 +8,7 @@ import ( "Moonshark/core/metadata" "Moonshark/core/routers" "Moonshark/core/runner" + "Moonshark/core/runner/sandbox" "Moonshark/core/utils" "Moonshark/core/utils/config" "Moonshark/core/utils/logger" @@ -226,7 +227,7 @@ func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scrip // Special handling for CSRF error if err != nil { - if csrfErr, ok := err.(*runner.CSRFError); ok { + if csrfErr, ok := err.(*CSRFError); ok { logger.Warning("CSRF error executing Lua route: %v", csrfErr) HandleCSRFError(ctx, s.errorConfig) return @@ -258,8 +259,8 @@ func writeResponse(ctx *fasthttp.RequestCtx, result any) { } // Check for HTTPResponse type - if httpResp, ok := result.(*runner.HTTPResponse); ok { - defer runner.ReleaseResponse(httpResp) + if httpResp, ok := result.(*sandbox.HTTPResponse); ok { + defer sandbox.ReleaseResponse(httpResp) // Set response headers for name, value := range httpResp.Headers { diff --git a/core/runner/Context.go b/core/runner/Context.go index 673c097..9586030 100644 --- a/core/runner/Context.go +++ b/core/runner/Context.go @@ -3,6 +3,8 @@ package runner import ( "sync" + "maps" + "github.com/valyala/bytebufferpool" "github.com/valyala/fasthttp" ) @@ -24,7 +26,7 @@ type Context struct { // Context pool to reduce allocations var contextPool = sync.Pool{ - New: func() interface{} { + New: func() any { return &Context{ Values: make(map[string]any, 16), // Pre-allocate with reasonable capacity } @@ -115,9 +117,7 @@ func (c *Context) All() map[string]any { defer c.mu.RUnlock() result := make(map[string]any, len(c.Values)) - for k, v := range c.Values { - result[k] = v - } + maps.Copy(result, c.Values) return result } diff --git a/core/runner/CookieModule.go b/core/runner/CookieModule.go deleted file mode 100644 index 356217b..0000000 --- a/core/runner/CookieModule.go +++ /dev/null @@ -1,187 +0,0 @@ -package runner - -import ( - "time" - - luajit "git.sharkk.net/Sky/LuaJIT-to-Go" - "github.com/valyala/fasthttp" -) - -// extractCookie grabs cookies from the Lua state -func extractCookie(state *luajit.State) *fasthttp.Cookie { - cookie := fasthttp.AcquireCookie() - - // Get name - state.GetField(-1, "name") - if !state.IsString(-1) { - state.Pop(1) - fasthttp.ReleaseCookie(cookie) - return nil // Name is required - } - cookie.SetKey(state.ToString(-1)) - state.Pop(1) - - // Get value - state.GetField(-1, "value") - if state.IsString(-1) { - cookie.SetValue(state.ToString(-1)) - } - state.Pop(1) - - // Get path - state.GetField(-1, "path") - if state.IsString(-1) { - cookie.SetPath(state.ToString(-1)) - } else { - cookie.SetPath("/") // Default path - } - state.Pop(1) - - // Get domain - state.GetField(-1, "domain") - if state.IsString(-1) { - cookie.SetDomain(state.ToString(-1)) - } - state.Pop(1) - - // Get expires - state.GetField(-1, "expires") - if state.IsNumber(-1) { - expiry := int64(state.ToNumber(-1)) - cookie.SetExpire(time.Unix(expiry, 0)) - } - state.Pop(1) - - // Get max age - state.GetField(-1, "max_age") - if state.IsNumber(-1) { - cookie.SetMaxAge(int(state.ToNumber(-1))) - } - state.Pop(1) - - // Get secure - state.GetField(-1, "secure") - if state.IsBoolean(-1) { - cookie.SetSecure(state.ToBoolean(-1)) - } - state.Pop(1) - - // Get http only - state.GetField(-1, "http_only") - if state.IsBoolean(-1) { - cookie.SetHTTPOnly(state.ToBoolean(-1)) - } - state.Pop(1) - - return cookie -} - -// LuaCookieModule provides cookie functionality to Lua scripts -const LuaCookieModule = ` --- Cookie module implementation -local cookie = { - -- Set a cookie - set = function(name, value, options, ...) - if type(name) ~= "string" then - error("cookie.set: name must be a string", 2) - end - - -- Get or create response - local resp = __http_responses[1] or {} - resp.cookies = resp.cookies or {} - __http_responses[1] = resp - - -- Handle options as table or legacy params - local opts = {} - if type(options) == "table" then - opts = options - elseif options ~= nil then - -- Legacy support: options is actually 'expires' - opts.expires = options - -- Check for other legacy params (4th-7th args) - local args = {...} - if args[1] then opts.path = args[1] end - if args[2] then opts.domain = args[2] end - if args[3] then opts.secure = args[3] end - if args[4] ~= nil then opts.http_only = args[4] end - end - - -- Create cookie table - local cookie = { - name = name, - value = value or "", - path = opts.path or "/", - domain = opts.domain - } - - -- Handle expiry - if opts.expires then - if type(opts.expires) == "number" then - if opts.expires > 0 then - cookie.max_age = opts.expires - local now = os.time() - cookie.expires = now + opts.expires - elseif opts.expires < 0 then - cookie.expires = 1 - cookie.max_age = 0 - else - -- opts.expires == 0: Session cookie - -- Do nothing (omitting both expires and max-age creates a session cookie) - end - end - end - - -- Security flags - cookie.secure = (opts.secure ~= false) - cookie.http_only = (opts.http_only ~= false) - - -- Store in cookies table - local n = #resp.cookies + 1 - resp.cookies[n] = cookie - - return true - end, - - -- Get a cookie value - get = function(name) - if type(name) ~= "string" then - error("cookie.get: name must be a string", 2) - end - - -- Access values directly from current environment - local env = getfenv(2) - - -- Check if context exists and has cookies - if env.ctx and env.ctx.cookies and env.ctx.cookies[name] then - return tostring(env.ctx.cookies[name]) - end - - return nil - end, - - -- Remove a cookie - remove = function(name, path, domain) - if type(name) ~= "string" then - error("cookie.remove: name must be a string", 2) - end - - -- Create an expired cookie - return cookie.set(name, "", {expires = 0, path = path or "/", domain = domain}) - end -} - --- Install cookie module -_G.cookie = cookie - --- Make sure the cookie module is accessible in sandbox -if __env_system and __env_system.base_env then - __env_system.base_env.cookie = cookie -end -` - -// CookieModuleInitFunc returns an initializer for the cookie module -func CookieModuleInitFunc() StateInitFunc { - return func(state *luajit.State) error { - return state.DoString(LuaCookieModule) - } -} diff --git a/core/runner/Cookies.go b/core/runner/Cookies.go new file mode 100644 index 0000000..b1258e4 --- /dev/null +++ b/core/runner/Cookies.go @@ -0,0 +1,77 @@ +package runner + +import ( + "time" + + luajit "git.sharkk.net/Sky/LuaJIT-to-Go" + "github.com/valyala/fasthttp" +) + +// extractCookie grabs cookies from the Lua state +func extractCookie(state *luajit.State) *fasthttp.Cookie { + cookie := fasthttp.AcquireCookie() + + // Get name + state.GetField(-1, "name") + if !state.IsString(-1) { + state.Pop(1) + fasthttp.ReleaseCookie(cookie) + return nil // Name is required + } + cookie.SetKey(state.ToString(-1)) + state.Pop(1) + + // Get value + state.GetField(-1, "value") + if state.IsString(-1) { + cookie.SetValue(state.ToString(-1)) + } + state.Pop(1) + + // Get path + state.GetField(-1, "path") + if state.IsString(-1) { + cookie.SetPath(state.ToString(-1)) + } else { + cookie.SetPath("/") // Default path + } + state.Pop(1) + + // Get domain + state.GetField(-1, "domain") + if state.IsString(-1) { + cookie.SetDomain(state.ToString(-1)) + } + state.Pop(1) + + // Get expires + state.GetField(-1, "expires") + if state.IsNumber(-1) { + expiry := int64(state.ToNumber(-1)) + cookie.SetExpire(time.Unix(expiry, 0)) + } + state.Pop(1) + + // Get max age + state.GetField(-1, "max_age") + if state.IsNumber(-1) { + cookie.SetMaxAge(int(state.ToNumber(-1))) + } + state.Pop(1) + + // Get secure + state.GetField(-1, "secure") + if state.IsBoolean(-1) { + cookie.SetSecure(state.ToBoolean(-1)) + } + state.Pop(1) + + // Get http only + state.GetField(-1, "http_only") + if state.IsBoolean(-1) { + cookie.SetHTTPOnly(state.ToBoolean(-1)) + } + state.Pop(1) + + return cookie +} diff --git a/core/runner/CoreModules.go b/core/runner/CoreModules.go index bf25778..98f18dc 100644 --- a/core/runner/CoreModules.go +++ b/core/runner/CoreModules.go @@ -1,6 +1,7 @@ package runner import ( + "Moonshark/core/runner/sandbox" "Moonshark/core/utils/logger" "fmt" "strings" @@ -265,17 +266,21 @@ func init() { GlobalRegistry.EnableDebug() // Enable debugging by default logger.Debug("[ModuleRegistry] Registering core modules...") - GlobalRegistry.Register("util", UtilModuleInitFunc()) - GlobalRegistry.Register("http", HTTPModuleInitFunc()) - GlobalRegistry.RegisterWithDependencies("cookie", CookieModuleInitFunc(), []string{"http"}) - GlobalRegistry.RegisterWithDependencies("csrf", CSRFModuleInitFunc(), []string{"util"}) + // Register core modules - these now point to the sandbox implementations + GlobalRegistry.Register("util", func(state *luajit.State) error { + return sandbox.UtilModuleInitFunc()(state) + }) + + GlobalRegistry.Register("http", func(state *luajit.State) error { + return sandbox.HTTPModuleInitFunc()(state) + }) // Set explicit initialization order GlobalRegistry.SetInitOrder([]string{ - "util", // First: core utilities - "http", // Second: HTTP functionality - "cookie", // Third: Cookie functionality (uses HTTP) - "csrf", // Fourth: CSRF protection (uses go and possibly session) + "util", // First: core utilities + "http", // Second: HTTP functionality + "session", // Third: Session functionality + "csrf", // Fourth: CSRF protection }) logger.DebugCont("Core modules registered successfully") diff --git a/core/runner/CsrfModule.go b/core/runner/CsrfModule.go deleted file mode 100644 index 683f35f..0000000 --- a/core/runner/CsrfModule.go +++ /dev/null @@ -1,219 +0,0 @@ -package runner - -import ( - "crypto/subtle" - - "Moonshark/core/utils/logger" - - luajit "git.sharkk.net/Sky/LuaJIT-to-Go" -) - -// LuaCSRFModule provides CSRF protection functionality to Lua scripts -const LuaCSRFModule = ` --- CSRF protection module -local csrf = { - -- Session key where the token is stored - TOKEN_KEY = "_csrf_token", - - -- Default form field name - DEFAULT_FIELD = "csrf", - - -- Generate a new CSRF token and store it in the session - generate = function(length) - -- Default length is 32 characters - length = length or 32 - - if length < 16 then - -- Enforce minimum security - length = 16 - end - - -- Check if we have a session module - if not session then - error("CSRF protection requires the session module", 2) - end - - local token = util.generate_token(length) - session.set(csrf.TOKEN_KEY, token) - return token - end, - - -- Get the current token or generate a new one - token = function() - -- Get from session if exists - local token = session.get(csrf.TOKEN_KEY) - - -- Generate if needed - if not token then - token = csrf.generate() - end - - return token - end, - - -- Generate a hidden form field with the CSRF token - field = function(field_name) - field_name = field_name or csrf.DEFAULT_FIELD - local token = csrf.token() - return string.format('', field_name, token) - end, - - -- Verify a given token against the session token - verify = function(token, field_name) - field_name = field_name or csrf.DEFAULT_FIELD - - local env = getfenv(2) - - local form = nil - if env.ctx and env.ctx.form then - form = env.ctx.form - else - return false - end - - token = token or form[field_name] - if not token then - return false - end - - local session_token = session.get(csrf.TOKEN_KEY) - if not session_token then - return false - end - - -- Constant-time comparison to prevent timing attacks - -- This is safe since Lua strings are immutable - if #token ~= #session_token then - return false - end - - local result = true - for i = 1, #token do - if token:sub(i, i) ~= session_token:sub(i, i) then - result = false - -- Don't break early - continue to prevent timing attacks - end - end - - return result - end -} - --- Install CSRF module -_G.csrf = csrf - --- Make sure the CSRF module is accessible in sandbox -if __env_system and __env_system.base_env then - __env_system.base_env.csrf = csrf -end -` - -// CSRFModuleInitFunc returns an initializer for the CSRF module -func CSRFModuleInitFunc() StateInitFunc { - return func(state *luajit.State) error { - return state.DoString(LuaCSRFModule) - } -} - -// ValidateCSRFToken checks if the CSRF token is valid for a request -func ValidateCSRFToken(state *luajit.State, ctx *Context) bool { - // Only validate for form submissions - method, ok := ctx.Get("method").(string) - if !ok || (method != "POST" && method != "PUT" && method != "PATCH" && method != "DELETE") { - return true - } - - // Get form data - formData, ok := ctx.Get("form").(map[string]any) - if !ok || formData == nil { - logger.Warning("CSRF validation failed: no form data") - return false - } - - // Get token from form - formToken, ok := formData["csrf"].(string) - if !ok || formToken == "" { - logger.Warning("CSRF validation failed: no token in form") - return false - } - - // Get session token - state.GetGlobal("session") - if state.IsNil(-1) { - state.Pop(1) - logger.Warning("CSRF validation failed: session module not available") - return false - } - - state.GetField(-1, "get") - if !state.IsFunction(-1) { - state.Pop(2) - logger.Warning("CSRF validation failed: session.get not available") - return false - } - - state.PushCopy(-1) // Duplicate function - state.PushString("_csrf_token") - - if err := state.Call(1, 1); err != nil { - state.Pop(3) // Pop error, function and session table - logger.Warning("CSRF validation failed: %v", err) - return false - } - - if state.IsNil(-1) { - state.Pop(3) // Pop nil, function and session table - logger.Warning("CSRF validation failed: no token in session") - return false - } - - sessionToken := state.ToString(-1) - state.Pop(3) // Pop token, function and session table - - // Constant-time comparison to prevent timing attacks - return subtle.ConstantTimeCompare([]byte(formToken), []byte(sessionToken)) == 1 -} - -// WithCSRFProtection creates a runner option to add CSRF protection -func WithCSRFProtection() RunnerOption { - return func(r *Runner) { - r.AddInitHook(func(state *luajit.State, ctx *Context) error { - // Get request method - method, ok := ctx.Get("method").(string) - if !ok { - return nil - } - - // Only validate for form submissions - if method != "POST" && method != "PUT" && method != "PATCH" && method != "DELETE" { - return nil - } - - // Check for form data - form, ok := ctx.Get("form").(map[string]any) - if !ok || form == nil { - return nil - } - - // Validate CSRF token - if !ValidateCSRFToken(state, ctx) { - return ErrCSRFValidationFailed - } - - return nil - }) - } -} - -// Error for CSRF validation failure -var ErrCSRFValidationFailed = &CSRFError{message: "CSRF token validation failed"} - -// CSRFError represents a CSRF validation error -type CSRFError struct { - message string -} - -// Error implements the error interface -func (e *CSRFError) Error() string { - return e.message -} diff --git a/core/runner/Runner.go b/core/runner/Runner.go index 6455a4b..de6a647 100644 --- a/core/runner/Runner.go +++ b/core/runner/Runner.go @@ -12,6 +12,7 @@ import ( "github.com/panjf2000/ants/v2" "github.com/valyala/bytebufferpool" + "Moonshark/core/runner/sandbox" "Moonshark/core/utils/logger" luajit "git.sharkk.net/Sky/LuaJIT-to-Go" @@ -30,11 +31,11 @@ type RunnerOption func(*Runner) // State wraps a Lua state with its sandbox type State struct { - L *luajit.State // The Lua state - sandbox *Sandbox // Associated sandbox - index int // Index for debugging - inUse bool // Whether the state is currently in use - initTime time.Time // When this state was initialized + L *luajit.State // The Lua state + sandbox *sandbox.Sandbox // Associated sandbox + index int // Index for debugging + inUse bool // Whether the state is currently in use + initTime time.Time // When this state was initialized } // InitHook runs before executing a script @@ -217,7 +218,7 @@ func (r *Runner) createState(index int) (*State, error) { } // Create sandbox - sandbox := NewSandbox() + sandbox := sandbox.NewSandbox() if r.debug && verbose { sandbox.EnableDebug() } diff --git a/core/runner/SessionModule.go b/core/runner/SessionModule.go deleted file mode 100644 index bdea3a7..0000000 --- a/core/runner/SessionModule.go +++ /dev/null @@ -1,177 +0,0 @@ -package runner - -import ( - "Moonshark/core/utils/logger" - - luajit "git.sharkk.net/Sky/LuaJIT-to-Go" -) - -// LuaSessionModule provides session functionality to Lua scripts -const LuaSessionModule = ` --- Global table to store session data -__session_data = __session_data or {} -__session_id = __session_id or nil -__session_modified = false - --- Session module implementation -local session = { - -- Get a session value - get = function(key) - if type(key) ~= "string" then - error("session.get: key must be a string", 2) - end - - if __session_data and __session_data[key] then - return __session_data[key] - end - - return nil - end, - - -- Set a session value - set = function(key, value) - if type(key) ~= "string" then - error("session.set: key must be a string", 2) - end - - -- Ensure session data table exists - __session_data = __session_data or {} - - -- Store value - __session_data[key] = value - - -- Mark session as modified - __session_modified = true - - return true - end, - - -- Delete a session value - delete = function(key) - if type(key) ~= "string" then - error("session.delete: key must be a string", 2) - end - - if __session_data then - __session_data[key] = nil - __session_modified = true - end - - return true - end, - - -- Clear all session data - clear = function() - __session_data = {} - __session_modified = true - return true - end, - - -- Get the session ID - get_id = function() - return __session_id or nil - end, - - -- Get all session data - get_all = function() - local result = {} - for k, v in pairs(__session_data or {}) do - result[k] = v - end - return result - end, - - -- Check if session has a key - has = function(key) - if type(key) ~= "string" then - error("session.has: key must be a string", 2) - end - - return __session_data and __session_data[key] ~= nil - end -} - --- Install session module -_G.session = session - --- Make sure the session module is accessible in sandbox -if __env_system and __env_system.base_env then - __env_system.base_env.session = session -end - --- Hook into script execution to preserve session state -local old_execute_script = __execute_script -if old_execute_script then - __execute_script = function(fn, ctx) - -- Reset modification flag at the start of request - __session_modified = false - - -- Execute original function - return old_execute_script(fn, ctx) - end -end -` - -// GetSessionData extracts session data from Lua state -func GetSessionData(state *luajit.State) (string, map[string]any, bool) { - // Check if session was modified - state.GetGlobal("__session_modified") - modified := state.ToBoolean(-1) - state.Pop(1) - - if !modified { - return "", nil, false - } - - // Get session ID - state.GetGlobal("__session_id") - sessionID := state.ToString(-1) - state.Pop(1) - - // Get session data - state.GetGlobal("__session_data") - if !state.IsTable(-1) { - state.Pop(1) - return sessionID, nil, false - } - - data, err := state.ToTable(-1) - state.Pop(1) - - if err != nil { - logger.Error("Failed to extract session data: %v", err) - return sessionID, nil, false - } - - return sessionID, data, true -} - -// SetSessionData sets session data in Lua state -func SetSessionData(state *luajit.State, sessionID string, data map[string]any) error { - // Set session ID - state.PushString(sessionID) - state.SetGlobal("__session_id") - - // Set session data - if data == nil { - data = make(map[string]any) - } - - if err := state.PushTable(data); err != nil { - return err - } - state.SetGlobal("__session_data") - - // Reset modification flag - state.PushBoolean(false) - state.SetGlobal("__session_modified") - - return nil -} - -// SessionModuleInitFunc returns an initializer for the session module -func SessionModuleInitFunc() StateInitFunc { - return func(state *luajit.State) error { - return state.DoString(LuaSessionModule) - } -} diff --git a/core/runner/SessionHandler.go b/core/runner/Sessions.go similarity index 76% rename from core/runner/SessionHandler.go rename to core/runner/Sessions.go index b423b30..e73de0e 100644 --- a/core/runner/SessionHandler.go +++ b/core/runner/Sessions.go @@ -3,6 +3,7 @@ package runner import ( "github.com/valyala/fasthttp" + "Moonshark/core/runner/sandbox" "Moonshark/core/sessions" "Moonshark/core/utils/logger" @@ -40,9 +41,6 @@ func WithSessionManager(manager *sessions.SessionManager) RunnerOption { return func(r *Runner) { handler := NewSessionHandler(manager) - // Register the session module - RegisterCoreModule("session", SessionModuleInitFunc()) - // Add hooks to the runner r.AddInitHook(handler.preRequestHook) r.AddFinalizeHook(handler.postRequestHook) @@ -140,8 +138,10 @@ func (h *SessionHandler) postRequestHook(state *luajit.State, ctx *Context, resu session.Set(k, v) } + h.manager.SaveSession(session) + // Add session cookie to result if it's an HTTP response - if httpResp, ok := result.(*HTTPResponse); ok { + if httpResp, ok := result.(*sandbox.HTTPResponse); ok { h.addSessionCookie(httpResp, modifiedID) } @@ -150,7 +150,7 @@ func (h *SessionHandler) postRequestHook(state *luajit.State, ctx *Context, resu } // addSessionCookie adds a session cookie to an HTTP response -func (h *SessionHandler) addSessionCookie(resp *HTTPResponse, sessionID string) { +func (h *SessionHandler) addSessionCookie(resp *sandbox.HTTPResponse, sessionID string) { // Get cookie options opts := h.manager.CookieOptions() @@ -184,3 +184,60 @@ func (h *SessionHandler) addSessionCookie(resp *HTTPResponse, sessionID string) resp.Cookies = append(resp.Cookies, cookie) } + +// GetSessionData extracts session data from Lua state +func GetSessionData(state *luajit.State) (string, map[string]any, bool) { + // Check if session was modified + state.GetGlobal("__session_modified") + modified := state.ToBoolean(-1) + state.Pop(1) + + if !modified { + return "", nil, false + } + + // Get session ID + state.GetGlobal("__session_id") + sessionID := state.ToString(-1) + state.Pop(1) + + // Get session data + state.GetGlobal("__session_data") + if !state.IsTable(-1) { + state.Pop(1) + return sessionID, nil, false + } + + data, err := state.ToTable(-1) + state.Pop(1) + + if err != nil { + logger.Error("Failed to extract session data: %v", err) + return sessionID, nil, false + } + + return sessionID, data, true +} + +// SetSessionData sets session data in Lua state +func SetSessionData(state *luajit.State, sessionID string, data map[string]any) error { + // Set session ID + state.PushString(sessionID) + state.SetGlobal("__session_id") + + // Set session data + if data == nil { + data = make(map[string]any) + } + + if err := state.PushTable(data); err != nil { + return err + } + state.SetGlobal("__session_data") + + // Reset modification flag + state.PushBoolean(false) + state.SetGlobal("__session_modified") + + return nil +} diff --git a/core/runner/sandbox/Embed.go b/core/runner/sandbox/Embed.go new file mode 100644 index 0000000..944f9d8 --- /dev/null +++ b/core/runner/sandbox/Embed.go @@ -0,0 +1,98 @@ +package sandbox + +import ( + _ "embed" + + "Moonshark/core/utils/logger" + + luajit "git.sharkk.net/Sky/LuaJIT-to-Go" +) + +//go:embed lua/sandbox.lua +var sandboxLua string + +// InitializeSandbox loads the embedded Lua sandbox code into a Lua state +func InitializeSandbox(state *luajit.State) error { + // Compile once, use many times + bytecodeOnce.Do(precompileSandbox) + + if sandboxBytecode != nil { + logger.Debug("Loading sandbox.lua from precompiled bytecode") + return state.LoadAndRunBytecode(sandboxBytecode, "sandbox.lua") + } + + // Fallback if compilation failed + logger.Warning("Using non-precompiled sandbox.lua (bytecode compilation failed)") + return state.DoString(sandboxLua) +} + +// ModuleInitializers stores initializer functions for core modules +type ModuleInitializers struct { + HTTP func(*luajit.State) error + Util func(*luajit.State) error + Session func(*luajit.State) error + Cookie func(*luajit.State) error + CSRF func(*luajit.State) error +} + +// DefaultInitializers returns the default set of initializers +func DefaultInitializers() *ModuleInitializers { + return &ModuleInitializers{ + HTTP: func(state *luajit.State) error { + // Register the native Go function first + if err := state.RegisterGoFunction("__http_request", httpRequest); err != nil { + logger.Error("[HTTP Module] Failed to register __http_request function: %v", err) + return err + } + return nil + }, + Util: func(state *luajit.State) error { + // Register util functions + return RegisterModule(state, "util", UtilModuleFunctions()) + }, + Session: func(state *luajit.State) error { + // Session doesn't need special initialization + return nil + }, + Cookie: func(state *luajit.State) error { + // Cookie doesn't need special initialization + return nil + }, + CSRF: func(state *luajit.State) error { + // CSRF doesn't need special initialization + return nil + }, + } +} + +// InitializeAll initializes all modules in the Lua state +func InitializeAll(state *luajit.State, initializers *ModuleInitializers) error { + // Set up dependencies first + if err := initializers.Util(state); err != nil { + return err + } + + if err := initializers.HTTP(state); err != nil { + return err + } + + // Load the embedded sandbox code + if err := InitializeSandbox(state); err != nil { + return err + } + + // Initialize the rest of the modules + if err := initializers.Session(state); err != nil { + return err + } + + if err := initializers.Cookie(state); err != nil { + return err + } + + if err := initializers.CSRF(state); err != nil { + return err + } + + return nil +} diff --git a/core/runner/HttpModule.go b/core/runner/sandbox/Http.go similarity index 66% rename from core/runner/HttpModule.go rename to core/runner/sandbox/Http.go index ea28e9f..c207f70 100644 --- a/core/runner/HttpModule.go +++ b/core/runner/sandbox/Http.go @@ -1,4 +1,4 @@ -package runner +package sandbox import ( "context" @@ -28,7 +28,7 @@ type HTTPResponse struct { // Response pool to reduce allocations var responsePool = sync.Pool{ - New: func() interface{} { + New: func() any { return &HTTPResponse{ Status: 200, Headers: make(map[string]string, 8), // Pre-allocate with reasonable capacity @@ -37,36 +37,6 @@ var responsePool = sync.Pool{ }, } -// NewHTTPResponse creates a default HTTP response, potentially reusing one from the pool -func NewHTTPResponse() *HTTPResponse { - return responsePool.Get().(*HTTPResponse) -} - -// ReleaseResponse returns the response to the pool after clearing its values -func ReleaseResponse(resp *HTTPResponse) { - if resp == nil { - return - } - - // Clear all values to prevent data leakage - resp.Status = 200 // Reset to default - - // Clear headers - for k := range resp.Headers { - delete(resp.Headers, k) - } - - // Clear cookies - resp.Cookies = resp.Cookies[:0] // Keep capacity but set length to 0 - - // Clear body - resp.Body = nil - - responsePool.Put(resp) -} - -// ---------- HTTP CLIENT FUNCTIONALITY ---------- - // Default HTTP client with sensible timeout var defaultFastClient fasthttp.Client = fasthttp.Client{ MaxConnsPerHost: 1024, @@ -96,8 +66,256 @@ var DefaultHTTPClientConfig = HTTPClientConfig{ AllowRemote: true, } -// Function name constant to ensure consistency -const httpRequestFuncName = "__http_request" +// NewHTTPResponse creates a default HTTP response, potentially reusing one from the pool +func NewHTTPResponse() *HTTPResponse { + return responsePool.Get().(*HTTPResponse) +} + +// ReleaseResponse returns the response to the pool after clearing its values +func ReleaseResponse(resp *HTTPResponse) { + if resp == nil { + return + } + + // Clear all values to prevent data leakage + resp.Status = 200 // Reset to default + + // Clear headers + for k := range resp.Headers { + delete(resp.Headers, k) + } + + // Clear cookies + resp.Cookies = resp.Cookies[:0] // Keep capacity but set length to 0 + + // Clear body + resp.Body = nil + + responsePool.Put(resp) +} + +// HTTPModuleInitFunc returns an initializer function for the HTTP module +func HTTPModuleInitFunc() func(*luajit.State) error { + return func(state *luajit.State) error { + // Register the native Go function first + if err := state.RegisterGoFunction("__http_request", httpRequest); err != nil { + logger.Error("[HTTP Module] Failed to register __http_request function") + logger.ErrorCont("%v", err) + return err + } + + // Set up default HTTP client configuration + setupHTTPClientConfig(state) + + return nil + } +} + +// Helper to set up HTTP client config +func setupHTTPClientConfig(state *luajit.State) { + state.NewTable() + + state.PushNumber(float64(DefaultHTTPClientConfig.MaxTimeout / time.Second)) + state.SetField(-2, "max_timeout") + + state.PushNumber(float64(DefaultHTTPClientConfig.DefaultTimeout / time.Second)) + state.SetField(-2, "default_timeout") + + state.PushNumber(float64(DefaultHTTPClientConfig.MaxResponseSize)) + state.SetField(-2, "max_response_size") + + state.PushBoolean(DefaultHTTPClientConfig.AllowRemote) + state.SetField(-2, "allow_remote") + + state.SetGlobal("__http_client_config") +} + +// GetHTTPResponse extracts the HTTP response from Lua state +func GetHTTPResponse(state *luajit.State) (*HTTPResponse, bool) { + response := NewHTTPResponse() + + // Get response table + state.GetGlobal("__http_responses") + if state.IsNil(-1) { + state.Pop(1) + ReleaseResponse(response) // Return unused response to pool + return nil, false + } + + // Check for response at thread index + state.PushNumber(1) + state.GetTable(-2) + if state.IsNil(-1) { + state.Pop(2) + ReleaseResponse(response) // Return unused response to pool + return nil, false + } + + // Get status + state.GetField(-1, "status") + if state.IsNumber(-1) { + response.Status = int(state.ToNumber(-1)) + } + state.Pop(1) + + // Get headers + state.GetField(-1, "headers") + if state.IsTable(-1) { + // Iterate through headers table + state.PushNil() // Start iteration + for state.Next(-2) { + // Stack has key at -2 and value at -1 + if state.IsString(-2) && state.IsString(-1) { + key := state.ToString(-2) + value := state.ToString(-1) + response.Headers[key] = value + } + state.Pop(1) // Pop value, leave key for next iteration + } + } + state.Pop(1) + + // Get cookies + state.GetField(-1, "cookies") + if state.IsTable(-1) { + // Iterate through cookies array + length := state.GetTableLength(-1) + for i := 1; i <= length; i++ { + state.PushNumber(float64(i)) + state.GetTable(-2) + + if state.IsTable(-1) { + cookie := extractCookie(state) + if cookie != nil { + response.Cookies = append(response.Cookies, cookie) + } + } + state.Pop(1) + } + } + state.Pop(1) + + // Clean up + state.Pop(2) // Pop response table and __http_responses + + return response, true +} + +// ApplyHTTPResponse applies an HTTP response to a fasthttp.RequestCtx +func ApplyHTTPResponse(httpResp *HTTPResponse, ctx *fasthttp.RequestCtx) { + // Set status code + ctx.SetStatusCode(httpResp.Status) + + // Set headers + for name, value := range httpResp.Headers { + ctx.Response.Header.Set(name, value) + } + + // Set cookies + for _, cookie := range httpResp.Cookies { + ctx.Response.Header.SetCookie(cookie) + } + + // Process the body based on its type + if httpResp.Body == nil { + return + } + + // Set body based on type + switch body := httpResp.Body.(type) { + case string: + ctx.SetBodyString(body) + case []byte: + ctx.SetBody(body) + case map[string]any, []any, []float64, []string, []int: + // Marshal JSON using a buffer from the pool + buf := bytebufferpool.Get() + defer bytebufferpool.Put(buf) + + if err := json.NewEncoder(buf).Encode(body); err == nil { + // Set content type if not already set + if len(ctx.Response.Header.ContentType()) == 0 { + ctx.Response.Header.SetContentType("application/json") + } + ctx.SetBody(buf.Bytes()) + } else { + // Fallback + ctx.SetBodyString(fmt.Sprintf("%v", body)) + } + default: + // Default to string representation + ctx.SetBodyString(fmt.Sprintf("%v", body)) + } +} + +// extractCookie grabs cookies from the Lua state +func extractCookie(state *luajit.State) *fasthttp.Cookie { + cookie := fasthttp.AcquireCookie() + + // Get name + state.GetField(-1, "name") + if !state.IsString(-1) { + state.Pop(1) + fasthttp.ReleaseCookie(cookie) + return nil // Name is required + } + cookie.SetKey(state.ToString(-1)) + state.Pop(1) + + // Get value + state.GetField(-1, "value") + if state.IsString(-1) { + cookie.SetValue(state.ToString(-1)) + } + state.Pop(1) + + // Get path + state.GetField(-1, "path") + if state.IsString(-1) { + cookie.SetPath(state.ToString(-1)) + } else { + cookie.SetPath("/") // Default path + } + state.Pop(1) + + // Get domain + state.GetField(-1, "domain") + if state.IsString(-1) { + cookie.SetDomain(state.ToString(-1)) + } + state.Pop(1) + + // Get expires + state.GetField(-1, "expires") + if state.IsNumber(-1) { + expiry := int64(state.ToNumber(-1)) + cookie.SetExpire(time.Unix(expiry, 0)) + } + state.Pop(1) + + // Get max age + state.GetField(-1, "max_age") + if state.IsNumber(-1) { + cookie.SetMaxAge(int(state.ToNumber(-1))) + } + state.Pop(1) + + // Get secure + state.GetField(-1, "secure") + if state.IsBoolean(-1) { + cookie.SetSecure(state.ToBoolean(-1)) + } + state.Pop(1) + + // Get http only + state.GetField(-1, "http_only") + if state.IsBoolean(-1) { + cookie.SetHTTPOnly(state.ToBoolean(-1)) + } + state.Pop(1) + + return cookie +} // httpRequest makes an HTTP request and returns the result to Lua func httpRequest(state *luajit.State) int { @@ -360,372 +578,3 @@ func httpRequest(state *luajit.State) int { return 1 } - -// HTTPModuleInitFunc returns an initializer function for the HTTP module -func HTTPModuleInitFunc() StateInitFunc { - return func(state *luajit.State) error { - // CRITICAL: Register the native Go function first - // This must be done BEFORE any Lua code that references it - if err := state.RegisterGoFunction(httpRequestFuncName, httpRequest); err != nil { - logger.Error("[HTTP Module] Failed to register __http_request function") - logger.ErrorCont("%v", err) - return err - } - - // Set up default HTTP client configuration - setupHTTPClientConfig(state) - - // Initialize Lua HTTP module - if err := state.DoString(LuaHTTPModule); err != nil { - logger.Error("[HTTP Module] Failed to initialize HTTP module Lua code") - logger.ErrorCont("%v", err) - return err - } - - // Verify HTTP client functions are available - verifyHTTPClient(state) - - return nil - } -} - -// Helper to set up HTTP client config -func setupHTTPClientConfig(state *luajit.State) { - state.NewTable() - - state.PushNumber(float64(DefaultHTTPClientConfig.MaxTimeout / time.Second)) - state.SetField(-2, "max_timeout") - - state.PushNumber(float64(DefaultHTTPClientConfig.DefaultTimeout / time.Second)) - state.SetField(-2, "default_timeout") - - state.PushNumber(float64(DefaultHTTPClientConfig.MaxResponseSize)) - state.SetField(-2, "max_response_size") - - state.PushBoolean(DefaultHTTPClientConfig.AllowRemote) - state.SetField(-2, "allow_remote") - - state.SetGlobal("__http_client_config") -} - -// GetHTTPResponse extracts the HTTP response from Lua state -func GetHTTPResponse(state *luajit.State) (*HTTPResponse, bool) { - response := NewHTTPResponse() - - // Get response table - state.GetGlobal("__http_responses") - if state.IsNil(-1) { - state.Pop(1) - ReleaseResponse(response) // Return unused response to pool - return nil, false - } - - // Check for response at thread index - state.PushNumber(1) - state.GetTable(-2) - if state.IsNil(-1) { - state.Pop(2) - ReleaseResponse(response) // Return unused response to pool - return nil, false - } - - // Get status - state.GetField(-1, "status") - if state.IsNumber(-1) { - response.Status = int(state.ToNumber(-1)) - } - state.Pop(1) - - // Get headers - state.GetField(-1, "headers") - if state.IsTable(-1) { - // Iterate through headers table - state.PushNil() // Start iteration - for state.Next(-2) { - // Stack has key at -2 and value at -1 - if state.IsString(-2) && state.IsString(-1) { - key := state.ToString(-2) - value := state.ToString(-1) - response.Headers[key] = value - } - state.Pop(1) // Pop value, leave key for next iteration - } - } - state.Pop(1) - - // Get cookies - state.GetField(-1, "cookies") - if state.IsTable(-1) { - // Iterate through cookies array - length := state.GetTableLength(-1) - for i := 1; i <= length; i++ { - state.PushNumber(float64(i)) - state.GetTable(-2) - - if state.IsTable(-1) { - cookie := extractCookie(state) - if cookie != nil { - response.Cookies = append(response.Cookies, cookie) - } - } - state.Pop(1) - } - } - state.Pop(1) - - // Clean up - state.Pop(2) // Pop response table and __http_responses - - return response, true -} - -// ApplyHTTPResponse applies an HTTP response to a fasthttp.RequestCtx -func ApplyHTTPResponse(httpResp *HTTPResponse, ctx *fasthttp.RequestCtx) { - // Set status code - ctx.SetStatusCode(httpResp.Status) - - // Set headers - for name, value := range httpResp.Headers { - ctx.Response.Header.Set(name, value) - } - - // Set cookies - for _, cookie := range httpResp.Cookies { - ctx.Response.Header.SetCookie(cookie) - } - - // Process the body based on its type - if httpResp.Body == nil { - return - } - - // Set body based on type - switch body := httpResp.Body.(type) { - case string: - ctx.SetBodyString(body) - case []byte: - ctx.SetBody(body) - case map[string]any, []any, []float64, []string, []int: - // Marshal JSON using a buffer from the pool - buf := bytebufferpool.Get() - defer bytebufferpool.Put(buf) - - if err := json.NewEncoder(buf).Encode(body); err == nil { - // Set content type if not already set - if len(ctx.Response.Header.ContentType()) == 0 { - ctx.Response.Header.SetContentType("application/json") - } - ctx.SetBody(buf.Bytes()) - } else { - // Fallback - ctx.SetBodyString(fmt.Sprintf("%v", body)) - } - default: - // Default to string representation - ctx.SetBodyString(fmt.Sprintf("%v", body)) - } -} - -// WithHTTPClientConfig creates a runner option to configure the HTTP client -func WithHTTPClientConfig(config HTTPClientConfig) RunnerOption { - return func(r *Runner) { - // Store the config to be applied during initialization - r.AddModule("__http_client_config", map[string]any{ - "max_timeout": float64(config.MaxTimeout / time.Second), - "default_timeout": float64(config.DefaultTimeout / time.Second), - "max_response_size": float64(config.MaxResponseSize), - "allow_remote": config.AllowRemote, - }) - } -} - -// RestrictHTTPToLocalhost is a convenience function to restrict HTTP client -// to localhost connections only -func RestrictHTTPToLocalhost() RunnerOption { - return WithHTTPClientConfig(HTTPClientConfig{ - MaxTimeout: DefaultHTTPClientConfig.MaxTimeout, - DefaultTimeout: DefaultHTTPClientConfig.DefaultTimeout, - MaxResponseSize: DefaultHTTPClientConfig.MaxResponseSize, - AllowRemote: false, - }) -} - -// Verify that HTTP client is properly set up -func verifyHTTPClient(state *luajit.State) { - // Get the client table - state.GetGlobal("http") - if !state.IsTable(-1) { - logger.Warning("[HTTP Module] 'http' is not a table") - state.Pop(1) - return - } - - state.GetField(-1, "client") - if !state.IsTable(-1) { - logger.Warning("[HTTP Module] 'http.client' is not a table") - state.Pop(2) - return - } - - // Check for get function - state.GetField(-1, "get") - if !state.IsFunction(-1) { - logger.Warning("[HTTP Module] 'http.client.get' is not a function") - } - state.Pop(1) - - // Check for the request function - state.GetField(-1, "request") - if !state.IsFunction(-1) { - logger.Warning("[HTTP Module] 'http.client.request' is not a function") - } - state.Pop(3) // Pop request, client, http -} - -const LuaHTTPModule = ` --- Table to store response data -__http_responses = {} - --- HTTP module implementation -local http = { - -- Set HTTP status code - set_status = function(code) - if type(code) ~= "number" then - error("http.set_status: status code must be a number", 2) - end - - local resp = __http_responses[1] or {} - resp.status = code - __http_responses[1] = resp - end, - - -- Set HTTP header - set_header = function(name, value) - if type(name) ~= "string" or type(value) ~= "string" then - error("http.set_header: name and value must be strings", 2) - end - - local resp = __http_responses[1] or {} - resp.headers = resp.headers or {} - resp.headers[name] = value - __http_responses[1] = resp - end, - - -- Set content type; set_header helper - set_content_type = function(content_type) - http.set_header("Content-Type", content_type) - end, - - -- HTTP client submodule - client = { - -- Generic request function - request = function(method, url, body, options) - if type(method) ~= "string" then - error("http.client.request: method must be a string", 2) - end - if type(url) ~= "string" then - error("http.client.request: url must be a string", 2) - end - - -- Call native implementation (this is the critical part) - local result = __http_request(method, url, body, options) - return result - end, - - -- Simple GET request - get = function(url, options) - return http.client.request("GET", url, nil, options) - end, - - -- Simple POST request with automatic content-type - post = function(url, body, options) - options = options or {} - return http.client.request("POST", url, body, options) - end, - - -- Simple PUT request with automatic content-type - put = function(url, body, options) - options = options or {} - return http.client.request("PUT", url, body, options) - end, - - -- Simple DELETE request - delete = function(url, options) - return http.client.request("DELETE", url, nil, options) - end, - - -- Simple PATCH request - patch = function(url, body, options) - options = options or {} - return http.client.request("PATCH", url, body, options) - end, - - -- Simple HEAD request - head = function(url, options) - options = options or {} - local old_options = options - options = {headers = old_options.headers, timeout = old_options.timeout, query = old_options.query} - local response = http.client.request("HEAD", url, nil, options) - return response - end, - - -- Simple OPTIONS request - options = function(url, options) - return http.client.request("OPTIONS", url, nil, options) - end, - - -- Shorthand function to directly get JSON - get_json = function(url, options) - options = options or {} - local response = http.client.get(url, options) - if response.ok and response.json then - return response.json - end - return nil, response - end, - - -- Utility to build a URL with query parameters - build_url = function(base_url, params) - if not params or type(params) ~= "table" then - return base_url - end - - local query = {} - for k, v in pairs(params) do - if type(v) == "table" then - for _, item in ipairs(v) do - table.insert(query, k .. "=" .. tostring(item)) - end - else - table.insert(query, k .. "=" .. tostring(v)) - end - end - - if #query > 0 then - if base_url:find("?") then - return base_url .. "&" .. table.concat(query, "&") - else - return base_url .. "?" .. table.concat(query, "&") - end - end - - return base_url - end - } -} - --- Install HTTP module -_G.http = http - --- Clear previous responses when executing scripts -local old_execute_script = __execute_script -if old_execute_script then - __execute_script = function(fn, ctx) - -- Clear previous response - __http_responses[1] = nil - - -- Execute original function - return old_execute_script(fn, ctx) - end -end -` diff --git a/core/runner/sandbox/Modules.go b/core/runner/sandbox/Modules.go new file mode 100644 index 0000000..c9a9764 --- /dev/null +++ b/core/runner/sandbox/Modules.go @@ -0,0 +1,86 @@ +package sandbox + +import ( + "Moonshark/core/utils/logger" + + luajit "git.sharkk.net/Sky/LuaJIT-to-Go" +) + +// ModuleFunc is a function that returns a map of module functions +type ModuleFunc func() map[string]luajit.GoFunction + +// RegisterModule registers a map of functions as a Lua module +func RegisterModule(state *luajit.State, name string, funcs map[string]luajit.GoFunction) error { + // Create a new table for the module + state.NewTable() + + // Add each function to the module table + for fname, f := range funcs { + // Push function name + state.PushString(fname) + + // Push function + if err := state.PushGoFunction(f); err != nil { + state.Pop(1) // Pop table + return err + } + + // Set table[fname] = f + state.SetTable(-3) + } + + // Register the module globally + state.SetGlobal(name) + return nil +} + +// CombineInitFuncs combines multiple state initializer functions into one +func CombineInitFuncs(funcs ...func(*luajit.State) error) func(*luajit.State) error { + return func(state *luajit.State) error { + for _, f := range funcs { + if f != nil { + if err := f(state); err != nil { + return err + } + } + } + return nil + } +} + +// ModuleInitFunc creates a state initializer that registers multiple modules +func ModuleInitFunc(modules map[string]ModuleFunc) func(*luajit.State) error { + return func(state *luajit.State) error { + for name, moduleFunc := range modules { + if err := RegisterModule(state, name, moduleFunc()); err != nil { + logger.Error("Failed to register module %s: %v", name, err) + return err + } + } + return nil + } +} + +// RegisterLuaCode registers a Lua code snippet as a module +func RegisterLuaCode(state *luajit.State, code string) error { + return state.DoString(code) +} + +// RegisterLuaCodeInitFunc returns a StateInitFunc that registers Lua code +func RegisterLuaCodeInitFunc(code string) func(*luajit.State) error { + return func(state *luajit.State) error { + return RegisterLuaCode(state, code) + } +} + +// RegisterLuaModuleInitFunc returns a StateInitFunc that registers a Lua module +func RegisterLuaModuleInitFunc(name string, code string) func(*luajit.State) error { + return func(state *luajit.State) error { + // Create name = {} global + state.NewTable() + state.SetGlobal(name) + + // Then run the module code which will populate it + return state.DoString(code) + } +} diff --git a/core/runner/Sandbox.go b/core/runner/sandbox/Sandbox.go similarity index 73% rename from core/runner/Sandbox.go rename to core/runner/sandbox/Sandbox.go index 3f01b47..b502df0 100644 --- a/core/runner/Sandbox.go +++ b/core/runner/sandbox/Sandbox.go @@ -1,4 +1,4 @@ -package runner +package sandbox import ( "fmt" @@ -6,24 +6,52 @@ import ( "github.com/goccy/go-json" "github.com/valyala/bytebufferpool" + "github.com/valyala/fasthttp" "Moonshark/core/utils/logger" luajit "git.sharkk.net/Sky/LuaJIT-to-Go" ) +// Global bytecode cache to improve performance +var ( + sandboxBytecode []byte + bytecodeOnce sync.Once +) + +// precompileSandbox compiles the sandbox.lua code to bytecode once +func precompileSandbox() { + tempState := luajit.New() + if tempState == nil { + logger.Error("Failed to create temporary Lua state for bytecode compilation") + return + } + defer tempState.Close() + defer tempState.Cleanup() + + var err error + sandboxBytecode, err = tempState.CompileBytecode(sandboxLua, "sandbox.lua") + if err != nil { + logger.Error("Failed to precompile sandbox.lua: %v", err) + } else { + logger.Debug("Successfully precompiled sandbox.lua to bytecode (%d bytes)", len(sandboxBytecode)) + } +} + // Sandbox provides a secure execution environment for Lua scripts type Sandbox struct { - modules map[string]any // Custom modules for environment - debug bool // Enable debug output - mu sync.RWMutex // Protects modules + modules map[string]any // Custom modules for environment + debug bool // Enable debug output + mu sync.RWMutex // Protects modules + initializers *ModuleInitializers // Module initializers } // NewSandbox creates a new sandbox environment func NewSandbox() *Sandbox { return &Sandbox{ - modules: make(map[string]any, 8), // Pre-allocate with reasonable capacity - debug: false, + modules: make(map[string]any, 8), // Pre-allocate with reasonable capacity + debug: false, + initializers: DefaultInitializers(), } } @@ -39,7 +67,7 @@ func (s *Sandbox) debugLog(format string, args ...interface{}) { } } -// debugLog logs a message if debug mode is enabled +// debugLogCont logs a continuation message if debug mode is enabled func (s *Sandbox) debugLogCont(format string, args ...interface{}) { if s.debug { logger.DebugCont(format, args...) @@ -60,19 +88,27 @@ func (s *Sandbox) Setup(state *luajit.State, stateIndex int) error { verbose := stateIndex == 0 if verbose { - s.debugLog("is setting up...") + s.debugLog("Setting up sandbox...") } - // Register modules in the global environment + // Initialize modules with the embedded sandbox code + if err := InitializeAll(state, s.initializers); err != nil { + if verbose { + s.debugLog("Failed to initialize sandbox: %v", err) + } + return err + } + + // Register custom modules in the global environment s.mu.RLock() for name, module := range s.modules { if verbose { - s.debugLog("is registering module: %s", name) + s.debugLog("Registering module: %s", name) } if err := state.PushValue(module); err != nil { s.mu.RUnlock() if verbose { - s.debugLog("failed to register module %s: %v", name, err) + s.debugLog("Failed to register module %s: %v", name, err) } return err } @@ -80,60 +116,8 @@ func (s *Sandbox) Setup(state *luajit.State, stateIndex int) error { } s.mu.RUnlock() - // Initialize environment setup - err := state.DoString(` - -- Global tables for response handling - __http_responses = __http_responses or {} - - -- Create environment inheriting from _G - function __create_env(ctx) - -- Create environment with metatable inheriting from _G - local env = setmetatable({}, {__index = _G}) - - -- Add context if provided - if ctx then - env.ctx = ctx - end - - -- Add proper require function to this environment - if __setup_require then - __setup_require(env) - end - - return env - end - - -- Execute script with clean environment - function __execute_script(fn, ctx) - -- Clear previous responses - __http_responses[1] = nil - - -- Create environment - local env = __create_env(ctx) - - -- Set environment for function - setfenv(fn, env) - - -- Execute with protected call - local ok, result = pcall(fn) - if not ok then - error(result, 0) - end - - return result - end - `) - - if err != nil { - if verbose { - s.debugLog("failed to set up...") - s.debugLogCont("%v", err) - } - return err - } - if verbose { - s.debugLogCont("Complete") + s.debugLogCont("Sandbox setup complete") } return nil } @@ -152,6 +136,14 @@ func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx map[string]a return s.OptimizedExecute(state, bytecode, nil) } +// Context represents execution context for a Lua script +type Context struct { + // Values stores any context values (route params, HTTP request info, etc.) + Values map[string]any + // RequestCtx for HTTP requests + RequestCtx *fasthttp.RequestCtx +} + // OptimizedExecute runs bytecode with a fasthttp context if available func (s *Sandbox) OptimizedExecute(state *luajit.State, bytecode []byte, ctx *Context) (any, error) { // Use a buffer from the pool for any string operations diff --git a/core/runner/UtilModule.go b/core/runner/sandbox/Utils.go similarity index 85% rename from core/runner/UtilModule.go rename to core/runner/sandbox/Utils.go index 9141af5..75863a9 100644 --- a/core/runner/UtilModule.go +++ b/core/runner/sandbox/Utils.go @@ -1,4 +1,4 @@ -package runner +package sandbox import ( "crypto/rand" @@ -9,6 +9,20 @@ import ( luajit "git.sharkk.net/Sky/LuaJIT-to-Go" ) +// UtilModuleInitFunc returns an initializer for the util module +func UtilModuleInitFunc() func(*luajit.State) error { + return func(state *luajit.State) error { + return RegisterModule(state, "util", UtilModuleFunctions()) + } +} + +// UtilModuleFunctions returns all functions for the util module +func UtilModuleFunctions() map[string]luajit.GoFunction { + return map[string]luajit.GoFunction{ + "generate_token": GenerateToken, + } +} + // GenerateToken creates a cryptographically secure random token func GenerateToken(s *luajit.State) int { // Get the length from the Lua arguments (default to 32) @@ -42,17 +56,3 @@ func GenerateToken(s *luajit.State) int { s.PushString(token) return 1 // One return value } - -// UtilModuleFunctions returns all functions for the go module -func UtilModuleFunctions() map[string]luajit.GoFunction { - return map[string]luajit.GoFunction{ - "generate_token": GenerateToken, - } -} - -// UtilModuleInitFunc returns an initializer for the go module -func UtilModuleInitFunc() StateInitFunc { - return func(state *luajit.State) error { - return RegisterModule(state, "util", UtilModuleFunctions()) - } -} diff --git a/core/runner/sandbox/lua/sandbox.lua b/core/runner/sandbox/lua/sandbox.lua new file mode 100644 index 0000000..c755c27 --- /dev/null +++ b/core/runner/sandbox/lua/sandbox.lua @@ -0,0 +1,552 @@ +--[[ +Moonshark Lua Sandbox Environment + +This file contains all the Lua code needed for the sandbox environment, +including core modules and utilities. It's designed to be embedded in the +Go binary at build time. +]]-- + +-- Global tables for execution context +__http_responses = {} +__module_paths = {} +__module_bytecode = {} +__ready_modules = {} +__session_data = {} +__session_id = nil +__session_modified = false +__env_system = { + base_env = {} +} + +-- ====================================================================== +-- CORE SANDBOX FUNCTIONALITY +-- ====================================================================== + +-- Create environment inheriting from _G +function __create_env(ctx) + -- Create environment with metatable inheriting from _G + local env = setmetatable({}, {__index = _G}) + + -- Add context if provided + if ctx then + env.ctx = ctx + end + + -- Add proper require function to this environment + if __setup_require then + __setup_require(env) + end + + return env +end + +-- Execute script with clean environment +function __execute_script(fn, ctx) + -- Clear previous responses + __http_responses[1] = nil + + -- Reset session modification flag + __session_modified = false + + -- Create environment + local env = __create_env(ctx) + + -- Set environment for function + setfenv(fn, env) + + -- Execute with protected call + local ok, result = pcall(fn) + if not ok then + error(result, 0) + end + + return result +end + +-- ====================================================================== +-- MODULE LOADING SYSTEM +-- ====================================================================== + +-- Setup environment-aware require function +function __setup_require(env) + -- Create require function specific to this environment + env.require = function(modname) + -- Check if already loaded + if package.loaded[modname] then + return package.loaded[modname] + end + + -- Check preloaded modules + if __ready_modules[modname] then + local loader = package.preload[modname] + if loader then + -- Set environment for loader + setfenv(loader, env) + + -- Execute and store result + local result = loader() + if result == nil then + result = true + end + + package.loaded[modname] = result + return result + end + end + + -- Direct file load as fallback + if __module_paths[modname] then + local path = __module_paths[modname] + local chunk, err = loadfile(path) + if chunk then + setfenv(chunk, env) + local result = chunk() + if result == nil then + result = true + end + package.loaded[modname] = result + return result + end + end + + -- Full path search as last resort + local errors = {} + for path in package.path:gmatch("[^;]+") do + local file_path = path:gsub("?", modname:gsub("%.", "/")) + local chunk, err = loadfile(file_path) + if chunk then + setfenv(chunk, env) + local result = chunk() + if result == nil then + result = true + end + package.loaded[modname] = result + return result + end + table.insert(errors, "\tno file '" .. file_path .. "'") + end + + error("module '" .. modname .. "' not found:\n" .. table.concat(errors, "\n"), 2) + end + + return env +end + +-- ====================================================================== +-- HTTP MODULE +-- ====================================================================== + +-- HTTP module implementation +local http = { + -- Set HTTP status code + set_status = function(code) + if type(code) ~= "number" then + error("http.set_status: status code must be a number", 2) + end + + local resp = __http_responses[1] or {} + resp.status = code + __http_responses[1] = resp + end, + + -- Set HTTP header + set_header = function(name, value) + if type(name) ~= "string" or type(value) ~= "string" then + error("http.set_header: name and value must be strings", 2) + end + + local resp = __http_responses[1] or {} + resp.headers = resp.headers or {} + resp.headers[name] = value + __http_responses[1] = resp + end, + + -- Set content type; set_header helper + set_content_type = function(content_type) + http.set_header("Content-Type", content_type) + end, + + -- HTTP client submodule + client = { + -- Generic request function + request = function(method, url, body, options) + if type(method) ~= "string" then + error("http.client.request: method must be a string", 2) + end + if type(url) ~= "string" then + error("http.client.request: url must be a string", 2) + end + + -- Call native implementation + local result = __http_request(method, url, body, options) + return result + end, + + -- Simple GET request + get = function(url, options) + return http.client.request("GET", url, nil, options) + end, + + -- Simple POST request with automatic content-type + post = function(url, body, options) + options = options or {} + return http.client.request("POST", url, body, options) + end, + + -- Simple PUT request with automatic content-type + put = function(url, body, options) + options = options or {} + return http.client.request("PUT", url, body, options) + end, + + -- Simple DELETE request + delete = function(url, options) + return http.client.request("DELETE", url, nil, options) + end, + + -- Simple PATCH request + patch = function(url, body, options) + options = options or {} + return http.client.request("PATCH", url, body, options) + end, + + -- Simple HEAD request + head = function(url, options) + options = options or {} + local old_options = options + options = {headers = old_options.headers, timeout = old_options.timeout, query = old_options.query} + local response = http.client.request("HEAD", url, nil, options) + return response + end, + + -- Simple OPTIONS request + options = function(url, options) + return http.client.request("OPTIONS", url, nil, options) + end, + + -- Shorthand function to directly get JSON + get_json = function(url, options) + options = options or {} + local response = http.client.get(url, options) + if response.ok and response.json then + return response.json + end + return nil, response + end, + + -- Utility to build a URL with query parameters + build_url = function(base_url, params) + if not params or type(params) ~= "table" then + return base_url + end + + local query = {} + for k, v in pairs(params) do + if type(v) == "table" then + for _, item in ipairs(v) do + table.insert(query, k .. "=" .. tostring(item)) + end + else + table.insert(query, k .. "=" .. tostring(v)) + end + end + + if #query > 0 then + if base_url:find("?") then + return base_url .. "&" .. table.concat(query, "&") + else + return base_url .. "?" .. table.concat(query, "&") + end + end + + return base_url + end + } +} + +-- ====================================================================== +-- COOKIE MODULE +-- ====================================================================== + +-- Cookie module implementation +local cookie = { + -- Set a cookie + set = function(name, value, options, ...) + if type(name) ~= "string" then + error("cookie.set: name must be a string", 2) + end + + -- Get or create response + local resp = __http_responses[1] or {} + resp.cookies = resp.cookies or {} + __http_responses[1] = resp + + -- Handle options as table or legacy params + local opts = {} + if type(options) == "table" then + opts = options + elseif options ~= nil then + -- Legacy support: options is actually 'expires' + opts.expires = options + -- Check for other legacy params (4th-7th args) + local args = {...} + if args[1] then opts.path = args[1] end + if args[2] then opts.domain = args[2] end + if args[3] then opts.secure = args[3] end + if args[4] ~= nil then opts.http_only = args[4] end + end + + -- Create cookie table + local cookie = { + name = name, + value = value or "", + path = opts.path or "/", + domain = opts.domain + } + + -- Handle expiry + if opts.expires then + if type(opts.expires) == "number" then + if opts.expires > 0 then + cookie.max_age = opts.expires + local now = os.time() + cookie.expires = now + opts.expires + elseif opts.expires < 0 then + cookie.expires = 1 + cookie.max_age = 0 + else + -- opts.expires == 0: Session cookie + -- Do nothing (omitting both expires and max-age creates a session cookie) + end + end + end + + -- Security flags + cookie.secure = (opts.secure ~= false) + cookie.http_only = (opts.http_only ~= false) + + -- Store in cookies table + local n = #resp.cookies + 1 + resp.cookies[n] = cookie + + return true + end, + + -- Get a cookie value + get = function(name) + if type(name) ~= "string" then + error("cookie.get: name must be a string", 2) + end + + -- Access values directly from current environment + local env = getfenv(2) + + -- Check if context exists and has cookies + if env.ctx and env.ctx.cookies and env.ctx.cookies[name] then + return tostring(env.ctx.cookies[name]) + end + + return nil + end, + + -- Remove a cookie + remove = function(name, path, domain) + if type(name) ~= "string" then + error("cookie.remove: name must be a string", 2) + end + + -- Create an expired cookie + return cookie.set(name, "", {expires = 0, path = path or "/", domain = domain}) + end +} + +-- ====================================================================== +-- SESSION MODULE +-- ====================================================================== + +-- Session module implementation +local session = { + -- Get a session value + get = function(key) + if type(key) ~= "string" then + error("session.get: key must be a string", 2) + end + + if __session_data and __session_data[key] then + return __session_data[key] + end + + return nil + end, + + -- Set a session value + set = function(key, value) + if type(key) ~= "string" then + error("session.set: key must be a string", 2) + end + + -- Ensure session data table exists + __session_data = __session_data or {} + + -- Store value + __session_data[key] = value + + -- Mark session as modified + __session_modified = true + + return true + end, + + -- Delete a session value + delete = function(key) + if type(key) ~= "string" then + error("session.delete: key must be a string", 2) + end + + if __session_data then + __session_data[key] = nil + __session_modified = true + end + + return true + end, + + -- Clear all session data + clear = function() + __session_data = {} + __session_modified = true + return true + end, + + -- Get the session ID + get_id = function() + return __session_id or nil + end, + + -- Get all session data + get_all = function() + local result = {} + for k, v in pairs(__session_data or {}) do + result[k] = v + end + return result + end, + + -- Check if session has a key + has = function(key) + if type(key) ~= "string" then + error("session.has: key must be a string", 2) + end + + return __session_data and __session_data[key] ~= nil + end +} + +-- ====================================================================== +-- CSRF MODULE +-- ====================================================================== + +-- CSRF protection module +local csrf = { + -- Session key where the token is stored + TOKEN_KEY = "_csrf_token", + + -- Default form field name + DEFAULT_FIELD = "csrf", + + -- Generate a new CSRF token and store it in the session + generate = function(length) + -- Default length is 32 characters + length = length or 32 + + if length < 16 then + -- Enforce minimum security + length = 16 + end + + -- Check if we have a session module + if not session then + error("CSRF protection requires the session module", 2) + end + + local token = util.generate_token(length) + session.set(csrf.TOKEN_KEY, token) + return token + end, + + -- Get the current token or generate a new one + token = function() + -- Get from session if exists + local token = session.get(csrf.TOKEN_KEY) + + -- Generate if needed + if not token then + token = csrf.generate() + end + + return token + end, + + -- Generate a hidden form field with the CSRF token + field = function(field_name) + field_name = field_name or csrf.DEFAULT_FIELD + local token = csrf.token() + return string.format('', field_name, token) + end, + + -- Verify a given token against the session token + verify = function(token, field_name) + field_name = field_name or csrf.DEFAULT_FIELD + + local env = getfenv(2) + + local form = nil + if env.ctx and env.ctx.form then + form = env.ctx.form + else + return false + end + + token = token or form[field_name] + if not token then + return false + end + + local session_token = session.get(csrf.TOKEN_KEY) + if not session_token then + return false + end + + -- Constant-time comparison to prevent timing attacks + -- This is safe since Lua strings are immutable + if #token ~= #session_token then + return false + end + + local result = true + for i = 1, #token do + if token:sub(i, i) ~= session_token:sub(i, i) then + result = false + -- Don't break early - continue to prevent timing attacks + end + end + + return result + end +} + +-- ====================================================================== +-- REGISTER MODULES GLOBALLY +-- ====================================================================== + +-- Install modules in global scope +_G.http = http +_G.cookie = cookie +_G.session = session +_G.csrf = csrf + +-- Register modules in sandbox base environment +__env_system.base_env.http = http +__env_system.base_env.cookie = cookie +__env_system.base_env.session = session +__env_system.base_env.csrf = csrf \ No newline at end of file diff --git a/core/sessions/Manager.go b/core/sessions/Manager.go index 8be2a11..7513b91 100644 --- a/core/sessions/Manager.go +++ b/core/sessions/Manager.go @@ -87,7 +87,6 @@ func (sm *SessionManager) GetSession(id string) *Session { func (sm *SessionManager) CreateSession() *Session { id := generateSessionID() - // Create new session session := NewSession(id) data, _ := json.Marshal(session) sm.cache.Set([]byte(id), data) @@ -95,6 +94,12 @@ func (sm *SessionManager) CreateSession() *Session { return session } +// SaveSession persists a session back to the cache +func (sm *SessionManager) SaveSession(session *Session) { + data, _ := json.Marshal(session) + sm.cache.Set([]byte(session.ID), data) +} + // DestroySession removes a session func (sm *SessionManager) DestroySession(id string) { sm.cache.Del([]byte(id)) diff --git a/core/sessions/Session.go b/core/sessions/Session.go index fa18609..ec75f7f 100644 --- a/core/sessions/Session.go +++ b/core/sessions/Session.go @@ -18,13 +18,13 @@ var ( // Session stores data for a single user session type Session struct { - ID string - Data map[string]any - CreatedAt time.Time - UpdatedAt time.Time - mu sync.RWMutex // Protect concurrent access to Data - maxValueSize int // Maximum size of individual values in bytes - totalDataSize int // Track total size of all data + ID string `json:"id"` + Data map[string]any `json:"data"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + mu sync.RWMutex `json:"-"` + maxValueSize int `json:"max_value_size"` + totalDataSize int `json:"total_data_size"` } // NewSession creates a new session with the given ID