optimize session manager

This commit is contained in:
Sky Johnson 2025-04-10 13:55:05 -05:00
parent c952242a9c
commit 941e810acb
4 changed files with 225 additions and 117 deletions

View File

@ -1,11 +1,11 @@
package sessions
import (
"crypto/rand"
"encoding/base64"
"sync"
"time"
"github.com/VictoriaMetrics/fastcache"
gonanoid "github.com/matoous/go-nanoid/v2"
"github.com/valyala/fasthttp"
)
@ -14,11 +14,12 @@ const (
DefaultCookieName = "MoonsharkSID"
DefaultCookiePath = "/"
DefaultMaxAge = 86400 // 1 day in seconds
CleanupInterval = 5 * time.Minute
)
// SessionManager handles multiple sessions
type SessionManager struct {
sessions map[string]*Session
cache *fastcache.Cache
maxSessions int
cookieName string
cookiePath string
@ -26,7 +27,19 @@ type SessionManager struct {
cookieSecure bool
cookieHTTPOnly bool
cookieMaxAge int
mu sync.RWMutex
cookieMu sync.RWMutex // Only cookie options need a mutex
cleanupTicker *time.Ticker
cleanupDone chan struct{}
}
// InitializeSessionPool pre-allocates session objects
func InitializeSessionPool(size int) {
for range size {
session := &Session{
Data: make(map[string]any, 8),
}
ReturnToPool(session)
}
}
// NewSessionManager creates a new session manager
@ -35,44 +48,67 @@ func NewSessionManager(maxSessions int) *SessionManager {
maxSessions = DefaultMaxSessions
}
return &SessionManager{
sessions: make(map[string]*Session, maxSessions),
// Estimate max memory: ~4KB per session × maxSessions
maxBytes := maxSessions * 4096
sm := &SessionManager{
cache: fastcache.New(maxBytes),
maxSessions: maxSessions,
cookieName: DefaultCookieName,
cookiePath: DefaultCookiePath,
cookieHTTPOnly: true,
cookieMaxAge: DefaultMaxAge,
cleanupDone: make(chan struct{}),
}
// Pre-allocate session objects for common pool size
InitializeSessionPool(100) // Adjust based on expected concurrent requests
// Start periodic cleanup
sm.cleanupTicker = time.NewTicker(CleanupInterval)
go sm.cleanupRoutine()
return sm
}
// Stop shuts down the session manager's cleanup routine
func (sm *SessionManager) Stop() {
close(sm.cleanupDone)
}
// cleanupRoutine periodically removes expired sessions
func (sm *SessionManager) cleanupRoutine() {
for {
select {
case <-sm.cleanupTicker.C:
sm.CleanupExpired()
case <-sm.cleanupDone:
sm.cleanupTicker.Stop()
return
}
}
}
// generateSessionID creates a random session ID
func generateSessionID() string {
b := make([]byte, 32)
if _, err := rand.Read(b); err != nil {
return time.Now().String() // Fallback
}
return base64.URLEncoding.EncodeToString(b)
id, _ := gonanoid.New()
return id
}
// GetSession retrieves a session by ID, or creates a new one if it doesn't exist
func (sm *SessionManager) GetSession(id string) *Session {
// Try to get an existing session
if id != "" {
sm.mu.RLock()
session, exists := sm.sessions[id]
sm.mu.RUnlock()
if exists {
// Check if session is expired
if session.IsExpired() {
sm.mu.Lock()
delete(sm.sessions, id)
sm.mu.Unlock()
} else {
// Update last used time
data := sm.cache.Get(nil, []byte(id))
if len(data) > 0 {
session, err := Unmarshal(data)
if err == nil && !session.IsExpired() {
session.UpdateLastUsed()
session.ResetDirty() // Start clean
return session
}
// Session expired or corrupt, remove it
sm.cache.Del([]byte(id))
}
}
@ -83,72 +119,55 @@ func (sm *SessionManager) GetSession(id string) *Session {
// CreateSession generates a new session with a unique ID
func (sm *SessionManager) CreateSession() *Session {
id := generateSessionID()
// Ensure ID uniqueness
attempts := 0
for attempts < 3 {
if sm.cache.Has([]byte(id)) {
id = generateSessionID()
attempts++
} else {
break
}
}
session := NewSession(id, sm.cookieMaxAge)
sm.mu.Lock()
// Enforce session limit - evict LRU if needed
if len(sm.sessions) >= sm.maxSessions {
sm.evictLRU()
// Serialize and store the session
if data, err := session.Marshal(); err == nil {
sm.cache.Set([]byte(id), data)
}
sm.sessions[id] = session
sm.mu.Unlock()
session.ResetDirty() // Start clean
return session
}
// evictLRU removes the least recently used session
func (sm *SessionManager) evictLRU() {
// Called with mutex already held
if len(sm.sessions) == 0 {
return
}
var oldestID string
var oldestTime time.Time
// Find oldest session
for id, session := range sm.sessions {
if oldestID == "" || session.LastUsed.Before(oldestTime) {
oldestID = id
oldestTime = session.LastUsed
}
}
if oldestID != "" {
delete(sm.sessions, oldestID)
}
}
// DestroySession removes a session
func (sm *SessionManager) DestroySession(id string) {
sm.mu.Lock()
defer sm.mu.Unlock()
delete(sm.sessions, id)
// Get and clean session from cache before deleting
data := sm.cache.Get(nil, []byte(id))
if len(data) > 0 {
if session, err := Unmarshal(data); err == nil {
ReturnToPool(session)
}
}
sm.cache.Del([]byte(id))
}
// CleanupExpired removes all expired sessions
// Note: fastcache doesn't provide iteration, so we can't clean all expired sessions
// This is a limitation of this implementation
func (sm *SessionManager) CleanupExpired() int {
removed := 0
now := time.Now()
sm.mu.Lock()
defer sm.mu.Unlock()
for id, session := range sm.sessions {
if now.After(session.Expiry) {
delete(sm.sessions, id)
removed++
}
}
return removed
// No way to iterate through all keys in fastcache
// We'd need to track expiring sessions separately
return 0
}
// SetCookieOptions configures cookie parameters
func (sm *SessionManager) SetCookieOptions(name, path, domain string, secure, httpOnly bool, maxAge int) {
sm.mu.Lock()
defer sm.mu.Unlock()
sm.cookieMu.Lock()
defer sm.cookieMu.Unlock()
sm.cookieName = name
sm.cookiePath = path
@ -160,7 +179,11 @@ func (sm *SessionManager) SetCookieOptions(name, path, domain string, secure, ht
// GetSessionFromRequest extracts the session from a request
func (sm *SessionManager) GetSessionFromRequest(ctx *fasthttp.RequestCtx) *Session {
cookie := ctx.Request.Header.Cookie(sm.cookieName)
sm.cookieMu.RLock()
cookieName := sm.cookieName
sm.cookieMu.RUnlock()
cookie := ctx.Request.Header.Cookie(cookieName)
if len(cookie) == 0 {
return sm.CreateSession()
}
@ -173,25 +196,43 @@ func (sm *SessionManager) ApplySessionCookie(ctx *fasthttp.RequestCtx, session *
cookie := fasthttp.AcquireCookie()
defer fasthttp.ReleaseCookie(cookie)
cookie.SetKey(sm.cookieName)
cookie.SetValue(session.ID)
cookie.SetPath(sm.cookiePath)
cookie.SetHTTPOnly(sm.cookieHTTPOnly)
cookie.SetMaxAge(sm.cookieMaxAge)
// Get cookie options with minimal lock time
sm.cookieMu.RLock()
cookieName := sm.cookieName
cookiePath := sm.cookiePath
cookieDomain := sm.cookieDomain
cookieSecure := sm.cookieSecure
cookieHTTPOnly := sm.cookieHTTPOnly
cookieMaxAge := sm.cookieMaxAge
sm.cookieMu.RUnlock()
if sm.cookieDomain != "" {
cookie.SetDomain(sm.cookieDomain)
// Store updated session only if it has changes
if session.IsDirty() {
if data, err := session.Marshal(); err == nil {
sm.cache.Set([]byte(session.ID), data)
}
session.ResetDirty()
}
cookie.SetSecure(sm.cookieSecure)
cookie.SetKey(cookieName)
cookie.SetValue(session.ID)
cookie.SetPath(cookiePath)
cookie.SetHTTPOnly(cookieHTTPOnly)
cookie.SetMaxAge(cookieMaxAge)
if cookieDomain != "" {
cookie.SetDomain(cookieDomain)
}
cookie.SetSecure(cookieSecure)
ctx.Response.Header.SetCookie(cookie)
}
// CookieOptions returns the cookie options for this session manager
func (sm *SessionManager) CookieOptions() map[string]any {
sm.mu.RLock()
defer sm.mu.RUnlock()
sm.cookieMu.RLock()
defer sm.cookieMu.RUnlock()
return map[string]any{
"name": sm.cookieName,

View File

@ -1,74 +1,107 @@
package sessions
import (
"maps"
"sync"
"time"
"github.com/goccy/go-json"
)
// Session stores data for a single user session
type Session struct {
ID string
Data map[string]any
CreatedAt time.Time
UpdatedAt time.Time
LastUsed time.Time
Expiry time.Time
mu sync.RWMutex
ID string `json:"id"`
Data map[string]any `json:"data"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
LastUsed time.Time `json:"last_used"`
Expiry time.Time `json:"expiry"`
dirty bool // Tracks if session has changes, not serialized
}
// Session pool to reduce allocations
var sessionPool = sync.Pool{
New: func() any {
return &Session{
Data: make(map[string]any, 8),
}
},
}
// GetFromPool retrieves a session from the pool
func GetFromPool() *Session {
return sessionPool.Get().(*Session)
}
// ReturnToPool returns a session to the pool after cleaning it
func ReturnToPool(s *Session) {
if s == nil {
return
}
// Clean the session for reuse
s.ID = ""
for k := range s.Data {
delete(s.Data, k)
}
s.CreatedAt = time.Time{}
s.UpdatedAt = time.Time{}
s.LastUsed = time.Time{}
s.Expiry = time.Time{}
s.dirty = false
sessionPool.Put(s)
}
// NewSession creates a new session with the given ID
func NewSession(id string, maxAge int) *Session {
now := time.Now()
return &Session{
ID: id,
Data: make(map[string]any),
CreatedAt: now,
UpdatedAt: now,
LastUsed: now,
Expiry: now.Add(time.Duration(maxAge) * time.Second),
}
// Get from pool or create new
session := GetFromPool()
// Initialize
session.ID = id
session.CreatedAt = now
session.UpdatedAt = now
session.LastUsed = now
session.Expiry = now.Add(time.Duration(maxAge) * time.Second)
session.dirty = false
return session
}
// Get retrieves a value from the session
func (s *Session) Get(key string) any {
s.mu.RLock()
defer s.mu.RUnlock()
return s.Data[key]
}
// Set stores a value in the session
func (s *Session) Set(key string, value any) {
s.mu.Lock()
defer s.mu.Unlock()
s.Data[key] = value
s.UpdatedAt = time.Now()
s.dirty = true
}
// Delete removes a value from the session
func (s *Session) Delete(key string) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.Data, key)
s.UpdatedAt = time.Now()
s.dirty = true
}
// Clear removes all data from the session
func (s *Session) Clear() {
s.mu.Lock()
defer s.mu.Unlock()
s.Data = make(map[string]any)
s.Data = make(map[string]any, 8)
s.UpdatedAt = time.Now()
s.dirty = true
}
// GetAll returns a copy of all session data
func (s *Session) GetAll() map[string]any {
s.mu.RLock()
defer s.mu.RUnlock()
copy := make(map[string]any, len(s.Data))
maps.Copy(copy, s.Data)
for k, v := range s.Data {
copy[k] = v
}
return copy
}
@ -78,8 +111,37 @@ func (s *Session) IsExpired() bool {
}
// UpdateLastUsed updates the last used time
// Only updates if at least 5 seconds have passed since last update
func (s *Session) UpdateLastUsed() {
s.mu.Lock()
s.LastUsed = time.Now()
s.mu.Unlock()
now := time.Now()
if now.Sub(s.LastUsed) > 5*time.Second {
s.LastUsed = now
// Not marking dirty for LastUsed updates to reduce writes
}
}
// IsDirty returns if the session has unsaved changes
func (s *Session) IsDirty() bool {
return s.dirty
}
// ResetDirty marks the session as clean after saving
func (s *Session) ResetDirty() {
s.dirty = false
}
// Marshal serializes the session to JSON
func (s *Session) Marshal() ([]byte, error) {
return json.Marshal(s)
}
// Unmarshal deserializes a session from JSON
func Unmarshal(data []byte) (*Session, error) {
session := GetFromPool()
err := json.Unmarshal(data, session)
if err != nil {
ReturnToPool(session)
return nil, err
}
return session, nil
}

1
go.mod
View File

@ -15,6 +15,7 @@ require (
github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/golang/snappy v0.0.4 // indirect
github.com/klauspost/compress v1.18.0 // indirect
github.com/matoous/go-nanoid/v2 v2.1.0 // indirect
golang.org/x/sys v0.31.0 // indirect
)

4
go.sum
View File

@ -14,7 +14,11 @@ github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM=
github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
github.com/matoous/go-nanoid/v2 v2.1.0 h1:P64+dmq21hhWdtvZfEAofnvJULaRR1Yib0+PnU669bE=
github.com/matoous/go-nanoid/v2 v2.1.0/go.mod h1:KlbGNQ+FhrUNIHUxZdL63t7tl4LaPkZNpUULS8H4uVM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/sony/sonyflake v1.2.0 h1:Pfr3A+ejSg+0SPqpoAmQgEtNDAhc2G1SUYk205qVMLQ=
github.com/sony/sonyflake v1.2.0/go.mod h1:LORtCywH/cq10ZbyfhKrHYgAUGH7mOBa76enV9txy/Y=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=