work on sessions

This commit is contained in:
Sky Johnson 2025-04-10 07:51:15 -05:00
parent 35ce09d66e
commit 0abf31ed3a
5 changed files with 84 additions and 128 deletions

View File

@ -35,13 +35,21 @@ func ValidateCSRFToken(ctx *runner.Context) bool {
return false
}
// Get token from session
sessionData := ctx.SessionData
if sessionData == nil {
// Get session from context
sessionMap, ok := ctx.Get("session").(map[string]any)
if !ok || sessionMap == nil {
logger.Warning("CSRF validation failed: no session data")
return false
}
// Get session data
sessionData, ok := sessionMap["data"].(map[string]any)
if !ok || sessionData == nil {
logger.Warning("CSRF validation failed: no session data map")
return false
}
// Get token from session
sessionToken, ok := sessionData["_csrf_token"].(string)
if !ok || sessionToken == "" {
logger.Warning("CSRF validation failed: no token in session")
@ -79,15 +87,41 @@ func GenerateCSRFToken(ctx *runner.Context, length int) (string, error) {
return "", err
}
// Get session from context
sessionMap, ok := ctx.Get("session").(map[string]any)
if !ok || sessionMap == nil {
return "", errors.New("no session found in context")
}
// Get session data
sessionData, ok := sessionMap["data"].(map[string]any)
if !ok {
// Initialize session data if it doesn't exist
sessionData = make(map[string]any)
sessionMap["data"] = sessionData
}
// Store token in session
ctx.SessionData["_csrf_token"] = token
sessionData["_csrf_token"] = token
return token, nil
}
// GetCSRFToken retrieves the current CSRF token or generates a new one
func GetCSRFToken(ctx *runner.Context) (string, error) {
// Get session from context
sessionMap, ok := ctx.Get("session").(map[string]any)
if !ok || sessionMap == nil {
return "", errors.New("no session found in context")
}
// Get session data
sessionData, ok := sessionMap["data"].(map[string]any)
if !ok || sessionData == nil {
return GenerateCSRFToken(ctx, 32)
}
// Check if token already exists in session
if token, ok := ctx.SessionData["_csrf_token"].(string); ok && token != "" {
if token, ok := sessionData["_csrf_token"].(string); ok && token != "" {
return token, nil
}

View File

@ -169,8 +169,11 @@ func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scrip
// Initialize session
session := s.sessionManager.GetSessionFromRequest(ctx)
luaCtx.SessionID = session.ID
luaCtx.SessionData = session.GetAll()
sessionMap := map[string]any{
"id": session.ID,
"data": session.Data,
}
luaCtx.Set("session", sessionMap)
// URL parameters
if params.Count > 0 {

View File

@ -15,10 +15,6 @@ type Context struct {
// FastHTTP context if this was created from an HTTP request
RequestCtx *fasthttp.RequestCtx
// Session information
SessionID string
SessionData map[string]any
// Buffer for efficient string operations
buffer *bytebufferpool.ByteBuffer
}
@ -27,8 +23,7 @@ type Context struct {
var contextPool = sync.Pool{
New: func() any {
return &Context{
Values: make(map[string]any, 16),
SessionData: make(map[string]any, 8),
Values: make(map[string]any, 32),
}
},
}
@ -90,13 +85,6 @@ func (c *Context) Release() {
delete(c.Values, k)
}
for k := range c.SessionData {
delete(c.SessionData, k)
}
// Reset session info
c.SessionID = ""
// Reset request context
c.RequestCtx = nil
@ -126,13 +114,3 @@ func (c *Context) Set(key string, value any) {
func (c *Context) Get(key string) any {
return c.Values[key]
}
// SetSession sets a session data value
func (c *Context) SetSession(key string, value any) {
c.SessionData[key] = value
}
// GetSession retrieves a session data value
func (c *Context) GetSession(key string) any {
return c.SessionData[key]
}

View File

@ -121,18 +121,8 @@ func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx *Context) (*
return nil, fmt.Errorf("failed to load script: %w", err)
}
// Add session data to context
contextWithSession := make(map[string]any)
maps.Copy(contextWithSession, ctx.Values)
// Pass session data through context
if ctx.SessionID != "" {
contextWithSession["session_id"] = ctx.SessionID
contextWithSession["session_data"] = ctx.SessionData
}
// Set up context values for execution
if err := state.PushTable(contextWithSession); err != nil {
if err := state.PushTable(ctx.Values); err != nil {
ReleaseResponse(response)
return nil, err
}
@ -166,8 +156,6 @@ func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx *Context) (*
extractHTTPResponseData(state, response)
extractSessionData(state, response)
return response, nil
}
@ -229,15 +217,37 @@ func extractHTTPResponseData(state *luajit.State, response *Response) {
if state.IsTable(-1) {
table, err := state.ToTable(-1)
if err == nil {
for k, v := range table {
response.Metadata[k] = v
}
maps.Copy(response.Metadata, table)
}
}
state.Pop(1)
// Check session modified flag
state.GetField(-1, "session_modified")
if state.IsBoolean(-1) && state.ToBoolean(-1) {
logger.DebugCont("Found session_modified=true")
response.SessionModified = true
// Get session data (using the new structure)
state.Pop(1) // Remove session_modified
state.GetField(-1, "session_data")
if state.IsTable(-1) {
sessionData, err := state.ToTable(-1)
if err == nil {
for k, v := range sessionData {
response.SessionData[k] = v
}
}
}
state.Pop(1)
} else {
logger.DebugCont("session_modified is not set or not true")
}
state.Pop(1)
// Clean up
state.Pop(2)
state.Pop(1)
}
// extractCookie pulls cookie data from the current table on the stack
@ -298,69 +308,3 @@ func extractCookie(state *luajit.State, response *Response) {
response.Cookies = append(response.Cookies, cookie)
}
// Extract session data if modified
func extractSessionData(state *luajit.State, response *Response) {
logger.Debug("extractSessionData: Starting extraction")
// Get HTTP response table
state.GetGlobal("__http_responses")
if !state.IsTable(-1) {
logger.Debug("extractSessionData: __http_responses is not a table")
state.Pop(1)
return
}
// Get first response
state.PushNumber(1)
state.GetTable(-2)
if !state.IsTable(-1) {
logger.Debug("extractSessionData: __http_responses[1] is not a table")
state.Pop(2)
return
}
// Check session_modified flag
state.GetField(-1, "session_modified")
if !state.IsBoolean(-1) || !state.ToBoolean(-1) {
logger.Debug("extractSessionData: session_modified is not true")
state.Pop(3)
return
}
logger.Debug("extractSessionData: Found session_modified=true")
state.Pop(1)
// Get session ID
state.GetField(-1, "session_id")
if state.IsString(-1) {
response.SessionID = state.ToString(-1)
logger.Debug("extractSessionData: Found session ID: %s", response.SessionID)
} else {
logger.Debug("extractSessionData: session_id not found or not a string")
}
state.Pop(1)
// Get session data
state.GetField(-1, "session_data")
if state.IsTable(-1) {
logger.Debug("extractSessionData: Found session_data table")
sessionData, err := state.ToTable(-1)
if err == nil {
logger.Debug("extractSessionData: Converted session data, size=%d", len(sessionData))
for k, v := range sessionData {
response.SessionData[k] = v
logger.Debug("extractSessionData: Added session key=%s, value=%v", k, v)
}
response.SessionModified = true
} else {
logger.Debug("extractSessionData: Failed to convert session data: %v", err)
}
} else {
logger.Debug("extractSessionData: session_data not found or not a table")
}
state.Pop(1)
// Clean up stack
state.Pop(2)
logger.Debug("extractSessionData: Finished extraction, modified=%v", response.SessionModified)
}

View File

@ -50,18 +50,19 @@ function __execute_script(fn, ctx)
env.ctx = ctx
end
print("INIT SESSION DATA:", util.json_encode(ctx.session_data or {}))
-- Initialize local session variables in the environment
env.__session_data = ctx.session_data or {}
env.__session_id = ctx.session_id
env.__session_modified = false
local sessionData = {}
local sessionId = ""
-- Add proper require function to this environment
if __setup_require then
__setup_require(env)
if ctx.session then
sessionId = ctx.session.id or ""
sessionData = ctx.session.data or {}
end
env.__session_data = sessionData
env.__session_id = sessionId
env.__session_modified = false
-- Set environment for function
setfenv(fn, env)
@ -75,13 +76,9 @@ function __execute_script(fn, ctx)
if env.__session_modified then
__http_responses[1] = __http_responses[1] or {}
__http_responses[1].session_data = env.__session_data
__http_responses[1].session_id = env.__session_id
__http_responses[1].session_modified = true
end
print("SESSION MODIFIED:", env.__session_modified)
print("FINAL DATA:", util.json_encode(env.__session_data or {}))
return result
end