optimize session manager

This commit is contained in:
Sky Johnson 2025-04-10 13:55:05 -05:00
parent c952242a9c
commit 941e810acb
4 changed files with 225 additions and 117 deletions

View File

@ -1,11 +1,11 @@
package sessions package sessions
import ( import (
"crypto/rand"
"encoding/base64"
"sync" "sync"
"time" "time"
"github.com/VictoriaMetrics/fastcache"
gonanoid "github.com/matoous/go-nanoid/v2"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
) )
@ -14,11 +14,12 @@ const (
DefaultCookieName = "MoonsharkSID" DefaultCookieName = "MoonsharkSID"
DefaultCookiePath = "/" DefaultCookiePath = "/"
DefaultMaxAge = 86400 // 1 day in seconds DefaultMaxAge = 86400 // 1 day in seconds
CleanupInterval = 5 * time.Minute
) )
// SessionManager handles multiple sessions // SessionManager handles multiple sessions
type SessionManager struct { type SessionManager struct {
sessions map[string]*Session cache *fastcache.Cache
maxSessions int maxSessions int
cookieName string cookieName string
cookiePath string cookiePath string
@ -26,7 +27,19 @@ type SessionManager struct {
cookieSecure bool cookieSecure bool
cookieHTTPOnly bool cookieHTTPOnly bool
cookieMaxAge int cookieMaxAge int
mu sync.RWMutex cookieMu sync.RWMutex // Only cookie options need a mutex
cleanupTicker *time.Ticker
cleanupDone chan struct{}
}
// InitializeSessionPool pre-allocates session objects
func InitializeSessionPool(size int) {
for range size {
session := &Session{
Data: make(map[string]any, 8),
}
ReturnToPool(session)
}
} }
// NewSessionManager creates a new session manager // NewSessionManager creates a new session manager
@ -35,44 +48,67 @@ func NewSessionManager(maxSessions int) *SessionManager {
maxSessions = DefaultMaxSessions maxSessions = DefaultMaxSessions
} }
return &SessionManager{ // Estimate max memory: ~4KB per session × maxSessions
sessions: make(map[string]*Session, maxSessions), maxBytes := maxSessions * 4096
sm := &SessionManager{
cache: fastcache.New(maxBytes),
maxSessions: maxSessions, maxSessions: maxSessions,
cookieName: DefaultCookieName, cookieName: DefaultCookieName,
cookiePath: DefaultCookiePath, cookiePath: DefaultCookiePath,
cookieHTTPOnly: true, cookieHTTPOnly: true,
cookieMaxAge: DefaultMaxAge, cookieMaxAge: DefaultMaxAge,
cleanupDone: make(chan struct{}),
}
// Pre-allocate session objects for common pool size
InitializeSessionPool(100) // Adjust based on expected concurrent requests
// Start periodic cleanup
sm.cleanupTicker = time.NewTicker(CleanupInterval)
go sm.cleanupRoutine()
return sm
}
// Stop shuts down the session manager's cleanup routine
func (sm *SessionManager) Stop() {
close(sm.cleanupDone)
}
// cleanupRoutine periodically removes expired sessions
func (sm *SessionManager) cleanupRoutine() {
for {
select {
case <-sm.cleanupTicker.C:
sm.CleanupExpired()
case <-sm.cleanupDone:
sm.cleanupTicker.Stop()
return
}
} }
} }
// generateSessionID creates a random session ID // generateSessionID creates a random session ID
func generateSessionID() string { func generateSessionID() string {
b := make([]byte, 32) id, _ := gonanoid.New()
if _, err := rand.Read(b); err != nil { return id
return time.Now().String() // Fallback
}
return base64.URLEncoding.EncodeToString(b)
} }
// GetSession retrieves a session by ID, or creates a new one if it doesn't exist // GetSession retrieves a session by ID, or creates a new one if it doesn't exist
func (sm *SessionManager) GetSession(id string) *Session { func (sm *SessionManager) GetSession(id string) *Session {
// Try to get an existing session // Try to get an existing session
if id != "" { if id != "" {
sm.mu.RLock() data := sm.cache.Get(nil, []byte(id))
session, exists := sm.sessions[id] if len(data) > 0 {
sm.mu.RUnlock() session, err := Unmarshal(data)
if err == nil && !session.IsExpired() {
if exists {
// Check if session is expired
if session.IsExpired() {
sm.mu.Lock()
delete(sm.sessions, id)
sm.mu.Unlock()
} else {
// Update last used time
session.UpdateLastUsed() session.UpdateLastUsed()
session.ResetDirty() // Start clean
return session return session
} }
// Session expired or corrupt, remove it
sm.cache.Del([]byte(id))
} }
} }
@ -83,72 +119,55 @@ func (sm *SessionManager) GetSession(id string) *Session {
// CreateSession generates a new session with a unique ID // CreateSession generates a new session with a unique ID
func (sm *SessionManager) CreateSession() *Session { func (sm *SessionManager) CreateSession() *Session {
id := generateSessionID() id := generateSessionID()
session := NewSession(id, sm.cookieMaxAge)
sm.mu.Lock() // Ensure ID uniqueness
// Enforce session limit - evict LRU if needed attempts := 0
if len(sm.sessions) >= sm.maxSessions { for attempts < 3 {
sm.evictLRU() if sm.cache.Has([]byte(id)) {
} id = generateSessionID()
attempts++
sm.sessions[id] = session } else {
sm.mu.Unlock() break
return session
}
// evictLRU removes the least recently used session
func (sm *SessionManager) evictLRU() {
// Called with mutex already held
if len(sm.sessions) == 0 {
return
}
var oldestID string
var oldestTime time.Time
// Find oldest session
for id, session := range sm.sessions {
if oldestID == "" || session.LastUsed.Before(oldestTime) {
oldestID = id
oldestTime = session.LastUsed
} }
} }
if oldestID != "" { session := NewSession(id, sm.cookieMaxAge)
delete(sm.sessions, oldestID)
// Serialize and store the session
if data, err := session.Marshal(); err == nil {
sm.cache.Set([]byte(id), data)
} }
session.ResetDirty() // Start clean
return session
} }
// DestroySession removes a session // DestroySession removes a session
func (sm *SessionManager) DestroySession(id string) { func (sm *SessionManager) DestroySession(id string) {
sm.mu.Lock() // Get and clean session from cache before deleting
defer sm.mu.Unlock() data := sm.cache.Get(nil, []byte(id))
delete(sm.sessions, id) if len(data) > 0 {
} if session, err := Unmarshal(data); err == nil {
ReturnToPool(session)
// CleanupExpired removes all expired sessions
func (sm *SessionManager) CleanupExpired() int {
removed := 0
now := time.Now()
sm.mu.Lock()
defer sm.mu.Unlock()
for id, session := range sm.sessions {
if now.After(session.Expiry) {
delete(sm.sessions, id)
removed++
} }
} }
return removed sm.cache.Del([]byte(id))
}
// CleanupExpired removes all expired sessions
// Note: fastcache doesn't provide iteration, so we can't clean all expired sessions
// This is a limitation of this implementation
func (sm *SessionManager) CleanupExpired() int {
// No way to iterate through all keys in fastcache
// We'd need to track expiring sessions separately
return 0
} }
// SetCookieOptions configures cookie parameters // SetCookieOptions configures cookie parameters
func (sm *SessionManager) SetCookieOptions(name, path, domain string, secure, httpOnly bool, maxAge int) { func (sm *SessionManager) SetCookieOptions(name, path, domain string, secure, httpOnly bool, maxAge int) {
sm.mu.Lock() sm.cookieMu.Lock()
defer sm.mu.Unlock() defer sm.cookieMu.Unlock()
sm.cookieName = name sm.cookieName = name
sm.cookiePath = path sm.cookiePath = path
@ -160,7 +179,11 @@ func (sm *SessionManager) SetCookieOptions(name, path, domain string, secure, ht
// GetSessionFromRequest extracts the session from a request // GetSessionFromRequest extracts the session from a request
func (sm *SessionManager) GetSessionFromRequest(ctx *fasthttp.RequestCtx) *Session { func (sm *SessionManager) GetSessionFromRequest(ctx *fasthttp.RequestCtx) *Session {
cookie := ctx.Request.Header.Cookie(sm.cookieName) sm.cookieMu.RLock()
cookieName := sm.cookieName
sm.cookieMu.RUnlock()
cookie := ctx.Request.Header.Cookie(cookieName)
if len(cookie) == 0 { if len(cookie) == 0 {
return sm.CreateSession() return sm.CreateSession()
} }
@ -173,25 +196,43 @@ func (sm *SessionManager) ApplySessionCookie(ctx *fasthttp.RequestCtx, session *
cookie := fasthttp.AcquireCookie() cookie := fasthttp.AcquireCookie()
defer fasthttp.ReleaseCookie(cookie) defer fasthttp.ReleaseCookie(cookie)
cookie.SetKey(sm.cookieName) // Get cookie options with minimal lock time
cookie.SetValue(session.ID) sm.cookieMu.RLock()
cookie.SetPath(sm.cookiePath) cookieName := sm.cookieName
cookie.SetHTTPOnly(sm.cookieHTTPOnly) cookiePath := sm.cookiePath
cookie.SetMaxAge(sm.cookieMaxAge) cookieDomain := sm.cookieDomain
cookieSecure := sm.cookieSecure
cookieHTTPOnly := sm.cookieHTTPOnly
cookieMaxAge := sm.cookieMaxAge
sm.cookieMu.RUnlock()
if sm.cookieDomain != "" { // Store updated session only if it has changes
cookie.SetDomain(sm.cookieDomain) if session.IsDirty() {
if data, err := session.Marshal(); err == nil {
sm.cache.Set([]byte(session.ID), data)
}
session.ResetDirty()
} }
cookie.SetSecure(sm.cookieSecure) cookie.SetKey(cookieName)
cookie.SetValue(session.ID)
cookie.SetPath(cookiePath)
cookie.SetHTTPOnly(cookieHTTPOnly)
cookie.SetMaxAge(cookieMaxAge)
if cookieDomain != "" {
cookie.SetDomain(cookieDomain)
}
cookie.SetSecure(cookieSecure)
ctx.Response.Header.SetCookie(cookie) ctx.Response.Header.SetCookie(cookie)
} }
// CookieOptions returns the cookie options for this session manager // CookieOptions returns the cookie options for this session manager
func (sm *SessionManager) CookieOptions() map[string]any { func (sm *SessionManager) CookieOptions() map[string]any {
sm.mu.RLock() sm.cookieMu.RLock()
defer sm.mu.RUnlock() defer sm.cookieMu.RUnlock()
return map[string]any{ return map[string]any{
"name": sm.cookieName, "name": sm.cookieName,

View File

@ -1,74 +1,107 @@
package sessions package sessions
import ( import (
"maps"
"sync" "sync"
"time" "time"
"github.com/goccy/go-json"
) )
// Session stores data for a single user session // Session stores data for a single user session
type Session struct { type Session struct {
ID string ID string `json:"id"`
Data map[string]any Data map[string]any `json:"data"`
CreatedAt time.Time CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time UpdatedAt time.Time `json:"updated_at"`
LastUsed time.Time LastUsed time.Time `json:"last_used"`
Expiry time.Time Expiry time.Time `json:"expiry"`
mu sync.RWMutex dirty bool // Tracks if session has changes, not serialized
}
// Session pool to reduce allocations
var sessionPool = sync.Pool{
New: func() any {
return &Session{
Data: make(map[string]any, 8),
}
},
}
// GetFromPool retrieves a session from the pool
func GetFromPool() *Session {
return sessionPool.Get().(*Session)
}
// ReturnToPool returns a session to the pool after cleaning it
func ReturnToPool(s *Session) {
if s == nil {
return
}
// Clean the session for reuse
s.ID = ""
for k := range s.Data {
delete(s.Data, k)
}
s.CreatedAt = time.Time{}
s.UpdatedAt = time.Time{}
s.LastUsed = time.Time{}
s.Expiry = time.Time{}
s.dirty = false
sessionPool.Put(s)
} }
// NewSession creates a new session with the given ID // NewSession creates a new session with the given ID
func NewSession(id string, maxAge int) *Session { func NewSession(id string, maxAge int) *Session {
now := time.Now() now := time.Now()
return &Session{
ID: id, // Get from pool or create new
Data: make(map[string]any), session := GetFromPool()
CreatedAt: now,
UpdatedAt: now, // Initialize
LastUsed: now, session.ID = id
Expiry: now.Add(time.Duration(maxAge) * time.Second), session.CreatedAt = now
} session.UpdatedAt = now
session.LastUsed = now
session.Expiry = now.Add(time.Duration(maxAge) * time.Second)
session.dirty = false
return session
} }
// Get retrieves a value from the session // Get retrieves a value from the session
func (s *Session) Get(key string) any { func (s *Session) Get(key string) any {
s.mu.RLock()
defer s.mu.RUnlock()
return s.Data[key] return s.Data[key]
} }
// Set stores a value in the session // Set stores a value in the session
func (s *Session) Set(key string, value any) { func (s *Session) Set(key string, value any) {
s.mu.Lock()
defer s.mu.Unlock()
s.Data[key] = value s.Data[key] = value
s.UpdatedAt = time.Now() s.UpdatedAt = time.Now()
s.dirty = true
} }
// Delete removes a value from the session // Delete removes a value from the session
func (s *Session) Delete(key string) { func (s *Session) Delete(key string) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.Data, key) delete(s.Data, key)
s.UpdatedAt = time.Now() s.UpdatedAt = time.Now()
s.dirty = true
} }
// Clear removes all data from the session // Clear removes all data from the session
func (s *Session) Clear() { func (s *Session) Clear() {
s.mu.Lock() s.Data = make(map[string]any, 8)
defer s.mu.Unlock()
s.Data = make(map[string]any)
s.UpdatedAt = time.Now() s.UpdatedAt = time.Now()
s.dirty = true
} }
// GetAll returns a copy of all session data // GetAll returns a copy of all session data
func (s *Session) GetAll() map[string]any { func (s *Session) GetAll() map[string]any {
s.mu.RLock()
defer s.mu.RUnlock()
copy := make(map[string]any, len(s.Data)) copy := make(map[string]any, len(s.Data))
maps.Copy(copy, s.Data) for k, v := range s.Data {
copy[k] = v
}
return copy return copy
} }
@ -78,8 +111,37 @@ func (s *Session) IsExpired() bool {
} }
// UpdateLastUsed updates the last used time // UpdateLastUsed updates the last used time
// Only updates if at least 5 seconds have passed since last update
func (s *Session) UpdateLastUsed() { func (s *Session) UpdateLastUsed() {
s.mu.Lock() now := time.Now()
s.LastUsed = time.Now() if now.Sub(s.LastUsed) > 5*time.Second {
s.mu.Unlock() s.LastUsed = now
// Not marking dirty for LastUsed updates to reduce writes
}
}
// IsDirty returns if the session has unsaved changes
func (s *Session) IsDirty() bool {
return s.dirty
}
// ResetDirty marks the session as clean after saving
func (s *Session) ResetDirty() {
s.dirty = false
}
// Marshal serializes the session to JSON
func (s *Session) Marshal() ([]byte, error) {
return json.Marshal(s)
}
// Unmarshal deserializes a session from JSON
func Unmarshal(data []byte) (*Session, error) {
session := GetFromPool()
err := json.Unmarshal(data, session)
if err != nil {
ReturnToPool(session)
return nil, err
}
return session, nil
} }

1
go.mod
View File

@ -15,6 +15,7 @@ require (
github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/golang/snappy v0.0.4 // indirect github.com/golang/snappy v0.0.4 // indirect
github.com/klauspost/compress v1.18.0 // indirect github.com/klauspost/compress v1.18.0 // indirect
github.com/matoous/go-nanoid/v2 v2.1.0 // indirect
golang.org/x/sys v0.31.0 // indirect golang.org/x/sys v0.31.0 // indirect
) )

4
go.sum
View File

@ -14,7 +14,11 @@ github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM=
github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
github.com/matoous/go-nanoid/v2 v2.1.0 h1:P64+dmq21hhWdtvZfEAofnvJULaRR1Yib0+PnU669bE=
github.com/matoous/go-nanoid/v2 v2.1.0/go.mod h1:KlbGNQ+FhrUNIHUxZdL63t7tl4LaPkZNpUULS8H4uVM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/sony/sonyflake v1.2.0 h1:Pfr3A+ejSg+0SPqpoAmQgEtNDAhc2G1SUYk205qVMLQ=
github.com/sony/sonyflake v1.2.0/go.mod h1:LORtCywH/cq10ZbyfhKrHYgAUGH7mOBa76enV9txy/Y=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=