work on sessions

This commit is contained in:
Sky Johnson 2025-04-10 07:51:15 -05:00
parent 35ce09d66e
commit 0abf31ed3a
5 changed files with 84 additions and 128 deletions

View File

@ -35,13 +35,21 @@ func ValidateCSRFToken(ctx *runner.Context) bool {
return false return false
} }
// Get token from session // Get session from context
sessionData := ctx.SessionData sessionMap, ok := ctx.Get("session").(map[string]any)
if sessionData == nil { if !ok || sessionMap == nil {
logger.Warning("CSRF validation failed: no session data") logger.Warning("CSRF validation failed: no session data")
return false return false
} }
// Get session data
sessionData, ok := sessionMap["data"].(map[string]any)
if !ok || sessionData == nil {
logger.Warning("CSRF validation failed: no session data map")
return false
}
// Get token from session
sessionToken, ok := sessionData["_csrf_token"].(string) sessionToken, ok := sessionData["_csrf_token"].(string)
if !ok || sessionToken == "" { if !ok || sessionToken == "" {
logger.Warning("CSRF validation failed: no token in session") logger.Warning("CSRF validation failed: no token in session")
@ -79,15 +87,41 @@ func GenerateCSRFToken(ctx *runner.Context, length int) (string, error) {
return "", err return "", err
} }
// Get session from context
sessionMap, ok := ctx.Get("session").(map[string]any)
if !ok || sessionMap == nil {
return "", errors.New("no session found in context")
}
// Get session data
sessionData, ok := sessionMap["data"].(map[string]any)
if !ok {
// Initialize session data if it doesn't exist
sessionData = make(map[string]any)
sessionMap["data"] = sessionData
}
// Store token in session // Store token in session
ctx.SessionData["_csrf_token"] = token sessionData["_csrf_token"] = token
return token, nil return token, nil
} }
// GetCSRFToken retrieves the current CSRF token or generates a new one // GetCSRFToken retrieves the current CSRF token or generates a new one
func GetCSRFToken(ctx *runner.Context) (string, error) { func GetCSRFToken(ctx *runner.Context) (string, error) {
// Get session from context
sessionMap, ok := ctx.Get("session").(map[string]any)
if !ok || sessionMap == nil {
return "", errors.New("no session found in context")
}
// Get session data
sessionData, ok := sessionMap["data"].(map[string]any)
if !ok || sessionData == nil {
return GenerateCSRFToken(ctx, 32)
}
// Check if token already exists in session // Check if token already exists in session
if token, ok := ctx.SessionData["_csrf_token"].(string); ok && token != "" { if token, ok := sessionData["_csrf_token"].(string); ok && token != "" {
return token, nil return token, nil
} }

View File

@ -169,8 +169,11 @@ func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scrip
// Initialize session // Initialize session
session := s.sessionManager.GetSessionFromRequest(ctx) session := s.sessionManager.GetSessionFromRequest(ctx)
luaCtx.SessionID = session.ID sessionMap := map[string]any{
luaCtx.SessionData = session.GetAll() "id": session.ID,
"data": session.Data,
}
luaCtx.Set("session", sessionMap)
// URL parameters // URL parameters
if params.Count > 0 { if params.Count > 0 {

View File

@ -15,10 +15,6 @@ type Context struct {
// FastHTTP context if this was created from an HTTP request // FastHTTP context if this was created from an HTTP request
RequestCtx *fasthttp.RequestCtx RequestCtx *fasthttp.RequestCtx
// Session information
SessionID string
SessionData map[string]any
// Buffer for efficient string operations // Buffer for efficient string operations
buffer *bytebufferpool.ByteBuffer buffer *bytebufferpool.ByteBuffer
} }
@ -27,8 +23,7 @@ type Context struct {
var contextPool = sync.Pool{ var contextPool = sync.Pool{
New: func() any { New: func() any {
return &Context{ return &Context{
Values: make(map[string]any, 16), Values: make(map[string]any, 32),
SessionData: make(map[string]any, 8),
} }
}, },
} }
@ -90,13 +85,6 @@ func (c *Context) Release() {
delete(c.Values, k) delete(c.Values, k)
} }
for k := range c.SessionData {
delete(c.SessionData, k)
}
// Reset session info
c.SessionID = ""
// Reset request context // Reset request context
c.RequestCtx = nil c.RequestCtx = nil
@ -126,13 +114,3 @@ func (c *Context) Set(key string, value any) {
func (c *Context) Get(key string) any { func (c *Context) Get(key string) any {
return c.Values[key] return c.Values[key]
} }
// SetSession sets a session data value
func (c *Context) SetSession(key string, value any) {
c.SessionData[key] = value
}
// GetSession retrieves a session data value
func (c *Context) GetSession(key string) any {
return c.SessionData[key]
}

View File

@ -121,18 +121,8 @@ func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx *Context) (*
return nil, fmt.Errorf("failed to load script: %w", err) return nil, fmt.Errorf("failed to load script: %w", err)
} }
// Add session data to context
contextWithSession := make(map[string]any)
maps.Copy(contextWithSession, ctx.Values)
// Pass session data through context
if ctx.SessionID != "" {
contextWithSession["session_id"] = ctx.SessionID
contextWithSession["session_data"] = ctx.SessionData
}
// Set up context values for execution // Set up context values for execution
if err := state.PushTable(contextWithSession); err != nil { if err := state.PushTable(ctx.Values); err != nil {
ReleaseResponse(response) ReleaseResponse(response)
return nil, err return nil, err
} }
@ -166,8 +156,6 @@ func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx *Context) (*
extractHTTPResponseData(state, response) extractHTTPResponseData(state, response)
extractSessionData(state, response)
return response, nil return response, nil
} }
@ -229,15 +217,37 @@ func extractHTTPResponseData(state *luajit.State, response *Response) {
if state.IsTable(-1) { if state.IsTable(-1) {
table, err := state.ToTable(-1) table, err := state.ToTable(-1)
if err == nil { if err == nil {
for k, v := range table { maps.Copy(response.Metadata, table)
response.Metadata[k] = v
}
} }
} }
state.Pop(1) state.Pop(1)
// Check session modified flag
state.GetField(-1, "session_modified")
if state.IsBoolean(-1) && state.ToBoolean(-1) {
logger.DebugCont("Found session_modified=true")
response.SessionModified = true
// Get session data (using the new structure)
state.Pop(1) // Remove session_modified
state.GetField(-1, "session_data")
if state.IsTable(-1) {
sessionData, err := state.ToTable(-1)
if err == nil {
for k, v := range sessionData {
response.SessionData[k] = v
}
}
}
state.Pop(1)
} else {
logger.DebugCont("session_modified is not set or not true")
}
state.Pop(1)
// Clean up // Clean up
state.Pop(2) state.Pop(1)
} }
// extractCookie pulls cookie data from the current table on the stack // extractCookie pulls cookie data from the current table on the stack
@ -298,69 +308,3 @@ func extractCookie(state *luajit.State, response *Response) {
response.Cookies = append(response.Cookies, cookie) response.Cookies = append(response.Cookies, cookie)
} }
// Extract session data if modified
func extractSessionData(state *luajit.State, response *Response) {
logger.Debug("extractSessionData: Starting extraction")
// Get HTTP response table
state.GetGlobal("__http_responses")
if !state.IsTable(-1) {
logger.Debug("extractSessionData: __http_responses is not a table")
state.Pop(1)
return
}
// Get first response
state.PushNumber(1)
state.GetTable(-2)
if !state.IsTable(-1) {
logger.Debug("extractSessionData: __http_responses[1] is not a table")
state.Pop(2)
return
}
// Check session_modified flag
state.GetField(-1, "session_modified")
if !state.IsBoolean(-1) || !state.ToBoolean(-1) {
logger.Debug("extractSessionData: session_modified is not true")
state.Pop(3)
return
}
logger.Debug("extractSessionData: Found session_modified=true")
state.Pop(1)
// Get session ID
state.GetField(-1, "session_id")
if state.IsString(-1) {
response.SessionID = state.ToString(-1)
logger.Debug("extractSessionData: Found session ID: %s", response.SessionID)
} else {
logger.Debug("extractSessionData: session_id not found or not a string")
}
state.Pop(1)
// Get session data
state.GetField(-1, "session_data")
if state.IsTable(-1) {
logger.Debug("extractSessionData: Found session_data table")
sessionData, err := state.ToTable(-1)
if err == nil {
logger.Debug("extractSessionData: Converted session data, size=%d", len(sessionData))
for k, v := range sessionData {
response.SessionData[k] = v
logger.Debug("extractSessionData: Added session key=%s, value=%v", k, v)
}
response.SessionModified = true
} else {
logger.Debug("extractSessionData: Failed to convert session data: %v", err)
}
} else {
logger.Debug("extractSessionData: session_data not found or not a table")
}
state.Pop(1)
// Clean up stack
state.Pop(2)
logger.Debug("extractSessionData: Finished extraction, modified=%v", response.SessionModified)
}

View File

@ -50,18 +50,19 @@ function __execute_script(fn, ctx)
env.ctx = ctx env.ctx = ctx
end end
print("INIT SESSION DATA:", util.json_encode(ctx.session_data or {}))
-- Initialize local session variables in the environment -- Initialize local session variables in the environment
env.__session_data = ctx.session_data or {} local sessionData = {}
env.__session_id = ctx.session_id local sessionId = ""
env.__session_modified = false
-- Add proper require function to this environment if ctx.session then
if __setup_require then sessionId = ctx.session.id or ""
__setup_require(env) sessionData = ctx.session.data or {}
end end
env.__session_data = sessionData
env.__session_id = sessionId
env.__session_modified = false
-- Set environment for function -- Set environment for function
setfenv(fn, env) setfenv(fn, env)
@ -75,13 +76,9 @@ function __execute_script(fn, ctx)
if env.__session_modified then if env.__session_modified then
__http_responses[1] = __http_responses[1] or {} __http_responses[1] = __http_responses[1] or {}
__http_responses[1].session_data = env.__session_data __http_responses[1].session_data = env.__session_data
__http_responses[1].session_id = env.__session_id
__http_responses[1].session_modified = true __http_responses[1].session_modified = true
end end
print("SESSION MODIFIED:", env.__session_modified)
print("FINAL DATA:", util.json_encode(env.__session_data or {}))
return result return result
end end
@ -332,15 +329,15 @@ local session = {
if type(key) ~= "string" then if type(key) ~= "string" then
error("session.set: key must be a string", 2) error("session.set: key must be a string", 2)
end end
local env = getfenv(2) local env = getfenv(2)
print("SET ENV:", tostring(env)) -- Debug the environment print("SET ENV:", tostring(env)) -- Debug the environment
if not env.__session_data then if not env.__session_data then
env.__session_data = {} env.__session_data = {}
print("CREATED NEW SESSION TABLE") print("CREATED NEW SESSION TABLE")
end end
env.__session_data[key] = value env.__session_data[key] = value
env.__session_modified = true env.__session_modified = true
print("SET:", key, "=", tostring(value), "MODIFIED:", env.__session_modified) print("SET:", key, "=", tostring(value), "MODIFIED:", env.__session_modified)