diff --git a/core/http/Csrf.go b/core/http/Csrf.go index 7347ad4..1b7d7bb 100644 --- a/core/http/Csrf.go +++ b/core/http/Csrf.go @@ -2,6 +2,7 @@ package http import ( "Moonshark/core/runner" + luaCtx "Moonshark/core/runner/context" "Moonshark/core/utils" "Moonshark/core/utils/logger" "crypto/subtle" @@ -12,7 +13,7 @@ import ( ) // ValidateCSRFToken checks if the CSRF token is valid for a request -func ValidateCSRFToken(state *luajit.State, ctx *runner.Context) bool { +func ValidateCSRFToken(state *luajit.State, ctx *luaCtx.Context) bool { // Only validate for form submissions method, ok := ctx.Get("method").(string) if !ok || (method != "POST" && method != "PUT" && method != "PATCH" && method != "DELETE") { @@ -73,7 +74,7 @@ func ValidateCSRFToken(state *luajit.State, ctx *runner.Context) bool { // WithCSRFProtection creates a runner option to add CSRF protection func WithCSRFProtection() runner.RunnerOption { return func(r *runner.Runner) { - r.AddInitHook(func(state *luajit.State, ctx *runner.Context) error { + r.AddInitHook(func(state *luajit.State, ctx *luaCtx.Context) error { // Get request method method, ok := ctx.Get("method").(string) if !ok { diff --git a/core/http/Server.go b/core/http/Server.go index 4337e83..a6c5bdb 100644 --- a/core/http/Server.go +++ b/core/http/Server.go @@ -9,7 +9,9 @@ import ( "Moonshark/core/metadata" "Moonshark/core/routers" "Moonshark/core/runner" + luaCtx "Moonshark/core/runner/context" "Moonshark/core/runner/sandbox" + "Moonshark/core/sessions" "Moonshark/core/utils" "Moonshark/core/utils/config" "Moonshark/core/utils/logger" @@ -160,8 +162,9 @@ func HandleMethodNotAllowed(ctx *fasthttp.RequestCtx, errorConfig utils.ErrorPag } // handleLuaRoute executes a Lua route +// Updated handleLuaRoute function to handle sessions func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scriptPath string, params *routers.Params) { - luaCtx := runner.NewHTTPContext(ctx) // Use NewHTTPContext instead of NewContext + luaCtx := luaCtx.NewHTTPContext(ctx) defer luaCtx.Release() method := string(ctx.Method()) @@ -223,6 +226,27 @@ func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scrip luaCtx.Set("form", make(map[string]any)) } + // Session handling + cookieOpts := sessions.GlobalSessionManager.CookieOptions() + cookieName := cookieOpts["name"].(string) + sessionCookie := ctx.Request.Header.Cookie(cookieName) + + var sessionID string + if sessionCookie != nil { + sessionID = string(sessionCookie) + } + + // Get or create session + var session *sessions.Session + if sessionID != "" { + session = sessions.GlobalSessionManager.GetSession(sessionID) + } else { + session = sessions.GlobalSessionManager.CreateSession() + } + + // Set session in context + luaCtx.Session = session + // Execute Lua script result, err := s.luaRunner.Run(bytecode, luaCtx, scriptPath) @@ -243,6 +267,31 @@ func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scrip return } + // Handle session updates if needed + if luaCtx.SessionModified { + sessions.GlobalSessionManager.SaveSession(luaCtx.Session) + + // Set session cookie + cookie := fasthttp.AcquireCookie() + cookie.SetKey(cookieName) + cookie.SetValue(luaCtx.Session.ID) + cookie.SetPath(cookieOpts["path"].(string)) + + if domain, ok := cookieOpts["domain"].(string); ok && domain != "" { + cookie.SetDomain(domain) + } + + if maxAge, ok := cookieOpts["max_age"].(int); ok { + cookie.SetMaxAge(maxAge) + } + + cookie.SetSecure(cookieOpts["secure"].(bool)) + cookie.SetHTTPOnly(cookieOpts["http_only"].(bool)) + + ctx.Response.Header.SetCookie(cookie) + fasthttp.ReleaseCookie(cookie) + } + // If we got a non-nil result, write it to the response if result != nil { writeResponse(ctx, result) diff --git a/core/runner/Runner.go b/core/runner/Runner.go index e6eb010..40952a9 100644 --- a/core/runner/Runner.go +++ b/core/runner/Runner.go @@ -9,6 +9,7 @@ import ( "sync/atomic" "time" + luaCtx "Moonshark/core/runner/context" "Moonshark/core/runner/sandbox" "Moonshark/core/utils/logger" @@ -35,10 +36,10 @@ type State struct { } // InitHook runs before executing a script -type InitHook func(*luajit.State, *Context) error +type InitHook func(*luajit.State, *luaCtx.Context) error // FinalizeHook runs after executing a script -type FinalizeHook func(*luajit.State, *Context, any) error +type FinalizeHook func(*luajit.State, *luaCtx.Context, any) error // Runner runs Lua scripts using a pool of Lua states type Runner struct { @@ -216,7 +217,7 @@ func (r *Runner) createState(index int) (*State, error) { } // Execute runs a script with context -func (r *Runner) Execute(ctx context.Context, bytecode []byte, execCtx *Context, scriptPath string) (any, error) { +func (r *Runner) Execute(ctx context.Context, bytecode []byte, execCtx *luaCtx.Context, scriptPath string) (any, error) { if !r.isRunning.Load() { return nil, ErrRunnerClosed } @@ -282,7 +283,7 @@ func (r *Runner) Execute(ctx context.Context, bytecode []byte, execCtx *Context, if execCtx != nil && execCtx.RequestCtx != nil { // Use OptimizedExecute directly with the full context if we have RequestCtx - result, err = state.sandbox.OptimizedExecute(state.L, bytecode, &sandbox.Context{ + result, err = state.sandbox.OptimizedExecute(state.L, bytecode, &luaCtx.Context{ Values: ctxValues, RequestCtx: execCtx.RequestCtx, }) @@ -326,7 +327,7 @@ func (r *Runner) Execute(ctx context.Context, bytecode []byte, execCtx *Context, } // Run executes a Lua script (convenience wrapper) -func (r *Runner) Run(bytecode []byte, execCtx *Context, scriptPath string) (any, error) { +func (r *Runner) Run(bytecode []byte, execCtx *luaCtx.Context, scriptPath string) (any, error) { return r.Execute(context.Background(), bytecode, execCtx, scriptPath) } diff --git a/core/runner/Sessions.go b/core/runner/Sessions.go index 6a474e0..1077841 100644 --- a/core/runner/Sessions.go +++ b/core/runner/Sessions.go @@ -1,6 +1,7 @@ package runner import ( + luaCtx "Moonshark/core/runner/context" "Moonshark/core/runner/sandbox" "Moonshark/core/sessions" "Moonshark/core/utils/logger" @@ -38,7 +39,7 @@ func WithSessionManager(manager *sessions.SessionManager) RunnerOption { } // preRequestHook initializes session before script execution -func (h *SessionHandler) preRequestHook(state *luajit.State, ctx *Context) error { +func (h *SessionHandler) preRequestHook(state *luajit.State, ctx *luaCtx.Context) error { if ctx == nil || ctx.Values["_request_cookies"] == nil { return nil } @@ -78,7 +79,7 @@ func (h *SessionHandler) preRequestHook(state *luajit.State, ctx *Context) error } // postRequestHook handles session after script execution -func (h *SessionHandler) postRequestHook(state *luajit.State, ctx *Context, result any) error { +func (h *SessionHandler) postRequestHook(state *luajit.State, ctx *luaCtx.Context, result any) error { // Check if session was modified modifiedID, modifiedData, modified := GetSessionData(state) if !modified { diff --git a/core/runner/Context.go b/core/runner/context/Context.go similarity index 91% rename from core/runner/Context.go rename to core/runner/context/Context.go index e916761..b093e2f 100644 --- a/core/runner/Context.go +++ b/core/runner/context/Context.go @@ -3,6 +3,8 @@ package runner import ( "sync" + "Moonshark/core/sessions" + "github.com/valyala/bytebufferpool" "github.com/valyala/fasthttp" ) @@ -15,6 +17,10 @@ type Context struct { // FastHTTP context if this was created from an HTTP request RequestCtx *fasthttp.RequestCtx + // Session data and management + Session *sessions.Session + SessionModified bool + // Buffer for efficient string operations buffer *bytebufferpool.ByteBuffer } @@ -47,6 +53,10 @@ func (c *Context) Release() { delete(c.Values, k) } + // Reset session info + c.Session = nil + c.SessionModified = false + // Reset request context c.RequestCtx = nil diff --git a/core/runner/sandbox/Http.go b/core/runner/sandbox/Http.go index 18bc767..2482dd8 100644 --- a/core/runner/sandbox/Http.go +++ b/core/runner/sandbox/Http.go @@ -18,12 +18,19 @@ import ( luajit "git.sharkk.net/Sky/LuaJIT-to-Go" ) +// SessionHandler interface for session management +type SessionHandler interface { + LoadSession(ctx *fasthttp.RequestCtx) (string, map[string]any) + SaveSession(ctx *fasthttp.RequestCtx, sessionID string, data map[string]any) bool +} + // HTTPResponse represents an HTTP response from Lua type HTTPResponse struct { - Status int `json:"status"` - Headers map[string]string `json:"headers"` - Body any `json:"body"` - Cookies []*fasthttp.Cookie `json:"-"` + Status int `json:"status"` + Headers map[string]string `json:"headers"` + Body any `json:"body"` + Cookies []*fasthttp.Cookie `json:"-"` + SessionModified bool `json:"-"` } // Response pool to reduce allocations @@ -84,6 +91,9 @@ func ReleaseResponse(resp *HTTPResponse) { // Clear cookies resp.Cookies = resp.Cookies[:0] // Keep capacity but set length to 0 + // Reset session flag + resp.SessionModified = false + // Clear body resp.Body = nil @@ -190,6 +200,13 @@ func GetHTTPResponse(state *luajit.State) (*HTTPResponse, bool) { } state.Pop(1) + // Check if session was modified + state.GetGlobal("__session_modified") + if state.IsBoolean(-1) && state.ToBoolean(-1) { + response.SessionModified = true + } + state.Pop(1) + // Clean up state.Pop(2) // Pop response table and __http_responses diff --git a/core/runner/sandbox/Sandbox.go b/core/runner/sandbox/Sandbox.go index b502df0..654d123 100644 --- a/core/runner/sandbox/Sandbox.go +++ b/core/runner/sandbox/Sandbox.go @@ -8,6 +8,8 @@ import ( "github.com/valyala/bytebufferpool" "github.com/valyala/fasthttp" + luaCtx "Moonshark/core/runner/context" + "Moonshark/core/sessions" "Moonshark/core/utils/logger" luajit "git.sharkk.net/Sky/LuaJIT-to-Go" @@ -126,7 +128,7 @@ func (s *Sandbox) Setup(state *luajit.State, stateIndex int) error { func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx map[string]any) (any, error) { // Create a temporary context if we only have a map if ctx != nil { - tempCtx := &Context{ + tempCtx := &luaCtx.Context{ Values: ctx, } return s.OptimizedExecute(state, bytecode, tempCtx) @@ -136,16 +138,8 @@ func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx map[string]a return s.OptimizedExecute(state, bytecode, nil) } -// Context represents execution context for a Lua script -type Context struct { - // Values stores any context values (route params, HTTP request info, etc.) - Values map[string]any - // RequestCtx for HTTP requests - RequestCtx *fasthttp.RequestCtx -} - // OptimizedExecute runs bytecode with a fasthttp context if available -func (s *Sandbox) OptimizedExecute(state *luajit.State, bytecode []byte, ctx *Context) (any, error) { +func (s *Sandbox) OptimizedExecute(state *luajit.State, bytecode []byte, ctx *luaCtx.Context) (any, error) { // Use a buffer from the pool for any string operations buf := bytebufferpool.Get() defer bytebufferpool.Put(buf) @@ -164,6 +158,35 @@ func (s *Sandbox) OptimizedExecute(state *luajit.State, bytecode []byte, ctx *Co ctxValues = nil } + // Initialize session tracking in Lua + if err := state.DoString("__session_data = {}; __session_modified = false"); err != nil { + s.debugLog("Failed to initialize session data: %v", err) + } + + // Load session data if available + if ctx != nil && ctx.Session != nil { + // Set session ID in Lua + sessionIDCode := fmt.Sprintf("__session_id = %q", ctx.Session.ID) + if err := state.DoString(sessionIDCode); err != nil { + s.debugLog("Failed to set session ID: %v", err) + } + + // Get session data and populate Lua table + state.GetGlobal("__session_data") + if state.IsTable(-1) { + sessionData := ctx.Session.GetAll() + for k, v := range sessionData { + state.PushString(k) + if err := state.PushValue(v); err != nil { + s.debugLog("Failed to push session value %s: %v", k, err) + continue + } + state.SetTable(-3) + } + } + state.Pop(1) // Pop __session_data + } + // Prepare context table if ctxValues != nil { state.CreateTable(0, len(ctxValues)) @@ -206,14 +229,85 @@ func (s *Sandbox) OptimizedExecute(state *luajit.State, bytecode []byte, ctx *Co result, err := state.ToValue(-1) state.Pop(1) // Pop result + // Extract session data if it was modified + if ctx != nil && ctx.Session != nil { + // Check if session was modified + state.GetGlobal("__session_modified") + if state.IsBoolean(-1) && state.ToBoolean(-1) { + ctx.SessionModified = true + + // Extract session data + state.GetGlobal("__session_data") + if state.IsTable(-1) { + // Clear existing data and extract new data from Lua + sessionData := make(map[string]any) + + // Extract new session data + state.PushNil() // Start iteration + for state.Next(-2) { + // Stack now has key at -2 and value at -1 + if state.IsString(-2) { + key := state.ToString(-2) + value, err := state.ToValue(-1) + if err == nil { + sessionData[key] = value + } + } + state.Pop(1) // Pop value, leave key for next iteration + } + + // Update session with the new data + for k, v := range sessionData { + if err := ctx.Session.Set(k, v); err != nil { + s.debugLog("Failed to set session value %s: %v", k, err) + } + } + } + state.Pop(1) // Pop __session_data + } + state.Pop(1) // Pop __session_modified + } + // Check for HTTP response httpResponse, hasResponse := GetHTTPResponse(state) if hasResponse { // Add the script result as the response body httpResponse.Body = result + // Mark session as modified if needed + if ctx != nil && ctx.SessionModified { + httpResponse.SessionModified = true + } + // If we have a fasthttp context, apply the response directly if ctx != nil && ctx.RequestCtx != nil { + // If session was modified, save it + if ctx.SessionModified && ctx.Session != nil { + // Save session and set cookie if needed + sessions.GlobalSessionManager.SaveSession(ctx.Session) + + // Add session cookie to the response + cookieOpts := sessions.GlobalSessionManager.CookieOptions() + cookie := fasthttp.AcquireCookie() + cookie.SetKey(cookieOpts["name"].(string)) + cookie.SetValue(ctx.Session.ID) + cookie.SetPath(cookieOpts["path"].(string)) + + if domain, ok := cookieOpts["domain"].(string); ok && domain != "" { + cookie.SetDomain(domain) + } + + if maxAge, ok := cookieOpts["max_age"].(int); ok && maxAge > 0 { + cookie.SetMaxAge(maxAge) + } + + cookie.SetSecure(cookieOpts["secure"].(bool)) + cookie.SetHTTPOnly(cookieOpts["http_only"].(bool)) + + // Add to response cookies + httpResponse.Cookies = append(httpResponse.Cookies, cookie) + } + ApplyHTTPResponse(httpResponse, ctx.RequestCtx) ReleaseResponse(httpResponse) return nil, nil // No need to return response object @@ -242,6 +336,34 @@ func (s *Sandbox) OptimizedExecute(state *luajit.State, bytecode []byte, ctx *Co // Default string conversion ctx.RequestCtx.SetBodyString(fmt.Sprintf("%v", r)) } + + // Handle session if modified + if ctx.SessionModified && ctx.Session != nil { + // Save session + sessions.GlobalSessionManager.SaveSession(ctx.Session) + + // Add session cookie + cookieOpts := sessions.GlobalSessionManager.CookieOptions() + cookie := fasthttp.AcquireCookie() + cookie.SetKey(cookieOpts["name"].(string)) + cookie.SetValue(ctx.Session.ID) + cookie.SetPath(cookieOpts["path"].(string)) + + if domain, ok := cookieOpts["domain"].(string); ok && domain != "" { + cookie.SetDomain(domain) + } + + if maxAge, ok := cookieOpts["max_age"].(int); ok && maxAge > 0 { + cookie.SetMaxAge(maxAge) + } + + cookie.SetSecure(cookieOpts["secure"].(bool)) + cookie.SetHTTPOnly(cookieOpts["http_only"].(bool)) + + // Add to response + ctx.RequestCtx.Response.Header.SetCookie(cookie) + } + return nil, nil }