work on sessions
This commit is contained in:
parent
35ce09d66e
commit
0abf31ed3a
|
@ -35,13 +35,21 @@ func ValidateCSRFToken(ctx *runner.Context) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
// Get token from session
|
||||
sessionData := ctx.SessionData
|
||||
if sessionData == nil {
|
||||
// Get session from context
|
||||
sessionMap, ok := ctx.Get("session").(map[string]any)
|
||||
if !ok || sessionMap == nil {
|
||||
logger.Warning("CSRF validation failed: no session data")
|
||||
return false
|
||||
}
|
||||
|
||||
// Get session data
|
||||
sessionData, ok := sessionMap["data"].(map[string]any)
|
||||
if !ok || sessionData == nil {
|
||||
logger.Warning("CSRF validation failed: no session data map")
|
||||
return false
|
||||
}
|
||||
|
||||
// Get token from session
|
||||
sessionToken, ok := sessionData["_csrf_token"].(string)
|
||||
if !ok || sessionToken == "" {
|
||||
logger.Warning("CSRF validation failed: no token in session")
|
||||
|
@ -79,15 +87,41 @@ func GenerateCSRFToken(ctx *runner.Context, length int) (string, error) {
|
|||
return "", err
|
||||
}
|
||||
|
||||
// Get session from context
|
||||
sessionMap, ok := ctx.Get("session").(map[string]any)
|
||||
if !ok || sessionMap == nil {
|
||||
return "", errors.New("no session found in context")
|
||||
}
|
||||
|
||||
// Get session data
|
||||
sessionData, ok := sessionMap["data"].(map[string]any)
|
||||
if !ok {
|
||||
// Initialize session data if it doesn't exist
|
||||
sessionData = make(map[string]any)
|
||||
sessionMap["data"] = sessionData
|
||||
}
|
||||
|
||||
// Store token in session
|
||||
ctx.SessionData["_csrf_token"] = token
|
||||
sessionData["_csrf_token"] = token
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// GetCSRFToken retrieves the current CSRF token or generates a new one
|
||||
func GetCSRFToken(ctx *runner.Context) (string, error) {
|
||||
// Get session from context
|
||||
sessionMap, ok := ctx.Get("session").(map[string]any)
|
||||
if !ok || sessionMap == nil {
|
||||
return "", errors.New("no session found in context")
|
||||
}
|
||||
|
||||
// Get session data
|
||||
sessionData, ok := sessionMap["data"].(map[string]any)
|
||||
if !ok || sessionData == nil {
|
||||
return GenerateCSRFToken(ctx, 32)
|
||||
}
|
||||
|
||||
// Check if token already exists in session
|
||||
if token, ok := ctx.SessionData["_csrf_token"].(string); ok && token != "" {
|
||||
if token, ok := sessionData["_csrf_token"].(string); ok && token != "" {
|
||||
return token, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -169,8 +169,11 @@ func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scrip
|
|||
|
||||
// Initialize session
|
||||
session := s.sessionManager.GetSessionFromRequest(ctx)
|
||||
luaCtx.SessionID = session.ID
|
||||
luaCtx.SessionData = session.GetAll()
|
||||
sessionMap := map[string]any{
|
||||
"id": session.ID,
|
||||
"data": session.Data,
|
||||
}
|
||||
luaCtx.Set("session", sessionMap)
|
||||
|
||||
// URL parameters
|
||||
if params.Count > 0 {
|
||||
|
|
|
@ -15,10 +15,6 @@ type Context struct {
|
|||
// FastHTTP context if this was created from an HTTP request
|
||||
RequestCtx *fasthttp.RequestCtx
|
||||
|
||||
// Session information
|
||||
SessionID string
|
||||
SessionData map[string]any
|
||||
|
||||
// Buffer for efficient string operations
|
||||
buffer *bytebufferpool.ByteBuffer
|
||||
}
|
||||
|
@ -27,8 +23,7 @@ type Context struct {
|
|||
var contextPool = sync.Pool{
|
||||
New: func() any {
|
||||
return &Context{
|
||||
Values: make(map[string]any, 16),
|
||||
SessionData: make(map[string]any, 8),
|
||||
Values: make(map[string]any, 32),
|
||||
}
|
||||
},
|
||||
}
|
||||
|
@ -90,13 +85,6 @@ func (c *Context) Release() {
|
|||
delete(c.Values, k)
|
||||
}
|
||||
|
||||
for k := range c.SessionData {
|
||||
delete(c.SessionData, k)
|
||||
}
|
||||
|
||||
// Reset session info
|
||||
c.SessionID = ""
|
||||
|
||||
// Reset request context
|
||||
c.RequestCtx = nil
|
||||
|
||||
|
@ -126,13 +114,3 @@ func (c *Context) Set(key string, value any) {
|
|||
func (c *Context) Get(key string) any {
|
||||
return c.Values[key]
|
||||
}
|
||||
|
||||
// SetSession sets a session data value
|
||||
func (c *Context) SetSession(key string, value any) {
|
||||
c.SessionData[key] = value
|
||||
}
|
||||
|
||||
// GetSession retrieves a session data value
|
||||
func (c *Context) GetSession(key string) any {
|
||||
return c.SessionData[key]
|
||||
}
|
||||
|
|
|
@ -121,18 +121,8 @@ func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx *Context) (*
|
|||
return nil, fmt.Errorf("failed to load script: %w", err)
|
||||
}
|
||||
|
||||
// Add session data to context
|
||||
contextWithSession := make(map[string]any)
|
||||
maps.Copy(contextWithSession, ctx.Values)
|
||||
|
||||
// Pass session data through context
|
||||
if ctx.SessionID != "" {
|
||||
contextWithSession["session_id"] = ctx.SessionID
|
||||
contextWithSession["session_data"] = ctx.SessionData
|
||||
}
|
||||
|
||||
// Set up context values for execution
|
||||
if err := state.PushTable(contextWithSession); err != nil {
|
||||
if err := state.PushTable(ctx.Values); err != nil {
|
||||
ReleaseResponse(response)
|
||||
return nil, err
|
||||
}
|
||||
|
@ -166,8 +156,6 @@ func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx *Context) (*
|
|||
|
||||
extractHTTPResponseData(state, response)
|
||||
|
||||
extractSessionData(state, response)
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
|
@ -229,15 +217,37 @@ func extractHTTPResponseData(state *luajit.State, response *Response) {
|
|||
if state.IsTable(-1) {
|
||||
table, err := state.ToTable(-1)
|
||||
if err == nil {
|
||||
for k, v := range table {
|
||||
response.Metadata[k] = v
|
||||
}
|
||||
maps.Copy(response.Metadata, table)
|
||||
}
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
// Check session modified flag
|
||||
state.GetField(-1, "session_modified")
|
||||
if state.IsBoolean(-1) && state.ToBoolean(-1) {
|
||||
logger.DebugCont("Found session_modified=true")
|
||||
response.SessionModified = true
|
||||
|
||||
// Get session data (using the new structure)
|
||||
state.Pop(1) // Remove session_modified
|
||||
|
||||
state.GetField(-1, "session_data")
|
||||
if state.IsTable(-1) {
|
||||
sessionData, err := state.ToTable(-1)
|
||||
if err == nil {
|
||||
for k, v := range sessionData {
|
||||
response.SessionData[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
state.Pop(1)
|
||||
} else {
|
||||
logger.DebugCont("session_modified is not set or not true")
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
// Clean up
|
||||
state.Pop(2)
|
||||
state.Pop(1)
|
||||
}
|
||||
|
||||
// extractCookie pulls cookie data from the current table on the stack
|
||||
|
@ -298,69 +308,3 @@ func extractCookie(state *luajit.State, response *Response) {
|
|||
|
||||
response.Cookies = append(response.Cookies, cookie)
|
||||
}
|
||||
|
||||
// Extract session data if modified
|
||||
func extractSessionData(state *luajit.State, response *Response) {
|
||||
logger.Debug("extractSessionData: Starting extraction")
|
||||
|
||||
// Get HTTP response table
|
||||
state.GetGlobal("__http_responses")
|
||||
if !state.IsTable(-1) {
|
||||
logger.Debug("extractSessionData: __http_responses is not a table")
|
||||
state.Pop(1)
|
||||
return
|
||||
}
|
||||
|
||||
// Get first response
|
||||
state.PushNumber(1)
|
||||
state.GetTable(-2)
|
||||
if !state.IsTable(-1) {
|
||||
logger.Debug("extractSessionData: __http_responses[1] is not a table")
|
||||
state.Pop(2)
|
||||
return
|
||||
}
|
||||
|
||||
// Check session_modified flag
|
||||
state.GetField(-1, "session_modified")
|
||||
if !state.IsBoolean(-1) || !state.ToBoolean(-1) {
|
||||
logger.Debug("extractSessionData: session_modified is not true")
|
||||
state.Pop(3)
|
||||
return
|
||||
}
|
||||
logger.Debug("extractSessionData: Found session_modified=true")
|
||||
state.Pop(1)
|
||||
|
||||
// Get session ID
|
||||
state.GetField(-1, "session_id")
|
||||
if state.IsString(-1) {
|
||||
response.SessionID = state.ToString(-1)
|
||||
logger.Debug("extractSessionData: Found session ID: %s", response.SessionID)
|
||||
} else {
|
||||
logger.Debug("extractSessionData: session_id not found or not a string")
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
// Get session data
|
||||
state.GetField(-1, "session_data")
|
||||
if state.IsTable(-1) {
|
||||
logger.Debug("extractSessionData: Found session_data table")
|
||||
sessionData, err := state.ToTable(-1)
|
||||
if err == nil {
|
||||
logger.Debug("extractSessionData: Converted session data, size=%d", len(sessionData))
|
||||
for k, v := range sessionData {
|
||||
response.SessionData[k] = v
|
||||
logger.Debug("extractSessionData: Added session key=%s, value=%v", k, v)
|
||||
}
|
||||
response.SessionModified = true
|
||||
} else {
|
||||
logger.Debug("extractSessionData: Failed to convert session data: %v", err)
|
||||
}
|
||||
} else {
|
||||
logger.Debug("extractSessionData: session_data not found or not a table")
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
// Clean up stack
|
||||
state.Pop(2)
|
||||
logger.Debug("extractSessionData: Finished extraction, modified=%v", response.SessionModified)
|
||||
}
|
||||
|
|
|
@ -50,18 +50,19 @@ function __execute_script(fn, ctx)
|
|||
env.ctx = ctx
|
||||
end
|
||||
|
||||
print("INIT SESSION DATA:", util.json_encode(ctx.session_data or {}))
|
||||
|
||||
-- Initialize local session variables in the environment
|
||||
env.__session_data = ctx.session_data or {}
|
||||
env.__session_id = ctx.session_id
|
||||
env.__session_modified = false
|
||||
local sessionData = {}
|
||||
local sessionId = ""
|
||||
|
||||
-- Add proper require function to this environment
|
||||
if __setup_require then
|
||||
__setup_require(env)
|
||||
if ctx.session then
|
||||
sessionId = ctx.session.id or ""
|
||||
sessionData = ctx.session.data or {}
|
||||
end
|
||||
|
||||
env.__session_data = sessionData
|
||||
env.__session_id = sessionId
|
||||
env.__session_modified = false
|
||||
|
||||
-- Set environment for function
|
||||
setfenv(fn, env)
|
||||
|
||||
|
@ -75,13 +76,9 @@ function __execute_script(fn, ctx)
|
|||
if env.__session_modified then
|
||||
__http_responses[1] = __http_responses[1] or {}
|
||||
__http_responses[1].session_data = env.__session_data
|
||||
__http_responses[1].session_id = env.__session_id
|
||||
__http_responses[1].session_modified = true
|
||||
end
|
||||
|
||||
print("SESSION MODIFIED:", env.__session_modified)
|
||||
print("FINAL DATA:", util.json_encode(env.__session_data or {}))
|
||||
|
||||
return result
|
||||
end
|
||||
|
||||
|
@ -332,15 +329,15 @@ local session = {
|
|||
if type(key) ~= "string" then
|
||||
error("session.set: key must be a string", 2)
|
||||
end
|
||||
|
||||
|
||||
local env = getfenv(2)
|
||||
print("SET ENV:", tostring(env)) -- Debug the environment
|
||||
|
||||
|
||||
if not env.__session_data then
|
||||
env.__session_data = {}
|
||||
print("CREATED NEW SESSION TABLE")
|
||||
end
|
||||
|
||||
|
||||
env.__session_data[key] = value
|
||||
env.__session_modified = true
|
||||
print("SET:", key, "=", tostring(value), "MODIFIED:", env.__session_modified)
|
||||
|
|
Loading…
Reference in New Issue
Block a user