From 0abf31ed3a0feeb1daece497eb95b45d22cf201e Mon Sep 17 00:00:00 2001 From: Sky Johnson Date: Thu, 10 Apr 2025 07:51:15 -0500 Subject: [PATCH] work on sessions --- core/http/Csrf.go | 44 ++++++++++++++-- core/http/Server.go | 7 ++- core/runner/Context.go | 24 +-------- core/runner/Sandbox.go | 110 ++++++++++------------------------------ core/runner/sandbox.lua | 27 +++++----- 5 files changed, 84 insertions(+), 128 deletions(-) diff --git a/core/http/Csrf.go b/core/http/Csrf.go index 902e689..f526793 100644 --- a/core/http/Csrf.go +++ b/core/http/Csrf.go @@ -35,13 +35,21 @@ func ValidateCSRFToken(ctx *runner.Context) bool { return false } - // Get token from session - sessionData := ctx.SessionData - if sessionData == nil { + // Get session from context + sessionMap, ok := ctx.Get("session").(map[string]any) + if !ok || sessionMap == nil { logger.Warning("CSRF validation failed: no session data") return false } + // Get session data + sessionData, ok := sessionMap["data"].(map[string]any) + if !ok || sessionData == nil { + logger.Warning("CSRF validation failed: no session data map") + return false + } + + // Get token from session sessionToken, ok := sessionData["_csrf_token"].(string) if !ok || sessionToken == "" { logger.Warning("CSRF validation failed: no token in session") @@ -79,15 +87,41 @@ func GenerateCSRFToken(ctx *runner.Context, length int) (string, error) { return "", err } + // Get session from context + sessionMap, ok := ctx.Get("session").(map[string]any) + if !ok || sessionMap == nil { + return "", errors.New("no session found in context") + } + + // Get session data + sessionData, ok := sessionMap["data"].(map[string]any) + if !ok { + // Initialize session data if it doesn't exist + sessionData = make(map[string]any) + sessionMap["data"] = sessionData + } + // Store token in session - ctx.SessionData["_csrf_token"] = token + sessionData["_csrf_token"] = token return token, nil } // GetCSRFToken retrieves the current CSRF token or generates a new one func GetCSRFToken(ctx *runner.Context) (string, error) { + // Get session from context + sessionMap, ok := ctx.Get("session").(map[string]any) + if !ok || sessionMap == nil { + return "", errors.New("no session found in context") + } + + // Get session data + sessionData, ok := sessionMap["data"].(map[string]any) + if !ok || sessionData == nil { + return GenerateCSRFToken(ctx, 32) + } + // Check if token already exists in session - if token, ok := ctx.SessionData["_csrf_token"].(string); ok && token != "" { + if token, ok := sessionData["_csrf_token"].(string); ok && token != "" { return token, nil } diff --git a/core/http/Server.go b/core/http/Server.go index 1968bbe..a357b2b 100644 --- a/core/http/Server.go +++ b/core/http/Server.go @@ -169,8 +169,11 @@ func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scrip // Initialize session session := s.sessionManager.GetSessionFromRequest(ctx) - luaCtx.SessionID = session.ID - luaCtx.SessionData = session.GetAll() + sessionMap := map[string]any{ + "id": session.ID, + "data": session.Data, + } + luaCtx.Set("session", sessionMap) // URL parameters if params.Count > 0 { diff --git a/core/runner/Context.go b/core/runner/Context.go index c6dd0e1..5324645 100644 --- a/core/runner/Context.go +++ b/core/runner/Context.go @@ -15,10 +15,6 @@ type Context struct { // FastHTTP context if this was created from an HTTP request RequestCtx *fasthttp.RequestCtx - // Session information - SessionID string - SessionData map[string]any - // Buffer for efficient string operations buffer *bytebufferpool.ByteBuffer } @@ -27,8 +23,7 @@ type Context struct { var contextPool = sync.Pool{ New: func() any { return &Context{ - Values: make(map[string]any, 16), - SessionData: make(map[string]any, 8), + Values: make(map[string]any, 32), } }, } @@ -90,13 +85,6 @@ func (c *Context) Release() { delete(c.Values, k) } - for k := range c.SessionData { - delete(c.SessionData, k) - } - - // Reset session info - c.SessionID = "" - // Reset request context c.RequestCtx = nil @@ -126,13 +114,3 @@ func (c *Context) Set(key string, value any) { func (c *Context) Get(key string) any { return c.Values[key] } - -// SetSession sets a session data value -func (c *Context) SetSession(key string, value any) { - c.SessionData[key] = value -} - -// GetSession retrieves a session data value -func (c *Context) GetSession(key string) any { - return c.SessionData[key] -} diff --git a/core/runner/Sandbox.go b/core/runner/Sandbox.go index ddefd60..d6dcf93 100644 --- a/core/runner/Sandbox.go +++ b/core/runner/Sandbox.go @@ -121,18 +121,8 @@ func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx *Context) (* return nil, fmt.Errorf("failed to load script: %w", err) } - // Add session data to context - contextWithSession := make(map[string]any) - maps.Copy(contextWithSession, ctx.Values) - - // Pass session data through context - if ctx.SessionID != "" { - contextWithSession["session_id"] = ctx.SessionID - contextWithSession["session_data"] = ctx.SessionData - } - // Set up context values for execution - if err := state.PushTable(contextWithSession); err != nil { + if err := state.PushTable(ctx.Values); err != nil { ReleaseResponse(response) return nil, err } @@ -166,8 +156,6 @@ func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx *Context) (* extractHTTPResponseData(state, response) - extractSessionData(state, response) - return response, nil } @@ -229,15 +217,37 @@ func extractHTTPResponseData(state *luajit.State, response *Response) { if state.IsTable(-1) { table, err := state.ToTable(-1) if err == nil { - for k, v := range table { - response.Metadata[k] = v - } + maps.Copy(response.Metadata, table) } } state.Pop(1) + // Check session modified flag + state.GetField(-1, "session_modified") + if state.IsBoolean(-1) && state.ToBoolean(-1) { + logger.DebugCont("Found session_modified=true") + response.SessionModified = true + + // Get session data (using the new structure) + state.Pop(1) // Remove session_modified + + state.GetField(-1, "session_data") + if state.IsTable(-1) { + sessionData, err := state.ToTable(-1) + if err == nil { + for k, v := range sessionData { + response.SessionData[k] = v + } + } + } + state.Pop(1) + } else { + logger.DebugCont("session_modified is not set or not true") + } + state.Pop(1) + // Clean up - state.Pop(2) + state.Pop(1) } // extractCookie pulls cookie data from the current table on the stack @@ -298,69 +308,3 @@ func extractCookie(state *luajit.State, response *Response) { response.Cookies = append(response.Cookies, cookie) } - -// Extract session data if modified -func extractSessionData(state *luajit.State, response *Response) { - logger.Debug("extractSessionData: Starting extraction") - - // Get HTTP response table - state.GetGlobal("__http_responses") - if !state.IsTable(-1) { - logger.Debug("extractSessionData: __http_responses is not a table") - state.Pop(1) - return - } - - // Get first response - state.PushNumber(1) - state.GetTable(-2) - if !state.IsTable(-1) { - logger.Debug("extractSessionData: __http_responses[1] is not a table") - state.Pop(2) - return - } - - // Check session_modified flag - state.GetField(-1, "session_modified") - if !state.IsBoolean(-1) || !state.ToBoolean(-1) { - logger.Debug("extractSessionData: session_modified is not true") - state.Pop(3) - return - } - logger.Debug("extractSessionData: Found session_modified=true") - state.Pop(1) - - // Get session ID - state.GetField(-1, "session_id") - if state.IsString(-1) { - response.SessionID = state.ToString(-1) - logger.Debug("extractSessionData: Found session ID: %s", response.SessionID) - } else { - logger.Debug("extractSessionData: session_id not found or not a string") - } - state.Pop(1) - - // Get session data - state.GetField(-1, "session_data") - if state.IsTable(-1) { - logger.Debug("extractSessionData: Found session_data table") - sessionData, err := state.ToTable(-1) - if err == nil { - logger.Debug("extractSessionData: Converted session data, size=%d", len(sessionData)) - for k, v := range sessionData { - response.SessionData[k] = v - logger.Debug("extractSessionData: Added session key=%s, value=%v", k, v) - } - response.SessionModified = true - } else { - logger.Debug("extractSessionData: Failed to convert session data: %v", err) - } - } else { - logger.Debug("extractSessionData: session_data not found or not a table") - } - state.Pop(1) - - // Clean up stack - state.Pop(2) - logger.Debug("extractSessionData: Finished extraction, modified=%v", response.SessionModified) -} diff --git a/core/runner/sandbox.lua b/core/runner/sandbox.lua index b96e9ba..c7c6e09 100644 --- a/core/runner/sandbox.lua +++ b/core/runner/sandbox.lua @@ -50,18 +50,19 @@ function __execute_script(fn, ctx) env.ctx = ctx end - print("INIT SESSION DATA:", util.json_encode(ctx.session_data or {})) - -- Initialize local session variables in the environment - env.__session_data = ctx.session_data or {} - env.__session_id = ctx.session_id - env.__session_modified = false + local sessionData = {} + local sessionId = "" - -- Add proper require function to this environment - if __setup_require then - __setup_require(env) + if ctx.session then + sessionId = ctx.session.id or "" + sessionData = ctx.session.data or {} end + env.__session_data = sessionData + env.__session_id = sessionId + env.__session_modified = false + -- Set environment for function setfenv(fn, env) @@ -75,13 +76,9 @@ function __execute_script(fn, ctx) if env.__session_modified then __http_responses[1] = __http_responses[1] or {} __http_responses[1].session_data = env.__session_data - __http_responses[1].session_id = env.__session_id __http_responses[1].session_modified = true end - print("SESSION MODIFIED:", env.__session_modified) - print("FINAL DATA:", util.json_encode(env.__session_data or {})) - return result end @@ -332,15 +329,15 @@ local session = { if type(key) ~= "string" then error("session.set: key must be a string", 2) end - + local env = getfenv(2) print("SET ENV:", tostring(env)) -- Debug the environment - + if not env.__session_data then env.__session_data = {} print("CREATED NEW SESSION TABLE") end - + env.__session_data[key] = value env.__session_modified = true print("SET:", key, "=", tostring(value), "MODIFIED:", env.__session_modified)