// Package sessions provides cookie-based session management for fasthttp
// servers, keeping serialized sessions in an in-memory fastcache.
package sessions

import (
	"sync"
	"time"

	"github.com/VictoriaMetrics/fastcache"
	gonanoid "github.com/matoous/go-nanoid/v2"
	"github.com/valyala/fasthttp"
)

const (
	// DefaultMaxSessions is used by NewSessionManager when maxSessions <= 0.
	DefaultMaxSessions = 10000
	// DefaultCookieName is the initial name of the session ID cookie.
	DefaultCookieName = "MoonsharkSID"
	// DefaultCookiePath is the initial path of the session ID cookie.
	DefaultCookiePath = "/"
	// DefaultMaxAge is the initial cookie Max-Age in seconds (1 day). The
	// current max-age is also passed to NewSession for newly created sessions.
	DefaultMaxAge = 86400
	// CleanupInterval is how often the background routine calls CleanupExpired.
	CleanupInterval = 5 * time.Minute
)

// SessionManager creates, loads, persists and destroys sessions stored in an
// in-memory fastcache, and writes the session ID cookie to responses.
//
// Cookie settings are guarded by cookieMu and may be changed at any time via
// SetCookieOptions.
type SessionManager struct {
	// cache maps session ID -> Session.Marshal output.
	cache *fastcache.Cache

	// Cookie settings; all guarded by cookieMu.
	cookieName     string
	cookiePath     string
	cookieDomain   string
	cookieSecure   bool
	cookieHTTPOnly bool
	cookieMaxAge   int
	cookieMu       sync.RWMutex

	cleanupTicker *time.Ticker
	cleanupDone   chan struct{}
	stopOnce      sync.Once // makes Stop idempotent
}

// NewSessionManager creates a SessionManager whose cache is sized at
// maxSessions*4096 bytes (DefaultMaxSessions is used when maxSessions <= 0).
// It also starts a background goroutine that ticks every CleanupInterval
// until Stop is called.
func NewSessionManager(maxSessions int) *SessionManager {
	if maxSessions <= 0 {
		maxSessions = DefaultMaxSessions
	}

	sm := &SessionManager{
		cache:          fastcache.New(maxSessions * 4096),
		cookieName:     DefaultCookieName,
		cookiePath:     DefaultCookiePath,
		cookieDomain:   "",
		cookieSecure:   false,
		cookieHTTPOnly: true,
		cookieMaxAge:   DefaultMaxAge,
		cleanupDone:    make(chan struct{}),
	}

	// Pre-populate the session pool by creating and immediately releasing
	// 100 sessions.
	for range 100 {
		s := NewSession("", 0)
		s.Release()
	}

	sm.cleanupTicker = time.NewTicker(CleanupInterval)
	go sm.cleanupRoutine()

	return sm
}

// Stop shuts down the session manager's cleanup goroutine. It is safe to call
// more than once; calls after the first are no-ops.
func (sm *SessionManager) Stop() {
	// A bare close would panic on a second Stop call.
	sm.stopOnce.Do(func() { close(sm.cleanupDone) })
}

// cleanupRoutine calls CleanupExpired on every tick until Stop closes
// cleanupDone. It then stops the ticker and returns.
func (sm *SessionManager) cleanupRoutine() {
	for {
		select {
		case <-sm.cleanupTicker.C:
			sm.CleanupExpired()
		case <-sm.cleanupDone:
			sm.cleanupTicker.Stop()
			return
		}
	}
}

// GetSession returns the cached session with the given ID. A loaded session
// has its last-used time updated and its dirty flag cleared.
//
// A new session is created via CreateSession instead in these cases:
//   - id is empty;
//   - no entry is cached for id;
//   - the cached data fails to unmarshal;
//   - the session has expired.
//
// When a cached entry exists but is unusable, it is deleted first.
func (sm *SessionManager) GetSession(id string) *Session {
	if id == "" {
		return sm.CreateSession()
	}

	key := []byte(id)
	if data := sm.cache.Get(nil, key); len(data) > 0 {
		if s, err := Unmarshal(data); err == nil {
			if !s.IsExpired() {
				s.UpdateLastUsed()
				s.ResetDirty()
				return s
			}
			// Expired: release the decoded session (as DestroySession does)
			// instead of silently dropping it.
			s.Release()
		}
		sm.cache.Del(key)
	}

	return sm.CreateSession()
}

// CreateSession generates a session with a random nanoid ID and stores its
// serialized form in the cache. The session is created with the current
// cookie max-age, and it is returned with its dirty flag cleared.
//
// NOTE(review): fastcache.Set silently skips entries whose key+value exceed
// 64KB, so oversized sessions are not persisted.
func (sm *SessionManager) CreateSession() *Session {
	// cookieMaxAge can be changed concurrently by SetCookieOptions, so read
	// it under the lock.
	sm.cookieMu.RLock()
	maxAge := sm.cookieMaxAge
	sm.cookieMu.RUnlock()

	// With default arguments, gonanoid.New can only fail if the system random
	// source fails. There is no way to report that from this method, so the
	// error is deliberately ignored.
	id, _ := gonanoid.New()

	// Ensure uniqueness (max 3 extra attempts).
	for i := 0; i < 3 && sm.cache.Has([]byte(id)); i++ {
		id, _ = gonanoid.New()
	}

	s := NewSession(id, maxAge)
	if data, err := s.Marshal(); err == nil {
		sm.cache.Set([]byte(id), data)
	}
	s.ResetDirty()

	return s
}

// DestroySession removes the session with the given ID from the cache. If
// the stored data decodes into a session, that session is released first.
func (sm *SessionManager) DestroySession(id string) {
	key := []byte(id)
	if data := sm.cache.Get(nil, key); len(data) > 0 {
		if s, err := Unmarshal(data); err == nil {
			s.Release()
		}
	}
	sm.cache.Del(key)
}

// CleanupExpired is intended to remove expired sessions and report how many
// were removed. fastcache does not support iterating its entries, so this is
// currently a no-op that always returns 0. Expired sessions are instead
// discarded lazily by GetSession.
func (sm *SessionManager) CleanupExpired() int {
	return 0
}

// SetCookieOptions replaces all cookie settings at once. maxAge is also the
// value CreateSession passes to NewSession for subsequently created sessions.
func (sm *SessionManager) SetCookieOptions(name, path, domain string, secure, httpOnly bool, maxAge int) {
	sm.cookieMu.Lock()
	defer sm.cookieMu.Unlock()

	sm.cookieName = name
	sm.cookiePath = path
	sm.cookieDomain = domain
	sm.cookieSecure = secure
	sm.cookieHTTPOnly = httpOnly
	sm.cookieMaxAge = maxAge
}

// GetSessionFromRequest loads the session named by the request's session
// cookie. It creates a new session when the cookie is absent or empty, or
// when the ID does not resolve to a live session.
func (sm *SessionManager) GetSessionFromRequest(ctx *fasthttp.RequestCtx) *Session {
	sm.cookieMu.RLock()
	name := sm.cookieName
	sm.cookieMu.RUnlock()

	if cookie := ctx.Request.Header.Cookie(name); len(cookie) > 0 {
		return sm.GetSession(string(cookie))
	}

	return sm.CreateSession()
}

// ApplySessionCookie does two things. If the session is dirty, it writes the
// session back to the cache and clears the dirty flag. It then sets the
// session ID cookie on the response using the current cookie settings.
//
// A Marshal error is ignored, which leaves the previously cached state in
// place.
func (sm *SessionManager) ApplySessionCookie(ctx *fasthttp.RequestCtx, session *Session) {
	if session.IsDirty() {
		if data, err := session.Marshal(); err == nil {
			sm.cache.Set([]byte(session.ID), data)
		}
		session.ResetDirty()
	}

	cookie := fasthttp.AcquireCookie()
	defer fasthttp.ReleaseCookie(cookie)

	sm.cookieMu.RLock()
	cookie.SetKey(sm.cookieName)
	cookie.SetPath(sm.cookiePath)
	cookie.SetHTTPOnly(sm.cookieHTTPOnly)
	cookie.SetMaxAge(sm.cookieMaxAge)
	if sm.cookieDomain != "" {
		cookie.SetDomain(sm.cookieDomain)
	}
	cookie.SetSecure(sm.cookieSecure)
	sm.cookieMu.RUnlock()

	cookie.SetValue(session.ID)
	ctx.Response.Header.SetCookie(cookie)
}

// CookieOptions returns a snapshot of the current cookie settings. Its keys
// are "name", "path", "domain", "secure", "http_only" and "max_age".
func (sm *SessionManager) CookieOptions() map[string]any {
	sm.cookieMu.RLock()
	defer sm.cookieMu.RUnlock()

	return map[string]any{
		"name":      sm.cookieName,
		"path":      sm.cookiePath,
		"domain":    sm.cookieDomain,
		"secure":    sm.cookieSecure,
		"http_only": sm.cookieHTTPOnly,
		"max_age":   sm.cookieMaxAge,
	}
}

// GetCacheStats returns a subset of fastcache statistics for the session
// cache. It returns an empty map when the manager or its cache is nil.
func (sm *SessionManager) GetCacheStats() map[string]uint64 {
	if sm == nil || sm.cache == nil {
		return map[string]uint64{}
	}

	var stats fastcache.Stats
	sm.cache.UpdateStats(&stats)

	return map[string]uint64{
		"entries":   stats.EntriesCount,
		"bytes":     stats.BytesSize,
		"max_bytes": stats.MaxBytesSize,
		"gets":      stats.GetCalls,
		"sets":      stats.SetCalls,
		"misses":    stats.Misses,
	}
}