diff --git a/core/Moonshark.go b/core/Moonshark.go index 8ccc263..be4fd89 100644 --- a/core/Moonshark.go +++ b/core/Moonshark.go @@ -184,8 +184,6 @@ func (s *Moonshark) initRunner() error { runnerOpts := []runner.RunnerOption{ runner.WithPoolSize(s.Config.Runner.PoolSize), runner.WithLibDirs(s.Config.Dirs.Libs...), - runner.WithSessionManager(sessionManager), - http.WithCSRFProtection(), } // Add debug option conditionally diff --git a/core/http/Csrf.go b/core/http/Csrf.go index 1b7d7bb..902e689 100644 --- a/core/http/Csrf.go +++ b/core/http/Csrf.go @@ -2,18 +2,19 @@ package http import ( "Moonshark/core/runner" - luaCtx "Moonshark/core/runner/context" "Moonshark/core/utils" "Moonshark/core/utils/logger" "crypto/subtle" + "errors" "github.com/valyala/fasthttp" - - luajit "git.sharkk.net/Sky/LuaJIT-to-Go" ) +// Error for CSRF validation failure +var ErrCSRFValidationFailed = errors.New("CSRF token validation failed") + // ValidateCSRFToken checks if the CSRF token is valid for a request -func ValidateCSRFToken(state *luajit.State, ctx *luaCtx.Context) bool { +func ValidateCSRFToken(ctx *runner.Context) bool { // Only validate for form submissions method, ok := ctx.Get("method").(string) if !ok || (method != "POST" && method != "PUT" && method != "PATCH" && method != "DELETE") { @@ -34,87 +35,23 @@ func ValidateCSRFToken(state *luajit.State, ctx *luaCtx.Context) bool { return false } - // Get session token - state.GetGlobal("session") - if state.IsNil(-1) { - state.Pop(1) - logger.Warning("CSRF validation failed: session module not available") + // Get token from session + sessionData := ctx.SessionData + if sessionData == nil { + logger.Warning("CSRF validation failed: no session data") return false } - state.GetField(-1, "get") - if !state.IsFunction(-1) { - state.Pop(2) - logger.Warning("CSRF validation failed: session.get not available") - return false - } - - state.PushCopy(-1) // Duplicate function - state.PushString("_csrf_token") - - if err := state.Call(1, 1); err != nil { - state.Pop(3) // Pop error, function and session table - logger.Warning("CSRF validation failed: %v", err) - return false - } - - if state.IsNil(-1) { - state.Pop(3) // Pop nil, function and session table + sessionToken, ok := sessionData["_csrf_token"].(string) + if !ok || sessionToken == "" { logger.Warning("CSRF validation failed: no token in session") return false } - sessionToken := state.ToString(-1) - state.Pop(3) // Pop token, function and session table - // Constant-time comparison to prevent timing attacks return subtle.ConstantTimeCompare([]byte(formToken), []byte(sessionToken)) == 1 } -// WithCSRFProtection creates a runner option to add CSRF protection -func WithCSRFProtection() runner.RunnerOption { - return func(r *runner.Runner) { - r.AddInitHook(func(state *luajit.State, ctx *luaCtx.Context) error { - // Get request method - method, ok := ctx.Get("method").(string) - if !ok { - return nil - } - - // Only validate for form submissions - if method != "POST" && method != "PUT" && method != "PATCH" && method != "DELETE" { - return nil - } - - // Check for form data - form, ok := ctx.Get("form").(map[string]any) - if !ok || form == nil { - return nil - } - - // Validate CSRF token - if !ValidateCSRFToken(state, ctx) { - return ErrCSRFValidationFailed - } - - return nil - }) - } -} - -// Error for CSRF validation failure -var ErrCSRFValidationFailed = &CSRFError{message: "CSRF token validation failed"} - -// CSRFError represents a CSRF validation error -type CSRFError struct { - message string -} - -// Error implements the error interface -func (e *CSRFError) Error() string { - return e.message -} - // HandleCSRFError handles a CSRF validation error func HandleCSRFError(ctx *fasthttp.RequestCtx, errorConfig utils.ErrorPageConfig) { method := string(ctx.Method()) @@ -129,3 +66,39 @@ func HandleCSRFError(ctx *fasthttp.RequestCtx, errorConfig utils.ErrorPageConfig errorHTML := utils.ForbiddenPage(errorConfig, path, errorMsg) ctx.SetBody([]byte(errorHTML)) } + +// GenerateCSRFToken creates a new CSRF token and stores it in the session +func GenerateCSRFToken(ctx *runner.Context, length int) (string, error) { + if length < 16 { + length = 16 // Minimum token length for security + } + + // Create secure random token + token, err := GenerateSecureToken(length) + if err != nil { + return "", err + } + + // Store token in session + ctx.SessionData["_csrf_token"] = token + return token, nil +} + +// GetCSRFToken retrieves the current CSRF token or generates a new one +func GetCSRFToken(ctx *runner.Context) (string, error) { + // Check if token already exists in session + if token, ok := ctx.SessionData["_csrf_token"].(string); ok && token != "" { + return token, nil + } + + // Generate new token + return GenerateCSRFToken(ctx, 32) +} + +// CSRFMiddleware validates CSRF tokens for state-changing requests +func CSRFMiddleware(ctx *runner.Context) error { + if !ValidateCSRFToken(ctx) { + return ErrCSRFValidationFailed + } + return nil +} diff --git a/core/http/Forms.go b/core/http/Forms.go deleted file mode 100644 index 9c2720b..0000000 --- a/core/http/Forms.go +++ /dev/null @@ -1,116 +0,0 @@ -package http - -import ( - "errors" - "mime/multipart" - "strings" - - "github.com/valyala/fasthttp" -) - -// Maximum form parse size (16MB) -const maxFormSize = 16 << 20 - -// Common errors -var ( - ErrFormSizeTooLarge = errors.New("form size too large") - ErrInvalidFormType = errors.New("invalid form content type") -) - -// ParseForm parses a POST request body into a map of values -// Supports both application/x-www-form-urlencoded and multipart/form-data content types -func ParseForm(ctx *fasthttp.RequestCtx) (map[string]any, error) { - // Only handle POST, PUT, PATCH - method := string(ctx.Method()) - if method != "POST" && method != "PUT" && method != "PATCH" { - return make(map[string]any), nil - } - - // Check content type - contentType := string(ctx.Request.Header.ContentType()) - if contentType == "" { - return make(map[string]any), nil - } - - result := make(map[string]any) - - // Check for content length to prevent DOS - if len(ctx.Request.Body()) > maxFormSize { - return nil, ErrFormSizeTooLarge - } - - // Handle by content type - if strings.HasPrefix(contentType, "application/x-www-form-urlencoded") { - return parseURLEncodedForm(ctx) - } else if strings.HasPrefix(contentType, "multipart/form-data") { - return parseMultipartForm(ctx) - } - - // Unrecognized content type - return result, nil -} - -// parseURLEncodedForm handles application/x-www-form-urlencoded forms -func parseURLEncodedForm(ctx *fasthttp.RequestCtx) (map[string]any, error) { - result := make(map[string]any) - - // Process form values directly from PostArgs() - ctx.PostArgs().VisitAll(func(key, value []byte) { - keyStr := string(key) - valStr := string(value) - - // Check if we already have this key - if existing, ok := result[keyStr]; ok { - // If it's already a slice, append - if existingSlice, ok := existing.([]string); ok { - result[keyStr] = append(existingSlice, valStr) - } else { - // Convert to slice and append - result[keyStr] = []string{existing.(string), valStr} - } - } else { - // New key - result[keyStr] = valStr - } - }) - - return result, nil -} - -// parseMultipartForm handles multipart/form-data forms -func parseMultipartForm(ctx *fasthttp.RequestCtx) (map[string]any, error) { - result := make(map[string]any) - - // Parse multipart form - form, err := ctx.MultipartForm() - if err != nil { - if err == multipart.ErrMessageTooLarge || strings.Contains(err.Error(), "too large") { - return nil, ErrFormSizeTooLarge - } - return nil, err - } - - // Process form values - for key, values := range form.Value { - if len(values) == 1 { - // Single value - result[key] = values[0] - } else if len(values) > 1 { - // Multiple values - store as string slice - strValues := make([]string, len(values)) - copy(strValues, values) - result[key] = strValues - } - } - - // We don't handle file uploads here - could be extended in the future - // if needed to support file uploads to Lua - - return result, nil -} - -// Usage: -// After parsing the form with ParseForm, you can add it to the context with: -// ctx.Set("form", formData) -// -// This makes the form data accessible in Lua as ctx.form.field_name diff --git a/core/http/HttpLogger.go b/core/http/HttpLogger.go deleted file mode 100644 index 8705f6f..0000000 --- a/core/http/HttpLogger.go +++ /dev/null @@ -1,43 +0,0 @@ -package http - -import ( - "time" - - "Moonshark/core/utils/logger" -) - -// StatusColors for different status code ranges -const ( - colorGreen = "\033[32m" // 2xx - Success - colorCyan = "\033[36m" // 3xx - Redirection - colorYellow = "\033[33m" // 4xx - Client Errors - colorRed = "\033[31m" // 5xx - Server Errors - colorReset = "\033[0m" // Reset color - colorGray = "\033[90m" -) - -// LogRequest logs an HTTP request with custom formatting -func LogRequest(statusCode int, method, path string, duration time.Duration) { - statusColor := getStatusColor(statusCode) - - // Use the logger's raw message writer to bypass the standard format - logger.LogRaw("%s%s%s %s%d %s%s %s %s(%v)%s", - colorGray, time.Now().Format(logger.TimeFormat()), colorReset, - statusColor, statusCode, method, colorReset, path, colorGray, duration, colorReset) -} - -// getStatusColor returns the ANSI color code for a status code -func getStatusColor(code int) string { - switch { - case code >= 200 && code < 300: - return colorGreen - case code >= 300 && code < 400: - return colorCyan - case code >= 400 && code < 500: - return colorYellow - case code >= 500: - return colorRed - default: - return "" - } -} diff --git a/core/http/Queries.go b/core/http/Queries.go deleted file mode 100644 index d972138..0000000 --- a/core/http/Queries.go +++ /dev/null @@ -1,43 +0,0 @@ -package http - -import ( - "github.com/valyala/fasthttp" -) - -// QueryToLua converts HTTP query parameters to a map that can be used with LuaJIT. -// Single value parameters are stored as strings. -// Multi-value parameters are converted to []any arrays. -func QueryToLua(ctx *fasthttp.RequestCtx) map[string]any { - result := make(map[string]any) - - // Use a map to track keys that have multiple values - multiValueKeys := make(map[string]bool) - - // Process all query args - ctx.QueryArgs().VisitAll(func(key, value []byte) { - keyStr := string(key) - valStr := string(value) - - if _, exists := result[keyStr]; exists { - // This key already exists, convert to array if not already - if !multiValueKeys[keyStr] { - // First duplicate, convert existing value to array - multiValueKeys[keyStr] = true - result[keyStr] = []any{result[keyStr], valStr} - } else { - // Already an array, append - result[keyStr] = append(result[keyStr].([]any), valStr) - } - } else { - // New key - result[keyStr] = valStr - } - }) - - // If we don't have any query parameters, return empty map - if len(result) == 0 { - return make(map[string]any) - } - - return result -} diff --git a/core/http/Server.go b/core/http/Server.go index a6c5bdb..fe32fbc 100644 --- a/core/http/Server.go +++ b/core/http/Server.go @@ -2,21 +2,17 @@ package http import ( "context" - "fmt" - "strings" + "errors" "time" "Moonshark/core/metadata" "Moonshark/core/routers" "Moonshark/core/runner" - luaCtx "Moonshark/core/runner/context" - "Moonshark/core/runner/sandbox" "Moonshark/core/sessions" "Moonshark/core/utils" "Moonshark/core/utils/config" "Moonshark/core/utils/logger" - "github.com/goccy/go-json" "github.com/valyala/fasthttp" ) @@ -29,12 +25,14 @@ type Server struct { loggingEnabled bool debugMode bool config *config.Config + sessionManager *sessions.SessionManager errorConfig utils.ErrorPageConfig } // New creates a new HTTP server with optimized connection settings -func New(luaRouter *routers.LuaRouter, staticRouter *routers.StaticRouter, runner *runner.Runner, - loggingEnabled bool, debugMode bool, overrideDir string, config *config.Config) *Server { +func New(luaRouter *routers.LuaRouter, staticRouter *routers.StaticRouter, + runner *runner.Runner, loggingEnabled bool, debugMode bool, + overrideDir string, config *config.Config) *Server { server := &Server{ luaRouter: luaRouter, @@ -43,6 +41,7 @@ func New(luaRouter *routers.LuaRouter, staticRouter *routers.StaticRouter, runne loggingEnabled: loggingEnabled, debugMode: debugMode, config: config, + sessionManager: sessions.GlobalSessionManager, errorConfig: utils.ErrorPageConfig{ OverrideDir: overrideDir, DebugMode: debugMode, @@ -55,7 +54,7 @@ func New(luaRouter *routers.LuaRouter, staticRouter *routers.StaticRouter, runne Name: "Moonshark/" + metadata.Version, ReadTimeout: 30 * time.Second, WriteTimeout: 30 * time.Second, - MaxRequestBodySize: 16 << 20, // 16MB - consistent with Forms.go + MaxRequestBodySize: 16 << 20, // 16MB DisableKeepalive: false, TCPKeepalive: true, TCPKeepalivePeriod: 60 * time.Second, @@ -99,7 +98,7 @@ func (s *Server) handleRequest(ctx *fasthttp.RequestCtx) { // Process the request s.processRequest(ctx) - // Log the request with our custom format + // Log the request if s.loggingEnabled { duration := time.Since(start) LogRequest(ctx.Response.StatusCode(), method, path, duration) @@ -153,48 +152,25 @@ func (s *Server) processRequest(ctx *fasthttp.RequestCtx) { ctx.SetBody([]byte(utils.NotFoundPage(s.errorConfig, path))) } -// HandleMethodNotAllowed responds with a 405 Method Not Allowed error -func HandleMethodNotAllowed(ctx *fasthttp.RequestCtx, errorConfig utils.ErrorPageConfig) { - path := string(ctx.Path()) - ctx.SetContentType("text/html; charset=utf-8") - ctx.SetStatusCode(fasthttp.StatusMethodNotAllowed) - ctx.SetBody([]byte(utils.MethodNotAllowedPage(errorConfig, path))) -} - // handleLuaRoute executes a Lua route -// Updated handleLuaRoute function to handle sessions func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scriptPath string, params *routers.Params) { - luaCtx := luaCtx.NewHTTPContext(ctx) + // Create context for Lua execution + luaCtx := runner.NewHTTPContext(ctx) defer luaCtx.Release() method := string(ctx.Method()) path := string(ctx.Path()) host := string(ctx.Host()) - // Set up context + // Set up additional context values luaCtx.Set("method", method) luaCtx.Set("path", path) luaCtx.Set("host", host) - // Headers - headerMap := make(map[string]any) - ctx.Request.Header.VisitAll(func(key, value []byte) { - headerMap[string(key)] = string(value) - }) - luaCtx.Set("headers", headerMap) - - // Cookies - cookieMap := make(map[string]any) - ctx.Request.Header.VisitAllCookie(func(key, value []byte) { - cookieMap[string(key)] = string(value) - }) - if len(cookieMap) > 0 { - luaCtx.Set("cookies", cookieMap) - luaCtx.Set("_request_cookies", cookieMap) // For backward compatibility - } else { - luaCtx.Set("cookies", make(map[string]any)) - luaCtx.Set("_request_cookies", make(map[string]any)) - } + // Initialize session + session := s.sessionManager.GetSessionFromRequest(ctx) + luaCtx.SessionID = session.ID + luaCtx.SessionData = session.GetAll() // URL parameters if params.Count > 0 { @@ -207,11 +183,7 @@ func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scrip luaCtx.Set("params", make(map[string]any)) } - // Query parameters - queryMap := QueryToLua(ctx) - luaCtx.Set("query", queryMap) - - // Form data + // Parse form data for POST/PUT/PATCH requests if method == "POST" || method == "PUT" || method == "PATCH" { formData, err := ParseForm(ctx) if err == nil && len(formData) > 0 { @@ -226,40 +198,26 @@ func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scrip luaCtx.Set("form", make(map[string]any)) } - // Session handling - cookieOpts := sessions.GlobalSessionManager.CookieOptions() - cookieName := cookieOpts["name"].(string) - sessionCookie := ctx.Request.Header.Cookie(cookieName) - - var sessionID string - if sessionCookie != nil { - sessionID = string(sessionCookie) + // CSRF middleware for state-changing requests + if method == "POST" || method == "PUT" || method == "PATCH" || method == "DELETE" { + if !ValidateCSRFToken(luaCtx) { + HandleCSRFError(ctx, s.errorConfig) + return + } } - // Get or create session - var session *sessions.Session - if sessionID != "" { - session = sessions.GlobalSessionManager.GetSession(sessionID) - } else { - session = sessions.GlobalSessionManager.CreateSession() - } - - // Set session in context - luaCtx.Session = session - // Execute Lua script - result, err := s.luaRunner.Run(bytecode, luaCtx, scriptPath) - - // Special handling for CSRF error + response, err := s.luaRunner.Run(bytecode, luaCtx, scriptPath) if err != nil { - if csrfErr, ok := err.(*CSRFError); ok { - logger.Warning("CSRF error executing Lua route: %v", csrfErr) + logger.Error("Error executing Lua route: %v", err) + + // Special handling for specific errors + if errors.Is(err, ErrCSRFValidationFailed) { HandleCSRFError(ctx, s.errorConfig) return } - // Normal error handling - logger.Error("Error executing Lua route: %v", err) + // General error handling ctx.SetContentType("text/html; charset=utf-8") ctx.SetStatusCode(fasthttp.StatusInternalServerError) errorHTML := utils.InternalErrorPage(s.errorConfig, path, err.Error()) @@ -267,129 +225,21 @@ func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scrip return } - // Handle session updates if needed - if luaCtx.SessionModified { - sessions.GlobalSessionManager.SaveSession(luaCtx.Session) - - // Set session cookie - cookie := fasthttp.AcquireCookie() - cookie.SetKey(cookieName) - cookie.SetValue(luaCtx.Session.ID) - cookie.SetPath(cookieOpts["path"].(string)) - - if domain, ok := cookieOpts["domain"].(string); ok && domain != "" { - cookie.SetDomain(domain) + // Save session if modified + if response.SessionModified { + // Update session data + for k, v := range response.SessionData { + session.Set(k, v) } - - if maxAge, ok := cookieOpts["max_age"].(int); ok { - cookie.SetMaxAge(maxAge) - } - - cookie.SetSecure(cookieOpts["secure"].(bool)) - cookie.SetHTTPOnly(cookieOpts["http_only"].(bool)) - - ctx.Response.Header.SetCookie(cookie) - fasthttp.ReleaseCookie(cookie) + s.sessionManager.SaveSession(session) + s.sessionManager.ApplySessionCookie(ctx, session) } - // If we got a non-nil result, write it to the response - if result != nil { - writeResponse(ctx, result) - } -} + // Apply response to HTTP context + runner.ApplyResponse(response, ctx) -// Content types for responses -const ( - contentTypeJSON = "application/json" - contentTypePlain = "text/plain" -) - -// writeResponse writes the Lua result to the HTTP response -func writeResponse(ctx *fasthttp.RequestCtx, result any) { - if result == nil { - ctx.SetStatusCode(fasthttp.StatusNoContent) - return - } - - // First check the raw type of the result for strong type identification - // Sometimes type assertions don't work as expected with interface values - resultType := fmt.Sprintf("%T", result) - - // Strong check for HTTP response - if strings.Contains(resultType, "HTTPResponse") || strings.Contains(resultType, "sandbox.HTTPResponse") { - httpResp, ok := result.(*sandbox.HTTPResponse) - if ok { - defer sandbox.ReleaseResponse(httpResp) - - // Set response headers - for name, value := range httpResp.Headers { - ctx.Response.Header.Set(name, value) - } - - // Set cookies - for _, cookie := range httpResp.Cookies { - ctx.Response.Header.SetCookie(cookie) - } - - // Set status code - ctx.SetStatusCode(httpResp.Status) - - // Process the body based on its type - if httpResp.Body == nil { - return - } - - // Continue with the body only - result = httpResp.Body - } else { - // We identified it as HTTPResponse but couldn't convert it - // This is a programming error - logger.Error("Found HTTPResponse type but failed to convert: %v", resultType) - ctx.Error("Internal Server Error", fasthttp.StatusInternalServerError) - return - } - } - - // Check if it's a map (table) or array - return as JSON - isJSON := false - switch result.(type) { - case map[string]any, []any, []float64, []string, []int: - isJSON = true - } - - if isJSON { - setContentTypeIfMissing(ctx, contentTypeJSON) - data, err := json.Marshal(result) - if err != nil { - logger.Error("Failed to marshal response: %v", err) - ctx.Error("Internal Server Error", fasthttp.StatusInternalServerError) - return - } - ctx.SetBody(data) - return - } - - // Handle string and byte slice cases directly - switch r := result.(type) { - case string: - setContentTypeIfMissing(ctx, contentTypePlain) - ctx.SetBodyString(r) - return - case []byte: - setContentTypeIfMissing(ctx, contentTypePlain) - ctx.SetBody(r) - return - } - - // If we reach here, it's an unexpected type - convert to string as a last resort - setContentTypeIfMissing(ctx, contentTypePlain) - ctx.SetBodyString(fmt.Sprintf("%v", result)) -} - -func setContentTypeIfMissing(ctx *fasthttp.RequestCtx, contentType string) { - if len(ctx.Response.Header.ContentType()) == 0 { - ctx.SetContentType(contentType) - } + // Release the response when done + runner.ReleaseResponse(response) } // handleDebugStats displays debug statistics @@ -399,12 +249,14 @@ func (s *Server) handleDebugStats(ctx *fasthttp.RequestCtx) { // Add component stats routeCount, bytecodeBytes := s.luaRouter.GetRouteStats() - moduleCount := s.luaRunner.GetModuleCount() + //stateCount := s.luaRunner.GetStateCount() + //activeStates := s.luaRunner.GetActiveStateCount() stats.Components = utils.ComponentStats{ RouteCount: routeCount, BytecodeBytes: bytecodeBytes, - ModuleCount: moduleCount, + //StatesCount: stateCount, + //ActiveStates: activeStates, } // Generate HTML page diff --git a/core/http/Utils.go b/core/http/Utils.go new file mode 100644 index 0000000..15b0118 --- /dev/null +++ b/core/http/Utils.go @@ -0,0 +1,206 @@ +package http + +import ( + "crypto/rand" + "encoding/base64" + "fmt" + "mime/multipart" + "strings" + "time" + + "Moonshark/core/utils/logger" + + "github.com/valyala/fasthttp" +) + +// LogRequest logs an HTTP request with its status code and duration +func LogRequest(statusCode int, method, path string, duration time.Duration) { + var statusColor, resetColor, methodColor string + + // Status code colors + if statusCode >= 200 && statusCode < 300 { + statusColor = "\u001b[32m" // Green for 2xx + } else if statusCode >= 300 && statusCode < 400 { + statusColor = "\u001b[36m" // Cyan for 3xx + } else if statusCode >= 400 && statusCode < 500 { + statusColor = "\u001b[33m" // Yellow for 4xx + } else { + statusColor = "\u001b[31m" // Red for 5xx and others + } + + // Method colors + switch method { + case "GET": + methodColor = "\u001b[32m" // Green + case "POST": + methodColor = "\u001b[34m" // Blue + case "PUT": + methodColor = "\u001b[33m" // Yellow + case "DELETE": + methodColor = "\u001b[31m" // Red + default: + methodColor = "\u001b[35m" // Magenta for others + } + + resetColor = "\u001b[0m" + + // Format duration + var durationStr string + if duration.Milliseconds() < 1 { + durationStr = fmt.Sprintf("%.2fµs", float64(duration.Microseconds())) + } else if duration.Milliseconds() < 1000 { + durationStr = fmt.Sprintf("%.2fms", float64(duration.Microseconds())/1000) + } else { + durationStr = fmt.Sprintf("%.2fs", duration.Seconds()) + } + + // Log with colors + logger.Server("%s%d%s %s%s%s %s %s", + statusColor, statusCode, resetColor, + methodColor, method, resetColor, + path, durationStr) +} + +// QueryToLua converts HTTP query args to a Lua-friendly map +func QueryToLua(ctx *fasthttp.RequestCtx) map[string]any { + queryMap := make(map[string]any) + + // Visit all query parameters + ctx.QueryArgs().VisitAll(func(key, value []byte) { + // Convert to string + k := string(key) + v := string(value) + + // Check if this key already exists as an array + if existing, ok := queryMap[k]; ok { + // If it's already an array, append to it + if arr, ok := existing.([]string); ok { + queryMap[k] = append(arr, v) + } else if str, ok := existing.(string); ok { + // Convert existing string to array and append new value + queryMap[k] = []string{str, v} + } + } else { + // New key, store as string + queryMap[k] = v + } + }) + + return queryMap +} + +// ParseForm extracts form data from a request +func ParseForm(ctx *fasthttp.RequestCtx) (map[string]any, error) { + formData := make(map[string]any) + + // Check if multipart form + if strings.Contains(string(ctx.Request.Header.ContentType()), "multipart/form-data") { + return parseMultipartForm(ctx) + } + + // Regular form + ctx.PostArgs().VisitAll(func(key, value []byte) { + k := string(key) + v := string(value) + + // Check if this key already exists + if existing, ok := formData[k]; ok { + // If it's already an array, append to it + if arr, ok := existing.([]string); ok { + formData[k] = append(arr, v) + } else if str, ok := existing.(string); ok { + // Convert existing string to array and append new value + formData[k] = []string{str, v} + } + } else { + // New key, store as string + formData[k] = v + } + }) + + return formData, nil +} + +// parseMultipartForm handles multipart/form-data requests +func parseMultipartForm(ctx *fasthttp.RequestCtx) (map[string]any, error) { + formData := make(map[string]any) + + // Parse multipart form + form, err := ctx.MultipartForm() + if err != nil { + return nil, err + } + + // Process form values + for key, values := range form.Value { + if len(values) == 1 { + formData[key] = values[0] + } else if len(values) > 1 { + formData[key] = values + } + } + + // Process files (store file info, not the content) + if len(form.File) > 0 { + files := make(map[string]any) + + for fieldName, fileHeaders := range form.File { + if len(fileHeaders) == 1 { + files[fieldName] = fileInfoToMap(fileHeaders[0]) + } else if len(fileHeaders) > 1 { + fileInfos := make([]map[string]any, 0, len(fileHeaders)) + for _, fh := range fileHeaders { + fileInfos = append(fileInfos, fileInfoToMap(fh)) + } + files[fieldName] = fileInfos + } + } + + formData["_files"] = files + } + + return formData, nil +} + +// fileInfoToMap converts a FileHeader to a map for Lua +func fileInfoToMap(fh *multipart.FileHeader) map[string]any { + return map[string]any{ + "filename": fh.Filename, + "size": fh.Size, + "mimetype": getMimeType(fh), + } +} + +// getMimeType gets the mime type from a file header +func getMimeType(fh *multipart.FileHeader) string { + if fh.Header != nil { + contentType := fh.Header.Get("Content-Type") + if contentType != "" { + return contentType + } + } + + // Fallback to basic type detection from filename + if strings.HasSuffix(fh.Filename, ".pdf") { + return "application/pdf" + } else if strings.HasSuffix(fh.Filename, ".png") { + return "image/png" + } else if strings.HasSuffix(fh.Filename, ".jpg") || strings.HasSuffix(fh.Filename, ".jpeg") { + return "image/jpeg" + } else if strings.HasSuffix(fh.Filename, ".gif") { + return "image/gif" + } else if strings.HasSuffix(fh.Filename, ".svg") { + return "image/svg+xml" + } + + return "application/octet-stream" +} + +// GenerateSecureToken creates a cryptographically secure random token +func GenerateSecureToken(length int) (string, error) { + b := make([]byte, length) + if _, err := rand.Read(b); err != nil { + return "", err + } + return base64.URLEncoding.EncodeToString(b)[:length], nil +} diff --git a/core/runner/context/Context.go b/core/runner/Context.go similarity index 51% rename from core/runner/context/Context.go rename to core/runner/Context.go index b093e2f..c6dd0e1 100644 --- a/core/runner/context/Context.go +++ b/core/runner/Context.go @@ -3,8 +3,6 @@ package runner import ( "sync" - "Moonshark/core/sessions" - "github.com/valyala/bytebufferpool" "github.com/valyala/fasthttp" ) @@ -17,9 +15,9 @@ type Context struct { // FastHTTP context if this was created from an HTTP request RequestCtx *fasthttp.RequestCtx - // Session data and management - Session *sessions.Session - SessionModified bool + // Session information + SessionID string + SessionData map[string]any // Buffer for efficient string operations buffer *bytebufferpool.ByteBuffer @@ -29,7 +27,8 @@ type Context struct { var contextPool = sync.Pool{ New: func() any { return &Context{ - Values: make(map[string]any, 16), + Values: make(map[string]any, 16), + SessionData: make(map[string]any, 8), } }, } @@ -43,6 +42,44 @@ func NewContext() *Context { func NewHTTPContext(requestCtx *fasthttp.RequestCtx) *Context { ctx := NewContext() ctx.RequestCtx = requestCtx + + // Extract common HTTP values that Lua might need + if requestCtx != nil { + ctx.Values["_request_method"] = string(requestCtx.Method()) + ctx.Values["_request_path"] = string(requestCtx.Path()) + ctx.Values["_request_url"] = string(requestCtx.RequestURI()) + + // Extract cookies + cookies := make(map[string]any) + requestCtx.Request.Header.VisitAllCookie(func(key, value []byte) { + cookies[string(key)] = string(value) + }) + ctx.Values["_request_cookies"] = cookies + + // Extract query params + query := make(map[string]any) + requestCtx.QueryArgs().VisitAll(func(key, value []byte) { + query[string(key)] = string(value) + }) + ctx.Values["_request_query"] = query + + // Extract form data if present + if requestCtx.IsPost() || requestCtx.IsPut() { + form := make(map[string]any) + requestCtx.PostArgs().VisitAll(func(key, value []byte) { + form[string(key)] = string(value) + }) + ctx.Values["_request_form"] = form + } + + // Extract headers + headers := make(map[string]any) + requestCtx.Request.Header.VisitAll(func(key, value []byte) { + headers[string(key)] = string(value) + }) + ctx.Values["_request_headers"] = headers + } + return ctx } @@ -53,9 +90,12 @@ func (c *Context) Release() { delete(c.Values, k) } + for k := range c.SessionData { + delete(c.SessionData, k) + } + // Reset session info - c.Session = nil - c.SessionModified = false + c.SessionID = "" // Reset request context c.RequestCtx = nil @@ -87,13 +127,12 @@ func (c *Context) Get(key string) any { return c.Values[key] } -// Contains checks if a key exists in the context -func (c *Context) Contains(key string) bool { - _, exists := c.Values[key] - return exists +// SetSession sets a session data value +func (c *Context) SetSession(key string, value any) { + c.SessionData[key] = value } -// Delete removes a value from the context -func (c *Context) Delete(key string) { - delete(c.Values, key) +// GetSession retrieves a session data value +func (c *Context) GetSession(key string) any { + return c.SessionData[key] } diff --git a/core/runner/CoreModules.go b/core/runner/CoreModules.go deleted file mode 100644 index 21d9ed0..0000000 --- a/core/runner/CoreModules.go +++ /dev/null @@ -1,262 +0,0 @@ -package runner - -import ( - "Moonshark/core/runner/sandbox" - "Moonshark/core/utils/logger" - "fmt" - "strings" - "sync" - - luajit "git.sharkk.net/Sky/LuaJIT-to-Go" -) - -// CoreModuleRegistry manages the initialization and reloading of core modules -type CoreModuleRegistry struct { - modules map[string]sandbox.StateInitFunc // Module initializers - initOrder []string // Explicit initialization order - dependencies map[string][]string // Module dependencies - initializedFlag map[string]bool // Track which modules are initialized - mu sync.RWMutex - debug bool -} - -// NewCoreModuleRegistry creates a new core module registry -func NewCoreModuleRegistry() *CoreModuleRegistry { - return &CoreModuleRegistry{ - modules: make(map[string]sandbox.StateInitFunc), - initOrder: []string{}, - dependencies: make(map[string][]string), - initializedFlag: make(map[string]bool), - debug: false, - } -} - -// EnableDebug turns on debug logging -func (r *CoreModuleRegistry) EnableDebug() { - r.debug = true -} - -// debugLog prints debug messages if enabled -func (r *CoreModuleRegistry) debugLog(format string, args ...interface{}) { - if r.debug { - logger.Debug("CoreRegistry "+format, args...) - } -} - -// Register adds a module to the registry -func (r *CoreModuleRegistry) Register(name string, initFunc sandbox.StateInitFunc) { - r.mu.Lock() - defer r.mu.Unlock() - - r.modules[name] = initFunc - - // Add to initialization order if not already there - for _, n := range r.initOrder { - if n == name { - return // Already registered - } - } - - r.initOrder = append(r.initOrder, name) - r.debugLog("registered module %s", name) -} - -// RegisterWithDependencies registers a module with explicit dependencies -func (r *CoreModuleRegistry) RegisterWithDependencies(name string, initFunc sandbox.StateInitFunc, dependencies []string) { - r.mu.Lock() - defer r.mu.Unlock() - - r.modules[name] = initFunc - r.dependencies[name] = dependencies - - // Add to initialization order if not already there - for _, n := range r.initOrder { - if n == name { - return // Already registered - } - } - - r.initOrder = append(r.initOrder, name) - r.debugLog("registered module %s with dependencies: %v", name, dependencies) -} - -// SetInitOrder sets explicit initialization order -func (r *CoreModuleRegistry) SetInitOrder(order []string) { - r.mu.Lock() - defer r.mu.Unlock() - - // Create new init order - newOrder := make([]string, 0, len(order)) - - // First add all known modules that are in the specified order - for _, name := range order { - if _, exists := r.modules[name]; exists && !contains(newOrder, name) { - newOrder = append(newOrder, name) - } - } - - // Then add any modules not in the specified order - for name := range r.modules { - if !contains(newOrder, name) { - newOrder = append(newOrder, name) - } - } - - r.initOrder = newOrder - r.debugLog("Set initialization order: %v", r.initOrder) -} - -// Initialize initializes all registered modules -func (r *CoreModuleRegistry) Initialize(state *luajit.State, stateIndex int) error { - r.mu.RLock() - defer r.mu.RUnlock() - - verbose := stateIndex == 0 - if verbose { - r.debugLog("initializing %d modules...", len(r.initOrder)) - } - - // Clear initialization flags - r.initializedFlag = make(map[string]bool) - - // Initialize modules in order, respecting dependencies - for _, name := range r.initOrder { - if err := r.initializeModule(state, name, []string{}, verbose); err != nil { - return err - } - } - - if verbose { - r.debugLog("All modules initialized successfully") - } - return nil -} - -// initializeModule initializes a module and its dependencies -func (r *CoreModuleRegistry) initializeModule(state *luajit.State, name string, - initStack []string, verbose bool) error { - // Check if already initialized - if r.initializedFlag[name] { - return nil - } - - // Check for circular dependencies - for _, n := range initStack { - if n == name { - return fmt.Errorf("circular dependency detected: %s -> %s", - strings.Join(initStack, " -> "), name) - } - } - - // Get init function - initFunc, ok := r.modules[name] - if !ok { - return fmt.Errorf("module not found: %s", name) - } - - // Initialize dependencies first - deps := r.dependencies[name] - if len(deps) > 0 { - newStack := append(initStack, name) - for _, dep := range deps { - if err := r.initializeModule(state, dep, newStack, verbose); err != nil { - return err - } - } - } - - err := initFunc(state) - if err != nil { - // Always log failures regardless of verbose setting - r.debugLog("Initializing module %s... failure: %v", name, err) - return fmt.Errorf("failed to initialize module %s: %w", name, err) - } - - r.initializedFlag[name] = true - - if verbose { - r.debugLog("Initializing module %s... success", name) - } - - return nil -} - -// InitializeModule initializes a specific module -func (r *CoreModuleRegistry) InitializeModule(state *luajit.State, name string) error { - r.mu.RLock() - defer r.mu.RUnlock() - - // Clear initialization flag for this module - r.initializedFlag[name] = false - - // Always use verbose logging for explicit module initialization - return r.initializeModule(state, name, []string{}, true) -} - -// MatchModuleName checks if a file path corresponds to a registered module -func (r *CoreModuleRegistry) MatchModuleName(modName string) (string, bool) { - r.mu.RLock() - defer r.mu.RUnlock() - - // Exact match - if _, ok := r.modules[modName]; ok { - return modName, true - } - - // Check if the module name ends with a registered module - for name := range r.modules { - if strings.HasSuffix(modName, "."+name) { - return name, true - } - } - - return "", false -} - -// Global registry instance -var GlobalRegistry = NewCoreModuleRegistry() - -// Initialize global registry with core modules -func init() { - GlobalRegistry.EnableDebug() // Enable debugging by default - logger.Debug("[ModuleRegistry] Registering core modules...") - - // Register core modules - GlobalRegistry.Register("util", func(state *luajit.State) error { - return sandbox.UtilModuleInitFunc()(state) - }) - - GlobalRegistry.Register("http", func(state *luajit.State) error { - return sandbox.HTTPModuleInitFunc()(state) - }) - - // Set explicit initialization order - GlobalRegistry.SetInitOrder([]string{ - "util", // First: core utilities - "http", // Second: HTTP functionality - "session", // Third: Session functionality - "csrf", // Fourth: CSRF protection - }) - - logger.Debug("Core modules registered successfully") -} - -// RegisterCoreModule registers a core module with the global registry -func RegisterCoreModule(name string, initFunc sandbox.StateInitFunc) { - GlobalRegistry.Register(name, initFunc) -} - -// RegisterCoreModuleWithDependencies registers a module with dependencies -func RegisterCoreModuleWithDependencies(name string, initFunc sandbox.StateInitFunc, dependencies []string) { - GlobalRegistry.RegisterWithDependencies(name, initFunc, dependencies) -} - -// Helper functions -func contains(slice []string, item string) bool { - for _, s := range slice { - if s == item { - return true - } - } - return false -} diff --git a/core/runner/Embed.go b/core/runner/Embed.go new file mode 100644 index 0000000..700222a --- /dev/null +++ b/core/runner/Embed.go @@ -0,0 +1,61 @@ +package runner + +import ( + _ "embed" + "sync" + "sync/atomic" + + "Moonshark/core/utils/logger" + + luajit "git.sharkk.net/Sky/LuaJIT-to-Go" +) + +//go:embed sandbox.lua +var sandboxLuaCode string + +// Global bytecode cache to improve performance +var ( + sandboxBytecode atomic.Pointer[[]byte] + bytecodeOnce sync.Once +) + +// precompileSandboxCode compiles the sandbox.lua code to bytecode once +func precompileSandboxCode() { + // Create temporary state for compilation + tempState := luajit.New() + if tempState == nil { + logger.Error("Failed to create temp Lua state for bytecode compilation") + return + } + defer tempState.Close() + defer tempState.Cleanup() + + code, err := tempState.CompileBytecode(sandboxLuaCode, "sandbox.lua") + if err != nil { + logger.Error("Failed to compile sandbox code: %v", err) + return + } + + bytecode := make([]byte, len(code)) + copy(bytecode, code) + sandboxBytecode.Store(&bytecode) + + logger.Debug("Successfully precompiled sandbox.lua to bytecode (%d bytes)", len(code)) +} + +// loadSandboxIntoState loads the sandbox code into a Lua state +func loadSandboxIntoState(state *luajit.State) error { + // Initialize bytecode once + bytecodeOnce.Do(precompileSandboxCode) + + // Use precompiled bytecode if available + bytecode := sandboxBytecode.Load() + if bytecode != nil && len(*bytecode) > 0 { + logger.Debug("Loading sandbox.lua from precompiled bytecode") + return state.LoadAndRunBytecode(*bytecode, "sandbox.lua") + } + + // Fallback to direct execution + logger.Warning("Using non-precompiled sandbox.lua (bytecode compilation failed)") + return state.DoString(sandboxLuaCode) +} diff --git a/core/runner/Http.go b/core/runner/Http.go new file mode 100644 index 0000000..78a8ec9 --- /dev/null +++ b/core/runner/Http.go @@ -0,0 +1,334 @@ +package runner + +import ( + "context" + "crypto/rand" + "encoding/base64" + "errors" + "fmt" + "net/url" + "strings" + "time" + + "github.com/goccy/go-json" + "github.com/valyala/bytebufferpool" + "github.com/valyala/fasthttp" + + "Moonshark/core/utils/logger" + + luajit "git.sharkk.net/Sky/LuaJIT-to-Go" +) + +// Default HTTP client with sensible timeout +var defaultFastClient = fasthttp.Client{ + MaxConnsPerHost: 1024, + MaxIdleConnDuration: time.Minute, + ReadTimeout: 30 * time.Second, + WriteTimeout: 30 * time.Second, + DisableHeaderNamesNormalizing: true, +} + +// HTTPClientConfig contains client settings +type HTTPClientConfig struct { + MaxTimeout time.Duration // Maximum timeout for requests (0 = no limit) + DefaultTimeout time.Duration // Default request timeout + MaxResponseSize int64 // Maximum response size in bytes (0 = no limit) + AllowRemote bool // Whether to allow remote connections +} + +// DefaultHTTPClientConfig provides sensible defaults +var DefaultHTTPClientConfig = HTTPClientConfig{ + MaxTimeout: 60 * time.Second, + DefaultTimeout: 30 * time.Second, + MaxResponseSize: 10 * 1024 * 1024, // 10MB + AllowRemote: true, +} + +// ApplyResponse applies a Response to a fasthttp.RequestCtx +func ApplyResponse(resp *Response, ctx *fasthttp.RequestCtx) { + // Set status code + ctx.SetStatusCode(resp.Status) + + // Set headers + for name, value := range resp.Headers { + ctx.Response.Header.Set(name, value) + } + + // Set cookies + for _, cookie := range resp.Cookies { + ctx.Response.Header.SetCookie(cookie) + } + + // Process the body based on its type + if resp.Body == nil { + return + } + + // Get a buffer from the pool + buf := bytebufferpool.Get() + defer bytebufferpool.Put(buf) + + // Set body based on type + switch body := resp.Body.(type) { + case string: + ctx.SetBodyString(body) + case []byte: + ctx.SetBody(body) + case map[string]any, []any, []float64, []string, []int: + // Marshal JSON + if err := json.NewEncoder(buf).Encode(body); err == nil { + // Set content type if not already set + if len(ctx.Response.Header.ContentType()) == 0 { + ctx.Response.Header.SetContentType("application/json") + } + ctx.SetBody(buf.Bytes()) + } else { + // Fallback + ctx.SetBodyString(fmt.Sprintf("%v", body)) + } + default: + // Default to string representation + ctx.SetBodyString(fmt.Sprintf("%v", body)) + } +} + +// httpRequest makes an HTTP request and returns the result to Lua +func httpRequest(state *luajit.State) int { + // Get method (required) + if !state.IsString(1) { + state.PushString("http.client.request: method must be a string") + return -1 + } + method := strings.ToUpper(state.ToString(1)) + + // Get URL (required) + if !state.IsString(2) { + state.PushString("http.client.request: url must be a string") + return -1 + } + urlStr := state.ToString(2) + + // Parse URL to check if it's valid + parsedURL, err := url.Parse(urlStr) + if err != nil { + state.PushString("Invalid URL: " + err.Error()) + return -1 + } + + // Get client configuration + config := DefaultHTTPClientConfig + + // Check if remote connections are allowed + if !config.AllowRemote && (parsedURL.Hostname() != "localhost" && parsedURL.Hostname() != "127.0.0.1") { + state.PushString("Remote connections are not allowed") + return -1 + } + + // Use bytebufferpool for request and response + req := fasthttp.AcquireRequest() + resp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseRequest(req) + defer fasthttp.ReleaseResponse(resp) + + // Set up request + req.Header.SetMethod(method) + req.SetRequestURI(urlStr) + req.Header.Set("User-Agent", "Moonshark/1.0") + + // Get body (optional) + if state.GetTop() >= 3 && !state.IsNil(3) { + if state.IsString(3) { + // String body + req.SetBodyString(state.ToString(3)) + } else if state.IsTable(3) { + // Table body - convert to JSON + luaTable, err := state.ToTable(3) + if err != nil { + state.PushString("Failed to parse body table: " + err.Error()) + return -1 + } + + // Use bytebufferpool for JSON serialization + buf := bytebufferpool.Get() + defer bytebufferpool.Put(buf) + + if err := json.NewEncoder(buf).Encode(luaTable); err != nil { + state.PushString("Failed to convert body to JSON: " + err.Error()) + return -1 + } + + req.SetBody(buf.Bytes()) + req.Header.SetContentType("application/json") + } else { + state.PushString("Body must be a string or table") + return -1 + } + } + + // Process options (headers, timeout, etc.) + timeout := config.DefaultTimeout + if state.GetTop() >= 4 && !state.IsNil(4) && state.IsTable(4) { + // Process headers + state.GetField(4, "headers") + if state.IsTable(-1) { + // Iterate through headers + state.PushNil() // Start iteration + for state.Next(-2) { + // Stack now has key at -2 and value at -1 + if state.IsString(-2) && state.IsString(-1) { + headerName := state.ToString(-2) + headerValue := state.ToString(-1) + req.Header.Set(headerName, headerValue) + } + state.Pop(1) // Pop value, leave key for next iteration + } + } + state.Pop(1) // Pop headers table + + // Get timeout + state.GetField(4, "timeout") + if state.IsNumber(-1) { + requestTimeout := time.Duration(state.ToNumber(-1)) * time.Second + + // Apply max timeout if configured + if config.MaxTimeout > 0 && requestTimeout > config.MaxTimeout { + timeout = config.MaxTimeout + } else { + timeout = requestTimeout + } + } + state.Pop(1) // Pop timeout + + // Process query parameters + state.GetField(4, "query") + if state.IsTable(-1) { + // Create URL args + args := req.URI().QueryArgs() + + // Iterate through query params + state.PushNil() // Start iteration + for state.Next(-2) { + if state.IsString(-2) { + paramName := state.ToString(-2) + + // Handle different value types + if state.IsString(-1) { + args.Add(paramName, state.ToString(-1)) + } else if state.IsNumber(-1) { + args.Add(paramName, strings.TrimRight(strings.TrimRight( + state.ToString(-1), "0"), ".")) + } else if state.IsBoolean(-1) { + if state.ToBoolean(-1) { + args.Add(paramName, "true") + } else { + args.Add(paramName, "false") + } + } + } + state.Pop(1) // Pop value, leave key for next iteration + } + } + state.Pop(1) // Pop query table + } + + // Create context with timeout + _, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + // Execute request + err = defaultFastClient.DoTimeout(req, resp, timeout) + if err != nil { + errStr := "Request failed: " + err.Error() + if errors.Is(err, fasthttp.ErrTimeout) { + errStr = "Request timed out after " + timeout.String() + } + state.PushString(errStr) + return -1 + } + + // Create response table + state.NewTable() + + // Set status code + state.PushNumber(float64(resp.StatusCode())) + state.SetField(-2, "status") + + // Set status text + statusText := fasthttp.StatusMessage(resp.StatusCode()) + state.PushString(statusText) + state.SetField(-2, "status_text") + + // Set body + var respBody []byte + + // Apply size limits to response + if config.MaxResponseSize > 0 && int64(len(resp.Body())) > config.MaxResponseSize { + // Make a limited copy + respBody = make([]byte, config.MaxResponseSize) + copy(respBody, resp.Body()) + } else { + respBody = resp.Body() + } + + state.PushString(string(respBody)) + state.SetField(-2, "body") + + // Parse body as JSON if content type is application/json + contentType := string(resp.Header.ContentType()) + if strings.Contains(contentType, "application/json") { + var jsonData any + if err := json.Unmarshal(respBody, &jsonData); err == nil { + if err := state.PushValue(jsonData); err == nil { + state.SetField(-2, "json") + } + } + } + + // Set headers + state.NewTable() + resp.Header.VisitAll(func(key, value []byte) { + state.PushString(string(value)) + state.SetField(-2, string(key)) + }) + state.SetField(-2, "headers") + + // Create ok field (true if status code is 2xx) + state.PushBoolean(resp.StatusCode() >= 200 && resp.StatusCode() < 300) + state.SetField(-2, "ok") + + return 1 +} + +// generateToken creates a cryptographically secure random token +func generateToken(state *luajit.State) int { + // Get the length from the Lua arguments (default to 32) + length := 32 + if state.GetTop() >= 1 && state.IsNumber(1) { + length = int(state.ToNumber(1)) + } + + // Enforce minimum length for security + if length < 16 { + length = 16 + } + + // Generate secure random bytes + tokenBytes := make([]byte, length) + if _, err := rand.Read(tokenBytes); err != nil { + logger.Error("Failed to generate secure token: %v", err) + state.PushString("") + return 1 // Return empty string on error + } + + // Encode as base64 + token := base64.RawURLEncoding.EncodeToString(tokenBytes) + + // Trim to requested length (base64 might be longer) + if len(token) > length { + token = token[:length] + } + + // Push the token to the Lua stack + state.PushString(token) + return 1 // One return value +} diff --git a/core/runner/ModuleLoader.go b/core/runner/ModuleLoader.go index e4da43d..7d0496c 100644 --- a/core/runner/ModuleLoader.go +++ b/core/runner/ModuleLoader.go @@ -6,6 +6,8 @@ import ( "strings" "sync" + "Moonshark/core/utils/logger" + luajit "git.sharkk.net/Sky/LuaJIT-to-Go" ) @@ -15,61 +17,15 @@ type ModuleConfig struct { LibDirs []string // Additional library directories } -// ModuleInfo stores information about a loaded module -type ModuleInfo struct { - Name string - Path string - IsCore bool - Bytecode []byte -} - // ModuleLoader manages module loading and caching type ModuleLoader struct { config *ModuleConfig - registry *ModuleRegistry pathCache map[string]string // Cache module paths for fast lookups bytecodeCache map[string][]byte // Cache of compiled bytecode debug bool mu sync.RWMutex } -// ModuleRegistry keeps track of Lua modules for file watching -type ModuleRegistry struct { - // Maps file paths to module names - pathToModule sync.Map - // Maps module names to file paths - moduleToPath sync.Map -} - -// NewModuleRegistry creates a new module registry -func NewModuleRegistry() *ModuleRegistry { - return &ModuleRegistry{} -} - -// Register adds a module path to the registry -func (r *ModuleRegistry) Register(path string, name string) { - r.pathToModule.Store(path, name) - r.moduleToPath.Store(name, path) -} - -// GetModuleName retrieves a module name by path -func (r *ModuleRegistry) GetModuleName(path string) (string, bool) { - value, ok := r.pathToModule.Load(path) - if !ok { - return "", false - } - return value.(string), true -} - -// GetModulePath retrieves a path by module name -func (r *ModuleRegistry) GetModulePath(name string) (string, bool) { - value, ok := r.moduleToPath.Load(name) - if !ok { - return "", false - } - return value.(string), true -} - // NewModuleLoader creates a new module loader func NewModuleLoader(config *ModuleConfig) *ModuleLoader { if config == nil { @@ -81,7 +37,6 @@ func NewModuleLoader(config *ModuleConfig) *ModuleLoader { return &ModuleLoader{ config: config, - registry: NewModuleRegistry(), pathCache: make(map[string]string), bytecodeCache: make(map[string][]byte), debug: false, @@ -100,6 +55,13 @@ func (l *ModuleLoader) SetScriptDir(dir string) { l.config.ScriptDir = dir } +// debugLog logs a message if debug mode is enabled +func (l *ModuleLoader) debugLog(format string, args ...interface{}) { + if l.debug { + logger.Debug("ModuleLoader "+format, args...) + } +} + // SetupRequire configures the require system in a Lua state func (l *ModuleLoader) SetupRequire(state *luajit.State) error { l.mu.RLock() @@ -207,6 +169,8 @@ func (l *ModuleLoader) PreloadModules(state *luajit.State) error { continue } + l.debugLog("Scanning directory: %s", absDir) + // Find all Lua files err = filepath.Walk(absDir, func(path string, info os.FileInfo, err error) error { if err != nil || info.IsDir() || !strings.HasSuffix(path, ".lua") { @@ -223,19 +187,22 @@ func (l *ModuleLoader) PreloadModules(state *luajit.State) error { modName := strings.TrimSuffix(relPath, ".lua") modName = strings.ReplaceAll(modName, string(filepath.Separator), ".") + l.debugLog("Found module: %s at %s", modName, path) + // Register in our caches l.pathCache[modName] = path - l.registry.Register(path, modName) // Load file content content, err := os.ReadFile(path) if err != nil { + l.debugLog("Failed to read module file: %v", err) return nil } // Compile to bytecode bytecode, err := state.CompileBytecode(string(content), path) if err != nil { + l.debugLog("Failed to compile module: %v", err) return nil } @@ -354,10 +321,11 @@ func (l *ModuleLoader) GetModuleByPath(path string) (string, bool) { // Clean path for proper comparison path = filepath.Clean(path) - // Try direct lookup from registry - modName, found := l.registry.GetModuleName(path) - if found { - return modName, true + // Try direct lookup from cache + for modName, modPath := range l.pathCache { + if modPath == path { + return modName, true + } } // Try to find by relative path from lib dirs @@ -373,7 +341,7 @@ func (l *ModuleLoader) GetModuleByPath(path string) (string, bool) { } if strings.HasSuffix(relPath, ".lua") { - modName = strings.TrimSuffix(relPath, ".lua") + modName := strings.TrimSuffix(relPath, ".lua") modName = strings.ReplaceAll(modName, string(filepath.Separator), ".") return modName, true } @@ -382,103 +350,6 @@ func (l *ModuleLoader) GetModuleByPath(path string) (string, bool) { return "", false } -// ReloadModule reloads a module from disk -func (l *ModuleLoader) ReloadModule(state *luajit.State, name string) (bool, error) { - l.mu.Lock() - defer l.mu.Unlock() - - // Get module path - path, ok := l.registry.GetModulePath(name) - if !ok { - for modName, modPath := range l.pathCache { - if modName == name { - path = modPath - ok = true - break - } - } - } - - if !ok || path == "" { - return false, nil - } - - // Invalidate module in Lua - err := state.DoString(` - package.loaded["` + name + `"] = nil - __ready_modules["` + name + `"] = nil - if package.preload then - package.preload["` + name + `"] = nil - end - `) - - if err != nil { - return false, err - } - - // Check if file still exists - if _, err := os.Stat(path); os.IsNotExist(err) { - // File was deleted, just invalidate - delete(l.pathCache, name) - delete(l.bytecodeCache, name) - l.registry.moduleToPath.Delete(name) - l.registry.pathToModule.Delete(path) - return true, nil - } - - // Read updated file - content, err := os.ReadFile(path) - if err != nil { - return false, err - } - - // Compile to bytecode - bytecode, err := state.CompileBytecode(string(content), path) - if err != nil { - return false, err - } - - // Update cache - l.bytecodeCache[name] = bytecode - - // Load bytecode into state - if err := state.LoadBytecode(bytecode, path); err != nil { - return false, err - } - - // Update preload - luaCode := ` - local modname = "` + name + `" - package.loaded[modname] = nil - package.preload[modname] = ... - __ready_modules[modname] = true - ` - - if err := state.DoString(luaCode); err != nil { - state.Pop(1) // Remove chunk from stack - return false, err - } - - state.Pop(1) // Remove chunk from stack - return true, nil -} - -// ResetModules clears non-core modules from package.loaded -func (l *ModuleLoader) ResetModules(state *luajit.State) error { - return state.DoString(` - local core_modules = { - string = true, table = true, math = true, os = true, - package = true, io = true, coroutine = true, debug = true, _G = true - } - - for name in pairs(package.loaded) do - if not core_modules[name] then - package.loaded[name] = nil - end - end - `) -} - // escapeLuaString escapes special characters in a string for Lua func escapeLuaString(s string) string { replacer := strings.NewReplacer( diff --git a/core/runner/Response.go b/core/runner/Response.go new file mode 100644 index 0000000..b608ba5 --- /dev/null +++ b/core/runner/Response.go @@ -0,0 +1,76 @@ +package runner + +import ( + "sync" + + "github.com/valyala/fasthttp" +) + +// Response represents a unified response from script execution +type Response struct { + // Basic properties + Body any // Body content (any type) + Metadata map[string]any // Additional metadata + + // HTTP specific properties + Status int // HTTP status code + Headers map[string]string // HTTP headers + Cookies []*fasthttp.Cookie // HTTP cookies + + // Session information + SessionID string // Session ID + SessionData map[string]any // Session data + SessionModified bool // Whether session was modified +} + +// Response pool to reduce allocations +var responsePool = sync.Pool{ + New: func() any { + return &Response{ + Status: 200, + Headers: make(map[string]string, 8), + Metadata: make(map[string]any, 8), + Cookies: make([]*fasthttp.Cookie, 0, 4), + SessionData: make(map[string]any, 8), + } + }, +} + +// NewResponse creates a new response object from the pool +func NewResponse() *Response { + return responsePool.Get().(*Response) +} + +// Release returns a response to the pool after cleaning it +func ReleaseResponse(resp *Response) { + if resp == nil { + return + } + + // Reset fields to default values + resp.Body = nil + resp.Status = 200 + + // Clear maps + for k := range resp.Headers { + delete(resp.Headers, k) + } + + for k := range resp.Metadata { + delete(resp.Metadata, k) + } + + for k := range resp.SessionData { + delete(resp.SessionData, k) + } + + // Clear cookies + resp.Cookies = resp.Cookies[:0] + + // Reset session info + resp.SessionID = "" + resp.SessionModified = false + + // Return to pool + responsePool.Put(resp) +} diff --git a/core/runner/Runner.go b/core/runner/Runner.go index 40952a9..aee4c85 100644 --- a/core/runner/Runner.go +++ b/core/runner/Runner.go @@ -9,8 +9,6 @@ import ( "sync/atomic" "time" - luaCtx "Moonshark/core/runner/context" - "Moonshark/core/runner/sandbox" "Moonshark/core/utils/logger" luajit "git.sharkk.net/Sky/LuaJIT-to-Go" @@ -29,30 +27,22 @@ type RunnerOption func(*Runner) // State wraps a Lua state with its sandbox type State struct { - L *luajit.State // The Lua state - sandbox *sandbox.Sandbox // Associated sandbox - index int // Index for debugging - inUse bool // Whether the state is currently in use + L *luajit.State // The Lua state + sandbox *Sandbox // Associated sandbox + index int // Index for debugging + inUse bool // Whether the state is currently in use } -// InitHook runs before executing a script -type InitHook func(*luajit.State, *luaCtx.Context) error - -// FinalizeHook runs after executing a script -type FinalizeHook func(*luajit.State, *luaCtx.Context, any) error - // Runner runs Lua scripts using a pool of Lua states type Runner struct { - states []*State // All states managed by this runner - statePool chan int // Pool of available state indexes - poolSize int // Size of the state pool - moduleLoader *ModuleLoader // Module loader - isRunning atomic.Bool // Whether the runner is active - mu sync.RWMutex // Mutex for thread safety - debug bool // Enable debug logging - initHooks []InitHook // Hooks run before script execution - finalizeHooks []FinalizeHook // Hooks run after script execution - scriptDir string // Current script directory + states []*State // All states managed by this runner + statePool chan int // Pool of available state indexes + poolSize int // Size of the state pool + moduleLoader *ModuleLoader // Module loader + isRunning atomic.Bool // Whether the runner is active + mu sync.RWMutex // Mutex for thread safety + debug bool // Enable debug logging + scriptDir string // Current script directory } // WithPoolSize sets the state pool size @@ -84,28 +74,12 @@ func WithLibDirs(dirs ...string) RunnerOption { } } -// WithInitHook adds a hook to run before script execution -func WithInitHook(hook InitHook) RunnerOption { - return func(r *Runner) { - r.initHooks = append(r.initHooks, hook) - } -} - -// WithFinalizeHook adds a hook to run after script execution -func WithFinalizeHook(hook FinalizeHook) RunnerOption { - return func(r *Runner) { - r.finalizeHooks = append(r.finalizeHooks, hook) - } -} - // NewRunner creates a new Runner with a pool of states func NewRunner(options ...RunnerOption) (*Runner, error) { // Default configuration runner := &Runner{ - poolSize: runtime.GOMAXPROCS(0), - debug: false, - initHooks: make([]InitHook, 0, 4), - finalizeHooks: make([]FinalizeHook, 0, 4), + poolSize: runtime.GOMAXPROCS(0), + debug: false, } // Apply options @@ -122,6 +96,11 @@ func NewRunner(options ...RunnerOption) (*Runner, error) { runner.moduleLoader = NewModuleLoader(config) } + // Enable debug if requested + if runner.debug { + runner.moduleLoader.EnableDebug() + } + // Initialize states and pool runner.states = make([]*State, runner.poolSize) runner.statePool = make(chan int, runner.poolSize) @@ -145,7 +124,7 @@ func (r *Runner) debugLog(format string, args ...interface{}) { // initializeStates creates and initializes all states in the pool func (r *Runner) initializeStates() error { - r.debugLog("is initializing %d states", r.poolSize) + r.debugLog("Initializing %d states", r.poolSize) // Create all states for i := 0; i < r.poolSize; i++ { @@ -175,39 +154,36 @@ func (r *Runner) createState(index int) (*State, error) { } // Create sandbox - sb := sandbox.NewSandbox() - if r.debug && verbose { + sb := NewSandbox() + if r.debug { sb.EnableDebug() } - // Set up require system + // Set up sandbox + if err := sb.Setup(L); err != nil { + L.Cleanup() + L.Close() + return nil, ErrInitFailed + } + + // Set up module loader if err := r.moduleLoader.SetupRequire(L); err != nil { L.Cleanup() L.Close() return nil, ErrInitFailed } - // Initialize all core modules from the registry - if err := GlobalRegistry.Initialize(L, index); err != nil { - L.Cleanup() - L.Close() - return nil, ErrInitFailed - } - - // Set up sandbox after core modules are initialized - if err := sb.Setup(L, index); err != nil { - L.Cleanup() - L.Close() - return nil, ErrInitFailed - } - - // Preload all modules + // Preload modules if err := r.moduleLoader.PreloadModules(L); err != nil { L.Cleanup() L.Close() return nil, errors.New("failed to preload modules") } + if verbose { + r.debugLog("Lua state %d initialized successfully", index) + } + return &State{ L: L, sandbox: sb, @@ -216,8 +192,8 @@ func (r *Runner) createState(index int) (*State, error) { }, nil } -// Execute runs a script with context -func (r *Runner) Execute(ctx context.Context, bytecode []byte, execCtx *luaCtx.Context, scriptPath string) (any, error) { +// Execute runs a script in a sandbox with context +func (r *Runner) Execute(ctx context.Context, bytecode []byte, execCtx *Context, scriptPath string) (*Response, error) { if !r.isRunning.Load() { return nil, ErrRunnerClosed } @@ -264,70 +240,17 @@ func (r *Runner) Execute(ctx context.Context, bytecode []byte, execCtx *luaCtx.C } }() - // Run init hooks - for _, hook := range r.initHooks { - if err := hook(state.L, execCtx); err != nil { - return nil, err - } - } - - // Get context values - var ctxValues map[string]any - if execCtx != nil { - ctxValues = execCtx.Values - } - - // Execute in sandbox with optimized context handling - var result any - var err error - - if execCtx != nil && execCtx.RequestCtx != nil { - // Use OptimizedExecute directly with the full context if we have RequestCtx - result, err = state.sandbox.OptimizedExecute(state.L, bytecode, &luaCtx.Context{ - Values: ctxValues, - RequestCtx: execCtx.RequestCtx, - }) - } else { - // Otherwise use standard Execute with just values - result, err = state.sandbox.Execute(state.L, bytecode, ctxValues) - } - + // Execute in sandbox + response, err := state.sandbox.Execute(state.L, bytecode, execCtx) if err != nil { return nil, err } - // Run finalize hooks - for _, hook := range r.finalizeHooks { - if hookErr := hook(state.L, execCtx, result); hookErr != nil { - return nil, hookErr - } - } - - // Check for HTTP response if we don't have a RequestCtx or if we still have a result - if execCtx == nil || execCtx.RequestCtx == nil || result != nil { - httpResp, hasResponse := sandbox.GetHTTPResponse(state.L) - if hasResponse { - // Set result as body if not already set - if httpResp.Body == nil { - httpResp.Body = result - } - - // Apply directly to request context if available - if execCtx != nil && execCtx.RequestCtx != nil { - sandbox.ApplyHTTPResponse(httpResp, execCtx.RequestCtx) - sandbox.ReleaseResponse(httpResp) - return nil, nil - } - - return httpResp, nil - } - } - - return result, err + return response, nil } -// Run executes a Lua script (convenience wrapper) -func (r *Runner) Run(bytecode []byte, execCtx *luaCtx.Context, scriptPath string) (any, error) { +// Run executes a Lua script with immediate context +func (r *Runner) Run(bytecode []byte, execCtx *Context, scriptPath string) (*Response, error) { return r.Execute(context.Background(), bytecode, execCtx, scriptPath) } @@ -363,6 +286,7 @@ cleanup: } } + r.debugLog("Runner closed") return nil } @@ -375,6 +299,8 @@ func (r *Runner) RefreshStates() error { return ErrRunnerClosed } + r.debugLog("Refreshing all states...") + // Drain all states from the pool for { select { @@ -408,81 +334,6 @@ cleanup: return nil } -// AddInitHook adds a hook to be called before script execution -func (r *Runner) AddInitHook(hook InitHook) { - r.mu.Lock() - defer r.mu.Unlock() - r.initHooks = append(r.initHooks, hook) -} - -// AddFinalizeHook adds a hook to be called after script execution -func (r *Runner) AddFinalizeHook(hook FinalizeHook) { - r.mu.Lock() - defer r.mu.Unlock() - r.finalizeHooks = append(r.finalizeHooks, hook) -} - -// GetStateCount returns the number of initialized states -func (r *Runner) GetStateCount() int { - r.mu.RLock() - defer r.mu.RUnlock() - - count := 0 - for _, state := range r.states { - if state != nil { - count++ - } - } - - return count -} - -// GetActiveStateCount returns the number of states currently in use -func (r *Runner) GetActiveStateCount() int { - r.mu.RLock() - defer r.mu.RUnlock() - - count := 0 - for _, state := range r.states { - if state != nil && state.inUse { - count++ - } - } - - return count -} - -// GetModuleCount returns the number of loaded modules in the first available state -func (r *Runner) GetModuleCount() int { - r.mu.RLock() - defer r.mu.RUnlock() - - if !r.isRunning.Load() { - return 0 - } - - // Find first available state - for _, state := range r.states { - if state != nil && !state.inUse { - // Execute a Lua snippet to count modules - if res, err := state.L.ExecuteWithResult(` - local count = 0 - for _ in pairs(package.loaded) do - count = count + 1 - end - return count - `); err == nil { - if num, ok := res.(float64); ok { - return int(num) - } - } - break - } - } - - return 0 -} - // NotifyFileChanged alerts the runner about file changes func (r *Runner) NotifyFileChanged(filePath string) bool { r.debugLog("File change detected: %s", filePath) @@ -514,9 +365,6 @@ func (r *Runner) RefreshModule(moduleName string) bool { r.debugLog("Refreshing module: %s", moduleName) - // Check if it's a core module - coreName, isCore := GlobalRegistry.MatchModuleName(moduleName) - success := true for _, state := range r.states { if state == nil || state.inUse { @@ -526,16 +374,39 @@ func (r *Runner) RefreshModule(moduleName string) bool { // Invalidate module in Lua if err := state.L.DoString(`package.loaded["` + moduleName + `"] = nil`); err != nil { success = false - continue - } - - // For core modules, reinitialize them - if isCore { - if err := GlobalRegistry.InitializeModule(state.L, coreName); err != nil { - success = false - } + r.debugLog("Failed to invalidate module %s: %v", moduleName, err) } } return success } + +// GetStateCount returns the number of initialized states +func (r *Runner) GetStateCount() int { + r.mu.RLock() + defer r.mu.RUnlock() + + count := 0 + for _, state := range r.states { + if state != nil { + count++ + } + } + + return count +} + +// GetActiveStateCount returns the number of states currently in use +func (r *Runner) GetActiveStateCount() int { + r.mu.RLock() + defer r.mu.RUnlock() + + count := 0 + for _, state := range r.states { + if state != nil && state.inUse { + count++ + } + } + + return count +} diff --git a/core/runner/Sandbox.go b/core/runner/Sandbox.go new file mode 100644 index 0000000..0a1d10c --- /dev/null +++ b/core/runner/Sandbox.go @@ -0,0 +1,345 @@ +package runner + +import ( + "fmt" + "sync" + + "github.com/valyala/bytebufferpool" + "github.com/valyala/fasthttp" + + "Moonshark/core/utils/logger" + + luajit "git.sharkk.net/Sky/LuaJIT-to-Go" +) + +// Error represents a simple error string +type Error string + +func (e Error) Error() string { + return string(e) +} + +// Error types +var ( + ErrSandboxNotInitialized = Error("sandbox not initialized") +) + +// Sandbox provides a secure execution environment for Lua scripts +type Sandbox struct { + modules map[string]any + debug bool + mu sync.RWMutex +} + +// NewSandbox creates a new sandbox environment +func NewSandbox() *Sandbox { + return &Sandbox{ + modules: make(map[string]any, 8), + debug: false, + } +} + +// EnableDebug turns on debug logging +func (s *Sandbox) EnableDebug() { + s.debug = true +} + +// debugLog logs a message if debug mode is enabled +func (s *Sandbox) debugLog(format string, args ...interface{}) { + if s.debug { + logger.Debug("Sandbox "+format, args...) + } +} + +// AddModule adds a module to the sandbox environment +func (s *Sandbox) AddModule(name string, module any) { + s.mu.Lock() + defer s.mu.Unlock() + s.modules[name] = module + s.debugLog("Added module: %s", name) +} + +// Setup initializes the sandbox in a Lua state +func (s *Sandbox) Setup(state *luajit.State) error { + s.debugLog("Setting up sandbox...") + + // Load the sandbox code + if err := loadSandboxIntoState(state); err != nil { + s.debugLog("Failed to load sandbox: %v", err) + return err + } + + // Register core functions + if err := s.registerCoreFunctions(state); err != nil { + s.debugLog("Failed to register core functions: %v", err) + return err + } + + // Register custom modules in the global environment + s.mu.RLock() + for name, module := range s.modules { + s.debugLog("Registering module: %s", name) + if err := state.PushValue(module); err != nil { + s.mu.RUnlock() + s.debugLog("Failed to register module %s: %v", name, err) + return err + } + state.SetGlobal(name) + } + s.mu.RUnlock() + + s.debugLog("Sandbox setup complete") + return nil +} + +// registerCoreFunctions registers all built-in functions in the Lua state +func (s *Sandbox) registerCoreFunctions(state *luajit.State) error { + // Register HTTP functions + if err := state.RegisterGoFunction("__http_request", httpRequest); err != nil { + return err + } + + // Register utility functions + if err := state.RegisterGoFunction("__generate_token", generateToken); err != nil { + return err + } + + // Additional registrations can be added here + + return nil +} + +// Execute runs a Lua script in the sandbox with the given context +func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx *Context) (*Response, error) { + s.debugLog("Executing script...") + + // Create a response object + response := NewResponse() + + // Get a buffer for string operations + buf := bytebufferpool.Get() + defer bytebufferpool.Put(buf) + + // Load bytecode + if err := state.LoadBytecode(bytecode, "script"); err != nil { + ReleaseResponse(response) + s.debugLog("Failed to load bytecode: %v", err) + return nil, fmt.Errorf("failed to load script: %w", err) + } + + // Initialize session data in Lua + if ctx.SessionID != "" { + // Set session ID + state.PushString(ctx.SessionID) + state.SetGlobal("__session_id") + + // Set session data + if err := state.PushTable(ctx.SessionData); err != nil { + ReleaseResponse(response) + s.debugLog("Failed to push session data: %v", err) + return nil, err + } + state.SetGlobal("__session_data") + + // Reset modification flag + state.PushBoolean(false) + state.SetGlobal("__session_modified") + } else { + // Initialize empty session + if err := state.DoString("__session_data = {}; __session_modified = false"); err != nil { + s.debugLog("Failed to initialize empty session data: %v", err) + } + } + + // Set up context values for execution + if err := state.PushTable(ctx.Values); err != nil { + ReleaseResponse(response) + s.debugLog("Failed to push context values: %v", err) + return nil, err + } + + // Get the execution function + state.GetGlobal("__execute_script") + if !state.IsFunction(-1) { + state.Pop(1) // Pop non-function + ReleaseResponse(response) + s.debugLog("__execute_script is not a function") + return nil, ErrSandboxNotInitialized + } + + // Push function and context to stack + state.PushCopy(-2) // bytecode + state.PushCopy(-2) // context + + // Remove duplicates + state.Remove(-4) + state.Remove(-3) + + // Execute with 2 args, 1 result + if err := state.Call(2, 1); err != nil { + ReleaseResponse(response) + s.debugLog("Execution failed: %v", err) + return nil, fmt.Errorf("script execution failed: %w", err) + } + + // Set response body from result + body, err := state.ToValue(-1) + if err == nil { + response.Body = body + } + state.Pop(1) + + // Extract HTTP response data from Lua state + s.extractResponseData(state, response) + + return response, nil +} + +// extractResponseData pulls response info from the Lua state +func (s *Sandbox) extractResponseData(state *luajit.State, response *Response) { + // Get HTTP response + state.GetGlobal("__http_responses") + if !state.IsNil(-1) && state.IsTable(-1) { + state.PushNumber(1) + state.GetTable(-2) + + if !state.IsNil(-1) && state.IsTable(-1) { + // Extract status + state.GetField(-1, "status") + if state.IsNumber(-1) { + response.Status = int(state.ToNumber(-1)) + } + state.Pop(1) + + // Extract headers + state.GetField(-1, "headers") + if state.IsTable(-1) { + state.PushNil() // Start iteration + for state.Next(-2) { + if state.IsString(-2) && state.IsString(-1) { + key := state.ToString(-2) + value := state.ToString(-1) + response.Headers[key] = value + } + state.Pop(1) + } + } + state.Pop(1) + + // Extract cookies + state.GetField(-1, "cookies") + if state.IsTable(-1) { + length := state.GetTableLength(-1) + for i := 1; i <= length; i++ { + state.PushNumber(float64(i)) + state.GetTable(-2) + + if state.IsTable(-1) { + s.extractCookie(state, response) + } + state.Pop(1) + } + } + state.Pop(1) + + // Extract metadata if present + state.GetField(-1, "metadata") + if state.IsTable(-1) { + table, err := state.ToTable(-1) + if err == nil { + for k, v := range table { + response.Metadata[k] = v + } + } + } + state.Pop(1) + } + state.Pop(1) + } + state.Pop(1) + + // Extract session data + state.GetGlobal("__session_modified") + if state.IsBoolean(-1) && state.ToBoolean(-1) { + response.SessionModified = true + + // Get session ID + state.GetGlobal("__session_id") + if state.IsString(-1) { + response.SessionID = state.ToString(-1) + } + state.Pop(1) + + // Get session data + state.GetGlobal("__session_data") + if state.IsTable(-1) { + sessionData, err := state.ToTable(-1) + if err == nil { + for k, v := range sessionData { + response.SessionData[k] = v + } + } + } + state.Pop(1) + } + state.Pop(1) +} + +// extractCookie pulls cookie data from the current table on the stack +func (s *Sandbox) extractCookie(state *luajit.State, response *Response) { + cookie := fasthttp.AcquireCookie() + + // Get name (required) + state.GetField(-1, "name") + if !state.IsString(-1) { + state.Pop(1) + fasthttp.ReleaseCookie(cookie) + return + } + cookie.SetKey(state.ToString(-1)) + state.Pop(1) + + // Get value + state.GetField(-1, "value") + if state.IsString(-1) { + cookie.SetValue(state.ToString(-1)) + } + state.Pop(1) + + // Get path + state.GetField(-1, "path") + if state.IsString(-1) { + cookie.SetPath(state.ToString(-1)) + } else { + cookie.SetPath("/") // Default + } + state.Pop(1) + + // Get domain + state.GetField(-1, "domain") + if state.IsString(-1) { + cookie.SetDomain(state.ToString(-1)) + } + state.Pop(1) + + // Get other parameters + state.GetField(-1, "http_only") + if state.IsBoolean(-1) { + cookie.SetHTTPOnly(state.ToBoolean(-1)) + } + state.Pop(1) + + state.GetField(-1, "secure") + if state.IsBoolean(-1) { + cookie.SetSecure(state.ToBoolean(-1)) + } + state.Pop(1) + + state.GetField(-1, "max_age") + if state.IsNumber(-1) { + cookie.SetMaxAge(int(state.ToNumber(-1))) + } + state.Pop(1) + + response.Cookies = append(response.Cookies, cookie) +} diff --git a/core/runner/Sessions.go b/core/runner/Sessions.go deleted file mode 100644 index 1077841..0000000 --- a/core/runner/Sessions.go +++ /dev/null @@ -1,241 +0,0 @@ -package runner - -import ( - luaCtx "Moonshark/core/runner/context" - "Moonshark/core/runner/sandbox" - "Moonshark/core/sessions" - "Moonshark/core/utils/logger" - - luajit "git.sharkk.net/Sky/LuaJIT-to-Go" - "github.com/valyala/fasthttp" -) - -// SessionHandler handles session management for Lua scripts -type SessionHandler struct { - manager *sessions.SessionManager - debugLog bool -} - -// NewSessionHandler creates a new session handler -func NewSessionHandler(manager *sessions.SessionManager) *SessionHandler { - return &SessionHandler{ - manager: manager, - debugLog: false, - } -} - -// EnableDebug enables debug logging -func (h *SessionHandler) EnableDebug() { - h.debugLog = true -} - -// WithSessionManager creates a RunnerOption to add session support -func WithSessionManager(manager *sessions.SessionManager) RunnerOption { - return func(r *Runner) { - handler := NewSessionHandler(manager) - r.AddInitHook(handler.preRequestHook) - r.AddFinalizeHook(handler.postRequestHook) - } -} - -// preRequestHook initializes session before script execution -func (h *SessionHandler) preRequestHook(state *luajit.State, ctx *luaCtx.Context) error { - if ctx == nil || ctx.Values["_request_cookies"] == nil { - return nil - } - - // Extract cookies from context - cookies, ok := ctx.Values["_request_cookies"].(map[string]any) - if !ok { - return nil - } - - // Get the session ID from cookies - cookieName := h.manager.CookieOptions()["name"].(string) - var sessionID string - - // Check if our session cookie exists - if cookieValue, exists := cookies[cookieName]; exists { - if strValue, ok := cookieValue.(string); ok && strValue != "" { - sessionID = strValue - } - } - - // Create new session if needed - if sessionID == "" { - session := h.manager.CreateSession() - sessionID = session.ID - } - - // Store the session ID in the context - ctx.Set("_session_id", sessionID) - - // Get session data - session := h.manager.GetSession(sessionID) - sessionData := session.GetAll() - - // Set session data in Lua state - return SetSessionData(state, sessionID, sessionData) -} - -// postRequestHook handles session after script execution -func (h *SessionHandler) postRequestHook(state *luajit.State, ctx *luaCtx.Context, result any) error { - // Check if session was modified - modifiedID, modifiedData, modified := GetSessionData(state) - if !modified { - return nil - } - - // Get the original session ID from context - var sessionID string - if ctx != nil { - if id, ok := ctx.Values["_session_id"].(string); ok { - sessionID = id - } - } - - // Use the original session ID if the modified one is empty - if modifiedID == "" { - modifiedID = sessionID - } - - if modifiedID == "" { - return nil - } - - // Update session in manager - session := h.manager.GetSession(modifiedID) - session.Clear() // clear to sync deleted values - for k, v := range modifiedData { - session.Set(k, v) - } - - h.manager.SaveSession(session) - - // Add session cookie to result if it's an HTTP response - if httpResp, ok := result.(*sandbox.HTTPResponse); ok { - h.addSessionCookie(httpResp, modifiedID) - } else if ctx != nil && ctx.RequestCtx != nil { - // Add cookie directly to the RequestCtx when result is not an HTTP response - h.addSessionCookieToRequestCtx(ctx.RequestCtx, modifiedID) - } - - return nil -} - -// addSessionCookie adds a session cookie to an HTTP response -func (h *SessionHandler) addSessionCookie(resp *sandbox.HTTPResponse, sessionID string) { - // Get cookie options - opts := h.manager.CookieOptions() - - // Check if session cookie is already set - cookieName := opts["name"].(string) - for _, cookie := range resp.Cookies { - if string(cookie.Key()) == cookieName { - return - } - } - - // Create and add cookie - cookie := fasthttp.AcquireCookie() - cookie.SetKey(cookieName) - cookie.SetValue(sessionID) - cookie.SetPath(opts["path"].(string)) - cookie.SetHTTPOnly(opts["http_only"].(bool)) - cookie.SetMaxAge(opts["max_age"].(int)) - - // Optional cookie parameters - if domain, ok := opts["domain"].(string); ok && domain != "" { - cookie.SetDomain(domain) - } - - if secure, ok := opts["secure"].(bool); ok { - cookie.SetSecure(secure) - } - - resp.Cookies = append(resp.Cookies, cookie) -} - -func (h *SessionHandler) addSessionCookieToRequestCtx(ctx *fasthttp.RequestCtx, sessionID string) { - // Get cookie options - opts := h.manager.CookieOptions() - cookieName := opts["name"].(string) - - // Create cookie - cookie := fasthttp.AcquireCookie() - defer fasthttp.ReleaseCookie(cookie) - - cookie.SetKey(cookieName) - cookie.SetValue(sessionID) - cookie.SetPath(opts["path"].(string)) - cookie.SetHTTPOnly(opts["http_only"].(bool)) - cookie.SetMaxAge(opts["max_age"].(int)) - - // Optional cookie parameters - if domain, ok := opts["domain"].(string); ok && domain != "" { - cookie.SetDomain(domain) - } - - if secure, ok := opts["secure"].(bool); ok { - cookie.SetSecure(secure) - } - - ctx.Response.Header.SetCookie(cookie) -} - -// GetSessionData extracts session data from Lua state -func GetSessionData(state *luajit.State) (string, map[string]any, bool) { - // Check if session was modified - state.GetGlobal("__session_modified") - modified := state.ToBoolean(-1) - state.Pop(1) - - if !modified { - return "", nil, false - } - - // Get session ID - state.GetGlobal("__session_id") - sessionID := state.ToString(-1) - state.Pop(1) - - // Get session data - state.GetGlobal("__session_data") - if !state.IsTable(-1) { - state.Pop(1) - return sessionID, nil, false - } - - data, err := state.ToTable(-1) - state.Pop(1) - - if err != nil { - logger.Error("Failed to extract session data: %v", err) - return sessionID, nil, false - } - - return sessionID, data, true -} - -// SetSessionData sets session data in Lua state -func SetSessionData(state *luajit.State, sessionID string, data map[string]any) error { - // Set session ID - state.PushString(sessionID) - state.SetGlobal("__session_id") - - // Set session data - if data == nil { - data = make(map[string]any) - } - - if err := state.PushTable(data); err != nil { - return err - } - state.SetGlobal("__session_data") - - // Reset modification flag - state.PushBoolean(false) - state.SetGlobal("__session_modified") - - return nil -} diff --git a/core/runner/sandbox/lua/sandbox.lua b/core/runner/sandbox.lua similarity index 71% rename from core/runner/sandbox/lua/sandbox.lua rename to core/runner/sandbox.lua index c755c27..324ea21 100644 --- a/core/runner/sandbox/lua/sandbox.lua +++ b/core/runner/sandbox.lua @@ -14,9 +14,6 @@ __ready_modules = {} __session_data = {} __session_id = nil __session_modified = false -__env_system = { - base_env = {} -} -- ====================================================================== -- CORE SANDBOX FUNCTIONALITY @@ -44,7 +41,7 @@ end function __execute_script(fn, ctx) -- Clear previous responses __http_responses[1] = nil - + -- Reset session modification flag __session_modified = false @@ -63,75 +60,6 @@ function __execute_script(fn, ctx) return result end --- ====================================================================== --- MODULE LOADING SYSTEM --- ====================================================================== - --- Setup environment-aware require function -function __setup_require(env) - -- Create require function specific to this environment - env.require = function(modname) - -- Check if already loaded - if package.loaded[modname] then - return package.loaded[modname] - end - - -- Check preloaded modules - if __ready_modules[modname] then - local loader = package.preload[modname] - if loader then - -- Set environment for loader - setfenv(loader, env) - - -- Execute and store result - local result = loader() - if result == nil then - result = true - end - - package.loaded[modname] = result - return result - end - end - - -- Direct file load as fallback - if __module_paths[modname] then - local path = __module_paths[modname] - local chunk, err = loadfile(path) - if chunk then - setfenv(chunk, env) - local result = chunk() - if result == nil then - result = true - end - package.loaded[modname] = result - return result - end - end - - -- Full path search as last resort - local errors = {} - for path in package.path:gmatch("[^;]+") do - local file_path = path:gsub("?", modname:gsub("%.", "/")) - local chunk, err = loadfile(file_path) - if chunk then - setfenv(chunk, env) - local result = chunk() - if result == nil then - result = true - end - package.loaded[modname] = result - return result - end - table.insert(errors, "\tno file '" .. file_path .. "'") - end - - error("module '" .. modname .. "' not found:\n" .. table.concat(errors, "\n"), 2) - end - - return env -end - -- ====================================================================== -- HTTP MODULE -- ====================================================================== @@ -166,6 +94,18 @@ local http = { http.set_header("Content-Type", content_type) end, + -- Set metadata (arbitrary data to be returned with response) + set_metadata = function(key, value) + if type(key) ~= "string" then + error("http.set_metadata: key must be a string", 2) + end + + local resp = __http_responses[1] or {} + resp.metadata = resp.metadata or {} + resp.metadata[key] = value + __http_responses[1] = resp + end, + -- HTTP client submodule client = { -- Generic request function @@ -213,10 +153,7 @@ local http = { -- Simple HEAD request head = function(url, options) options = options or {} - local old_options = options - options = {headers = old_options.headers, timeout = old_options.timeout, query = old_options.query} - local response = http.client.request("HEAD", url, nil, options) - return response + return http.client.request("HEAD", url, nil, options) end, -- Simple OPTIONS request @@ -265,13 +202,13 @@ local http = { } -- ====================================================================== --- COOKIE MODULE +-- COOKIE MODULE -- ====================================================================== -- Cookie module implementation local cookie = { -- Set a cookie - set = function(name, value, options, ...) + set = function(name, value, options) if type(name) ~= "string" then error("cookie.set: name must be a string", 2) end @@ -281,20 +218,8 @@ local cookie = { resp.cookies = resp.cookies or {} __http_responses[1] = resp - -- Handle options as table or legacy params - local opts = {} - if type(options) == "table" then - opts = options - elseif options ~= nil then - -- Legacy support: options is actually 'expires' - opts.expires = options - -- Check for other legacy params (4th-7th args) - local args = {...} - if args[1] then opts.path = args[1] end - if args[2] then opts.domain = args[2] end - if args[3] then opts.secure = args[3] end - if args[4] ~= nil then opts.http_only = args[4] end - end + -- Handle options as table + local opts = options or {} -- Create cookie table local cookie = { @@ -314,10 +239,8 @@ local cookie = { elseif opts.expires < 0 then cookie.expires = 1 cookie.max_age = 0 - else - -- opts.expires == 0: Session cookie - -- Do nothing (omitting both expires and max-age creates a session cookie) end + -- opts.expires == 0: Session cookie (omitting both expires and max-age) end end @@ -342,8 +265,13 @@ local cookie = { local env = getfenv(2) -- Check if context exists and has cookies - if env.ctx and env.ctx.cookies and env.ctx.cookies[name] then - return tostring(env.ctx.cookies[name]) + if env.ctx and env.ctx.cookies then + return env.ctx.cookies[name] + end + + -- If context has request_cookies map + if env.ctx and env.ctx._request_cookies then + return env.ctx._request_cookies[name] end return nil @@ -361,7 +289,7 @@ local cookie = { } -- ====================================================================== --- SESSION MODULE +-- SESSION MODULE -- ====================================================================== -- Session module implementation @@ -372,7 +300,7 @@ local session = { error("session.get: key must be a string", 2) end - if __session_data and __session_data[key] then + if __session_data and __session_data[key] ~= nil then return __session_data[key] end @@ -469,7 +397,7 @@ local csrf = { error("CSRF protection requires the session module", 2) end - local token = util.generate_token(length) + local token = __generate_token(length) session.set(csrf.TOKEN_KEY, token) return token end, @@ -495,48 +423,133 @@ local csrf = { end, -- Verify a given token against the session token - verify = function(token, field_name) - field_name = field_name or csrf.DEFAULT_FIELD + verify = function(token, field_name) + field_name = field_name or csrf.DEFAULT_FIELD - local env = getfenv(2) + local env = getfenv(2) - local form = nil - if env.ctx and env.ctx.form then - form = env.ctx.form - else - return false - end + local form = nil + if env.ctx and env.ctx._request_form then + form = env.ctx._request_form + elseif env.ctx and env.ctx.form then + form = env.ctx.form + else + return false + end - token = token or form[field_name] - if not token then - return false - end + token = token or form[field_name] + if not token then + return false + end - local session_token = session.get(csrf.TOKEN_KEY) - if not session_token then - return false - end + local session_token = session.get(csrf.TOKEN_KEY) + if not session_token then + return false + end - -- Constant-time comparison to prevent timing attacks - -- This is safe since Lua strings are immutable - if #token ~= #session_token then - return false - end + -- Constant-time comparison to prevent timing attacks + if #token ~= #session_token then + return false + end - local result = true - for i = 1, #token do - if token:sub(i, i) ~= session_token:sub(i, i) then - result = false - -- Don't break early - continue to prevent timing attacks - end - end + local result = true + for i = 1, #token do + if token:sub(i, i) ~= session_token:sub(i, i) then + result = false + -- Don't break early - continue to prevent timing attacks + end + end - return result - end + return result + end } -- ====================================================================== --- REGISTER MODULES GLOBALLY +-- UTIL MODULE +-- ====================================================================== + +-- Utility module implementation +local util = { + -- Generate a token (wrapper around __generate_token) + generate_token = function(length) + return __generate_token(length or 32) + end, + + -- Simple JSON stringify (for when you just need a quick string) + json_encode = function(value) + if type(value) == "table" then + local json = "{" + local sep = "" + for k, v in pairs(value) do + json = json .. sep + if type(k) == "number" then + -- Array-like + json = json .. util.json_encode(v) + else + -- Object-like + json = json .. '"' .. k .. '":' .. util.json_encode(v) + end + sep = "," + end + return json .. "}" + elseif type(value) == "string" then + return '"' .. value:gsub('"', '\\"'):gsub('\n', '\\n') .. '"' + elseif type(value) == "number" then + return tostring(value) + elseif type(value) == "boolean" then + return value and "true" or "false" + elseif value == nil then + return "null" + end + return '"' .. tostring(value) .. '"' + end, + + -- Deep copy of tables + deep_copy = function(obj) + if type(obj) ~= 'table' then return obj end + local res = {} + for k, v in pairs(obj) do res[k] = util.deep_copy(v) end + return res + end, + + -- Merge tables + merge_tables = function(t1, t2) + if type(t1) ~= 'table' or type(t2) ~= 'table' then + error("Both arguments must be tables", 2) + end + + local result = util.deep_copy(t1) + for k, v in pairs(t2) do + if type(v) == 'table' and type(result[k]) == 'table' then + result[k] = util.merge_tables(result[k], v) + else + result[k] = v + end + end + return result + end, + + -- String utilities + string = { + -- Trim whitespace + trim = function(s) + return (s:gsub("^%s*(.-)%s*$", "%1")) + end, + + -- Split string + split = function(s, delimiter) + delimiter = delimiter or "," + local result = {} + for match in (s..delimiter):gmatch("(.-)"..delimiter) do + table.insert(result, match) + end + return result + end + } +} + +-- ====================================================================== +-- REGISTER MODULES GLOBALLY -- ====================================================================== -- Install modules in global scope @@ -544,9 +557,4 @@ _G.http = http _G.cookie = cookie _G.session = session _G.csrf = csrf - --- Register modules in sandbox base environment -__env_system.base_env.http = http -__env_system.base_env.cookie = cookie -__env_system.base_env.session = session -__env_system.base_env.csrf = csrf \ No newline at end of file +_G.util = util diff --git a/core/runner/sandbox/Embed.go b/core/runner/sandbox/Embed.go deleted file mode 100644 index 944f9d8..0000000 --- a/core/runner/sandbox/Embed.go +++ /dev/null @@ -1,98 +0,0 @@ -package sandbox - -import ( - _ "embed" - - "Moonshark/core/utils/logger" - - luajit "git.sharkk.net/Sky/LuaJIT-to-Go" -) - -//go:embed lua/sandbox.lua -var sandboxLua string - -// InitializeSandbox loads the embedded Lua sandbox code into a Lua state -func InitializeSandbox(state *luajit.State) error { - // Compile once, use many times - bytecodeOnce.Do(precompileSandbox) - - if sandboxBytecode != nil { - logger.Debug("Loading sandbox.lua from precompiled bytecode") - return state.LoadAndRunBytecode(sandboxBytecode, "sandbox.lua") - } - - // Fallback if compilation failed - logger.Warning("Using non-precompiled sandbox.lua (bytecode compilation failed)") - return state.DoString(sandboxLua) -} - -// ModuleInitializers stores initializer functions for core modules -type ModuleInitializers struct { - HTTP func(*luajit.State) error - Util func(*luajit.State) error - Session func(*luajit.State) error - Cookie func(*luajit.State) error - CSRF func(*luajit.State) error -} - -// DefaultInitializers returns the default set of initializers -func DefaultInitializers() *ModuleInitializers { - return &ModuleInitializers{ - HTTP: func(state *luajit.State) error { - // Register the native Go function first - if err := state.RegisterGoFunction("__http_request", httpRequest); err != nil { - logger.Error("[HTTP Module] Failed to register __http_request function: %v", err) - return err - } - return nil - }, - Util: func(state *luajit.State) error { - // Register util functions - return RegisterModule(state, "util", UtilModuleFunctions()) - }, - Session: func(state *luajit.State) error { - // Session doesn't need special initialization - return nil - }, - Cookie: func(state *luajit.State) error { - // Cookie doesn't need special initialization - return nil - }, - CSRF: func(state *luajit.State) error { - // CSRF doesn't need special initialization - return nil - }, - } -} - -// InitializeAll initializes all modules in the Lua state -func InitializeAll(state *luajit.State, initializers *ModuleInitializers) error { - // Set up dependencies first - if err := initializers.Util(state); err != nil { - return err - } - - if err := initializers.HTTP(state); err != nil { - return err - } - - // Load the embedded sandbox code - if err := InitializeSandbox(state); err != nil { - return err - } - - // Initialize the rest of the modules - if err := initializers.Session(state); err != nil { - return err - } - - if err := initializers.Cookie(state); err != nil { - return err - } - - if err := initializers.CSRF(state); err != nil { - return err - } - - return nil -} diff --git a/core/runner/sandbox/Http.go b/core/runner/sandbox/Http.go deleted file mode 100644 index 2482dd8..0000000 --- a/core/runner/sandbox/Http.go +++ /dev/null @@ -1,590 +0,0 @@ -package sandbox - -import ( - "context" - "errors" - "fmt" - "net/url" - "strings" - "sync" - "time" - - "github.com/goccy/go-json" - "github.com/valyala/bytebufferpool" - "github.com/valyala/fasthttp" - - "Moonshark/core/utils/logger" - - luajit "git.sharkk.net/Sky/LuaJIT-to-Go" -) - -// SessionHandler interface for session management -type SessionHandler interface { - LoadSession(ctx *fasthttp.RequestCtx) (string, map[string]any) - SaveSession(ctx *fasthttp.RequestCtx, sessionID string, data map[string]any) bool -} - -// HTTPResponse represents an HTTP response from Lua -type HTTPResponse struct { - Status int `json:"status"` - Headers map[string]string `json:"headers"` - Body any `json:"body"` - Cookies []*fasthttp.Cookie `json:"-"` - SessionModified bool `json:"-"` -} - -// Response pool to reduce allocations -var responsePool = sync.Pool{ - New: func() any { - return &HTTPResponse{ - Status: 200, - Headers: make(map[string]string, 8), - Cookies: make([]*fasthttp.Cookie, 0, 4), - } - }, -} - -// Default HTTP client with sensible timeout -var defaultFastClient fasthttp.Client = fasthttp.Client{ - MaxConnsPerHost: 1024, - MaxIdleConnDuration: time.Minute, - ReadTimeout: 30 * time.Second, - WriteTimeout: 30 * time.Second, - DisableHeaderNamesNormalizing: true, -} - -// HTTPClientConfig contains client settings -type HTTPClientConfig struct { - MaxTimeout time.Duration // Maximum timeout for requests (0 = no limit) - DefaultTimeout time.Duration // Default request timeout - MaxResponseSize int64 // Maximum response size in bytes (0 = no limit) - AllowRemote bool // Whether to allow remote connections -} - -// DefaultHTTPClientConfig provides sensible defaults -var DefaultHTTPClientConfig = HTTPClientConfig{ - MaxTimeout: 60 * time.Second, - DefaultTimeout: 30 * time.Second, - MaxResponseSize: 10 * 1024 * 1024, // 10MB - AllowRemote: true, -} - -// NewHTTPResponse creates a default HTTP response from pool -func NewHTTPResponse() *HTTPResponse { - return responsePool.Get().(*HTTPResponse) -} - -// ReleaseResponse returns the response to the pool -func ReleaseResponse(resp *HTTPResponse) { - if resp == nil { - return - } - - // Clear all values to prevent data leakage - resp.Status = 200 // Reset to default - - // Clear headers - for k := range resp.Headers { - delete(resp.Headers, k) - } - - // Clear cookies - resp.Cookies = resp.Cookies[:0] // Keep capacity but set length to 0 - - // Reset session flag - resp.SessionModified = false - - // Clear body - resp.Body = nil - - responsePool.Put(resp) -} - -// HTTPModuleInitFunc returns an initializer function for the HTTP module -func HTTPModuleInitFunc() func(*luajit.State) error { - return func(state *luajit.State) error { - // Register the native Go function first - if err := state.RegisterGoFunction("__http_request", httpRequest); err != nil { - logger.Error("[HTTP Module] Failed to register __http_request function: %v", err) - return err - } - - // Set up default HTTP client configuration - setupHTTPClientConfig(state) - - return nil - } -} - -// setupHTTPClientConfig configures HTTP client in Lua -func setupHTTPClientConfig(state *luajit.State) { - state.NewTable() - - state.PushNumber(float64(DefaultHTTPClientConfig.MaxTimeout / time.Second)) - state.SetField(-2, "max_timeout") - - state.PushNumber(float64(DefaultHTTPClientConfig.DefaultTimeout / time.Second)) - state.SetField(-2, "default_timeout") - - state.PushNumber(float64(DefaultHTTPClientConfig.MaxResponseSize)) - state.SetField(-2, "max_response_size") - - state.PushBoolean(DefaultHTTPClientConfig.AllowRemote) - state.SetField(-2, "allow_remote") - - state.SetGlobal("__http_client_config") -} - -// GetHTTPResponse extracts the HTTP response from Lua state -func GetHTTPResponse(state *luajit.State) (*HTTPResponse, bool) { - response := NewHTTPResponse() - - // Get response table - state.GetGlobal("__http_responses") - if state.IsNil(-1) { - state.Pop(1) - ReleaseResponse(response) - return nil, false - } - - // Check for response at thread index - state.PushNumber(1) - state.GetTable(-2) - if state.IsNil(-1) { - state.Pop(2) - ReleaseResponse(response) - return nil, false - } - - // Get status - state.GetField(-1, "status") - if state.IsNumber(-1) { - response.Status = int(state.ToNumber(-1)) - } - state.Pop(1) - - // Get headers - state.GetField(-1, "headers") - if state.IsTable(-1) { - // Iterate through headers table - state.PushNil() // Start iteration - for state.Next(-2) { - // Stack has key at -2 and value at -1 - if state.IsString(-2) && state.IsString(-1) { - key := state.ToString(-2) - value := state.ToString(-1) - response.Headers[key] = value - } - state.Pop(1) // Pop value, leave key for next iteration - } - } - state.Pop(1) - - // Get cookies - state.GetField(-1, "cookies") - if state.IsTable(-1) { - // Iterate through cookies array - length := state.GetTableLength(-1) - for i := 1; i <= length; i++ { - state.PushNumber(float64(i)) - state.GetTable(-2) - - if state.IsTable(-1) { - cookie := extractCookie(state) - if cookie != nil { - response.Cookies = append(response.Cookies, cookie) - } - } - state.Pop(1) - } - } - state.Pop(1) - - // Check if session was modified - state.GetGlobal("__session_modified") - if state.IsBoolean(-1) && state.ToBoolean(-1) { - response.SessionModified = true - } - state.Pop(1) - - // Clean up - state.Pop(2) // Pop response table and __http_responses - - return response, true -} - -// ApplyHTTPResponse applies an HTTP response to a fasthttp.RequestCtx -func ApplyHTTPResponse(httpResp *HTTPResponse, ctx *fasthttp.RequestCtx) { - // Set status code - ctx.SetStatusCode(httpResp.Status) - - // Set headers - for name, value := range httpResp.Headers { - ctx.Response.Header.Set(name, value) - } - - // Set cookies - for _, cookie := range httpResp.Cookies { - ctx.Response.Header.SetCookie(cookie) - } - - // Process the body based on its type - if httpResp.Body == nil { - return - } - - // Set body based on type - switch body := httpResp.Body.(type) { - case string: - ctx.SetBodyString(body) - case []byte: - ctx.SetBody(body) - case map[string]any, []any, []float64, []string, []int: - // Marshal JSON using a buffer from the pool - buf := bytebufferpool.Get() - defer bytebufferpool.Put(buf) - - if err := json.NewEncoder(buf).Encode(body); err == nil { - // Set content type if not already set - if len(ctx.Response.Header.ContentType()) == 0 { - ctx.Response.Header.SetContentType("application/json") - } - ctx.SetBody(buf.Bytes()) - } else { - // Fallback - ctx.SetBodyString(fmt.Sprintf("%v", body)) - } - default: - // Default to string representation - ctx.SetBodyString(fmt.Sprintf("%v", body)) - } -} - -// extractCookie grabs cookies from the Lua state -func extractCookie(state *luajit.State) *fasthttp.Cookie { - cookie := fasthttp.AcquireCookie() - - // Get name - state.GetField(-1, "name") - if !state.IsString(-1) { - state.Pop(1) - fasthttp.ReleaseCookie(cookie) - return nil // Name is required - } - cookie.SetKey(state.ToString(-1)) - state.Pop(1) - - // Get value - state.GetField(-1, "value") - if state.IsString(-1) { - cookie.SetValue(state.ToString(-1)) - } - state.Pop(1) - - // Get path - state.GetField(-1, "path") - if state.IsString(-1) { - cookie.SetPath(state.ToString(-1)) - } else { - cookie.SetPath("/") // Default path - } - state.Pop(1) - - // Get domain - state.GetField(-1, "domain") - if state.IsString(-1) { - cookie.SetDomain(state.ToString(-1)) - } - state.Pop(1) - - // Get expires - state.GetField(-1, "expires") - if state.IsNumber(-1) { - expiry := int64(state.ToNumber(-1)) - cookie.SetExpire(time.Unix(expiry, 0)) - } - state.Pop(1) - - // Get max age - state.GetField(-1, "max_age") - if state.IsNumber(-1) { - cookie.SetMaxAge(int(state.ToNumber(-1))) - } - state.Pop(1) - - // Get secure - state.GetField(-1, "secure") - if state.IsBoolean(-1) { - cookie.SetSecure(state.ToBoolean(-1)) - } - state.Pop(1) - - // Get http only - state.GetField(-1, "http_only") - if state.IsBoolean(-1) { - cookie.SetHTTPOnly(state.ToBoolean(-1)) - } - state.Pop(1) - - return cookie -} - -// httpRequest makes an HTTP request and returns the result to Lua -func httpRequest(state *luajit.State) int { - // Get method (required) - if !state.IsString(1) { - state.PushString("http.client.request: method must be a string") - return -1 - } - method := strings.ToUpper(state.ToString(1)) - - // Get URL (required) - if !state.IsString(2) { - state.PushString("http.client.request: url must be a string") - return -1 - } - urlStr := state.ToString(2) - - // Parse URL to check if it's valid and if it's allowed - parsedURL, err := url.Parse(urlStr) - if err != nil { - state.PushString("Invalid URL: " + err.Error()) - return -1 - } - - // Get client configuration - var config HTTPClientConfig = DefaultHTTPClientConfig - state.GetGlobal("__http_client_config") - if !state.IsNil(-1) && state.IsTable(-1) { - // Extract max timeout - state.GetField(-1, "max_timeout") - if state.IsNumber(-1) { - config.MaxTimeout = time.Duration(state.ToNumber(-1)) * time.Second - } - state.Pop(1) - - // Extract default timeout - state.GetField(-1, "default_timeout") - if state.IsNumber(-1) { - config.DefaultTimeout = time.Duration(state.ToNumber(-1)) * time.Second - } - state.Pop(1) - - // Extract max response size - state.GetField(-1, "max_response_size") - if state.IsNumber(-1) { - config.MaxResponseSize = int64(state.ToNumber(-1)) - } - state.Pop(1) - - // Extract allow remote - state.GetField(-1, "allow_remote") - if state.IsBoolean(-1) { - config.AllowRemote = state.ToBoolean(-1) - } - state.Pop(1) - } - state.Pop(1) - - // Check if remote connections are allowed - if !config.AllowRemote && (parsedURL.Hostname() != "localhost" && parsedURL.Hostname() != "127.0.0.1") { - state.PushString("Remote connections are not allowed") - return -1 - } - - // Use bytebufferpool for request and response - req := fasthttp.AcquireRequest() - resp := fasthttp.AcquireResponse() - defer fasthttp.ReleaseRequest(req) - defer fasthttp.ReleaseResponse(resp) - - // Set up request - req.Header.SetMethod(method) - req.SetRequestURI(urlStr) - req.Header.Set("User-Agent", "Moonshark/1.0") - - // Get body (optional) - if state.GetTop() >= 3 && !state.IsNil(3) { - if state.IsString(3) { - // String body - req.SetBodyString(state.ToString(3)) - } else if state.IsTable(3) { - // Table body - convert to JSON - luaTable, err := state.ToTable(3) - if err != nil { - state.PushString("Failed to parse body table: " + err.Error()) - return -1 - } - - // Use bytebufferpool for JSON serialization - buf := bytebufferpool.Get() - defer bytebufferpool.Put(buf) - - if err := json.NewEncoder(buf).Encode(luaTable); err != nil { - state.PushString("Failed to convert body to JSON: " + err.Error()) - return -1 - } - - req.SetBody(buf.Bytes()) - } else { - state.PushString("Body must be a string or table") - return -1 - } - } - - // Process options (headers, timeout, etc.) - timeout := config.DefaultTimeout - if state.GetTop() >= 4 && !state.IsNil(4) { - if !state.IsTable(4) { - state.PushString("Options must be a table") - return -1 - } - - // Process headers - state.GetField(4, "headers") - if state.IsTable(-1) { - // Iterate through headers - state.PushNil() // Start iteration - for state.Next(-2) { - // Stack now has key at -2 and value at -1 - if state.IsString(-2) && state.IsString(-1) { - headerName := state.ToString(-2) - headerValue := state.ToString(-1) - req.Header.Set(headerName, headerValue) - } - state.Pop(1) // Pop value, leave key for next iteration - } - } - state.Pop(1) // Pop headers table - - // Get timeout - state.GetField(4, "timeout") - if state.IsNumber(-1) { - requestTimeout := time.Duration(state.ToNumber(-1)) * time.Second - - // Apply max timeout if configured - if config.MaxTimeout > 0 && requestTimeout > config.MaxTimeout { - timeout = config.MaxTimeout - } else { - timeout = requestTimeout - } - } - state.Pop(1) // Pop timeout - - // Set content type for POST/PUT if body is present and content-type not manually set - if (method == "POST" || method == "PUT") && req.Body() != nil && req.Header.Peek("Content-Type") == nil { - // Check if options specify content type - state.GetField(4, "content_type") - if state.IsString(-1) { - req.Header.Set("Content-Type", state.ToString(-1)) - } else { - // Default to JSON if body is a table, otherwise plain text - if state.IsTable(3) { - req.Header.Set("Content-Type", "application/json") - } else { - req.Header.Set("Content-Type", "text/plain") - } - } - state.Pop(1) // Pop content_type - } - - // Process query parameters - state.GetField(4, "query") - if state.IsTable(-1) { - // Create URL args - args := req.URI().QueryArgs() - - // Iterate through query params - state.PushNil() // Start iteration - for state.Next(-2) { - // Stack now has key at -2 and value at -1 - if state.IsString(-2) { - paramName := state.ToString(-2) - - // Handle different value types - if state.IsString(-1) { - args.Add(paramName, state.ToString(-1)) - } else if state.IsNumber(-1) { - args.Add(paramName, strings.TrimRight(strings.TrimRight( - state.ToString(-1), "0"), ".")) - } else if state.IsBoolean(-1) { - if state.ToBoolean(-1) { - args.Add(paramName, "true") - } else { - args.Add(paramName, "false") - } - } - } - state.Pop(1) // Pop value, leave key for next iteration - } - } - state.Pop(1) // Pop query table - } - - // Create context with timeout - _, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - - // Execute request - err = defaultFastClient.DoTimeout(req, resp, timeout) - if err != nil { - errStr := "Request failed: " + err.Error() - if errors.Is(err, fasthttp.ErrTimeout) { - errStr = "Request timed out after " + timeout.String() - } - state.PushString(errStr) - return -1 - } - - // Create response table - state.NewTable() - - // Set status code - state.PushNumber(float64(resp.StatusCode())) - state.SetField(-2, "status") - - // Set status text - statusText := fasthttp.StatusMessage(resp.StatusCode()) - state.PushString(statusText) - state.SetField(-2, "status_text") - - // Set body - var respBody []byte - - // Apply size limits to response - if config.MaxResponseSize > 0 && int64(len(resp.Body())) > config.MaxResponseSize { - // Make a limited copy - respBody = make([]byte, config.MaxResponseSize) - copy(respBody, resp.Body()) - } else { - respBody = resp.Body() - } - - state.PushString(string(respBody)) - state.SetField(-2, "body") - - // Parse body as JSON if content type is application/json - contentType := string(resp.Header.ContentType()) - if strings.Contains(contentType, "application/json") { - var jsonData any - if err := json.Unmarshal(respBody, &jsonData); err == nil { - if err := state.PushValue(jsonData); err == nil { - state.SetField(-2, "json") - } - } - } - - // Set headers - state.NewTable() - resp.Header.VisitAll(func(key, value []byte) { - state.PushString(string(value)) - state.SetField(-2, string(key)) - }) - state.SetField(-2, "headers") - - // Create ok field (true if status code is 2xx) - state.PushBoolean(resp.StatusCode() >= 200 && resp.StatusCode() < 300) - state.SetField(-2, "ok") - - return 1 -} diff --git a/core/runner/sandbox/Modules.go b/core/runner/sandbox/Modules.go deleted file mode 100644 index 1641402..0000000 --- a/core/runner/sandbox/Modules.go +++ /dev/null @@ -1,84 +0,0 @@ -package sandbox - -import ( - "Moonshark/core/utils/logger" - - luajit "git.sharkk.net/Sky/LuaJIT-to-Go" -) - -// ModuleFunc returns a map of module functions -type ModuleFunc func() map[string]luajit.GoFunction - -// StateInitFunc initializes a module in a Lua state -type StateInitFunc func(*luajit.State) error - -// RegisterModule registers a map of functions as a Lua module -func RegisterModule(state *luajit.State, name string, funcs map[string]luajit.GoFunction) error { - // Create a new table for the module - state.NewTable() - - // Add each function to the module table - for fname, f := range funcs { - state.PushString(fname) - if err := state.PushGoFunction(f); err != nil { - state.Pop(1) // Pop table - return err - } - state.SetTable(-3) - } - - // Register the module globally - state.SetGlobal(name) - return nil -} - -// ModuleInitFunc creates a state initializer that registers multiple modules -func ModuleInitFunc(modules map[string]ModuleFunc) StateInitFunc { - return func(state *luajit.State) error { - for name, moduleFunc := range modules { - if err := RegisterModule(state, name, moduleFunc()); err != nil { - logger.Error("Failed to register module %s: %v", name, err) - return err - } - } - return nil - } -} - -// CombineInitFuncs combines multiple state initializer functions into one -func CombineInitFuncs(funcs ...StateInitFunc) StateInitFunc { - return func(state *luajit.State) error { - for _, f := range funcs { - if f != nil { - if err := f(state); err != nil { - return err - } - } - } - return nil - } -} - -// RegisterLuaCode registers a Lua code snippet in a state -func RegisterLuaCode(state *luajit.State, code string) error { - return state.DoString(code) -} - -// RegisterLuaCodeInitFunc returns a StateInitFunc that registers Lua code -func RegisterLuaCodeInitFunc(code string) StateInitFunc { - return func(state *luajit.State) error { - return RegisterLuaCode(state, code) - } -} - -// RegisterLuaModuleInitFunc returns a StateInitFunc that registers a Lua module -func RegisterLuaModuleInitFunc(name string, code string) StateInitFunc { - return func(state *luajit.State) error { - // Create name = {} global - state.NewTable() - state.SetGlobal(name) - - // Then run the module code which will populate it - return state.DoString(code) - } -} diff --git a/core/runner/sandbox/Sandbox.go b/core/runner/sandbox/Sandbox.go deleted file mode 100644 index 654d123..0000000 --- a/core/runner/sandbox/Sandbox.go +++ /dev/null @@ -1,371 +0,0 @@ -package sandbox - -import ( - "fmt" - "sync" - - "github.com/goccy/go-json" - "github.com/valyala/bytebufferpool" - "github.com/valyala/fasthttp" - - luaCtx "Moonshark/core/runner/context" - "Moonshark/core/sessions" - "Moonshark/core/utils/logger" - - luajit "git.sharkk.net/Sky/LuaJIT-to-Go" -) - -// Global bytecode cache to improve performance -var ( - sandboxBytecode []byte - bytecodeOnce sync.Once -) - -// precompileSandbox compiles the sandbox.lua code to bytecode once -func precompileSandbox() { - tempState := luajit.New() - if tempState == nil { - logger.Error("Failed to create temporary Lua state for bytecode compilation") - return - } - defer tempState.Close() - defer tempState.Cleanup() - - var err error - sandboxBytecode, err = tempState.CompileBytecode(sandboxLua, "sandbox.lua") - if err != nil { - logger.Error("Failed to precompile sandbox.lua: %v", err) - } else { - logger.Debug("Successfully precompiled sandbox.lua to bytecode (%d bytes)", len(sandboxBytecode)) - } -} - -// Sandbox provides a secure execution environment for Lua scripts -type Sandbox struct { - modules map[string]any // Custom modules for environment - debug bool // Enable debug output - mu sync.RWMutex // Protects modules - initializers *ModuleInitializers // Module initializers -} - -// NewSandbox creates a new sandbox environment -func NewSandbox() *Sandbox { - return &Sandbox{ - modules: make(map[string]any, 8), // Pre-allocate with reasonable capacity - debug: false, - initializers: DefaultInitializers(), - } -} - -// EnableDebug turns on debug logging -func (s *Sandbox) EnableDebug() { - s.debug = true -} - -// debugLog logs a message if debug mode is enabled -func (s *Sandbox) debugLog(format string, args ...interface{}) { - if s.debug { - logger.Debug("Sandbox "+format, args...) - } -} - -// debugLogCont logs a continuation message if debug mode is enabled -func (s *Sandbox) debugLogCont(format string, args ...interface{}) { - if s.debug { - logger.DebugCont(format, args...) - } -} - -// AddModule adds a module to the sandbox environment -func (s *Sandbox) AddModule(name string, module any) { - s.mu.Lock() - defer s.mu.Unlock() - - s.modules[name] = module - s.debugLog("Added module: %s", name) -} - -// Setup initializes the sandbox in a Lua state -func (s *Sandbox) Setup(state *luajit.State, stateIndex int) error { - verbose := stateIndex == 0 - - if verbose { - s.debugLog("Setting up sandbox...") - } - - // Initialize modules with the embedded sandbox code - if err := InitializeAll(state, s.initializers); err != nil { - if verbose { - s.debugLog("Failed to initialize sandbox: %v", err) - } - return err - } - - // Register custom modules in the global environment - s.mu.RLock() - for name, module := range s.modules { - if verbose { - s.debugLog("Registering module: %s", name) - } - if err := state.PushValue(module); err != nil { - s.mu.RUnlock() - if verbose { - s.debugLog("Failed to register module %s: %v", name, err) - } - return err - } - state.SetGlobal(name) - } - s.mu.RUnlock() - - if verbose { - s.debugLogCont("Sandbox setup complete") - } - return nil -} - -// Execute runs bytecode in the sandbox -func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx map[string]any) (any, error) { - // Create a temporary context if we only have a map - if ctx != nil { - tempCtx := &luaCtx.Context{ - Values: ctx, - } - return s.OptimizedExecute(state, bytecode, tempCtx) - } - - // Just pass nil through if we have no context - return s.OptimizedExecute(state, bytecode, nil) -} - -// OptimizedExecute runs bytecode with a fasthttp context if available -func (s *Sandbox) OptimizedExecute(state *luajit.State, bytecode []byte, ctx *luaCtx.Context) (any, error) { - // Use a buffer from the pool for any string operations - buf := bytebufferpool.Get() - defer bytebufferpool.Put(buf) - - // Load bytecode - if err := state.LoadBytecode(bytecode, "script"); err != nil { - s.debugLog("Failed to load bytecode: %v", err) - return nil, fmt.Errorf("failed to load script: %w", err) - } - - // Prepare context values - var ctxValues map[string]any - if ctx != nil { - ctxValues = ctx.Values - } else { - ctxValues = nil - } - - // Initialize session tracking in Lua - if err := state.DoString("__session_data = {}; __session_modified = false"); err != nil { - s.debugLog("Failed to initialize session data: %v", err) - } - - // Load session data if available - if ctx != nil && ctx.Session != nil { - // Set session ID in Lua - sessionIDCode := fmt.Sprintf("__session_id = %q", ctx.Session.ID) - if err := state.DoString(sessionIDCode); err != nil { - s.debugLog("Failed to set session ID: %v", err) - } - - // Get session data and populate Lua table - state.GetGlobal("__session_data") - if state.IsTable(-1) { - sessionData := ctx.Session.GetAll() - for k, v := range sessionData { - state.PushString(k) - if err := state.PushValue(v); err != nil { - s.debugLog("Failed to push session value %s: %v", k, err) - continue - } - state.SetTable(-3) - } - } - state.Pop(1) // Pop __session_data - } - - // Prepare context table - if ctxValues != nil { - state.CreateTable(0, len(ctxValues)) - for k, v := range ctxValues { - state.PushString(k) - if err := state.PushValue(v); err != nil { - state.Pop(2) // Pop key and table - s.debugLog("Failed to push context value %s: %v", k, err) - return nil, fmt.Errorf("failed to prepare context: %w", err) - } - state.SetTable(-3) - } - } else { - state.PushNil() // No context - } - - // Get execution function - state.GetGlobal("__execute_script") - if !state.IsFunction(-1) { - state.Pop(2) // Pop context and non-function - s.debugLog("__execute_script is not a function") - return nil, fmt.Errorf("sandbox execution function not found") - } - - // Stack setup for call: __execute_script, bytecode function, context - state.PushCopy(-3) // bytecode function (copy from -3) - state.PushCopy(-3) // context (copy from -3) - - // Clean up duplicate references - state.Remove(-5) // Remove original bytecode function - state.Remove(-4) // Remove original context - - // Call with 2 args (function, context), 1 result - if err := state.Call(2, 1); err != nil { - s.debugLog("Execution failed: %v", err) - return nil, fmt.Errorf("script execution failed: %w", err) - } - - // Get result - result, err := state.ToValue(-1) - state.Pop(1) // Pop result - - // Extract session data if it was modified - if ctx != nil && ctx.Session != nil { - // Check if session was modified - state.GetGlobal("__session_modified") - if state.IsBoolean(-1) && state.ToBoolean(-1) { - ctx.SessionModified = true - - // Extract session data - state.GetGlobal("__session_data") - if state.IsTable(-1) { - // Clear existing data and extract new data from Lua - sessionData := make(map[string]any) - - // Extract new session data - state.PushNil() // Start iteration - for state.Next(-2) { - // Stack now has key at -2 and value at -1 - if state.IsString(-2) { - key := state.ToString(-2) - value, err := state.ToValue(-1) - if err == nil { - sessionData[key] = value - } - } - state.Pop(1) // Pop value, leave key for next iteration - } - - // Update session with the new data - for k, v := range sessionData { - if err := ctx.Session.Set(k, v); err != nil { - s.debugLog("Failed to set session value %s: %v", k, err) - } - } - } - state.Pop(1) // Pop __session_data - } - state.Pop(1) // Pop __session_modified - } - - // Check for HTTP response - httpResponse, hasResponse := GetHTTPResponse(state) - if hasResponse { - // Add the script result as the response body - httpResponse.Body = result - - // Mark session as modified if needed - if ctx != nil && ctx.SessionModified { - httpResponse.SessionModified = true - } - - // If we have a fasthttp context, apply the response directly - if ctx != nil && ctx.RequestCtx != nil { - // If session was modified, save it - if ctx.SessionModified && ctx.Session != nil { - // Save session and set cookie if needed - sessions.GlobalSessionManager.SaveSession(ctx.Session) - - // Add session cookie to the response - cookieOpts := sessions.GlobalSessionManager.CookieOptions() - cookie := fasthttp.AcquireCookie() - cookie.SetKey(cookieOpts["name"].(string)) - cookie.SetValue(ctx.Session.ID) - cookie.SetPath(cookieOpts["path"].(string)) - - if domain, ok := cookieOpts["domain"].(string); ok && domain != "" { - cookie.SetDomain(domain) - } - - if maxAge, ok := cookieOpts["max_age"].(int); ok && maxAge > 0 { - cookie.SetMaxAge(maxAge) - } - - cookie.SetSecure(cookieOpts["secure"].(bool)) - cookie.SetHTTPOnly(cookieOpts["http_only"].(bool)) - - // Add to response cookies - httpResponse.Cookies = append(httpResponse.Cookies, cookie) - } - - ApplyHTTPResponse(httpResponse, ctx.RequestCtx) - ReleaseResponse(httpResponse) - return nil, nil // No need to return response object - } - - return httpResponse, nil - } - - // If we have a fasthttp context and the result needs to be written directly - if ctx != nil && ctx.RequestCtx != nil && (result != nil) { - // For direct HTTP responses - switch r := result.(type) { - case string: - ctx.RequestCtx.SetBodyString(r) - case []byte: - ctx.RequestCtx.SetBody(r) - case map[string]any, []any: - // JSON response - ctx.RequestCtx.Response.Header.SetContentType("application/json") - if err := json.NewEncoder(buf).Encode(r); err == nil { - ctx.RequestCtx.SetBody(buf.Bytes()) - } else { - ctx.RequestCtx.SetBodyString(fmt.Sprintf("%v", r)) - } - default: - // Default string conversion - ctx.RequestCtx.SetBodyString(fmt.Sprintf("%v", r)) - } - - // Handle session if modified - if ctx.SessionModified && ctx.Session != nil { - // Save session - sessions.GlobalSessionManager.SaveSession(ctx.Session) - - // Add session cookie - cookieOpts := sessions.GlobalSessionManager.CookieOptions() - cookie := fasthttp.AcquireCookie() - cookie.SetKey(cookieOpts["name"].(string)) - cookie.SetValue(ctx.Session.ID) - cookie.SetPath(cookieOpts["path"].(string)) - - if domain, ok := cookieOpts["domain"].(string); ok && domain != "" { - cookie.SetDomain(domain) - } - - if maxAge, ok := cookieOpts["max_age"].(int); ok && maxAge > 0 { - cookie.SetMaxAge(maxAge) - } - - cookie.SetSecure(cookieOpts["secure"].(bool)) - cookie.SetHTTPOnly(cookieOpts["http_only"].(bool)) - - // Add to response - ctx.RequestCtx.Response.Header.SetCookie(cookie) - } - - return nil, nil - } - - return result, err -} diff --git a/core/runner/sandbox/Utils.go b/core/runner/sandbox/Utils.go deleted file mode 100644 index 75863a9..0000000 --- a/core/runner/sandbox/Utils.go +++ /dev/null @@ -1,58 +0,0 @@ -package sandbox - -import ( - "crypto/rand" - "encoding/base64" - - "Moonshark/core/utils/logger" - - luajit "git.sharkk.net/Sky/LuaJIT-to-Go" -) - -// UtilModuleInitFunc returns an initializer for the util module -func UtilModuleInitFunc() func(*luajit.State) error { - return func(state *luajit.State) error { - return RegisterModule(state, "util", UtilModuleFunctions()) - } -} - -// UtilModuleFunctions returns all functions for the util module -func UtilModuleFunctions() map[string]luajit.GoFunction { - return map[string]luajit.GoFunction{ - "generate_token": GenerateToken, - } -} - -// GenerateToken creates a cryptographically secure random token -func GenerateToken(s *luajit.State) int { - // Get the length from the Lua arguments (default to 32) - length := 32 - if s.GetTop() >= 1 && s.IsNumber(1) { - length = int(s.ToNumber(1)) - } - - // Enforce minimum length for security - if length < 16 { - length = 16 - } - - // Generate secure random bytes - tokenBytes := make([]byte, length) - if _, err := rand.Read(tokenBytes); err != nil { - s.PushString("") - logger.Error("Failed to generate secure token: %v", err) - return 1 // Return empty string on error - } - - // Encode as base64 - token := base64.RawURLEncoding.EncodeToString(tokenBytes) - - // Trim to requested length (base64 might be longer) - if len(token) > length { - token = token[:length] - } - - // Push the token to the Lua stack - s.PushString(token) - return 1 // One return value -} diff --git a/core/sessions/Manager.go b/core/sessions/Manager.go index 9e61e34..29ef6ae 100644 --- a/core/sessions/Manager.go +++ b/core/sessions/Manager.go @@ -9,6 +9,7 @@ import ( "github.com/VictoriaMetrics/fastcache" "github.com/goccy/go-json" + "github.com/valyala/fasthttp" ) const ( @@ -75,7 +76,7 @@ func (sm *SessionManager) GetSession(id string) *Session { // Store back with updated timestamp updatedData, _ := json.Marshal(session) - sm.cache.Set([]byte(id), updatedData) // Use updatedData, not data + sm.cache.Set([]byte(id), updatedData) return session } @@ -141,5 +142,39 @@ func (sm *SessionManager) SetCookieOptions(name, path, domain string, secure, ht sm.cookieMaxAge = maxAge } +// GetSessionFromRequest extracts the session from a request context +func (sm *SessionManager) GetSessionFromRequest(ctx *fasthttp.RequestCtx) *Session { + cookie := ctx.Request.Header.Cookie(sm.cookieName) + if len(cookie) == 0 { + // No session cookie, create a new session + return sm.CreateSession() + } + + // Session cookie exists, get the session + return sm.GetSession(string(cookie)) +} + +// SaveSessionToResponse adds the session cookie to an HTTP response +func (sm *SessionManager) ApplySessionCookie(ctx *fasthttp.RequestCtx, session *Session) { + cookie := fasthttp.AcquireCookie() + defer fasthttp.ReleaseCookie(cookie) + + sm.mu.RLock() + cookie.SetKey(sm.cookieName) + cookie.SetValue(session.ID) + cookie.SetPath(sm.cookiePath) + cookie.SetHTTPOnly(sm.cookieHTTPOnly) + cookie.SetMaxAge(sm.cookieMaxAge) + + if sm.cookieDomain != "" { + cookie.SetDomain(sm.cookieDomain) + } + + cookie.SetSecure(sm.cookieSecure) + sm.mu.RUnlock() + + ctx.Response.Header.SetCookie(cookie) +} + // GlobalSessionManager is the default session manager instance var GlobalSessionManager = NewSessionManager()