Compare commits
3 Commits
5ebcd97662
...
35ce09d66e
Author | SHA1 | Date | |
---|---|---|---|
35ce09d66e | |||
ac991f40a0 | |||
85b0551e70 |
|
@ -225,13 +225,12 @@ func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scrip
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Save session if modified
|
// Update session if modified
|
||||||
if response.SessionModified {
|
if response.SessionModified {
|
||||||
// Update session data
|
|
||||||
for k, v := range response.SessionData {
|
for k, v := range response.SessionData {
|
||||||
session.Set(k, v)
|
session.Set(k, v)
|
||||||
}
|
}
|
||||||
s.sessionManager.SaveSession(session)
|
|
||||||
s.sessionManager.ApplySessionCookie(ctx, session)
|
s.sessionManager.ApplySessionCookie(ctx, session)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -4,11 +4,12 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/valyala/bytebufferpool"
|
|
||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
|
|
||||||
"Moonshark/core/utils/logger"
|
"Moonshark/core/utils/logger"
|
||||||
|
|
||||||
|
"maps"
|
||||||
|
|
||||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -111,74 +112,48 @@ func (s *Sandbox) registerCoreFunctions(state *luajit.State) error {
|
||||||
|
|
||||||
// Execute runs a Lua script in the sandbox with the given context
|
// Execute runs a Lua script in the sandbox with the given context
|
||||||
func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx *Context) (*Response, error) {
|
func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx *Context) (*Response, error) {
|
||||||
s.debugLog("Executing script...")
|
|
||||||
|
|
||||||
// Create a response object
|
// Create a response object
|
||||||
response := NewResponse()
|
response := NewResponse()
|
||||||
|
|
||||||
// Get a buffer for string operations
|
|
||||||
buf := bytebufferpool.Get()
|
|
||||||
defer bytebufferpool.Put(buf)
|
|
||||||
|
|
||||||
// Load bytecode
|
// Load bytecode
|
||||||
if err := state.LoadBytecode(bytecode, "script"); err != nil {
|
if err := state.LoadBytecode(bytecode, "script"); err != nil {
|
||||||
ReleaseResponse(response)
|
ReleaseResponse(response)
|
||||||
s.debugLog("Failed to load bytecode: %v", err)
|
|
||||||
return nil, fmt.Errorf("failed to load script: %w", err)
|
return nil, fmt.Errorf("failed to load script: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize session data in Lua
|
// Add session data to context
|
||||||
|
contextWithSession := make(map[string]any)
|
||||||
|
maps.Copy(contextWithSession, ctx.Values)
|
||||||
|
|
||||||
|
// Pass session data through context
|
||||||
if ctx.SessionID != "" {
|
if ctx.SessionID != "" {
|
||||||
// Set session ID
|
contextWithSession["session_id"] = ctx.SessionID
|
||||||
state.PushString(ctx.SessionID)
|
contextWithSession["session_data"] = ctx.SessionData
|
||||||
state.SetGlobal("__session_id")
|
|
||||||
|
|
||||||
// Set session data
|
|
||||||
if err := state.PushTable(ctx.SessionData); err != nil {
|
|
||||||
ReleaseResponse(response)
|
|
||||||
s.debugLog("Failed to push session data: %v", err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
state.SetGlobal("__session_data")
|
|
||||||
|
|
||||||
// Reset modification flag
|
|
||||||
state.PushBoolean(false)
|
|
||||||
state.SetGlobal("__session_modified")
|
|
||||||
} else {
|
|
||||||
// Initialize empty session
|
|
||||||
if err := state.DoString("__session_data = {}; __session_modified = false"); err != nil {
|
|
||||||
s.debugLog("Failed to initialize empty session data: %v", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set up context values for execution
|
// Set up context values for execution
|
||||||
if err := state.PushTable(ctx.Values); err != nil {
|
if err := state.PushTable(contextWithSession); err != nil {
|
||||||
ReleaseResponse(response)
|
ReleaseResponse(response)
|
||||||
s.debugLog("Failed to push context values: %v", err)
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the execution function
|
// Get the execution function
|
||||||
state.GetGlobal("__execute_script")
|
state.GetGlobal("__execute_script")
|
||||||
if !state.IsFunction(-1) {
|
if !state.IsFunction(-1) {
|
||||||
state.Pop(1) // Pop non-function
|
state.Pop(1)
|
||||||
ReleaseResponse(response)
|
ReleaseResponse(response)
|
||||||
s.debugLog("__execute_script is not a function")
|
|
||||||
return nil, ErrSandboxNotInitialized
|
return nil, ErrSandboxNotInitialized
|
||||||
}
|
}
|
||||||
|
|
||||||
// Push function and context to stack
|
// Push function and bytecode
|
||||||
state.PushCopy(-2) // bytecode
|
state.PushCopy(-2) // Bytecode
|
||||||
state.PushCopy(-2) // context
|
state.PushCopy(-2) // Context
|
||||||
|
state.Remove(-4) // Remove bytecode duplicate
|
||||||
// Remove duplicates
|
state.Remove(-3) // Remove context duplicate
|
||||||
state.Remove(-4)
|
|
||||||
state.Remove(-3)
|
|
||||||
|
|
||||||
// Execute with 2 args, 1 result
|
// Execute with 2 args, 1 result
|
||||||
if err := state.Call(2, 1); err != nil {
|
if err := state.Call(2, 1); err != nil {
|
||||||
ReleaseResponse(response)
|
ReleaseResponse(response)
|
||||||
s.debugLog("Execution failed: %v", err)
|
|
||||||
return nil, fmt.Errorf("script execution failed: %w", err)
|
return nil, fmt.Errorf("script execution failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -189,104 +164,84 @@ func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx *Context) (*
|
||||||
}
|
}
|
||||||
state.Pop(1)
|
state.Pop(1)
|
||||||
|
|
||||||
// Extract HTTP response data from Lua state
|
extractHTTPResponseData(state, response)
|
||||||
s.extractResponseData(state, response)
|
|
||||||
|
extractSessionData(state, response)
|
||||||
|
|
||||||
return response, nil
|
return response, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// extractResponseData pulls response info from the Lua state
|
// extractResponseData pulls response info from the Lua state
|
||||||
func (s *Sandbox) extractResponseData(state *luajit.State, response *Response) {
|
func extractHTTPResponseData(state *luajit.State, response *Response) {
|
||||||
// Get HTTP response
|
|
||||||
state.GetGlobal("__http_responses")
|
state.GetGlobal("__http_responses")
|
||||||
if !state.IsNil(-1) && state.IsTable(-1) {
|
if !state.IsTable(-1) {
|
||||||
state.PushNumber(1)
|
|
||||||
state.GetTable(-2)
|
|
||||||
|
|
||||||
if !state.IsNil(-1) && state.IsTable(-1) {
|
|
||||||
// Extract status
|
|
||||||
state.GetField(-1, "status")
|
|
||||||
if state.IsNumber(-1) {
|
|
||||||
response.Status = int(state.ToNumber(-1))
|
|
||||||
}
|
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
// Extract headers
|
|
||||||
state.GetField(-1, "headers")
|
|
||||||
if state.IsTable(-1) {
|
|
||||||
state.PushNil() // Start iteration
|
|
||||||
for state.Next(-2) {
|
|
||||||
if state.IsString(-2) && state.IsString(-1) {
|
|
||||||
key := state.ToString(-2)
|
|
||||||
value := state.ToString(-1)
|
|
||||||
response.Headers[key] = value
|
|
||||||
}
|
|
||||||
state.Pop(1)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
// Extract cookies
|
|
||||||
state.GetField(-1, "cookies")
|
|
||||||
if state.IsTable(-1) {
|
|
||||||
length := state.GetTableLength(-1)
|
|
||||||
for i := 1; i <= length; i++ {
|
|
||||||
state.PushNumber(float64(i))
|
|
||||||
state.GetTable(-2)
|
|
||||||
|
|
||||||
if state.IsTable(-1) {
|
|
||||||
s.extractCookie(state, response)
|
|
||||||
}
|
|
||||||
state.Pop(1)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
// Extract metadata if present
|
|
||||||
state.GetField(-1, "metadata")
|
|
||||||
if state.IsTable(-1) {
|
|
||||||
table, err := state.ToTable(-1)
|
|
||||||
if err == nil {
|
|
||||||
for k, v := range table {
|
|
||||||
response.Metadata[k] = v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
state.Pop(1)
|
|
||||||
}
|
|
||||||
state.Pop(1)
|
state.Pop(1)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
state.PushNumber(1)
|
||||||
|
state.GetTable(-2)
|
||||||
|
if !state.IsTable(-1) {
|
||||||
|
state.Pop(2)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract status
|
||||||
|
state.GetField(-1, "status")
|
||||||
|
if state.IsNumber(-1) {
|
||||||
|
response.Status = int(state.ToNumber(-1))
|
||||||
}
|
}
|
||||||
state.Pop(1)
|
state.Pop(1)
|
||||||
|
|
||||||
// Extract session data
|
// Extract headers
|
||||||
state.GetGlobal("__session_modified")
|
state.GetField(-1, "headers")
|
||||||
if state.IsBoolean(-1) && state.ToBoolean(-1) {
|
if state.IsTable(-1) {
|
||||||
response.SessionModified = true
|
state.PushNil() // Start iteration
|
||||||
|
for state.Next(-2) {
|
||||||
// Get session ID
|
if state.IsString(-2) && state.IsString(-1) {
|
||||||
state.GetGlobal("__session_id")
|
key := state.ToString(-2)
|
||||||
if state.IsString(-1) {
|
value := state.ToString(-1)
|
||||||
response.SessionID = state.ToString(-1)
|
response.Headers[key] = value
|
||||||
}
|
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
// Get session data
|
|
||||||
state.GetGlobal("__session_data")
|
|
||||||
if state.IsTable(-1) {
|
|
||||||
sessionData, err := state.ToTable(-1)
|
|
||||||
if err == nil {
|
|
||||||
for k, v := range sessionData {
|
|
||||||
response.SessionData[k] = v
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
state.Pop(1)
|
||||||
}
|
}
|
||||||
state.Pop(1)
|
|
||||||
}
|
}
|
||||||
state.Pop(1)
|
state.Pop(1)
|
||||||
|
|
||||||
|
// Extract cookies
|
||||||
|
state.GetField(-1, "cookies")
|
||||||
|
if state.IsTable(-1) {
|
||||||
|
length := state.GetTableLength(-1)
|
||||||
|
for i := 1; i <= length; i++ {
|
||||||
|
state.PushNumber(float64(i))
|
||||||
|
state.GetTable(-2)
|
||||||
|
|
||||||
|
if state.IsTable(-1) {
|
||||||
|
extractCookie(state, response)
|
||||||
|
}
|
||||||
|
state.Pop(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
state.Pop(1)
|
||||||
|
|
||||||
|
// Extract metadata
|
||||||
|
state.GetField(-1, "metadata")
|
||||||
|
if state.IsTable(-1) {
|
||||||
|
table, err := state.ToTable(-1)
|
||||||
|
if err == nil {
|
||||||
|
for k, v := range table {
|
||||||
|
response.Metadata[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
state.Pop(1)
|
||||||
|
|
||||||
|
// Clean up
|
||||||
|
state.Pop(2)
|
||||||
}
|
}
|
||||||
|
|
||||||
// extractCookie pulls cookie data from the current table on the stack
|
// extractCookie pulls cookie data from the current table on the stack
|
||||||
func (s *Sandbox) extractCookie(state *luajit.State, response *Response) {
|
func extractCookie(state *luajit.State, response *Response) {
|
||||||
cookie := fasthttp.AcquireCookie()
|
cookie := fasthttp.AcquireCookie()
|
||||||
|
|
||||||
// Get name (required)
|
// Get name (required)
|
||||||
|
@ -343,3 +298,69 @@ func (s *Sandbox) extractCookie(state *luajit.State, response *Response) {
|
||||||
|
|
||||||
response.Cookies = append(response.Cookies, cookie)
|
response.Cookies = append(response.Cookies, cookie)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Extract session data if modified
|
||||||
|
func extractSessionData(state *luajit.State, response *Response) {
|
||||||
|
logger.Debug("extractSessionData: Starting extraction")
|
||||||
|
|
||||||
|
// Get HTTP response table
|
||||||
|
state.GetGlobal("__http_responses")
|
||||||
|
if !state.IsTable(-1) {
|
||||||
|
logger.Debug("extractSessionData: __http_responses is not a table")
|
||||||
|
state.Pop(1)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get first response
|
||||||
|
state.PushNumber(1)
|
||||||
|
state.GetTable(-2)
|
||||||
|
if !state.IsTable(-1) {
|
||||||
|
logger.Debug("extractSessionData: __http_responses[1] is not a table")
|
||||||
|
state.Pop(2)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check session_modified flag
|
||||||
|
state.GetField(-1, "session_modified")
|
||||||
|
if !state.IsBoolean(-1) || !state.ToBoolean(-1) {
|
||||||
|
logger.Debug("extractSessionData: session_modified is not true")
|
||||||
|
state.Pop(3)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logger.Debug("extractSessionData: Found session_modified=true")
|
||||||
|
state.Pop(1)
|
||||||
|
|
||||||
|
// Get session ID
|
||||||
|
state.GetField(-1, "session_id")
|
||||||
|
if state.IsString(-1) {
|
||||||
|
response.SessionID = state.ToString(-1)
|
||||||
|
logger.Debug("extractSessionData: Found session ID: %s", response.SessionID)
|
||||||
|
} else {
|
||||||
|
logger.Debug("extractSessionData: session_id not found or not a string")
|
||||||
|
}
|
||||||
|
state.Pop(1)
|
||||||
|
|
||||||
|
// Get session data
|
||||||
|
state.GetField(-1, "session_data")
|
||||||
|
if state.IsTable(-1) {
|
||||||
|
logger.Debug("extractSessionData: Found session_data table")
|
||||||
|
sessionData, err := state.ToTable(-1)
|
||||||
|
if err == nil {
|
||||||
|
logger.Debug("extractSessionData: Converted session data, size=%d", len(sessionData))
|
||||||
|
for k, v := range sessionData {
|
||||||
|
response.SessionData[k] = v
|
||||||
|
logger.Debug("extractSessionData: Added session key=%s, value=%v", k, v)
|
||||||
|
}
|
||||||
|
response.SessionModified = true
|
||||||
|
} else {
|
||||||
|
logger.Debug("extractSessionData: Failed to convert session data: %v", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
logger.Debug("extractSessionData: session_data not found or not a table")
|
||||||
|
}
|
||||||
|
state.Pop(1)
|
||||||
|
|
||||||
|
// Clean up stack
|
||||||
|
state.Pop(2)
|
||||||
|
logger.Debug("extractSessionData: Finished extraction, modified=%v", response.SessionModified)
|
||||||
|
}
|
||||||
|
|
|
@ -42,11 +42,25 @@ function __execute_script(fn, ctx)
|
||||||
-- Clear previous responses
|
-- Clear previous responses
|
||||||
__http_responses[1] = nil
|
__http_responses[1] = nil
|
||||||
|
|
||||||
-- Reset session modification flag
|
-- Create environment with metatable inheriting from _G
|
||||||
__session_modified = false
|
local env = setmetatable({}, {__index = _G})
|
||||||
|
|
||||||
-- Create environment
|
-- Add context if provided
|
||||||
local env = __create_env(ctx)
|
if ctx then
|
||||||
|
env.ctx = ctx
|
||||||
|
end
|
||||||
|
|
||||||
|
print("INIT SESSION DATA:", util.json_encode(ctx.session_data or {}))
|
||||||
|
|
||||||
|
-- Initialize local session variables in the environment
|
||||||
|
env.__session_data = ctx.session_data or {}
|
||||||
|
env.__session_id = ctx.session_id
|
||||||
|
env.__session_modified = false
|
||||||
|
|
||||||
|
-- Add proper require function to this environment
|
||||||
|
if __setup_require then
|
||||||
|
__setup_require(env)
|
||||||
|
end
|
||||||
|
|
||||||
-- Set environment for function
|
-- Set environment for function
|
||||||
setfenv(fn, env)
|
setfenv(fn, env)
|
||||||
|
@ -57,6 +71,17 @@ function __execute_script(fn, ctx)
|
||||||
error(result, 0)
|
error(result, 0)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
-- If session was modified, add to response
|
||||||
|
if env.__session_modified then
|
||||||
|
__http_responses[1] = __http_responses[1] or {}
|
||||||
|
__http_responses[1].session_data = env.__session_data
|
||||||
|
__http_responses[1].session_id = env.__session_id
|
||||||
|
__http_responses[1].session_modified = true
|
||||||
|
end
|
||||||
|
|
||||||
|
print("SESSION MODIFIED:", env.__session_modified)
|
||||||
|
print("FINAL DATA:", util.json_encode(env.__session_data or {}))
|
||||||
|
|
||||||
return result
|
return result
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -292,81 +317,79 @@ local cookie = {
|
||||||
-- SESSION MODULE
|
-- SESSION MODULE
|
||||||
-- ======================================================================
|
-- ======================================================================
|
||||||
|
|
||||||
-- Session module implementation
|
|
||||||
local session = {
|
local session = {
|
||||||
-- Get a session value
|
-- Get session value
|
||||||
get = function(key)
|
get = function(key)
|
||||||
if type(key) ~= "string" then
|
if type(key) ~= "string" then
|
||||||
error("session.get: key must be a string", 2)
|
error("session.get: key must be a string", 2)
|
||||||
end
|
end
|
||||||
|
local env = getfenv(2)
|
||||||
if __session_data and __session_data[key] ~= nil then
|
return env.__session_data and env.__session_data[key]
|
||||||
return __session_data[key]
|
|
||||||
end
|
|
||||||
|
|
||||||
return nil
|
|
||||||
end,
|
end,
|
||||||
|
|
||||||
-- Set a session value
|
-- Set session value
|
||||||
set = function(key, value)
|
set = function(key, value)
|
||||||
if type(key) ~= "string" then
|
if type(key) ~= "string" then
|
||||||
error("session.set: key must be a string", 2)
|
error("session.set: key must be a string", 2)
|
||||||
end
|
end
|
||||||
|
|
||||||
-- Ensure session data table exists
|
local env = getfenv(2)
|
||||||
__session_data = __session_data or {}
|
print("SET ENV:", tostring(env)) -- Debug the environment
|
||||||
|
|
||||||
-- Store value
|
if not env.__session_data then
|
||||||
__session_data[key] = value
|
env.__session_data = {}
|
||||||
|
print("CREATED NEW SESSION TABLE")
|
||||||
-- Mark session as modified
|
end
|
||||||
__session_modified = true
|
|
||||||
|
|
||||||
|
env.__session_data[key] = value
|
||||||
|
env.__session_modified = true
|
||||||
|
print("SET:", key, "=", tostring(value), "MODIFIED:", env.__session_modified)
|
||||||
return true
|
return true
|
||||||
end,
|
end,
|
||||||
|
|
||||||
-- Delete a session value
|
-- Delete session value
|
||||||
delete = function(key)
|
delete = function(key)
|
||||||
if type(key) ~= "string" then
|
if type(key) ~= "string" then
|
||||||
error("session.delete: key must be a string", 2)
|
error("session.delete: key must be a string", 2)
|
||||||
end
|
end
|
||||||
|
|
||||||
if __session_data then
|
local env = getfenv(2)
|
||||||
__session_data[key] = nil
|
if env.__session_data and env.__session_data[key] ~= nil then
|
||||||
__session_modified = true
|
env.__session_data[key] = nil
|
||||||
|
env.__session_modified = true
|
||||||
end
|
end
|
||||||
|
|
||||||
return true
|
return true
|
||||||
end,
|
end,
|
||||||
|
|
||||||
-- Clear all session data
|
-- Clear all session data
|
||||||
clear = function()
|
clear = function()
|
||||||
__session_data = {}
|
local env = getfenv(2)
|
||||||
__session_modified = true
|
if env.__session_data and next(env.__session_data) then
|
||||||
|
env.__session_data = {}
|
||||||
|
env.__session_modified = true
|
||||||
|
end
|
||||||
return true
|
return true
|
||||||
end,
|
end,
|
||||||
|
|
||||||
-- Get the session ID
|
-- Get session ID
|
||||||
get_id = function()
|
get_id = function()
|
||||||
return __session_id or nil
|
local env = getfenv(2)
|
||||||
|
return env.__session_id or ""
|
||||||
end,
|
end,
|
||||||
|
|
||||||
-- Get all session data
|
-- Get all session data
|
||||||
get_all = function()
|
get_all = function()
|
||||||
local result = {}
|
local env = getfenv(2)
|
||||||
for k, v in pairs(__session_data or {}) do
|
return env.__session_data or {}
|
||||||
result[k] = v
|
|
||||||
end
|
|
||||||
return result
|
|
||||||
end,
|
end,
|
||||||
|
|
||||||
-- Check if session has a key
|
-- Check if session has key
|
||||||
has = function(key)
|
has = function(key)
|
||||||
if type(key) ~= "string" then
|
if type(key) ~= "string" then
|
||||||
error("session.has: key must be a string", 2)
|
error("session.has: key must be a string", 2)
|
||||||
end
|
end
|
||||||
|
local env = getfenv(2)
|
||||||
return __session_data and __session_data[key] ~= nil
|
return env.__session_data ~= nil and env.__session_data[key] ~= nil
|
||||||
end
|
end
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,46 +1,43 @@
|
||||||
package sessions
|
package sessions
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"Moonshark/core/utils/logger"
|
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/VictoriaMetrics/fastcache"
|
|
||||||
"github.com/goccy/go-json"
|
|
||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// Default settings
|
DefaultMaxSessions = 10000
|
||||||
DefaultMaxSize = 100 * 1024 * 1024 // 100MB default cache size
|
DefaultCookieName = "MoonsharkSID"
|
||||||
DefaultCookieName = "MoonsharkSID"
|
DefaultCookiePath = "/"
|
||||||
DefaultCookiePath = "/"
|
DefaultMaxAge = 86400 // 1 day in seconds
|
||||||
DefaultMaxAge = 86400 // 1 day in seconds
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// SessionManager handles multiple sessions using fastcache for storage
|
// SessionManager handles multiple sessions
|
||||||
type SessionManager struct {
|
type SessionManager struct {
|
||||||
cache *fastcache.Cache
|
sessions map[string]*Session
|
||||||
|
maxSessions int
|
||||||
cookieName string
|
cookieName string
|
||||||
cookiePath string
|
cookiePath string
|
||||||
cookieDomain string
|
cookieDomain string
|
||||||
cookieSecure bool
|
cookieSecure bool
|
||||||
cookieHTTPOnly bool
|
cookieHTTPOnly bool
|
||||||
cookieMaxAge int
|
cookieMaxAge int
|
||||||
mu sync.RWMutex // Only for cookie settings
|
mu sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewSessionManager creates a new session manager with optional cache size
|
// NewSessionManager creates a new session manager
|
||||||
func NewSessionManager(maxSize ...int) *SessionManager {
|
func NewSessionManager(maxSessions int) *SessionManager {
|
||||||
size := DefaultMaxSize
|
if maxSessions <= 0 {
|
||||||
if len(maxSize) > 0 && maxSize[0] > 0 {
|
maxSessions = DefaultMaxSessions
|
||||||
size = maxSize[0]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return &SessionManager{
|
return &SessionManager{
|
||||||
cache: fastcache.New(size),
|
sessions: make(map[string]*Session, maxSessions),
|
||||||
|
maxSessions: maxSessions,
|
||||||
cookieName: DefaultCookieName,
|
cookieName: DefaultCookieName,
|
||||||
cookiePath: DefaultCookiePath,
|
cookiePath: DefaultCookiePath,
|
||||||
cookieHTTPOnly: true,
|
cookieHTTPOnly: true,
|
||||||
|
@ -48,7 +45,7 @@ func NewSessionManager(maxSize ...int) *SessionManager {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// generateSessionID creates a cryptographically secure random session ID
|
// generateSessionID creates a random session ID
|
||||||
func generateSessionID() string {
|
func generateSessionID() string {
|
||||||
b := make([]byte, 32)
|
b := make([]byte, 32)
|
||||||
if _, err := rand.Read(b); err != nil {
|
if _, err := rand.Read(b); err != nil {
|
||||||
|
@ -59,59 +56,136 @@ func generateSessionID() string {
|
||||||
|
|
||||||
// GetSession retrieves a session by ID, or creates a new one if it doesn't exist
|
// GetSession retrieves a session by ID, or creates a new one if it doesn't exist
|
||||||
func (sm *SessionManager) GetSession(id string) *Session {
|
func (sm *SessionManager) GetSession(id string) *Session {
|
||||||
// Check if session exists
|
// Try to get an existing session
|
||||||
data := sm.cache.Get(nil, []byte(id))
|
if id != "" {
|
||||||
|
sm.mu.RLock()
|
||||||
|
session, exists := sm.sessions[id]
|
||||||
|
sm.mu.RUnlock()
|
||||||
|
|
||||||
if len(data) > 0 {
|
if exists {
|
||||||
logger.Debug("Getting session %s", id)
|
// Check if session is expired
|
||||||
|
if session.IsExpired() {
|
||||||
// Session exists, unmarshal it
|
sm.mu.Lock()
|
||||||
session := &Session{}
|
delete(sm.sessions, id)
|
||||||
if err := json.Unmarshal(data, session); err == nil {
|
sm.mu.Unlock()
|
||||||
// Initialize mutex properly
|
} else {
|
||||||
session.mu = sync.RWMutex{}
|
// Update last used time
|
||||||
|
session.UpdateLastUsed()
|
||||||
// Update last accessed time
|
return session
|
||||||
session.UpdatedAt = time.Now()
|
}
|
||||||
|
|
||||||
// Store back with updated timestamp
|
|
||||||
updatedData, _ := json.Marshal(session)
|
|
||||||
sm.cache.Set([]byte(id), updatedData)
|
|
||||||
|
|
||||||
return session
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Debug("Session doesn't exist; creating it")
|
// Create a new session
|
||||||
|
return sm.CreateSession()
|
||||||
// Create new session
|
|
||||||
session := NewSession(id)
|
|
||||||
data, _ = json.Marshal(session)
|
|
||||||
sm.cache.Set([]byte(id), data)
|
|
||||||
|
|
||||||
return session
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateSession generates a new session with a unique ID
|
// CreateSession generates a new session with a unique ID
|
||||||
func (sm *SessionManager) CreateSession() *Session {
|
func (sm *SessionManager) CreateSession() *Session {
|
||||||
id := generateSessionID()
|
id := generateSessionID()
|
||||||
|
session := NewSession(id, sm.cookieMaxAge)
|
||||||
|
|
||||||
session := NewSession(id)
|
sm.mu.Lock()
|
||||||
data, _ := json.Marshal(session)
|
// Enforce session limit - evict LRU if needed
|
||||||
sm.cache.Set([]byte(id), data)
|
if len(sm.sessions) >= sm.maxSessions {
|
||||||
|
sm.evictLRU()
|
||||||
|
}
|
||||||
|
|
||||||
|
sm.sessions[id] = session
|
||||||
|
sm.mu.Unlock()
|
||||||
|
|
||||||
return session
|
return session
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveSession persists a session back to the cache
|
// evictLRU removes the least recently used session
|
||||||
func (sm *SessionManager) SaveSession(session *Session) {
|
func (sm *SessionManager) evictLRU() {
|
||||||
data, _ := json.Marshal(session)
|
// Called with mutex already held
|
||||||
sm.cache.Set([]byte(session.ID), data)
|
if len(sm.sessions) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var oldestID string
|
||||||
|
var oldestTime time.Time
|
||||||
|
|
||||||
|
// Find oldest session
|
||||||
|
for id, session := range sm.sessions {
|
||||||
|
if oldestID == "" || session.LastUsed.Before(oldestTime) {
|
||||||
|
oldestID = id
|
||||||
|
oldestTime = session.LastUsed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if oldestID != "" {
|
||||||
|
delete(sm.sessions, oldestID)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// DestroySession removes a session
|
// DestroySession removes a session
|
||||||
func (sm *SessionManager) DestroySession(id string) {
|
func (sm *SessionManager) DestroySession(id string) {
|
||||||
sm.cache.Del([]byte(id))
|
sm.mu.Lock()
|
||||||
|
defer sm.mu.Unlock()
|
||||||
|
delete(sm.sessions, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CleanupExpired removes all expired sessions
|
||||||
|
func (sm *SessionManager) CleanupExpired() int {
|
||||||
|
removed := 0
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
sm.mu.Lock()
|
||||||
|
defer sm.mu.Unlock()
|
||||||
|
|
||||||
|
for id, session := range sm.sessions {
|
||||||
|
if now.After(session.Expiry) {
|
||||||
|
delete(sm.sessions, id)
|
||||||
|
removed++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return removed
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetCookieOptions configures cookie parameters
|
||||||
|
func (sm *SessionManager) SetCookieOptions(name, path, domain string, secure, httpOnly bool, maxAge int) {
|
||||||
|
sm.mu.Lock()
|
||||||
|
defer sm.mu.Unlock()
|
||||||
|
|
||||||
|
sm.cookieName = name
|
||||||
|
sm.cookiePath = path
|
||||||
|
sm.cookieDomain = domain
|
||||||
|
sm.cookieSecure = secure
|
||||||
|
sm.cookieHTTPOnly = httpOnly
|
||||||
|
sm.cookieMaxAge = maxAge
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSessionFromRequest extracts the session from a request
|
||||||
|
func (sm *SessionManager) GetSessionFromRequest(ctx *fasthttp.RequestCtx) *Session {
|
||||||
|
cookie := ctx.Request.Header.Cookie(sm.cookieName)
|
||||||
|
if len(cookie) == 0 {
|
||||||
|
return sm.CreateSession()
|
||||||
|
}
|
||||||
|
|
||||||
|
return sm.GetSession(string(cookie))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ApplySessionCookie adds the session cookie to the response
|
||||||
|
func (sm *SessionManager) ApplySessionCookie(ctx *fasthttp.RequestCtx, session *Session) {
|
||||||
|
cookie := fasthttp.AcquireCookie()
|
||||||
|
defer fasthttp.ReleaseCookie(cookie)
|
||||||
|
|
||||||
|
cookie.SetKey(sm.cookieName)
|
||||||
|
cookie.SetValue(session.ID)
|
||||||
|
cookie.SetPath(sm.cookiePath)
|
||||||
|
cookie.SetHTTPOnly(sm.cookieHTTPOnly)
|
||||||
|
cookie.SetMaxAge(sm.cookieMaxAge)
|
||||||
|
|
||||||
|
if sm.cookieDomain != "" {
|
||||||
|
cookie.SetDomain(sm.cookieDomain)
|
||||||
|
}
|
||||||
|
|
||||||
|
cookie.SetSecure(sm.cookieSecure)
|
||||||
|
|
||||||
|
ctx.Response.Header.SetCookie(cookie)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CookieOptions returns the cookie options for this session manager
|
// CookieOptions returns the cookie options for this session manager
|
||||||
|
@ -129,52 +203,5 @@ func (sm *SessionManager) CookieOptions() map[string]any {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetCookieOptions configures cookie parameters
|
|
||||||
func (sm *SessionManager) SetCookieOptions(name, path, domain string, secure, httpOnly bool, maxAge int) {
|
|
||||||
sm.mu.Lock()
|
|
||||||
defer sm.mu.Unlock()
|
|
||||||
|
|
||||||
sm.cookieName = name
|
|
||||||
sm.cookiePath = path
|
|
||||||
sm.cookieDomain = domain
|
|
||||||
sm.cookieSecure = secure
|
|
||||||
sm.cookieHTTPOnly = httpOnly
|
|
||||||
sm.cookieMaxAge = maxAge
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetSessionFromRequest extracts the session from a request context
|
|
||||||
func (sm *SessionManager) GetSessionFromRequest(ctx *fasthttp.RequestCtx) *Session {
|
|
||||||
cookie := ctx.Request.Header.Cookie(sm.cookieName)
|
|
||||||
if len(cookie) == 0 {
|
|
||||||
// No session cookie, create a new session
|
|
||||||
return sm.CreateSession()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Session cookie exists, get the session
|
|
||||||
return sm.GetSession(string(cookie))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SaveSessionToResponse adds the session cookie to an HTTP response
|
|
||||||
func (sm *SessionManager) ApplySessionCookie(ctx *fasthttp.RequestCtx, session *Session) {
|
|
||||||
cookie := fasthttp.AcquireCookie()
|
|
||||||
defer fasthttp.ReleaseCookie(cookie)
|
|
||||||
|
|
||||||
sm.mu.RLock()
|
|
||||||
cookie.SetKey(sm.cookieName)
|
|
||||||
cookie.SetValue(session.ID)
|
|
||||||
cookie.SetPath(sm.cookiePath)
|
|
||||||
cookie.SetHTTPOnly(sm.cookieHTTPOnly)
|
|
||||||
cookie.SetMaxAge(sm.cookieMaxAge)
|
|
||||||
|
|
||||||
if sm.cookieDomain != "" {
|
|
||||||
cookie.SetDomain(sm.cookieDomain)
|
|
||||||
}
|
|
||||||
|
|
||||||
cookie.SetSecure(sm.cookieSecure)
|
|
||||||
sm.mu.RUnlock()
|
|
||||||
|
|
||||||
ctx.Response.Header.SetCookie(cookie)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GlobalSessionManager is the default session manager instance
|
// GlobalSessionManager is the default session manager instance
|
||||||
var GlobalSessionManager = NewSessionManager()
|
var GlobalSessionManager = NewSessionManager(DefaultMaxSessions)
|
||||||
|
|
|
@ -1,41 +1,31 @@
|
||||||
package sessions
|
package sessions
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/goccy/go-json"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
DefaultMaxValueSize = 256 * 1024 // 256KB per value
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
ErrValueTooLarge = errors.New("session value exceeds size limit")
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Session stores data for a single user session
|
// Session stores data for a single user session
|
||||||
type Session struct {
|
type Session struct {
|
||||||
ID string `json:"id"`
|
ID string
|
||||||
Data map[string]any `json:"data"`
|
Data map[string]any
|
||||||
CreatedAt time.Time `json:"created_at"`
|
CreatedAt time.Time
|
||||||
UpdatedAt time.Time `json:"updated_at"`
|
UpdatedAt time.Time
|
||||||
mu sync.RWMutex `json:"-"`
|
LastUsed time.Time
|
||||||
maxValueSize int `json:"max_value_size"`
|
Expiry time.Time
|
||||||
totalDataSize int `json:"total_data_size"`
|
mu sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewSession creates a new session with the given ID
|
// NewSession creates a new session with the given ID
|
||||||
func NewSession(id string) *Session {
|
func NewSession(id string, maxAge int) *Session {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
return &Session{
|
return &Session{
|
||||||
ID: id,
|
ID: id,
|
||||||
Data: make(map[string]any),
|
Data: make(map[string]any),
|
||||||
CreatedAt: now,
|
CreatedAt: now,
|
||||||
UpdatedAt: now,
|
UpdatedAt: now,
|
||||||
maxValueSize: DefaultMaxValueSize,
|
LastUsed: now,
|
||||||
|
Expiry: now.Add(time.Duration(maxAge) * time.Second),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -47,65 +37,17 @@ func (s *Session) Get(key string) any {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set stores a value in the session
|
// Set stores a value in the session
|
||||||
func (s *Session) Set(key string, value any) error {
|
func (s *Session) Set(key string, value any) {
|
||||||
// Estimate value size
|
|
||||||
size, err := estimateSize(value)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check against limit
|
|
||||||
if size > s.maxValueSize {
|
|
||||||
return ErrValueTooLarge
|
|
||||||
}
|
|
||||||
|
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
// If replacing, subtract old value size
|
|
||||||
if oldVal, exists := s.Data[key]; exists {
|
|
||||||
oldSize, _ := estimateSize(oldVal)
|
|
||||||
s.totalDataSize -= oldSize
|
|
||||||
}
|
|
||||||
|
|
||||||
s.Data[key] = value
|
s.Data[key] = value
|
||||||
s.totalDataSize += size
|
|
||||||
s.UpdatedAt = time.Now()
|
s.UpdatedAt = time.Now()
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetMaxValueSize changes the maximum allowed value size
|
|
||||||
func (s *Session) SetMaxValueSize(bytes int) {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
s.maxValueSize = bytes
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetMaxValueSize returns the current max value size
|
|
||||||
func (s *Session) GetMaxValueSize() int {
|
|
||||||
s.mu.RLock()
|
|
||||||
defer s.mu.RUnlock()
|
|
||||||
return s.maxValueSize
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetTotalSize returns the estimated total size of all session data
|
|
||||||
func (s *Session) GetTotalSize() int {
|
|
||||||
s.mu.RLock()
|
|
||||||
defer s.mu.RUnlock()
|
|
||||||
return s.totalDataSize
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete removes a value from the session
|
// Delete removes a value from the session
|
||||||
func (s *Session) Delete(key string) {
|
func (s *Session) Delete(key string) {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
// Update size tracking
|
|
||||||
if oldVal, exists := s.Data[key]; exists {
|
|
||||||
oldSize, _ := estimateSize(oldVal)
|
|
||||||
s.totalDataSize -= oldSize
|
|
||||||
}
|
|
||||||
|
|
||||||
delete(s.Data, key)
|
delete(s.Data, key)
|
||||||
s.UpdatedAt = time.Now()
|
s.UpdatedAt = time.Now()
|
||||||
}
|
}
|
||||||
|
@ -115,7 +57,6 @@ func (s *Session) Clear() {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
s.Data = make(map[string]any)
|
s.Data = make(map[string]any)
|
||||||
s.totalDataSize = 0
|
|
||||||
s.UpdatedAt = time.Now()
|
s.UpdatedAt = time.Now()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -124,7 +65,6 @@ func (s *Session) GetAll() map[string]any {
|
||||||
s.mu.RLock()
|
s.mu.RLock()
|
||||||
defer s.mu.RUnlock()
|
defer s.mu.RUnlock()
|
||||||
|
|
||||||
// Create a copy to avoid concurrent map access issues
|
|
||||||
copy := make(map[string]any, len(s.Data))
|
copy := make(map[string]any, len(s.Data))
|
||||||
for k, v := range s.Data {
|
for k, v := range s.Data {
|
||||||
copy[k] = v
|
copy[k] = v
|
||||||
|
@ -133,20 +73,14 @@ func (s *Session) GetAll() map[string]any {
|
||||||
return copy
|
return copy
|
||||||
}
|
}
|
||||||
|
|
||||||
// estimateSize approximates the memory footprint of a value
|
// IsExpired checks if the session has expired
|
||||||
func estimateSize(v any) (int, error) {
|
func (s *Session) IsExpired() bool {
|
||||||
// Fast path for common types
|
return time.Now().After(s.Expiry)
|
||||||
switch val := v.(type) {
|
}
|
||||||
case string:
|
|
||||||
return len(val), nil
|
// UpdateLastUsed updates the last used time
|
||||||
case []byte:
|
func (s *Session) UpdateLastUsed() {
|
||||||
return len(val), nil
|
s.mu.Lock()
|
||||||
}
|
s.LastUsed = time.Now()
|
||||||
|
s.mu.Unlock()
|
||||||
// For other types, use JSON serialization as approximation
|
|
||||||
data, err := json.Marshal(v)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
return len(data), nil
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user