Compare commits

...

3 Commits

Author SHA1 Message Date
35ce09d66e work on sessions 2025-04-09 23:12:23 -05:00
ac991f40a0 optimize sessions 1 2025-04-09 20:54:49 -05:00
85b0551e70 session rewrite 2025-04-09 20:47:22 -05:00
5 changed files with 361 additions and 357 deletions

View File

@ -225,13 +225,12 @@ func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scrip
return return
} }
// Save session if modified // Update session if modified
if response.SessionModified { if response.SessionModified {
// Update session data
for k, v := range response.SessionData { for k, v := range response.SessionData {
session.Set(k, v) session.Set(k, v)
} }
s.sessionManager.SaveSession(session)
s.sessionManager.ApplySessionCookie(ctx, session) s.sessionManager.ApplySessionCookie(ctx, session)
} }

View File

@ -4,11 +4,12 @@ import (
"fmt" "fmt"
"sync" "sync"
"github.com/valyala/bytebufferpool"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
"Moonshark/core/utils/logger" "Moonshark/core/utils/logger"
"maps"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go" luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
) )
@ -111,74 +112,48 @@ func (s *Sandbox) registerCoreFunctions(state *luajit.State) error {
// Execute runs a Lua script in the sandbox with the given context // Execute runs a Lua script in the sandbox with the given context
func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx *Context) (*Response, error) { func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx *Context) (*Response, error) {
s.debugLog("Executing script...")
// Create a response object // Create a response object
response := NewResponse() response := NewResponse()
// Get a buffer for string operations
buf := bytebufferpool.Get()
defer bytebufferpool.Put(buf)
// Load bytecode // Load bytecode
if err := state.LoadBytecode(bytecode, "script"); err != nil { if err := state.LoadBytecode(bytecode, "script"); err != nil {
ReleaseResponse(response) ReleaseResponse(response)
s.debugLog("Failed to load bytecode: %v", err)
return nil, fmt.Errorf("failed to load script: %w", err) return nil, fmt.Errorf("failed to load script: %w", err)
} }
// Initialize session data in Lua // Add session data to context
contextWithSession := make(map[string]any)
maps.Copy(contextWithSession, ctx.Values)
// Pass session data through context
if ctx.SessionID != "" { if ctx.SessionID != "" {
// Set session ID contextWithSession["session_id"] = ctx.SessionID
state.PushString(ctx.SessionID) contextWithSession["session_data"] = ctx.SessionData
state.SetGlobal("__session_id")
// Set session data
if err := state.PushTable(ctx.SessionData); err != nil {
ReleaseResponse(response)
s.debugLog("Failed to push session data: %v", err)
return nil, err
}
state.SetGlobal("__session_data")
// Reset modification flag
state.PushBoolean(false)
state.SetGlobal("__session_modified")
} else {
// Initialize empty session
if err := state.DoString("__session_data = {}; __session_modified = false"); err != nil {
s.debugLog("Failed to initialize empty session data: %v", err)
}
} }
// Set up context values for execution // Set up context values for execution
if err := state.PushTable(ctx.Values); err != nil { if err := state.PushTable(contextWithSession); err != nil {
ReleaseResponse(response) ReleaseResponse(response)
s.debugLog("Failed to push context values: %v", err)
return nil, err return nil, err
} }
// Get the execution function // Get the execution function
state.GetGlobal("__execute_script") state.GetGlobal("__execute_script")
if !state.IsFunction(-1) { if !state.IsFunction(-1) {
state.Pop(1) // Pop non-function state.Pop(1)
ReleaseResponse(response) ReleaseResponse(response)
s.debugLog("__execute_script is not a function")
return nil, ErrSandboxNotInitialized return nil, ErrSandboxNotInitialized
} }
// Push function and context to stack // Push function and bytecode
state.PushCopy(-2) // bytecode state.PushCopy(-2) // Bytecode
state.PushCopy(-2) // context state.PushCopy(-2) // Context
state.Remove(-4) // Remove bytecode duplicate
// Remove duplicates state.Remove(-3) // Remove context duplicate
state.Remove(-4)
state.Remove(-3)
// Execute with 2 args, 1 result // Execute with 2 args, 1 result
if err := state.Call(2, 1); err != nil { if err := state.Call(2, 1); err != nil {
ReleaseResponse(response) ReleaseResponse(response)
s.debugLog("Execution failed: %v", err)
return nil, fmt.Errorf("script execution failed: %w", err) return nil, fmt.Errorf("script execution failed: %w", err)
} }
@ -189,21 +164,28 @@ func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx *Context) (*
} }
state.Pop(1) state.Pop(1)
// Extract HTTP response data from Lua state extractHTTPResponseData(state, response)
s.extractResponseData(state, response)
extractSessionData(state, response)
return response, nil return response, nil
} }
// extractResponseData pulls response info from the Lua state // extractResponseData pulls response info from the Lua state
func (s *Sandbox) extractResponseData(state *luajit.State, response *Response) { func extractHTTPResponseData(state *luajit.State, response *Response) {
// Get HTTP response
state.GetGlobal("__http_responses") state.GetGlobal("__http_responses")
if !state.IsNil(-1) && state.IsTable(-1) { if !state.IsTable(-1) {
state.Pop(1)
return
}
state.PushNumber(1) state.PushNumber(1)
state.GetTable(-2) state.GetTable(-2)
if !state.IsTable(-1) {
state.Pop(2)
return
}
if !state.IsNil(-1) && state.IsTable(-1) {
// Extract status // Extract status
state.GetField(-1, "status") state.GetField(-1, "status")
if state.IsNumber(-1) { if state.IsNumber(-1) {
@ -235,14 +217,14 @@ func (s *Sandbox) extractResponseData(state *luajit.State, response *Response) {
state.GetTable(-2) state.GetTable(-2)
if state.IsTable(-1) { if state.IsTable(-1) {
s.extractCookie(state, response) extractCookie(state, response)
} }
state.Pop(1) state.Pop(1)
} }
} }
state.Pop(1) state.Pop(1)
// Extract metadata if present // Extract metadata
state.GetField(-1, "metadata") state.GetField(-1, "metadata")
if state.IsTable(-1) { if state.IsTable(-1) {
table, err := state.ToTable(-1) table, err := state.ToTable(-1)
@ -253,40 +235,13 @@ func (s *Sandbox) extractResponseData(state *luajit.State, response *Response) {
} }
} }
state.Pop(1) state.Pop(1)
}
state.Pop(1)
}
state.Pop(1)
// Extract session data // Clean up
state.GetGlobal("__session_modified") state.Pop(2)
if state.IsBoolean(-1) && state.ToBoolean(-1) {
response.SessionModified = true
// Get session ID
state.GetGlobal("__session_id")
if state.IsString(-1) {
response.SessionID = state.ToString(-1)
}
state.Pop(1)
// Get session data
state.GetGlobal("__session_data")
if state.IsTable(-1) {
sessionData, err := state.ToTable(-1)
if err == nil {
for k, v := range sessionData {
response.SessionData[k] = v
}
}
}
state.Pop(1)
}
state.Pop(1)
} }
// extractCookie pulls cookie data from the current table on the stack // extractCookie pulls cookie data from the current table on the stack
func (s *Sandbox) extractCookie(state *luajit.State, response *Response) { func extractCookie(state *luajit.State, response *Response) {
cookie := fasthttp.AcquireCookie() cookie := fasthttp.AcquireCookie()
// Get name (required) // Get name (required)
@ -343,3 +298,69 @@ func (s *Sandbox) extractCookie(state *luajit.State, response *Response) {
response.Cookies = append(response.Cookies, cookie) response.Cookies = append(response.Cookies, cookie)
} }
// Extract session data if modified
func extractSessionData(state *luajit.State, response *Response) {
logger.Debug("extractSessionData: Starting extraction")
// Get HTTP response table
state.GetGlobal("__http_responses")
if !state.IsTable(-1) {
logger.Debug("extractSessionData: __http_responses is not a table")
state.Pop(1)
return
}
// Get first response
state.PushNumber(1)
state.GetTable(-2)
if !state.IsTable(-1) {
logger.Debug("extractSessionData: __http_responses[1] is not a table")
state.Pop(2)
return
}
// Check session_modified flag
state.GetField(-1, "session_modified")
if !state.IsBoolean(-1) || !state.ToBoolean(-1) {
logger.Debug("extractSessionData: session_modified is not true")
state.Pop(3)
return
}
logger.Debug("extractSessionData: Found session_modified=true")
state.Pop(1)
// Get session ID
state.GetField(-1, "session_id")
if state.IsString(-1) {
response.SessionID = state.ToString(-1)
logger.Debug("extractSessionData: Found session ID: %s", response.SessionID)
} else {
logger.Debug("extractSessionData: session_id not found or not a string")
}
state.Pop(1)
// Get session data
state.GetField(-1, "session_data")
if state.IsTable(-1) {
logger.Debug("extractSessionData: Found session_data table")
sessionData, err := state.ToTable(-1)
if err == nil {
logger.Debug("extractSessionData: Converted session data, size=%d", len(sessionData))
for k, v := range sessionData {
response.SessionData[k] = v
logger.Debug("extractSessionData: Added session key=%s, value=%v", k, v)
}
response.SessionModified = true
} else {
logger.Debug("extractSessionData: Failed to convert session data: %v", err)
}
} else {
logger.Debug("extractSessionData: session_data not found or not a table")
}
state.Pop(1)
// Clean up stack
state.Pop(2)
logger.Debug("extractSessionData: Finished extraction, modified=%v", response.SessionModified)
}

View File

@ -42,11 +42,25 @@ function __execute_script(fn, ctx)
-- Clear previous responses -- Clear previous responses
__http_responses[1] = nil __http_responses[1] = nil
-- Reset session modification flag -- Create environment with metatable inheriting from _G
__session_modified = false local env = setmetatable({}, {__index = _G})
-- Create environment -- Add context if provided
local env = __create_env(ctx) if ctx then
env.ctx = ctx
end
print("INIT SESSION DATA:", util.json_encode(ctx.session_data or {}))
-- Initialize local session variables in the environment
env.__session_data = ctx.session_data or {}
env.__session_id = ctx.session_id
env.__session_modified = false
-- Add proper require function to this environment
if __setup_require then
__setup_require(env)
end
-- Set environment for function -- Set environment for function
setfenv(fn, env) setfenv(fn, env)
@ -57,6 +71,17 @@ function __execute_script(fn, ctx)
error(result, 0) error(result, 0)
end end
-- If session was modified, add to response
if env.__session_modified then
__http_responses[1] = __http_responses[1] or {}
__http_responses[1].session_data = env.__session_data
__http_responses[1].session_id = env.__session_id
__http_responses[1].session_modified = true
end
print("SESSION MODIFIED:", env.__session_modified)
print("FINAL DATA:", util.json_encode(env.__session_data or {}))
return result return result
end end
@ -292,81 +317,79 @@ local cookie = {
-- SESSION MODULE -- SESSION MODULE
-- ====================================================================== -- ======================================================================
-- Session module implementation
local session = { local session = {
-- Get a session value -- Get session value
get = function(key) get = function(key)
if type(key) ~= "string" then if type(key) ~= "string" then
error("session.get: key must be a string", 2) error("session.get: key must be a string", 2)
end end
local env = getfenv(2)
if __session_data and __session_data[key] ~= nil then return env.__session_data and env.__session_data[key]
return __session_data[key]
end
return nil
end, end,
-- Set a session value -- Set session value
set = function(key, value) set = function(key, value)
if type(key) ~= "string" then if type(key) ~= "string" then
error("session.set: key must be a string", 2) error("session.set: key must be a string", 2)
end end
-- Ensure session data table exists local env = getfenv(2)
__session_data = __session_data or {} print("SET ENV:", tostring(env)) -- Debug the environment
-- Store value if not env.__session_data then
__session_data[key] = value env.__session_data = {}
print("CREATED NEW SESSION TABLE")
-- Mark session as modified end
__session_modified = true
env.__session_data[key] = value
env.__session_modified = true
print("SET:", key, "=", tostring(value), "MODIFIED:", env.__session_modified)
return true return true
end, end,
-- Delete a session value -- Delete session value
delete = function(key) delete = function(key)
if type(key) ~= "string" then if type(key) ~= "string" then
error("session.delete: key must be a string", 2) error("session.delete: key must be a string", 2)
end end
if __session_data then local env = getfenv(2)
__session_data[key] = nil if env.__session_data and env.__session_data[key] ~= nil then
__session_modified = true env.__session_data[key] = nil
env.__session_modified = true
end end
return true return true
end, end,
-- Clear all session data -- Clear all session data
clear = function() clear = function()
__session_data = {} local env = getfenv(2)
__session_modified = true if env.__session_data and next(env.__session_data) then
env.__session_data = {}
env.__session_modified = true
end
return true return true
end, end,
-- Get the session ID -- Get session ID
get_id = function() get_id = function()
return __session_id or nil local env = getfenv(2)
return env.__session_id or ""
end, end,
-- Get all session data -- Get all session data
get_all = function() get_all = function()
local result = {} local env = getfenv(2)
for k, v in pairs(__session_data or {}) do return env.__session_data or {}
result[k] = v
end
return result
end, end,
-- Check if session has a key -- Check if session has key
has = function(key) has = function(key)
if type(key) ~= "string" then if type(key) ~= "string" then
error("session.has: key must be a string", 2) error("session.has: key must be a string", 2)
end end
local env = getfenv(2)
return __session_data and __session_data[key] ~= nil return env.__session_data ~= nil and env.__session_data[key] ~= nil
end end
} }

View File

@ -1,46 +1,43 @@
package sessions package sessions
import ( import (
"Moonshark/core/utils/logger"
"crypto/rand" "crypto/rand"
"encoding/base64" "encoding/base64"
"sync" "sync"
"time" "time"
"github.com/VictoriaMetrics/fastcache"
"github.com/goccy/go-json"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
) )
const ( const (
// Default settings DefaultMaxSessions = 10000
DefaultMaxSize = 100 * 1024 * 1024 // 100MB default cache size
DefaultCookieName = "MoonsharkSID" DefaultCookieName = "MoonsharkSID"
DefaultCookiePath = "/" DefaultCookiePath = "/"
DefaultMaxAge = 86400 // 1 day in seconds DefaultMaxAge = 86400 // 1 day in seconds
) )
// SessionManager handles multiple sessions using fastcache for storage // SessionManager handles multiple sessions
type SessionManager struct { type SessionManager struct {
cache *fastcache.Cache sessions map[string]*Session
maxSessions int
cookieName string cookieName string
cookiePath string cookiePath string
cookieDomain string cookieDomain string
cookieSecure bool cookieSecure bool
cookieHTTPOnly bool cookieHTTPOnly bool
cookieMaxAge int cookieMaxAge int
mu sync.RWMutex // Only for cookie settings mu sync.RWMutex
} }
// NewSessionManager creates a new session manager with optional cache size // NewSessionManager creates a new session manager
func NewSessionManager(maxSize ...int) *SessionManager { func NewSessionManager(maxSessions int) *SessionManager {
size := DefaultMaxSize if maxSessions <= 0 {
if len(maxSize) > 0 && maxSize[0] > 0 { maxSessions = DefaultMaxSessions
size = maxSize[0]
} }
return &SessionManager{ return &SessionManager{
cache: fastcache.New(size), sessions: make(map[string]*Session, maxSessions),
maxSessions: maxSessions,
cookieName: DefaultCookieName, cookieName: DefaultCookieName,
cookiePath: DefaultCookiePath, cookiePath: DefaultCookiePath,
cookieHTTPOnly: true, cookieHTTPOnly: true,
@ -48,7 +45,7 @@ func NewSessionManager(maxSize ...int) *SessionManager {
} }
} }
// generateSessionID creates a cryptographically secure random session ID // generateSessionID creates a random session ID
func generateSessionID() string { func generateSessionID() string {
b := make([]byte, 32) b := make([]byte, 32)
if _, err := rand.Read(b); err != nil { if _, err := rand.Read(b); err != nil {
@ -59,59 +56,136 @@ func generateSessionID() string {
// GetSession retrieves a session by ID, or creates a new one if it doesn't exist // GetSession retrieves a session by ID, or creates a new one if it doesn't exist
func (sm *SessionManager) GetSession(id string) *Session { func (sm *SessionManager) GetSession(id string) *Session {
// Check if session exists // Try to get an existing session
data := sm.cache.Get(nil, []byte(id)) if id != "" {
sm.mu.RLock()
if len(data) > 0 { session, exists := sm.sessions[id]
logger.Debug("Getting session %s", id) sm.mu.RUnlock()
// Session exists, unmarshal it
session := &Session{}
if err := json.Unmarshal(data, session); err == nil {
// Initialize mutex properly
session.mu = sync.RWMutex{}
// Update last accessed time
session.UpdatedAt = time.Now()
// Store back with updated timestamp
updatedData, _ := json.Marshal(session)
sm.cache.Set([]byte(id), updatedData)
if exists {
// Check if session is expired
if session.IsExpired() {
sm.mu.Lock()
delete(sm.sessions, id)
sm.mu.Unlock()
} else {
// Update last used time
session.UpdateLastUsed()
return session return session
} }
} }
}
logger.Debug("Session doesn't exist; creating it") // Create a new session
return sm.CreateSession()
// Create new session
session := NewSession(id)
data, _ = json.Marshal(session)
sm.cache.Set([]byte(id), data)
return session
} }
// CreateSession generates a new session with a unique ID // CreateSession generates a new session with a unique ID
func (sm *SessionManager) CreateSession() *Session { func (sm *SessionManager) CreateSession() *Session {
id := generateSessionID() id := generateSessionID()
session := NewSession(id, sm.cookieMaxAge)
session := NewSession(id) sm.mu.Lock()
data, _ := json.Marshal(session) // Enforce session limit - evict LRU if needed
sm.cache.Set([]byte(id), data) if len(sm.sessions) >= sm.maxSessions {
sm.evictLRU()
}
sm.sessions[id] = session
sm.mu.Unlock()
return session return session
} }
// SaveSession persists a session back to the cache // evictLRU removes the least recently used session
func (sm *SessionManager) SaveSession(session *Session) { func (sm *SessionManager) evictLRU() {
data, _ := json.Marshal(session) // Called with mutex already held
sm.cache.Set([]byte(session.ID), data) if len(sm.sessions) == 0 {
return
}
var oldestID string
var oldestTime time.Time
// Find oldest session
for id, session := range sm.sessions {
if oldestID == "" || session.LastUsed.Before(oldestTime) {
oldestID = id
oldestTime = session.LastUsed
}
}
if oldestID != "" {
delete(sm.sessions, oldestID)
}
} }
// DestroySession removes a session // DestroySession removes a session
func (sm *SessionManager) DestroySession(id string) { func (sm *SessionManager) DestroySession(id string) {
sm.cache.Del([]byte(id)) sm.mu.Lock()
defer sm.mu.Unlock()
delete(sm.sessions, id)
}
// CleanupExpired removes all expired sessions
func (sm *SessionManager) CleanupExpired() int {
removed := 0
now := time.Now()
sm.mu.Lock()
defer sm.mu.Unlock()
for id, session := range sm.sessions {
if now.After(session.Expiry) {
delete(sm.sessions, id)
removed++
}
}
return removed
}
// SetCookieOptions configures cookie parameters
func (sm *SessionManager) SetCookieOptions(name, path, domain string, secure, httpOnly bool, maxAge int) {
sm.mu.Lock()
defer sm.mu.Unlock()
sm.cookieName = name
sm.cookiePath = path
sm.cookieDomain = domain
sm.cookieSecure = secure
sm.cookieHTTPOnly = httpOnly
sm.cookieMaxAge = maxAge
}
// GetSessionFromRequest extracts the session from a request
func (sm *SessionManager) GetSessionFromRequest(ctx *fasthttp.RequestCtx) *Session {
cookie := ctx.Request.Header.Cookie(sm.cookieName)
if len(cookie) == 0 {
return sm.CreateSession()
}
return sm.GetSession(string(cookie))
}
// ApplySessionCookie adds the session cookie to the response
func (sm *SessionManager) ApplySessionCookie(ctx *fasthttp.RequestCtx, session *Session) {
cookie := fasthttp.AcquireCookie()
defer fasthttp.ReleaseCookie(cookie)
cookie.SetKey(sm.cookieName)
cookie.SetValue(session.ID)
cookie.SetPath(sm.cookiePath)
cookie.SetHTTPOnly(sm.cookieHTTPOnly)
cookie.SetMaxAge(sm.cookieMaxAge)
if sm.cookieDomain != "" {
cookie.SetDomain(sm.cookieDomain)
}
cookie.SetSecure(sm.cookieSecure)
ctx.Response.Header.SetCookie(cookie)
} }
// CookieOptions returns the cookie options for this session manager // CookieOptions returns the cookie options for this session manager
@ -129,52 +203,5 @@ func (sm *SessionManager) CookieOptions() map[string]any {
} }
} }
// SetCookieOptions configures cookie parameters
func (sm *SessionManager) SetCookieOptions(name, path, domain string, secure, httpOnly bool, maxAge int) {
sm.mu.Lock()
defer sm.mu.Unlock()
sm.cookieName = name
sm.cookiePath = path
sm.cookieDomain = domain
sm.cookieSecure = secure
sm.cookieHTTPOnly = httpOnly
sm.cookieMaxAge = maxAge
}
// GetSessionFromRequest extracts the session from a request context
func (sm *SessionManager) GetSessionFromRequest(ctx *fasthttp.RequestCtx) *Session {
cookie := ctx.Request.Header.Cookie(sm.cookieName)
if len(cookie) == 0 {
// No session cookie, create a new session
return sm.CreateSession()
}
// Session cookie exists, get the session
return sm.GetSession(string(cookie))
}
// SaveSessionToResponse adds the session cookie to an HTTP response
func (sm *SessionManager) ApplySessionCookie(ctx *fasthttp.RequestCtx, session *Session) {
cookie := fasthttp.AcquireCookie()
defer fasthttp.ReleaseCookie(cookie)
sm.mu.RLock()
cookie.SetKey(sm.cookieName)
cookie.SetValue(session.ID)
cookie.SetPath(sm.cookiePath)
cookie.SetHTTPOnly(sm.cookieHTTPOnly)
cookie.SetMaxAge(sm.cookieMaxAge)
if sm.cookieDomain != "" {
cookie.SetDomain(sm.cookieDomain)
}
cookie.SetSecure(sm.cookieSecure)
sm.mu.RUnlock()
ctx.Response.Header.SetCookie(cookie)
}
// GlobalSessionManager is the default session manager instance // GlobalSessionManager is the default session manager instance
var GlobalSessionManager = NewSessionManager() var GlobalSessionManager = NewSessionManager(DefaultMaxSessions)

View File

@ -1,41 +1,31 @@
package sessions package sessions
import ( import (
"errors"
"sync" "sync"
"time" "time"
"github.com/goccy/go-json"
)
const (
DefaultMaxValueSize = 256 * 1024 // 256KB per value
)
var (
ErrValueTooLarge = errors.New("session value exceeds size limit")
) )
// Session stores data for a single user session // Session stores data for a single user session
type Session struct { type Session struct {
ID string `json:"id"` ID string
Data map[string]any `json:"data"` Data map[string]any
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time
UpdatedAt time.Time `json:"updated_at"` UpdatedAt time.Time
mu sync.RWMutex `json:"-"` LastUsed time.Time
maxValueSize int `json:"max_value_size"` Expiry time.Time
totalDataSize int `json:"total_data_size"` mu sync.RWMutex
} }
// NewSession creates a new session with the given ID // NewSession creates a new session with the given ID
func NewSession(id string) *Session { func NewSession(id string, maxAge int) *Session {
now := time.Now() now := time.Now()
return &Session{ return &Session{
ID: id, ID: id,
Data: make(map[string]any), Data: make(map[string]any),
CreatedAt: now, CreatedAt: now,
UpdatedAt: now, UpdatedAt: now,
maxValueSize: DefaultMaxValueSize, LastUsed: now,
Expiry: now.Add(time.Duration(maxAge) * time.Second),
} }
} }
@ -47,65 +37,17 @@ func (s *Session) Get(key string) any {
} }
// Set stores a value in the session // Set stores a value in the session
func (s *Session) Set(key string, value any) error { func (s *Session) Set(key string, value any) {
// Estimate value size
size, err := estimateSize(value)
if err != nil {
return err
}
// Check against limit
if size > s.maxValueSize {
return ErrValueTooLarge
}
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
// If replacing, subtract old value size
if oldVal, exists := s.Data[key]; exists {
oldSize, _ := estimateSize(oldVal)
s.totalDataSize -= oldSize
}
s.Data[key] = value s.Data[key] = value
s.totalDataSize += size
s.UpdatedAt = time.Now() s.UpdatedAt = time.Now()
return nil
}
// SetMaxValueSize changes the maximum allowed value size
func (s *Session) SetMaxValueSize(bytes int) {
s.mu.Lock()
defer s.mu.Unlock()
s.maxValueSize = bytes
}
// GetMaxValueSize returns the current max value size
func (s *Session) GetMaxValueSize() int {
s.mu.RLock()
defer s.mu.RUnlock()
return s.maxValueSize
}
// GetTotalSize returns the estimated total size of all session data
func (s *Session) GetTotalSize() int {
s.mu.RLock()
defer s.mu.RUnlock()
return s.totalDataSize
} }
// Delete removes a value from the session // Delete removes a value from the session
func (s *Session) Delete(key string) { func (s *Session) Delete(key string) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
// Update size tracking
if oldVal, exists := s.Data[key]; exists {
oldSize, _ := estimateSize(oldVal)
s.totalDataSize -= oldSize
}
delete(s.Data, key) delete(s.Data, key)
s.UpdatedAt = time.Now() s.UpdatedAt = time.Now()
} }
@ -115,7 +57,6 @@ func (s *Session) Clear() {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
s.Data = make(map[string]any) s.Data = make(map[string]any)
s.totalDataSize = 0
s.UpdatedAt = time.Now() s.UpdatedAt = time.Now()
} }
@ -124,7 +65,6 @@ func (s *Session) GetAll() map[string]any {
s.mu.RLock() s.mu.RLock()
defer s.mu.RUnlock() defer s.mu.RUnlock()
// Create a copy to avoid concurrent map access issues
copy := make(map[string]any, len(s.Data)) copy := make(map[string]any, len(s.Data))
for k, v := range s.Data { for k, v := range s.Data {
copy[k] = v copy[k] = v
@ -133,20 +73,14 @@ func (s *Session) GetAll() map[string]any {
return copy return copy
} }
// estimateSize approximates the memory footprint of a value // IsExpired checks if the session has expired
func estimateSize(v any) (int, error) { func (s *Session) IsExpired() bool {
// Fast path for common types return time.Now().After(s.Expiry)
switch val := v.(type) {
case string:
return len(val), nil
case []byte:
return len(val), nil
} }
// For other types, use JSON serialization as approximation // UpdateLastUsed updates the last used time
data, err := json.Marshal(v) func (s *Session) UpdateLastUsed() {
if err != nil { s.mu.Lock()
return 0, err s.LastUsed = time.Now()
} s.mu.Unlock()
return len(data), nil
} }