package runner import ( "Moonshark/core/runner/sandbox" "Moonshark/core/sessions" "Moonshark/core/utils/logger" luajit "git.sharkk.net/Sky/LuaJIT-to-Go" "github.com/valyala/fasthttp" ) // SessionHandler handles session management for Lua scripts type SessionHandler struct { manager *sessions.SessionManager debugLog bool } // NewSessionHandler creates a new session handler func NewSessionHandler(manager *sessions.SessionManager) *SessionHandler { return &SessionHandler{ manager: manager, debugLog: false, } } // EnableDebug enables debug logging func (h *SessionHandler) EnableDebug() { h.debugLog = true } // WithSessionManager creates a RunnerOption to add session support func WithSessionManager(manager *sessions.SessionManager) RunnerOption { return func(r *Runner) { handler := NewSessionHandler(manager) r.AddInitHook(handler.preRequestHook) r.AddFinalizeHook(handler.postRequestHook) } } // preRequestHook initializes session before script execution func (h *SessionHandler) preRequestHook(state *luajit.State, ctx *Context) error { if ctx == nil || ctx.Values["_request_cookies"] == nil { return nil } // Extract cookies from context cookies, ok := ctx.Values["_request_cookies"].(map[string]any) if !ok { return nil } // Get the session ID from cookies cookieName := h.manager.CookieOptions()["name"].(string) var sessionID string // Check if our session cookie exists if cookieValue, exists := cookies[cookieName]; exists { if strValue, ok := cookieValue.(string); ok && strValue != "" { sessionID = strValue } } // Create new session if needed if sessionID == "" { session := h.manager.CreateSession() sessionID = session.ID } // Store the session ID in the context ctx.Set("_session_id", sessionID) // Get session data session := h.manager.GetSession(sessionID) sessionData := session.GetAll() // Set session data in Lua state return SetSessionData(state, sessionID, sessionData) } // postRequestHook handles session after script execution func (h *SessionHandler) postRequestHook(state *luajit.State, ctx *Context, result any) error { // Check if session was modified modifiedID, modifiedData, modified := GetSessionData(state) if !modified { return nil } // Get the original session ID from context var sessionID string if ctx != nil { if id, ok := ctx.Values["_session_id"].(string); ok { sessionID = id } } // Use the original session ID if the modified one is empty if modifiedID == "" { modifiedID = sessionID } if modifiedID == "" { return nil } // Update session in manager session := h.manager.GetSession(modifiedID) session.Clear() // clear to sync deleted values for k, v := range modifiedData { session.Set(k, v) } h.manager.SaveSession(session) // Add session cookie to result if it's an HTTP response if httpResp, ok := result.(*sandbox.HTTPResponse); ok { h.addSessionCookie(httpResp, modifiedID) } return nil } // addSessionCookie adds a session cookie to an HTTP response func (h *SessionHandler) addSessionCookie(resp *sandbox.HTTPResponse, sessionID string) { // Get cookie options opts := h.manager.CookieOptions() // Check if session cookie is already set cookieName := opts["name"].(string) for _, cookie := range resp.Cookies { if string(cookie.Key()) == cookieName { return } } // Create and add cookie cookie := fasthttp.AcquireCookie() cookie.SetKey(cookieName) cookie.SetValue(sessionID) cookie.SetPath(opts["path"].(string)) cookie.SetHTTPOnly(opts["http_only"].(bool)) cookie.SetMaxAge(opts["max_age"].(int)) // Optional cookie parameters if domain, ok := opts["domain"].(string); ok && domain != "" { cookie.SetDomain(domain) } if secure, ok := opts["secure"].(bool); ok { cookie.SetSecure(secure) } resp.Cookies = append(resp.Cookies, cookie) } // GetSessionData extracts session data from Lua state func GetSessionData(state *luajit.State) (string, map[string]any, bool) { // Check if session was modified state.GetGlobal("__session_modified") modified := state.ToBoolean(-1) state.Pop(1) if !modified { return "", nil, false } // Get session ID state.GetGlobal("__session_id") sessionID := state.ToString(-1) state.Pop(1) // Get session data state.GetGlobal("__session_data") if !state.IsTable(-1) { state.Pop(1) return sessionID, nil, false } data, err := state.ToTable(-1) state.Pop(1) if err != nil { logger.Error("Failed to extract session data: %v", err) return sessionID, nil, false } return sessionID, data, true } // SetSessionData sets session data in Lua state func SetSessionData(state *luajit.State, sessionID string, data map[string]any) error { // Set session ID state.PushString(sessionID) state.SetGlobal("__session_id") // Set session data if data == nil { data = make(map[string]any) } if err := state.PushTable(data); err != nil { return err } state.SetGlobal("__session_data") // Reset modification flag state.PushBoolean(false) state.SetGlobal("__session_modified") return nil }