Moonshark/core/runner/Sessions.go
2025-04-08 22:10:50 -05:00

211 lines
4.9 KiB
Go

package runner
import (
"Moonshark/core/runner/sandbox"
"Moonshark/core/sessions"
"Moonshark/core/utils/logger"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
"github.com/valyala/fasthttp"
)
// SessionHandler connects a sessions.SessionManager to the runner's request
// lifecycle: before a script runs it exposes the request's session to the Lua
// state, and after the script finishes it persists any changes the script made.
type SessionHandler struct {
	// manager loads, creates and saves sessions.
	manager *sessions.SessionManager

	// debugLog is switched on by EnableDebug.
	debugLog bool
}
// NewSessionHandler returns a SessionHandler backed by manager. Debug logging
// starts disabled; call EnableDebug to turn it on.
func NewSessionHandler(manager *sessions.SessionManager) *SessionHandler {
	handler := new(SessionHandler) // zero value leaves debugLog false
	handler.manager = manager
	return handler
}
// EnableDebug switches the handler's debug-logging flag on. There is no
// corresponding method to switch it off again.
func (h *SessionHandler) EnableDebug() {
	h.debugLog = true
}
// WithSessionManager returns a RunnerOption that enables session support on a
// Runner. Each time the option is applied, it creates a fresh SessionHandler
// for manager. It registers preRequestHook as an init hook and
// postRequestHook as a finalize hook.
func WithSessionManager(manager *sessions.SessionManager) RunnerOption {
	install := func(r *Runner) {
		sh := NewSessionHandler(manager)
		r.AddInitHook(sh.preRequestHook)
		r.AddFinalizeHook(sh.postRequestHook)
	}
	return install
}
// preRequestHook runs before script execution. It does the following:
//   - reads the session ID from the request cookies stored in ctx under
//     "_request_cookies";
//   - creates a new session when no non-empty ID is present;
//   - records the resolved ID in ctx under "_session_id";
//   - exposes the session's data to Lua via SetSessionData.
//
// It is a no-op when ctx is nil or does not carry a map[string]any of
// cookies. If the manager's cookie options have no string "name", the problem
// is logged and the request proceeds without a session instead of panicking.
func (h *SessionHandler) preRequestHook(state *luajit.State, ctx *Context) error {
	if ctx == nil || ctx.Values["_request_cookies"] == nil {
		return nil
	}

	cookies, ok := ctx.Values["_request_cookies"].(map[string]any)
	if !ok {
		return nil
	}

	cookieName, ok := h.manager.CookieOptions()["name"].(string)
	if !ok {
		logger.Error("Session cookie name is not configured; skipping session setup")
		return nil
	}

	// A missing cookie or a non-string value leaves sessionID empty.
	var sessionID string
	if value, ok := cookies[cookieName].(string); ok {
		sessionID = value
	}

	if sessionID == "" {
		sessionID = h.manager.CreateSession().ID
	}

	// The cookie value is client-supplied and may name a session the manager
	// no longer has. Trust the ID of the session actually returned rather than
	// the one that was requested; when the two match this changes nothing.
	session := h.manager.GetSession(sessionID)
	sessionID = session.ID

	ctx.Set("_session_id", sessionID)
	return SetSessionData(state, sessionID, session.GetAll())
}
// postRequestHook runs after script execution. If the script marked the
// session as modified (as reported by GetSessionData), it does the following:
//   - replaces the stored session's contents with the data from the Lua
//     state;
//   - saves the session;
//   - attaches the session cookie when result is a *sandbox.HTTPResponse.
//
// The session ID is taken from the Lua state. When that ID is empty, the hook
// falls back to the ID preRequestHook stored in ctx under "_session_id". With
// no ID from either source, nothing is persisted.
func (h *SessionHandler) postRequestHook(state *luajit.State, ctx *Context, result any) error {
	modifiedID, modifiedData, modified := GetSessionData(state)
	if !modified {
		return nil
	}

	if modifiedID == "" && ctx != nil {
		if id, ok := ctx.Values["_session_id"].(string); ok {
			modifiedID = id
		}
	}
	if modifiedID == "" {
		return nil
	}

	session := h.manager.GetSession(modifiedID)
	session.Clear() // start from empty so keys deleted in Lua are removed too
	for k, v := range modifiedData {
		session.Set(k, v)
	}
	h.manager.SaveSession(session)

	// Send the ID of the session that was actually saved. The manager might
	// return a session whose ID differs from modifiedID, and the client must
	// receive the ID that the data was saved under.
	if httpResp, ok := result.(*sandbox.HTTPResponse); ok {
		h.addSessionCookie(httpResp, session.ID)
	}
	return nil
}
// addSessionCookie appends a cookie carrying sessionID to resp, configured
// from the manager's CookieOptions:
//   - required: "name" (string), "path" (string), "http_only" (bool) and
//     "max_age" (int);
//   - optional: "domain" (string, applied only when non-empty) and "secure"
//     (bool).
//
// It does nothing if resp already has a cookie with the session cookie's
// name. If a required option is missing or has the wrong type, the problem is
// logged and no cookie is added. This replaces the previous behaviour of
// panicking mid-request, and avoids emitting a session cookie without the
// configured HttpOnly/Max-Age attributes.
func (h *SessionHandler) addSessionCookie(resp *sandbox.HTTPResponse, sessionID string) {
	opts := h.manager.CookieOptions()

	cookieName, ok := opts["name"].(string)
	if !ok {
		logger.Error("Session cookie name is not configured; not setting session cookie")
		return
	}

	for _, existing := range resp.Cookies {
		if string(existing.Key()) == cookieName {
			return
		}
	}

	path, pathOK := opts["path"].(string)
	httpOnly, httpOnlyOK := opts["http_only"].(bool)
	maxAge, maxAgeOK := opts["max_age"].(int)
	if !pathOK || !httpOnlyOK || !maxAgeOK {
		logger.Error("Invalid session cookie options (path/http_only/max_age); not setting session cookie")
		return
	}

	cookie := fasthttp.AcquireCookie()
	cookie.SetKey(cookieName)
	cookie.SetValue(sessionID)
	cookie.SetPath(path)
	cookie.SetHTTPOnly(httpOnly)
	cookie.SetMaxAge(maxAge)

	// Optional cookie parameters
	if domain, ok := opts["domain"].(string); ok && domain != "" {
		cookie.SetDomain(domain)
	}
	if secure, ok := opts["secure"].(bool); ok {
		cookie.SetSecure(secure)
	}

	resp.Cookies = append(resp.Cookies, cookie)
}
// GetSessionData reads back the session state a script left in the Lua
// globals. It returns three values:
//   - the session ID (__session_id);
//   - the session data (__session_data converted to a Go map);
//   - a flag that is true only when __session_modified is truthy and
//     __session_data is a table that converts without error.
//
// When __session_modified is not truthy, it returns "", nil, false. When the
// data global is not a table, or fails to convert, it returns the ID with nil
// data and false; conversion errors are also logged. Each global that is read
// is popped from the stack after use.
func GetSessionData(state *luajit.State) (string, map[string]any, bool) {
	state.GetGlobal("__session_modified")
	dirty := state.ToBoolean(-1)
	state.Pop(1)
	if !dirty {
		return "", nil, false
	}

	state.GetGlobal("__session_id")
	id := state.ToString(-1)
	state.Pop(1)

	state.GetGlobal("__session_data")
	var (
		table map[string]any
		err   error
	)
	isTable := state.IsTable(-1)
	if isTable {
		table, err = state.ToTable(-1)
	}
	state.Pop(1)

	switch {
	case !isTable:
		return id, nil, false
	case err != nil:
		logger.Error("Failed to extract session data: %v", err)
		return id, nil, false
	default:
		return id, table, true
	}
}
// SetSessionData publishes a session to the Lua state through three globals:
//   - __session_id: the session ID;
//   - __session_data: a table built from data; a nil map is published as an
//     empty table;
//   - __session_modified: reset to false.
//
// If building the table fails, the error is returned. __session_id has
// already been written at that point, but __session_data and
// __session_modified are not touched.
func SetSessionData(state *luajit.State, sessionID string, data map[string]any) error {
	payload := data
	if payload == nil {
		payload = map[string]any{}
	}

	state.PushString(sessionID)
	state.SetGlobal("__session_id")

	if err := state.PushTable(payload); err != nil {
		return err
	}
	state.SetGlobal("__session_data")

	state.PushBoolean(false)
	state.SetGlobal("__session_modified")
	return nil
}