Compare commits

..

No commits in common. "35ce09d66ec42b03b04bc157d11c7e687e5f1b5e" and "5ebcd97662dad8f21f1af67825cf6c0f63835efd" have entirely different histories.

5 changed files with 349 additions and 353 deletions

View File

@ -225,12 +225,13 @@ func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scrip
return return
} }
// Update session if modified // Save session if modified
if response.SessionModified { if response.SessionModified {
// Update session data
for k, v := range response.SessionData { for k, v := range response.SessionData {
session.Set(k, v) session.Set(k, v)
} }
s.sessionManager.SaveSession(session)
s.sessionManager.ApplySessionCookie(ctx, session) s.sessionManager.ApplySessionCookie(ctx, session)
} }

View File

@ -4,12 +4,11 @@ import (
"fmt" "fmt"
"sync" "sync"
"github.com/valyala/bytebufferpool"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
"Moonshark/core/utils/logger" "Moonshark/core/utils/logger"
"maps"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go" luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
) )
@ -112,48 +111,74 @@ func (s *Sandbox) registerCoreFunctions(state *luajit.State) error {
// Execute runs a Lua script in the sandbox with the given context // Execute runs a Lua script in the sandbox with the given context
func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx *Context) (*Response, error) { func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx *Context) (*Response, error) {
s.debugLog("Executing script...")
// Create a response object // Create a response object
response := NewResponse() response := NewResponse()
// Get a buffer for string operations
buf := bytebufferpool.Get()
defer bytebufferpool.Put(buf)
// Load bytecode // Load bytecode
if err := state.LoadBytecode(bytecode, "script"); err != nil { if err := state.LoadBytecode(bytecode, "script"); err != nil {
ReleaseResponse(response) ReleaseResponse(response)
s.debugLog("Failed to load bytecode: %v", err)
return nil, fmt.Errorf("failed to load script: %w", err) return nil, fmt.Errorf("failed to load script: %w", err)
} }
// Add session data to context // Initialize session data in Lua
contextWithSession := make(map[string]any)
maps.Copy(contextWithSession, ctx.Values)
// Pass session data through context
if ctx.SessionID != "" { if ctx.SessionID != "" {
contextWithSession["session_id"] = ctx.SessionID // Set session ID
contextWithSession["session_data"] = ctx.SessionData state.PushString(ctx.SessionID)
state.SetGlobal("__session_id")
// Set session data
if err := state.PushTable(ctx.SessionData); err != nil {
ReleaseResponse(response)
s.debugLog("Failed to push session data: %v", err)
return nil, err
}
state.SetGlobal("__session_data")
// Reset modification flag
state.PushBoolean(false)
state.SetGlobal("__session_modified")
} else {
// Initialize empty session
if err := state.DoString("__session_data = {}; __session_modified = false"); err != nil {
s.debugLog("Failed to initialize empty session data: %v", err)
}
} }
// Set up context values for execution // Set up context values for execution
if err := state.PushTable(contextWithSession); err != nil { if err := state.PushTable(ctx.Values); err != nil {
ReleaseResponse(response) ReleaseResponse(response)
s.debugLog("Failed to push context values: %v", err)
return nil, err return nil, err
} }
// Get the execution function // Get the execution function
state.GetGlobal("__execute_script") state.GetGlobal("__execute_script")
if !state.IsFunction(-1) { if !state.IsFunction(-1) {
state.Pop(1) state.Pop(1) // Pop non-function
ReleaseResponse(response) ReleaseResponse(response)
s.debugLog("__execute_script is not a function")
return nil, ErrSandboxNotInitialized return nil, ErrSandboxNotInitialized
} }
// Push function and bytecode // Push function and context to stack
state.PushCopy(-2) // Bytecode state.PushCopy(-2) // bytecode
state.PushCopy(-2) // Context state.PushCopy(-2) // context
state.Remove(-4) // Remove bytecode duplicate
state.Remove(-3) // Remove context duplicate // Remove duplicates
state.Remove(-4)
state.Remove(-3)
// Execute with 2 args, 1 result // Execute with 2 args, 1 result
if err := state.Call(2, 1); err != nil { if err := state.Call(2, 1); err != nil {
ReleaseResponse(response) ReleaseResponse(response)
s.debugLog("Execution failed: %v", err)
return nil, fmt.Errorf("script execution failed: %w", err) return nil, fmt.Errorf("script execution failed: %w", err)
} }
@ -164,28 +189,21 @@ func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx *Context) (*
} }
state.Pop(1) state.Pop(1)
extractHTTPResponseData(state, response) // Extract HTTP response data from Lua state
s.extractResponseData(state, response)
extractSessionData(state, response)
return response, nil return response, nil
} }
// extractResponseData pulls response info from the Lua state // extractResponseData pulls response info from the Lua state
func extractHTTPResponseData(state *luajit.State, response *Response) { func (s *Sandbox) extractResponseData(state *luajit.State, response *Response) {
// Get HTTP response
state.GetGlobal("__http_responses") state.GetGlobal("__http_responses")
if !state.IsTable(-1) { if !state.IsNil(-1) && state.IsTable(-1) {
state.Pop(1)
return
}
state.PushNumber(1) state.PushNumber(1)
state.GetTable(-2) state.GetTable(-2)
if !state.IsTable(-1) {
state.Pop(2)
return
}
if !state.IsNil(-1) && state.IsTable(-1) {
// Extract status // Extract status
state.GetField(-1, "status") state.GetField(-1, "status")
if state.IsNumber(-1) { if state.IsNumber(-1) {
@ -217,14 +235,14 @@ func extractHTTPResponseData(state *luajit.State, response *Response) {
state.GetTable(-2) state.GetTable(-2)
if state.IsTable(-1) { if state.IsTable(-1) {
extractCookie(state, response) s.extractCookie(state, response)
} }
state.Pop(1) state.Pop(1)
} }
} }
state.Pop(1) state.Pop(1)
// Extract metadata // Extract metadata if present
state.GetField(-1, "metadata") state.GetField(-1, "metadata")
if state.IsTable(-1) { if state.IsTable(-1) {
table, err := state.ToTable(-1) table, err := state.ToTable(-1)
@ -235,13 +253,40 @@ func extractHTTPResponseData(state *luajit.State, response *Response) {
} }
} }
state.Pop(1) state.Pop(1)
}
state.Pop(1)
}
state.Pop(1)
// Clean up // Extract session data
state.Pop(2) state.GetGlobal("__session_modified")
if state.IsBoolean(-1) && state.ToBoolean(-1) {
response.SessionModified = true
// Get session ID
state.GetGlobal("__session_id")
if state.IsString(-1) {
response.SessionID = state.ToString(-1)
}
state.Pop(1)
// Get session data
state.GetGlobal("__session_data")
if state.IsTable(-1) {
sessionData, err := state.ToTable(-1)
if err == nil {
for k, v := range sessionData {
response.SessionData[k] = v
}
}
}
state.Pop(1)
}
state.Pop(1)
} }
// extractCookie pulls cookie data from the current table on the stack // extractCookie pulls cookie data from the current table on the stack
func extractCookie(state *luajit.State, response *Response) { func (s *Sandbox) extractCookie(state *luajit.State, response *Response) {
cookie := fasthttp.AcquireCookie() cookie := fasthttp.AcquireCookie()
// Get name (required) // Get name (required)
@ -298,69 +343,3 @@ func extractCookie(state *luajit.State, response *Response) {
response.Cookies = append(response.Cookies, cookie) response.Cookies = append(response.Cookies, cookie)
} }
// Extract session data if modified
func extractSessionData(state *luajit.State, response *Response) {
logger.Debug("extractSessionData: Starting extraction")
// Get HTTP response table
state.GetGlobal("__http_responses")
if !state.IsTable(-1) {
logger.Debug("extractSessionData: __http_responses is not a table")
state.Pop(1)
return
}
// Get first response
state.PushNumber(1)
state.GetTable(-2)
if !state.IsTable(-1) {
logger.Debug("extractSessionData: __http_responses[1] is not a table")
state.Pop(2)
return
}
// Check session_modified flag
state.GetField(-1, "session_modified")
if !state.IsBoolean(-1) || !state.ToBoolean(-1) {
logger.Debug("extractSessionData: session_modified is not true")
state.Pop(3)
return
}
logger.Debug("extractSessionData: Found session_modified=true")
state.Pop(1)
// Get session ID
state.GetField(-1, "session_id")
if state.IsString(-1) {
response.SessionID = state.ToString(-1)
logger.Debug("extractSessionData: Found session ID: %s", response.SessionID)
} else {
logger.Debug("extractSessionData: session_id not found or not a string")
}
state.Pop(1)
// Get session data
state.GetField(-1, "session_data")
if state.IsTable(-1) {
logger.Debug("extractSessionData: Found session_data table")
sessionData, err := state.ToTable(-1)
if err == nil {
logger.Debug("extractSessionData: Converted session data, size=%d", len(sessionData))
for k, v := range sessionData {
response.SessionData[k] = v
logger.Debug("extractSessionData: Added session key=%s, value=%v", k, v)
}
response.SessionModified = true
} else {
logger.Debug("extractSessionData: Failed to convert session data: %v", err)
}
} else {
logger.Debug("extractSessionData: session_data not found or not a table")
}
state.Pop(1)
// Clean up stack
state.Pop(2)
logger.Debug("extractSessionData: Finished extraction, modified=%v", response.SessionModified)
}

View File

@ -42,25 +42,11 @@ function __execute_script(fn, ctx)
-- Clear previous responses -- Clear previous responses
__http_responses[1] = nil __http_responses[1] = nil
-- Create environment with metatable inheriting from _G -- Reset session modification flag
local env = setmetatable({}, {__index = _G}) __session_modified = false
-- Add context if provided -- Create environment
if ctx then local env = __create_env(ctx)
env.ctx = ctx
end
print("INIT SESSION DATA:", util.json_encode(ctx.session_data or {}))
-- Initialize local session variables in the environment
env.__session_data = ctx.session_data or {}
env.__session_id = ctx.session_id
env.__session_modified = false
-- Add proper require function to this environment
if __setup_require then
__setup_require(env)
end
-- Set environment for function -- Set environment for function
setfenv(fn, env) setfenv(fn, env)
@ -71,17 +57,6 @@ function __execute_script(fn, ctx)
error(result, 0) error(result, 0)
end end
-- If session was modified, add to response
if env.__session_modified then
__http_responses[1] = __http_responses[1] or {}
__http_responses[1].session_data = env.__session_data
__http_responses[1].session_id = env.__session_id
__http_responses[1].session_modified = true
end
print("SESSION MODIFIED:", env.__session_modified)
print("FINAL DATA:", util.json_encode(env.__session_data or {}))
return result return result
end end
@ -317,79 +292,81 @@ local cookie = {
-- SESSION MODULE -- SESSION MODULE
-- ====================================================================== -- ======================================================================
-- Session module implementation
local session = { local session = {
-- Get session value -- Get a session value
get = function(key) get = function(key)
if type(key) ~= "string" then if type(key) ~= "string" then
error("session.get: key must be a string", 2) error("session.get: key must be a string", 2)
end end
local env = getfenv(2)
return env.__session_data and env.__session_data[key] if __session_data and __session_data[key] ~= nil then
return __session_data[key]
end
return nil
end, end,
-- Set session value -- Set a session value
set = function(key, value) set = function(key, value)
if type(key) ~= "string" then if type(key) ~= "string" then
error("session.set: key must be a string", 2) error("session.set: key must be a string", 2)
end end
local env = getfenv(2) -- Ensure session data table exists
print("SET ENV:", tostring(env)) -- Debug the environment __session_data = __session_data or {}
if not env.__session_data then -- Store value
env.__session_data = {} __session_data[key] = value
print("CREATED NEW SESSION TABLE")
end -- Mark session as modified
__session_modified = true
env.__session_data[key] = value
env.__session_modified = true
print("SET:", key, "=", tostring(value), "MODIFIED:", env.__session_modified)
return true return true
end, end,
-- Delete session value -- Delete a session value
delete = function(key) delete = function(key)
if type(key) ~= "string" then if type(key) ~= "string" then
error("session.delete: key must be a string", 2) error("session.delete: key must be a string", 2)
end end
local env = getfenv(2) if __session_data then
if env.__session_data and env.__session_data[key] ~= nil then __session_data[key] = nil
env.__session_data[key] = nil __session_modified = true
env.__session_modified = true
end end
return true return true
end, end,
-- Clear all session data -- Clear all session data
clear = function() clear = function()
local env = getfenv(2) __session_data = {}
if env.__session_data and next(env.__session_data) then __session_modified = true
env.__session_data = {}
env.__session_modified = true
end
return true return true
end, end,
-- Get session ID -- Get the session ID
get_id = function() get_id = function()
local env = getfenv(2) return __session_id or nil
return env.__session_id or ""
end, end,
-- Get all session data -- Get all session data
get_all = function() get_all = function()
local env = getfenv(2) local result = {}
return env.__session_data or {} for k, v in pairs(__session_data or {}) do
result[k] = v
end
return result
end, end,
-- Check if session has key -- Check if session has a key
has = function(key) has = function(key)
if type(key) ~= "string" then if type(key) ~= "string" then
error("session.has: key must be a string", 2) error("session.has: key must be a string", 2)
end end
local env = getfenv(2)
return env.__session_data ~= nil and env.__session_data[key] ~= nil return __session_data and __session_data[key] ~= nil
end end
} }

View File

@ -1,43 +1,46 @@
package sessions package sessions
import ( import (
"Moonshark/core/utils/logger"
"crypto/rand" "crypto/rand"
"encoding/base64" "encoding/base64"
"sync" "sync"
"time" "time"
"github.com/VictoriaMetrics/fastcache"
"github.com/goccy/go-json"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
) )
const ( const (
DefaultMaxSessions = 10000 // Default settings
DefaultMaxSize = 100 * 1024 * 1024 // 100MB default cache size
DefaultCookieName = "MoonsharkSID" DefaultCookieName = "MoonsharkSID"
DefaultCookiePath = "/" DefaultCookiePath = "/"
DefaultMaxAge = 86400 // 1 day in seconds DefaultMaxAge = 86400 // 1 day in seconds
) )
// SessionManager handles multiple sessions // SessionManager handles multiple sessions using fastcache for storage
type SessionManager struct { type SessionManager struct {
sessions map[string]*Session cache *fastcache.Cache
maxSessions int
cookieName string cookieName string
cookiePath string cookiePath string
cookieDomain string cookieDomain string
cookieSecure bool cookieSecure bool
cookieHTTPOnly bool cookieHTTPOnly bool
cookieMaxAge int cookieMaxAge int
mu sync.RWMutex mu sync.RWMutex // Only for cookie settings
} }
// NewSessionManager creates a new session manager // NewSessionManager creates a new session manager with optional cache size
func NewSessionManager(maxSessions int) *SessionManager { func NewSessionManager(maxSize ...int) *SessionManager {
if maxSessions <= 0 { size := DefaultMaxSize
maxSessions = DefaultMaxSessions if len(maxSize) > 0 && maxSize[0] > 0 {
size = maxSize[0]
} }
return &SessionManager{ return &SessionManager{
sessions: make(map[string]*Session, maxSessions), cache: fastcache.New(size),
maxSessions: maxSessions,
cookieName: DefaultCookieName, cookieName: DefaultCookieName,
cookiePath: DefaultCookiePath, cookiePath: DefaultCookiePath,
cookieHTTPOnly: true, cookieHTTPOnly: true,
@ -45,7 +48,7 @@ func NewSessionManager(maxSessions int) *SessionManager {
} }
} }
// generateSessionID creates a random session ID // generateSessionID creates a cryptographically secure random session ID
func generateSessionID() string { func generateSessionID() string {
b := make([]byte, 32) b := make([]byte, 32)
if _, err := rand.Read(b); err != nil { if _, err := rand.Read(b); err != nil {
@ -56,136 +59,59 @@ func generateSessionID() string {
// GetSession retrieves a session by ID, or creates a new one if it doesn't exist // GetSession retrieves a session by ID, or creates a new one if it doesn't exist
func (sm *SessionManager) GetSession(id string) *Session { func (sm *SessionManager) GetSession(id string) *Session {
// Try to get an existing session // Check if session exists
if id != "" { data := sm.cache.Get(nil, []byte(id))
sm.mu.RLock()
session, exists := sm.sessions[id] if len(data) > 0 {
sm.mu.RUnlock() logger.Debug("Getting session %s", id)
// Session exists, unmarshal it
session := &Session{}
if err := json.Unmarshal(data, session); err == nil {
// Initialize mutex properly
session.mu = sync.RWMutex{}
// Update last accessed time
session.UpdatedAt = time.Now()
// Store back with updated timestamp
updatedData, _ := json.Marshal(session)
sm.cache.Set([]byte(id), updatedData)
if exists {
// Check if session is expired
if session.IsExpired() {
sm.mu.Lock()
delete(sm.sessions, id)
sm.mu.Unlock()
} else {
// Update last used time
session.UpdateLastUsed()
return session return session
} }
} }
}
// Create a new session logger.Debug("Session doesn't exist; creating it")
return sm.CreateSession()
// Create new session
session := NewSession(id)
data, _ = json.Marshal(session)
sm.cache.Set([]byte(id), data)
return session
} }
// CreateSession generates a new session with a unique ID // CreateSession generates a new session with a unique ID
func (sm *SessionManager) CreateSession() *Session { func (sm *SessionManager) CreateSession() *Session {
id := generateSessionID() id := generateSessionID()
session := NewSession(id, sm.cookieMaxAge)
sm.mu.Lock() session := NewSession(id)
// Enforce session limit - evict LRU if needed data, _ := json.Marshal(session)
if len(sm.sessions) >= sm.maxSessions { sm.cache.Set([]byte(id), data)
sm.evictLRU()
}
sm.sessions[id] = session
sm.mu.Unlock()
return session return session
} }
// evictLRU removes the least recently used session // SaveSession persists a session back to the cache
func (sm *SessionManager) evictLRU() { func (sm *SessionManager) SaveSession(session *Session) {
// Called with mutex already held data, _ := json.Marshal(session)
if len(sm.sessions) == 0 { sm.cache.Set([]byte(session.ID), data)
return
}
var oldestID string
var oldestTime time.Time
// Find oldest session
for id, session := range sm.sessions {
if oldestID == "" || session.LastUsed.Before(oldestTime) {
oldestID = id
oldestTime = session.LastUsed
}
}
if oldestID != "" {
delete(sm.sessions, oldestID)
}
} }
// DestroySession removes a session // DestroySession removes a session
func (sm *SessionManager) DestroySession(id string) { func (sm *SessionManager) DestroySession(id string) {
sm.mu.Lock() sm.cache.Del([]byte(id))
defer sm.mu.Unlock()
delete(sm.sessions, id)
}
// CleanupExpired removes all expired sessions
func (sm *SessionManager) CleanupExpired() int {
removed := 0
now := time.Now()
sm.mu.Lock()
defer sm.mu.Unlock()
for id, session := range sm.sessions {
if now.After(session.Expiry) {
delete(sm.sessions, id)
removed++
}
}
return removed
}
// SetCookieOptions configures cookie parameters
func (sm *SessionManager) SetCookieOptions(name, path, domain string, secure, httpOnly bool, maxAge int) {
sm.mu.Lock()
defer sm.mu.Unlock()
sm.cookieName = name
sm.cookiePath = path
sm.cookieDomain = domain
sm.cookieSecure = secure
sm.cookieHTTPOnly = httpOnly
sm.cookieMaxAge = maxAge
}
// GetSessionFromRequest extracts the session from a request
func (sm *SessionManager) GetSessionFromRequest(ctx *fasthttp.RequestCtx) *Session {
cookie := ctx.Request.Header.Cookie(sm.cookieName)
if len(cookie) == 0 {
return sm.CreateSession()
}
return sm.GetSession(string(cookie))
}
// ApplySessionCookie adds the session cookie to the response
func (sm *SessionManager) ApplySessionCookie(ctx *fasthttp.RequestCtx, session *Session) {
cookie := fasthttp.AcquireCookie()
defer fasthttp.ReleaseCookie(cookie)
cookie.SetKey(sm.cookieName)
cookie.SetValue(session.ID)
cookie.SetPath(sm.cookiePath)
cookie.SetHTTPOnly(sm.cookieHTTPOnly)
cookie.SetMaxAge(sm.cookieMaxAge)
if sm.cookieDomain != "" {
cookie.SetDomain(sm.cookieDomain)
}
cookie.SetSecure(sm.cookieSecure)
ctx.Response.Header.SetCookie(cookie)
} }
// CookieOptions returns the cookie options for this session manager // CookieOptions returns the cookie options for this session manager
@ -203,5 +129,52 @@ func (sm *SessionManager) CookieOptions() map[string]any {
} }
} }
// SetCookieOptions configures cookie parameters
func (sm *SessionManager) SetCookieOptions(name, path, domain string, secure, httpOnly bool, maxAge int) {
sm.mu.Lock()
defer sm.mu.Unlock()
sm.cookieName = name
sm.cookiePath = path
sm.cookieDomain = domain
sm.cookieSecure = secure
sm.cookieHTTPOnly = httpOnly
sm.cookieMaxAge = maxAge
}
// GetSessionFromRequest extracts the session from a request context
func (sm *SessionManager) GetSessionFromRequest(ctx *fasthttp.RequestCtx) *Session {
cookie := ctx.Request.Header.Cookie(sm.cookieName)
if len(cookie) == 0 {
// No session cookie, create a new session
return sm.CreateSession()
}
// Session cookie exists, get the session
return sm.GetSession(string(cookie))
}
// SaveSessionToResponse adds the session cookie to an HTTP response
func (sm *SessionManager) ApplySessionCookie(ctx *fasthttp.RequestCtx, session *Session) {
cookie := fasthttp.AcquireCookie()
defer fasthttp.ReleaseCookie(cookie)
sm.mu.RLock()
cookie.SetKey(sm.cookieName)
cookie.SetValue(session.ID)
cookie.SetPath(sm.cookiePath)
cookie.SetHTTPOnly(sm.cookieHTTPOnly)
cookie.SetMaxAge(sm.cookieMaxAge)
if sm.cookieDomain != "" {
cookie.SetDomain(sm.cookieDomain)
}
cookie.SetSecure(sm.cookieSecure)
sm.mu.RUnlock()
ctx.Response.Header.SetCookie(cookie)
}
// GlobalSessionManager is the default session manager instance // GlobalSessionManager is the default session manager instance
var GlobalSessionManager = NewSessionManager(DefaultMaxSessions) var GlobalSessionManager = NewSessionManager()

View File

@ -1,31 +1,41 @@
package sessions package sessions
import ( import (
"errors"
"sync" "sync"
"time" "time"
"github.com/goccy/go-json"
)
const (
DefaultMaxValueSize = 256 * 1024 // 256KB per value
)
var (
ErrValueTooLarge = errors.New("session value exceeds size limit")
) )
// Session stores data for a single user session // Session stores data for a single user session
type Session struct { type Session struct {
ID string ID string `json:"id"`
Data map[string]any Data map[string]any `json:"data"`
CreatedAt time.Time CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time UpdatedAt time.Time `json:"updated_at"`
LastUsed time.Time mu sync.RWMutex `json:"-"`
Expiry time.Time maxValueSize int `json:"max_value_size"`
mu sync.RWMutex totalDataSize int `json:"total_data_size"`
} }
// NewSession creates a new session with the given ID // NewSession creates a new session with the given ID
func NewSession(id string, maxAge int) *Session { func NewSession(id string) *Session {
now := time.Now() now := time.Now()
return &Session{ return &Session{
ID: id, ID: id,
Data: make(map[string]any), Data: make(map[string]any),
CreatedAt: now, CreatedAt: now,
UpdatedAt: now, UpdatedAt: now,
LastUsed: now, maxValueSize: DefaultMaxValueSize,
Expiry: now.Add(time.Duration(maxAge) * time.Second),
} }
} }
@ -37,17 +47,65 @@ func (s *Session) Get(key string) any {
} }
// Set stores a value in the session // Set stores a value in the session
func (s *Session) Set(key string, value any) { func (s *Session) Set(key string, value any) error {
// Estimate value size
size, err := estimateSize(value)
if err != nil {
return err
}
// Check against limit
if size > s.maxValueSize {
return ErrValueTooLarge
}
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
// If replacing, subtract old value size
if oldVal, exists := s.Data[key]; exists {
oldSize, _ := estimateSize(oldVal)
s.totalDataSize -= oldSize
}
s.Data[key] = value s.Data[key] = value
s.totalDataSize += size
s.UpdatedAt = time.Now() s.UpdatedAt = time.Now()
return nil
}
// SetMaxValueSize changes the maximum allowed value size
func (s *Session) SetMaxValueSize(bytes int) {
s.mu.Lock()
defer s.mu.Unlock()
s.maxValueSize = bytes
}
// GetMaxValueSize returns the current max value size
func (s *Session) GetMaxValueSize() int {
s.mu.RLock()
defer s.mu.RUnlock()
return s.maxValueSize
}
// GetTotalSize returns the estimated total size of all session data
func (s *Session) GetTotalSize() int {
s.mu.RLock()
defer s.mu.RUnlock()
return s.totalDataSize
} }
// Delete removes a value from the session // Delete removes a value from the session
func (s *Session) Delete(key string) { func (s *Session) Delete(key string) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
// Update size tracking
if oldVal, exists := s.Data[key]; exists {
oldSize, _ := estimateSize(oldVal)
s.totalDataSize -= oldSize
}
delete(s.Data, key) delete(s.Data, key)
s.UpdatedAt = time.Now() s.UpdatedAt = time.Now()
} }
@ -57,6 +115,7 @@ func (s *Session) Clear() {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
s.Data = make(map[string]any) s.Data = make(map[string]any)
s.totalDataSize = 0
s.UpdatedAt = time.Now() s.UpdatedAt = time.Now()
} }
@ -65,6 +124,7 @@ func (s *Session) GetAll() map[string]any {
s.mu.RLock() s.mu.RLock()
defer s.mu.RUnlock() defer s.mu.RUnlock()
// Create a copy to avoid concurrent map access issues
copy := make(map[string]any, len(s.Data)) copy := make(map[string]any, len(s.Data))
for k, v := range s.Data { for k, v := range s.Data {
copy[k] = v copy[k] = v
@ -73,14 +133,20 @@ func (s *Session) GetAll() map[string]any {
return copy return copy
} }
// IsExpired checks if the session has expired // estimateSize approximates the memory footprint of a value
func (s *Session) IsExpired() bool { func estimateSize(v any) (int, error) {
return time.Now().After(s.Expiry) // Fast path for common types
} switch val := v.(type) {
case string:
return len(val), nil
case []byte:
return len(val), nil
}
// UpdateLastUsed updates the last used time // For other types, use JSON serialization as approximation
func (s *Session) UpdateLastUsed() { data, err := json.Marshal(v)
s.mu.Lock() if err != nil {
s.LastUsed = time.Now() return 0, err
s.mu.Unlock() }
return len(data), nil
} }