244 lines
6.0 KiB
Go
244 lines
6.0 KiB
Go
package runner
|
|
|
|
import (
|
|
"github.com/valyala/fasthttp"
|
|
|
|
"Moonshark/core/runner/sandbox"
|
|
"Moonshark/core/sessions"
|
|
"Moonshark/core/utils/logger"
|
|
|
|
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
|
)
|
|
|
|
// SessionHandler handles session management for Lua scripts
|
|
type SessionHandler struct {
|
|
manager *sessions.SessionManager
|
|
debugLog bool
|
|
}
|
|
|
|
// NewSessionHandler creates a new session handler
|
|
func NewSessionHandler(manager *sessions.SessionManager) *SessionHandler {
|
|
return &SessionHandler{
|
|
manager: manager,
|
|
debugLog: false,
|
|
}
|
|
}
|
|
|
|
// EnableDebug enables debug logging
|
|
func (h *SessionHandler) EnableDebug() {
|
|
h.debugLog = true
|
|
}
|
|
|
|
// debug logs a message if debug is enabled
|
|
func (h *SessionHandler) debug(format string, args ...interface{}) {
|
|
if h.debugLog {
|
|
logger.Debug("[SessionHandler] "+format, args...)
|
|
}
|
|
}
|
|
|
|
// WithSessionManager creates a RunnerOption to add session support
|
|
func WithSessionManager(manager *sessions.SessionManager) RunnerOption {
|
|
return func(r *Runner) {
|
|
handler := NewSessionHandler(manager)
|
|
|
|
// Add hooks to the runner
|
|
r.AddInitHook(handler.preRequestHook)
|
|
r.AddFinalizeHook(handler.postRequestHook)
|
|
}
|
|
}
|
|
|
|
// preRequestHook is called before executing a request
|
|
func (h *SessionHandler) preRequestHook(state *luajit.State, ctx *Context) error {
|
|
h.debug("Running pre-request session hook")
|
|
|
|
// Check if we have cookie information in context
|
|
// Instead of raw request, we now look for the cookie map
|
|
if ctx == nil || ctx.Values["_request_cookies"] == nil {
|
|
return nil
|
|
}
|
|
|
|
// Extract cookies from context
|
|
cookies, ok := ctx.Values["_request_cookies"].(map[string]any)
|
|
if !ok {
|
|
return nil
|
|
}
|
|
|
|
// Get the session ID from cookies
|
|
cookieName := h.manager.CookieOptions()["name"].(string)
|
|
var sessionID string
|
|
|
|
// Check if our session cookie exists
|
|
if cookieValue, exists := cookies[cookieName]; exists {
|
|
if strValue, ok := cookieValue.(string); ok && strValue != "" {
|
|
sessionID = strValue
|
|
h.debug("Found existing session ID: %s", sessionID)
|
|
}
|
|
}
|
|
|
|
// If no session ID found, create new session
|
|
if sessionID == "" {
|
|
// Create a new session
|
|
session := h.manager.CreateSession()
|
|
sessionID = session.ID
|
|
h.debug("Created new session with ID: %s", sessionID)
|
|
}
|
|
|
|
// Store the session ID in the context for later use
|
|
ctx.Set("_session_id", sessionID)
|
|
|
|
// Get the session data
|
|
session := h.manager.GetSession(sessionID)
|
|
sessionData := session.GetAll()
|
|
|
|
// Set session data in Lua state
|
|
if err := SetSessionData(state, sessionID, sessionData); err != nil {
|
|
h.debug("Failed to set session data: %v", err)
|
|
return err
|
|
}
|
|
|
|
h.debug("Session data initialized successfully")
|
|
return nil
|
|
}
|
|
|
|
// postRequestHook is called after executing a request
|
|
func (h *SessionHandler) postRequestHook(state *luajit.State, ctx *Context, result any) error {
|
|
h.debug("Running post-request session hook")
|
|
|
|
// Check if session was modified
|
|
modifiedID, modifiedData, modified := GetSessionData(state)
|
|
if !modified {
|
|
h.debug("Session not modified, skipping")
|
|
return nil
|
|
}
|
|
|
|
// Get the original session ID from context
|
|
var sessionID string
|
|
if ctx != nil {
|
|
if id, ok := ctx.Values["_session_id"].(string); ok {
|
|
sessionID = id
|
|
}
|
|
}
|
|
|
|
// Use the original session ID if the modified one is empty
|
|
if modifiedID == "" {
|
|
modifiedID = sessionID
|
|
}
|
|
|
|
if modifiedID == "" {
|
|
h.debug("No session ID found, cannot persist session data")
|
|
return nil
|
|
}
|
|
|
|
h.debug("Persisting modified session data for ID: %s", modifiedID)
|
|
|
|
// Update session in manager
|
|
session := h.manager.GetSession(modifiedID)
|
|
session.Clear() // clear to sync deleted values
|
|
for k, v := range modifiedData {
|
|
session.Set(k, v)
|
|
}
|
|
|
|
h.manager.SaveSession(session)
|
|
|
|
// Add session cookie to result if it's an HTTP response
|
|
if httpResp, ok := result.(*sandbox.HTTPResponse); ok {
|
|
h.addSessionCookie(httpResp, modifiedID)
|
|
}
|
|
|
|
h.debug("Session data persisted successfully")
|
|
return nil
|
|
}
|
|
|
|
// addSessionCookie adds a session cookie to an HTTP response
|
|
func (h *SessionHandler) addSessionCookie(resp *sandbox.HTTPResponse, sessionID string) {
|
|
// Get cookie options
|
|
opts := h.manager.CookieOptions()
|
|
|
|
// Check if session cookie is already set
|
|
cookieName := opts["name"].(string)
|
|
for _, cookie := range resp.Cookies {
|
|
if string(cookie.Key()) == cookieName {
|
|
h.debug("Session cookie already set in response")
|
|
return
|
|
}
|
|
}
|
|
|
|
h.debug("Adding session cookie to response")
|
|
|
|
// Create and add cookie
|
|
cookie := fasthttp.AcquireCookie()
|
|
cookie.SetKey(cookieName)
|
|
cookie.SetValue(sessionID)
|
|
cookie.SetPath(opts["path"].(string))
|
|
cookie.SetHTTPOnly(opts["http_only"].(bool))
|
|
cookie.SetMaxAge(opts["max_age"].(int))
|
|
|
|
// Optional cookie parameters
|
|
if domain, ok := opts["domain"].(string); ok && domain != "" {
|
|
cookie.SetDomain(domain)
|
|
}
|
|
|
|
if secure, ok := opts["secure"].(bool); ok {
|
|
cookie.SetSecure(secure)
|
|
}
|
|
|
|
resp.Cookies = append(resp.Cookies, cookie)
|
|
}
|
|
|
|
// GetSessionData extracts session data from Lua state
|
|
func GetSessionData(state *luajit.State) (string, map[string]any, bool) {
|
|
// Check if session was modified
|
|
state.GetGlobal("__session_modified")
|
|
modified := state.ToBoolean(-1)
|
|
state.Pop(1)
|
|
|
|
if !modified {
|
|
return "", nil, false
|
|
}
|
|
|
|
// Get session ID
|
|
state.GetGlobal("__session_id")
|
|
sessionID := state.ToString(-1)
|
|
state.Pop(1)
|
|
|
|
// Get session data
|
|
state.GetGlobal("__session_data")
|
|
if !state.IsTable(-1) {
|
|
state.Pop(1)
|
|
return sessionID, nil, false
|
|
}
|
|
|
|
data, err := state.ToTable(-1)
|
|
state.Pop(1)
|
|
|
|
if err != nil {
|
|
logger.Error("Failed to extract session data: %v", err)
|
|
return sessionID, nil, false
|
|
}
|
|
|
|
return sessionID, data, true
|
|
}
|
|
|
|
// SetSessionData sets session data in Lua state
|
|
func SetSessionData(state *luajit.State, sessionID string, data map[string]any) error {
|
|
// Set session ID
|
|
state.PushString(sessionID)
|
|
state.SetGlobal("__session_id")
|
|
|
|
// Set session data
|
|
if data == nil {
|
|
data = make(map[string]any)
|
|
}
|
|
|
|
if err := state.PushTable(data); err != nil {
|
|
return err
|
|
}
|
|
state.SetGlobal("__session_data")
|
|
|
|
// Reset modification flag
|
|
state.PushBoolean(false)
|
|
state.SetGlobal("__session_modified")
|
|
|
|
return nil
|
|
}
|