work on sessions
This commit is contained in:
parent
35ce09d66e
commit
0abf31ed3a
|
@ -35,13 +35,21 @@ func ValidateCSRFToken(ctx *runner.Context) bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get token from session
|
// Get session from context
|
||||||
sessionData := ctx.SessionData
|
sessionMap, ok := ctx.Get("session").(map[string]any)
|
||||||
if sessionData == nil {
|
if !ok || sessionMap == nil {
|
||||||
logger.Warning("CSRF validation failed: no session data")
|
logger.Warning("CSRF validation failed: no session data")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get session data
|
||||||
|
sessionData, ok := sessionMap["data"].(map[string]any)
|
||||||
|
if !ok || sessionData == nil {
|
||||||
|
logger.Warning("CSRF validation failed: no session data map")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get token from session
|
||||||
sessionToken, ok := sessionData["_csrf_token"].(string)
|
sessionToken, ok := sessionData["_csrf_token"].(string)
|
||||||
if !ok || sessionToken == "" {
|
if !ok || sessionToken == "" {
|
||||||
logger.Warning("CSRF validation failed: no token in session")
|
logger.Warning("CSRF validation failed: no token in session")
|
||||||
|
@ -79,15 +87,41 @@ func GenerateCSRFToken(ctx *runner.Context, length int) (string, error) {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get session from context
|
||||||
|
sessionMap, ok := ctx.Get("session").(map[string]any)
|
||||||
|
if !ok || sessionMap == nil {
|
||||||
|
return "", errors.New("no session found in context")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get session data
|
||||||
|
sessionData, ok := sessionMap["data"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
// Initialize session data if it doesn't exist
|
||||||
|
sessionData = make(map[string]any)
|
||||||
|
sessionMap["data"] = sessionData
|
||||||
|
}
|
||||||
|
|
||||||
// Store token in session
|
// Store token in session
|
||||||
ctx.SessionData["_csrf_token"] = token
|
sessionData["_csrf_token"] = token
|
||||||
return token, nil
|
return token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetCSRFToken retrieves the current CSRF token or generates a new one
|
// GetCSRFToken retrieves the current CSRF token or generates a new one
|
||||||
func GetCSRFToken(ctx *runner.Context) (string, error) {
|
func GetCSRFToken(ctx *runner.Context) (string, error) {
|
||||||
|
// Get session from context
|
||||||
|
sessionMap, ok := ctx.Get("session").(map[string]any)
|
||||||
|
if !ok || sessionMap == nil {
|
||||||
|
return "", errors.New("no session found in context")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get session data
|
||||||
|
sessionData, ok := sessionMap["data"].(map[string]any)
|
||||||
|
if !ok || sessionData == nil {
|
||||||
|
return GenerateCSRFToken(ctx, 32)
|
||||||
|
}
|
||||||
|
|
||||||
// Check if token already exists in session
|
// Check if token already exists in session
|
||||||
if token, ok := ctx.SessionData["_csrf_token"].(string); ok && token != "" {
|
if token, ok := sessionData["_csrf_token"].(string); ok && token != "" {
|
||||||
return token, nil
|
return token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -169,8 +169,11 @@ func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scrip
|
||||||
|
|
||||||
// Initialize session
|
// Initialize session
|
||||||
session := s.sessionManager.GetSessionFromRequest(ctx)
|
session := s.sessionManager.GetSessionFromRequest(ctx)
|
||||||
luaCtx.SessionID = session.ID
|
sessionMap := map[string]any{
|
||||||
luaCtx.SessionData = session.GetAll()
|
"id": session.ID,
|
||||||
|
"data": session.Data,
|
||||||
|
}
|
||||||
|
luaCtx.Set("session", sessionMap)
|
||||||
|
|
||||||
// URL parameters
|
// URL parameters
|
||||||
if params.Count > 0 {
|
if params.Count > 0 {
|
||||||
|
|
|
@ -15,10 +15,6 @@ type Context struct {
|
||||||
// FastHTTP context if this was created from an HTTP request
|
// FastHTTP context if this was created from an HTTP request
|
||||||
RequestCtx *fasthttp.RequestCtx
|
RequestCtx *fasthttp.RequestCtx
|
||||||
|
|
||||||
// Session information
|
|
||||||
SessionID string
|
|
||||||
SessionData map[string]any
|
|
||||||
|
|
||||||
// Buffer for efficient string operations
|
// Buffer for efficient string operations
|
||||||
buffer *bytebufferpool.ByteBuffer
|
buffer *bytebufferpool.ByteBuffer
|
||||||
}
|
}
|
||||||
|
@ -27,8 +23,7 @@ type Context struct {
|
||||||
var contextPool = sync.Pool{
|
var contextPool = sync.Pool{
|
||||||
New: func() any {
|
New: func() any {
|
||||||
return &Context{
|
return &Context{
|
||||||
Values: make(map[string]any, 16),
|
Values: make(map[string]any, 32),
|
||||||
SessionData: make(map[string]any, 8),
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -90,13 +85,6 @@ func (c *Context) Release() {
|
||||||
delete(c.Values, k)
|
delete(c.Values, k)
|
||||||
}
|
}
|
||||||
|
|
||||||
for k := range c.SessionData {
|
|
||||||
delete(c.SessionData, k)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reset session info
|
|
||||||
c.SessionID = ""
|
|
||||||
|
|
||||||
// Reset request context
|
// Reset request context
|
||||||
c.RequestCtx = nil
|
c.RequestCtx = nil
|
||||||
|
|
||||||
|
@ -126,13 +114,3 @@ func (c *Context) Set(key string, value any) {
|
||||||
func (c *Context) Get(key string) any {
|
func (c *Context) Get(key string) any {
|
||||||
return c.Values[key]
|
return c.Values[key]
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetSession sets a session data value
|
|
||||||
func (c *Context) SetSession(key string, value any) {
|
|
||||||
c.SessionData[key] = value
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetSession retrieves a session data value
|
|
||||||
func (c *Context) GetSession(key string) any {
|
|
||||||
return c.SessionData[key]
|
|
||||||
}
|
|
||||||
|
|
|
@ -121,18 +121,8 @@ func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx *Context) (*
|
||||||
return nil, fmt.Errorf("failed to load script: %w", err)
|
return nil, fmt.Errorf("failed to load script: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add session data to context
|
|
||||||
contextWithSession := make(map[string]any)
|
|
||||||
maps.Copy(contextWithSession, ctx.Values)
|
|
||||||
|
|
||||||
// Pass session data through context
|
|
||||||
if ctx.SessionID != "" {
|
|
||||||
contextWithSession["session_id"] = ctx.SessionID
|
|
||||||
contextWithSession["session_data"] = ctx.SessionData
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set up context values for execution
|
// Set up context values for execution
|
||||||
if err := state.PushTable(contextWithSession); err != nil {
|
if err := state.PushTable(ctx.Values); err != nil {
|
||||||
ReleaseResponse(response)
|
ReleaseResponse(response)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -166,8 +156,6 @@ func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx *Context) (*
|
||||||
|
|
||||||
extractHTTPResponseData(state, response)
|
extractHTTPResponseData(state, response)
|
||||||
|
|
||||||
extractSessionData(state, response)
|
|
||||||
|
|
||||||
return response, nil
|
return response, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -229,15 +217,37 @@ func extractHTTPResponseData(state *luajit.State, response *Response) {
|
||||||
if state.IsTable(-1) {
|
if state.IsTable(-1) {
|
||||||
table, err := state.ToTable(-1)
|
table, err := state.ToTable(-1)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
for k, v := range table {
|
maps.Copy(response.Metadata, table)
|
||||||
response.Metadata[k] = v
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
state.Pop(1)
|
state.Pop(1)
|
||||||
|
|
||||||
|
// Check session modified flag
|
||||||
|
state.GetField(-1, "session_modified")
|
||||||
|
if state.IsBoolean(-1) && state.ToBoolean(-1) {
|
||||||
|
logger.DebugCont("Found session_modified=true")
|
||||||
|
response.SessionModified = true
|
||||||
|
|
||||||
|
// Get session data (using the new structure)
|
||||||
|
state.Pop(1) // Remove session_modified
|
||||||
|
|
||||||
|
state.GetField(-1, "session_data")
|
||||||
|
if state.IsTable(-1) {
|
||||||
|
sessionData, err := state.ToTable(-1)
|
||||||
|
if err == nil {
|
||||||
|
for k, v := range sessionData {
|
||||||
|
response.SessionData[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
state.Pop(1)
|
||||||
|
} else {
|
||||||
|
logger.DebugCont("session_modified is not set or not true")
|
||||||
|
}
|
||||||
|
state.Pop(1)
|
||||||
|
|
||||||
// Clean up
|
// Clean up
|
||||||
state.Pop(2)
|
state.Pop(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// extractCookie pulls cookie data from the current table on the stack
|
// extractCookie pulls cookie data from the current table on the stack
|
||||||
|
@ -298,69 +308,3 @@ func extractCookie(state *luajit.State, response *Response) {
|
||||||
|
|
||||||
response.Cookies = append(response.Cookies, cookie)
|
response.Cookies = append(response.Cookies, cookie)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract session data if modified
|
|
||||||
func extractSessionData(state *luajit.State, response *Response) {
|
|
||||||
logger.Debug("extractSessionData: Starting extraction")
|
|
||||||
|
|
||||||
// Get HTTP response table
|
|
||||||
state.GetGlobal("__http_responses")
|
|
||||||
if !state.IsTable(-1) {
|
|
||||||
logger.Debug("extractSessionData: __http_responses is not a table")
|
|
||||||
state.Pop(1)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get first response
|
|
||||||
state.PushNumber(1)
|
|
||||||
state.GetTable(-2)
|
|
||||||
if !state.IsTable(-1) {
|
|
||||||
logger.Debug("extractSessionData: __http_responses[1] is not a table")
|
|
||||||
state.Pop(2)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check session_modified flag
|
|
||||||
state.GetField(-1, "session_modified")
|
|
||||||
if !state.IsBoolean(-1) || !state.ToBoolean(-1) {
|
|
||||||
logger.Debug("extractSessionData: session_modified is not true")
|
|
||||||
state.Pop(3)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
logger.Debug("extractSessionData: Found session_modified=true")
|
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
// Get session ID
|
|
||||||
state.GetField(-1, "session_id")
|
|
||||||
if state.IsString(-1) {
|
|
||||||
response.SessionID = state.ToString(-1)
|
|
||||||
logger.Debug("extractSessionData: Found session ID: %s", response.SessionID)
|
|
||||||
} else {
|
|
||||||
logger.Debug("extractSessionData: session_id not found or not a string")
|
|
||||||
}
|
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
// Get session data
|
|
||||||
state.GetField(-1, "session_data")
|
|
||||||
if state.IsTable(-1) {
|
|
||||||
logger.Debug("extractSessionData: Found session_data table")
|
|
||||||
sessionData, err := state.ToTable(-1)
|
|
||||||
if err == nil {
|
|
||||||
logger.Debug("extractSessionData: Converted session data, size=%d", len(sessionData))
|
|
||||||
for k, v := range sessionData {
|
|
||||||
response.SessionData[k] = v
|
|
||||||
logger.Debug("extractSessionData: Added session key=%s, value=%v", k, v)
|
|
||||||
}
|
|
||||||
response.SessionModified = true
|
|
||||||
} else {
|
|
||||||
logger.Debug("extractSessionData: Failed to convert session data: %v", err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
logger.Debug("extractSessionData: session_data not found or not a table")
|
|
||||||
}
|
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
// Clean up stack
|
|
||||||
state.Pop(2)
|
|
||||||
logger.Debug("extractSessionData: Finished extraction, modified=%v", response.SessionModified)
|
|
||||||
}
|
|
||||||
|
|
|
@ -50,18 +50,19 @@ function __execute_script(fn, ctx)
|
||||||
env.ctx = ctx
|
env.ctx = ctx
|
||||||
end
|
end
|
||||||
|
|
||||||
print("INIT SESSION DATA:", util.json_encode(ctx.session_data or {}))
|
|
||||||
|
|
||||||
-- Initialize local session variables in the environment
|
-- Initialize local session variables in the environment
|
||||||
env.__session_data = ctx.session_data or {}
|
local sessionData = {}
|
||||||
env.__session_id = ctx.session_id
|
local sessionId = ""
|
||||||
env.__session_modified = false
|
|
||||||
|
|
||||||
-- Add proper require function to this environment
|
if ctx.session then
|
||||||
if __setup_require then
|
sessionId = ctx.session.id or ""
|
||||||
__setup_require(env)
|
sessionData = ctx.session.data or {}
|
||||||
end
|
end
|
||||||
|
|
||||||
|
env.__session_data = sessionData
|
||||||
|
env.__session_id = sessionId
|
||||||
|
env.__session_modified = false
|
||||||
|
|
||||||
-- Set environment for function
|
-- Set environment for function
|
||||||
setfenv(fn, env)
|
setfenv(fn, env)
|
||||||
|
|
||||||
|
@ -75,13 +76,9 @@ function __execute_script(fn, ctx)
|
||||||
if env.__session_modified then
|
if env.__session_modified then
|
||||||
__http_responses[1] = __http_responses[1] or {}
|
__http_responses[1] = __http_responses[1] or {}
|
||||||
__http_responses[1].session_data = env.__session_data
|
__http_responses[1].session_data = env.__session_data
|
||||||
__http_responses[1].session_id = env.__session_id
|
|
||||||
__http_responses[1].session_modified = true
|
__http_responses[1].session_modified = true
|
||||||
end
|
end
|
||||||
|
|
||||||
print("SESSION MODIFIED:", env.__session_modified)
|
|
||||||
print("FINAL DATA:", util.json_encode(env.__session_data or {}))
|
|
||||||
|
|
||||||
return result
|
return result
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -332,15 +329,15 @@ local session = {
|
||||||
if type(key) ~= "string" then
|
if type(key) ~= "string" then
|
||||||
error("session.set: key must be a string", 2)
|
error("session.set: key must be a string", 2)
|
||||||
end
|
end
|
||||||
|
|
||||||
local env = getfenv(2)
|
local env = getfenv(2)
|
||||||
print("SET ENV:", tostring(env)) -- Debug the environment
|
print("SET ENV:", tostring(env)) -- Debug the environment
|
||||||
|
|
||||||
if not env.__session_data then
|
if not env.__session_data then
|
||||||
env.__session_data = {}
|
env.__session_data = {}
|
||||||
print("CREATED NEW SESSION TABLE")
|
print("CREATED NEW SESSION TABLE")
|
||||||
end
|
end
|
||||||
|
|
||||||
env.__session_data[key] = value
|
env.__session_data[key] = value
|
||||||
env.__session_modified = true
|
env.__session_modified = true
|
||||||
print("SET:", key, "=", tostring(value), "MODIFIED:", env.__session_modified)
|
print("SET:", key, "=", tostring(value), "MODIFIED:", env.__session_modified)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user