work on sessions

This commit is contained in:
Sky Johnson 2025-04-09 23:12:23 -05:00
parent ac991f40a0
commit 35ce09d66e
2 changed files with 203 additions and 160 deletions

View File

@ -4,7 +4,6 @@ import (
"fmt" "fmt"
"sync" "sync"
"github.com/valyala/bytebufferpool"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
"Moonshark/core/utils/logger" "Moonshark/core/utils/logger"
@ -113,78 +112,48 @@ func (s *Sandbox) registerCoreFunctions(state *luajit.State) error {
// Execute runs a Lua script in the sandbox with the given context // Execute runs a Lua script in the sandbox with the given context
func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx *Context) (*Response, error) { func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx *Context) (*Response, error) {
s.debugLog("Executing script...")
// Create a response object // Create a response object
response := NewResponse() response := NewResponse()
// Get a buffer for string operations
buf := bytebufferpool.Get()
defer bytebufferpool.Put(buf)
// Load bytecode // Load bytecode
if err := state.LoadBytecode(bytecode, "script"); err != nil { if err := state.LoadBytecode(bytecode, "script"); err != nil {
ReleaseResponse(response) ReleaseResponse(response)
s.debugLog("Failed to load bytecode: %v", err)
return nil, fmt.Errorf("failed to load script: %w", err) return nil, fmt.Errorf("failed to load script: %w", err)
} }
// Initialize session data in Lua // Add session data to context
contextWithSession := make(map[string]any)
maps.Copy(contextWithSession, ctx.Values)
// Pass session data through context
if ctx.SessionID != "" { if ctx.SessionID != "" {
// Set session ID contextWithSession["session_id"] = ctx.SessionID
state.PushString(ctx.SessionID) contextWithSession["session_data"] = ctx.SessionData
state.SetGlobal("__session_id")
// Set session data
if err := state.PushTable(ctx.SessionData); err != nil {
ReleaseResponse(response)
s.debugLog("Failed to push session data: %v", err)
return nil, err
}
state.SetGlobal("__session_data")
// Reset modification flag and tracking
state.PushBoolean(false)
state.SetGlobal("__session_modified")
// Create empty modified keys table
state.NewTable()
state.SetGlobal("__session_modified_keys")
} else {
// Initialize empty session
if err := state.DoString("__session_data = {}; __session_modified = false; __session_modified_keys = {}"); err != nil {
s.debugLog("Failed to initialize empty session data: %v", err)
}
} }
// Set up context values for execution // Set up context values for execution
if err := state.PushTable(ctx.Values); err != nil { if err := state.PushTable(contextWithSession); err != nil {
ReleaseResponse(response) ReleaseResponse(response)
s.debugLog("Failed to push context values: %v", err)
return nil, err return nil, err
} }
// Get the execution function // Get the execution function
state.GetGlobal("__execute_script") state.GetGlobal("__execute_script")
if !state.IsFunction(-1) { if !state.IsFunction(-1) {
state.Pop(1) // Pop non-function state.Pop(1)
ReleaseResponse(response) ReleaseResponse(response)
s.debugLog("__execute_script is not a function")
return nil, ErrSandboxNotInitialized return nil, ErrSandboxNotInitialized
} }
// Push function and context to stack // Push function and bytecode
state.PushCopy(-2) // bytecode state.PushCopy(-2) // Bytecode
state.PushCopy(-2) // context state.PushCopy(-2) // Context
state.Remove(-4) // Remove bytecode duplicate
// Remove duplicates state.Remove(-3) // Remove context duplicate
state.Remove(-4)
state.Remove(-3)
// Execute with 2 args, 1 result // Execute with 2 args, 1 result
if err := state.Call(2, 1); err != nil { if err := state.Call(2, 1); err != nil {
ReleaseResponse(response) ReleaseResponse(response)
s.debugLog("Execution failed: %v", err)
return nil, fmt.Errorf("script execution failed: %w", err) return nil, fmt.Errorf("script execution failed: %w", err)
} }
@ -195,102 +164,84 @@ func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx *Context) (*
} }
state.Pop(1) state.Pop(1)
// Extract HTTP response data from Lua state extractHTTPResponseData(state, response)
s.extractResponseData(state, response)
extractSessionData(state, response)
return response, nil return response, nil
} }
// extractResponseData pulls response info from the Lua state // extractResponseData pulls response info from the Lua state
func (s *Sandbox) extractResponseData(state *luajit.State, response *Response) { func extractHTTPResponseData(state *luajit.State, response *Response) {
// Get HTTP response
state.GetGlobal("__http_responses") state.GetGlobal("__http_responses")
if !state.IsNil(-1) && state.IsTable(-1) { if !state.IsTable(-1) {
state.PushNumber(1)
state.GetTable(-2)
if !state.IsNil(-1) && state.IsTable(-1) {
// Extract status
state.GetField(-1, "status")
if state.IsNumber(-1) {
response.Status = int(state.ToNumber(-1))
}
state.Pop(1)
// Extract headers
state.GetField(-1, "headers")
if state.IsTable(-1) {
state.PushNil() // Start iteration
for state.Next(-2) {
if state.IsString(-2) && state.IsString(-1) {
key := state.ToString(-2)
value := state.ToString(-1)
response.Headers[key] = value
}
state.Pop(1)
}
}
state.Pop(1)
// Extract cookies
state.GetField(-1, "cookies")
if state.IsTable(-1) {
length := state.GetTableLength(-1)
for i := 1; i <= length; i++ {
state.PushNumber(float64(i))
state.GetTable(-2)
if state.IsTable(-1) {
s.extractCookie(state, response)
}
state.Pop(1)
}
}
state.Pop(1)
// Extract metadata if present
state.GetField(-1, "metadata")
if state.IsTable(-1) {
table, err := state.ToTable(-1)
if err == nil {
for k, v := range table {
response.Metadata[k] = v
}
}
}
state.Pop(1)
}
state.Pop(1) state.Pop(1)
return
}
state.PushNumber(1)
state.GetTable(-2)
if !state.IsTable(-1) {
state.Pop(2)
return
}
// Extract status
state.GetField(-1, "status")
if state.IsNumber(-1) {
response.Status = int(state.ToNumber(-1))
} }
state.Pop(1) state.Pop(1)
// Extract session data // Extract headers
state.GetGlobal("__session_modified") state.GetField(-1, "headers")
if state.IsBoolean(-1) && state.ToBoolean(-1) { if state.IsTable(-1) {
response.SessionModified = true state.PushNil() // Start iteration
for state.Next(-2) {
// Get session ID if state.IsString(-2) && state.IsString(-1) {
state.GetGlobal("__session_id") key := state.ToString(-2)
if state.IsString(-1) { value := state.ToString(-1)
response.SessionID = state.ToString(-1) response.Headers[key] = value
}
state.Pop(1)
// Get session data
state.GetGlobal("__session_data")
if state.IsTable(-1) {
sessionData, err := state.ToTable(-1)
if err == nil {
maps.Copy(response.SessionData, sessionData)
} }
state.Pop(1)
} }
state.Pop(1)
} }
state.Pop(1) state.Pop(1)
// Extract cookies
state.GetField(-1, "cookies")
if state.IsTable(-1) {
length := state.GetTableLength(-1)
for i := 1; i <= length; i++ {
state.PushNumber(float64(i))
state.GetTable(-2)
if state.IsTable(-1) {
extractCookie(state, response)
}
state.Pop(1)
}
}
state.Pop(1)
// Extract metadata
state.GetField(-1, "metadata")
if state.IsTable(-1) {
table, err := state.ToTable(-1)
if err == nil {
for k, v := range table {
response.Metadata[k] = v
}
}
}
state.Pop(1)
// Clean up
state.Pop(2)
} }
// extractCookie pulls cookie data from the current table on the stack // extractCookie pulls cookie data from the current table on the stack
func (s *Sandbox) extractCookie(state *luajit.State, response *Response) { func extractCookie(state *luajit.State, response *Response) {
cookie := fasthttp.AcquireCookie() cookie := fasthttp.AcquireCookie()
// Get name (required) // Get name (required)
@ -347,3 +298,69 @@ func (s *Sandbox) extractCookie(state *luajit.State, response *Response) {
response.Cookies = append(response.Cookies, cookie) response.Cookies = append(response.Cookies, cookie)
} }
// Extract session data if modified
func extractSessionData(state *luajit.State, response *Response) {
logger.Debug("extractSessionData: Starting extraction")
// Get HTTP response table
state.GetGlobal("__http_responses")
if !state.IsTable(-1) {
logger.Debug("extractSessionData: __http_responses is not a table")
state.Pop(1)
return
}
// Get first response
state.PushNumber(1)
state.GetTable(-2)
if !state.IsTable(-1) {
logger.Debug("extractSessionData: __http_responses[1] is not a table")
state.Pop(2)
return
}
// Check session_modified flag
state.GetField(-1, "session_modified")
if !state.IsBoolean(-1) || !state.ToBoolean(-1) {
logger.Debug("extractSessionData: session_modified is not true")
state.Pop(3)
return
}
logger.Debug("extractSessionData: Found session_modified=true")
state.Pop(1)
// Get session ID
state.GetField(-1, "session_id")
if state.IsString(-1) {
response.SessionID = state.ToString(-1)
logger.Debug("extractSessionData: Found session ID: %s", response.SessionID)
} else {
logger.Debug("extractSessionData: session_id not found or not a string")
}
state.Pop(1)
// Get session data
state.GetField(-1, "session_data")
if state.IsTable(-1) {
logger.Debug("extractSessionData: Found session_data table")
sessionData, err := state.ToTable(-1)
if err == nil {
logger.Debug("extractSessionData: Converted session data, size=%d", len(sessionData))
for k, v := range sessionData {
response.SessionData[k] = v
logger.Debug("extractSessionData: Added session key=%s, value=%v", k, v)
}
response.SessionModified = true
} else {
logger.Debug("extractSessionData: Failed to convert session data: %v", err)
}
} else {
logger.Debug("extractSessionData: session_data not found or not a table")
}
state.Pop(1)
// Clean up stack
state.Pop(2)
logger.Debug("extractSessionData: Finished extraction, modified=%v", response.SessionModified)
}

View File

@ -14,7 +14,6 @@ __ready_modules = {}
__session_data = {} __session_data = {}
__session_id = nil __session_id = nil
__session_modified = false __session_modified = false
__session_modified_keys = {}
-- ====================================================================== -- ======================================================================
-- CORE SANDBOX FUNCTIONALITY -- CORE SANDBOX FUNCTIONALITY
@ -43,11 +42,25 @@ function __execute_script(fn, ctx)
-- Clear previous responses -- Clear previous responses
__http_responses[1] = nil __http_responses[1] = nil
-- Reset session modification flag -- Create environment with metatable inheriting from _G
__session_modified = false local env = setmetatable({}, {__index = _G})
-- Create environment -- Add context if provided
local env = __create_env(ctx) if ctx then
env.ctx = ctx
end
print("INIT SESSION DATA:", util.json_encode(ctx.session_data or {}))
-- Initialize local session variables in the environment
env.__session_data = ctx.session_data or {}
env.__session_id = ctx.session_id
env.__session_modified = false
-- Add proper require function to this environment
if __setup_require then
__setup_require(env)
end
-- Set environment for function -- Set environment for function
setfenv(fn, env) setfenv(fn, env)
@ -58,6 +71,17 @@ function __execute_script(fn, ctx)
error(result, 0) error(result, 0)
end end
-- If session was modified, add to response
if env.__session_modified then
__http_responses[1] = __http_responses[1] or {}
__http_responses[1].session_data = env.__session_data
__http_responses[1].session_id = env.__session_id
__http_responses[1].session_modified = true
end
print("SESSION MODIFIED:", env.__session_modified)
print("FINAL DATA:", util.json_encode(env.__session_data or {}))
return result return result
end end
@ -293,77 +317,79 @@ local cookie = {
-- SESSION MODULE -- SESSION MODULE
-- ====================================================================== -- ======================================================================
-- Session module implementation
local session = { local session = {
-- Get a session value -- Get session value
get = function(key) get = function(key)
if type(key) ~= "string" then if type(key) ~= "string" then
error("session.get: key must be a string", 2) error("session.get: key must be a string", 2)
end end
return __session_data and __session_data[key] local env = getfenv(2)
return env.__session_data and env.__session_data[key]
end, end,
-- Set a session value -- Set session value
set = function(key, value) set = function(key, value)
if type(key) ~= "string" then if type(key) ~= "string" then
error("session.set: key must be a string", 2) error("session.set: key must be a string", 2)
end end
__session_data = __session_data or {}
__session_data[key] = value local env = getfenv(2)
__session_modified = true print("SET ENV:", tostring(env)) -- Debug the environment
if not env.__session_data then
env.__session_data = {}
print("CREATED NEW SESSION TABLE")
end
env.__session_data[key] = value
env.__session_modified = true
print("SET:", key, "=", tostring(value), "MODIFIED:", env.__session_modified)
return true return true
end, end,
-- Delete a session value -- Delete session value
delete = function(key) delete = function(key)
if type(key) ~= "string" then if type(key) ~= "string" then
error("session.delete: key must be a string", 2) error("session.delete: key must be a string", 2)
end end
if __session_data and __session_data[key] ~= nil then local env = getfenv(2)
__session_data[key] = nil if env.__session_data and env.__session_data[key] ~= nil then
__session_modified = true env.__session_data[key] = nil
__session_modified_keys[key] = true env.__session_modified = true
end end
return true return true
end, end,
-- Clear all session data -- Clear all session data
clear = function() clear = function()
if __session_data and next(__session_data) then local env = getfenv(2)
-- Track all keys as modified if env.__session_data and next(env.__session_data) then
for k in pairs(__session_data) do env.__session_data = {}
__session_modified_keys[k] = true env.__session_modified = true
end
__session_data = {}
__session_modified = true
end end
return true return true
end, end,
-- Get the session ID -- Get session ID
get_id = function() get_id = function()
return __session_id or nil local env = getfenv(2)
return env.__session_id or ""
end, end,
-- Get all session data -- Get all session data
get_all = function() get_all = function()
local result = {} local env = getfenv(2)
for k, v in pairs(__session_data or {}) do return env.__session_data or {}
result[k] = v
end
return result
end, end,
-- Check if session has a key -- Check if session has key
has = function(key) has = function(key)
if type(key) ~= "string" then if type(key) ~= "string" then
error("session.has: key must be a string", 2) error("session.has: key must be a string", 2)
end end
local env = getfenv(2)
return __session_data and __session_data[key] ~= nil return env.__session_data ~= nil and env.__session_data[key] ~= nil
end end
} }