249 lines
6.1 KiB
Go
249 lines
6.1 KiB
Go
package sessions
|
||
|
||
import (
|
||
"sync"
|
||
"time"
|
||
|
||
"github.com/VictoriaMetrics/fastcache"
|
||
gonanoid "github.com/matoous/go-nanoid/v2"
|
||
"github.com/valyala/fasthttp"
|
||
)
|
||
|
||
const (
	// DefaultMaxSessions is the session capacity assumed when a caller does
	// not supply a positive value.
	DefaultMaxSessions = 10000

	// DefaultCookieName is the name of the cookie that carries the session ID.
	DefaultCookieName = "MoonsharkSID"

	// DefaultCookiePath scopes the session cookie to the whole site.
	DefaultCookiePath = "/"

	// DefaultMaxAge is the default session/cookie max age in seconds (one day).
	DefaultMaxAge = 24 * 60 * 60

	// CleanupInterval is how often the background cleanup routine fires.
	CleanupInterval = 5 * time.Minute
)
|
||
|
||
// SessionManager handles multiple sessions
|
||
type SessionManager struct {
|
||
cache *fastcache.Cache
|
||
maxSessions int
|
||
cookieName string
|
||
cookiePath string
|
||
cookieDomain string
|
||
cookieSecure bool
|
||
cookieHTTPOnly bool
|
||
cookieMaxAge int
|
||
cookieMu sync.RWMutex // Only cookie options need a mutex
|
||
cleanupTicker *time.Ticker
|
||
cleanupDone chan struct{}
|
||
}
|
||
|
||
// InitializeSessionPool pre-allocates session objects
|
||
func InitializeSessionPool(size int) {
|
||
for range size {
|
||
session := &Session{
|
||
Data: make(map[string]any, 8),
|
||
}
|
||
ReturnToPool(session)
|
||
}
|
||
}
|
||
|
||
// NewSessionManager creates a new session manager
|
||
func NewSessionManager(maxSessions int) *SessionManager {
|
||
if maxSessions <= 0 {
|
||
maxSessions = DefaultMaxSessions
|
||
}
|
||
|
||
// Estimate max memory: ~4KB per session × maxSessions
|
||
maxBytes := maxSessions * 4096
|
||
|
||
sm := &SessionManager{
|
||
cache: fastcache.New(maxBytes),
|
||
maxSessions: maxSessions,
|
||
cookieName: DefaultCookieName,
|
||
cookiePath: DefaultCookiePath,
|
||
cookieHTTPOnly: true,
|
||
cookieMaxAge: DefaultMaxAge,
|
||
cleanupDone: make(chan struct{}),
|
||
}
|
||
|
||
// Pre-allocate session objects for common pool size
|
||
InitializeSessionPool(100) // Adjust based on expected concurrent requests
|
||
|
||
// Start periodic cleanup
|
||
sm.cleanupTicker = time.NewTicker(CleanupInterval)
|
||
go sm.cleanupRoutine()
|
||
|
||
return sm
|
||
}
|
||
|
||
// Stop shuts down the session manager's cleanup routine
|
||
func (sm *SessionManager) Stop() {
|
||
close(sm.cleanupDone)
|
||
}
|
||
|
||
// cleanupRoutine periodically removes expired sessions
|
||
func (sm *SessionManager) cleanupRoutine() {
|
||
for {
|
||
select {
|
||
case <-sm.cleanupTicker.C:
|
||
sm.CleanupExpired()
|
||
case <-sm.cleanupDone:
|
||
sm.cleanupTicker.Stop()
|
||
return
|
||
}
|
||
}
|
||
}
|
||
|
||
// generateSessionID creates a random session ID
|
||
func generateSessionID() string {
|
||
id, _ := gonanoid.New()
|
||
return id
|
||
}
|
||
|
||
// GetSession retrieves a session by ID, or creates a new one if it doesn't exist
|
||
func (sm *SessionManager) GetSession(id string) *Session {
|
||
// Try to get an existing session
|
||
if id != "" {
|
||
data := sm.cache.Get(nil, []byte(id))
|
||
if len(data) > 0 {
|
||
session, err := Unmarshal(data)
|
||
if err == nil && !session.IsExpired() {
|
||
session.UpdateLastUsed()
|
||
session.ResetDirty() // Start clean
|
||
return session
|
||
}
|
||
// Session expired or corrupt, remove it
|
||
sm.cache.Del([]byte(id))
|
||
}
|
||
}
|
||
|
||
// Create a new session
|
||
return sm.CreateSession()
|
||
}
|
||
|
||
// CreateSession generates a new session with a unique ID
|
||
func (sm *SessionManager) CreateSession() *Session {
|
||
id := generateSessionID()
|
||
|
||
// Ensure ID uniqueness
|
||
attempts := 0
|
||
for attempts < 3 {
|
||
if sm.cache.Has([]byte(id)) {
|
||
id = generateSessionID()
|
||
attempts++
|
||
} else {
|
||
break
|
||
}
|
||
}
|
||
|
||
session := NewSession(id, sm.cookieMaxAge)
|
||
|
||
// Serialize and store the session
|
||
if data, err := session.Marshal(); err == nil {
|
||
sm.cache.Set([]byte(id), data)
|
||
}
|
||
|
||
session.ResetDirty() // Start clean
|
||
return session
|
||
}
|
||
|
||
// DestroySession removes a session
|
||
func (sm *SessionManager) DestroySession(id string) {
|
||
// Get and clean session from cache before deleting
|
||
data := sm.cache.Get(nil, []byte(id))
|
||
if len(data) > 0 {
|
||
if session, err := Unmarshal(data); err == nil {
|
||
ReturnToPool(session)
|
||
}
|
||
}
|
||
|
||
sm.cache.Del([]byte(id))
|
||
}
|
||
|
||
// CleanupExpired removes all expired sessions
|
||
// Note: fastcache doesn't provide iteration, so we can't clean all expired sessions
|
||
// This is a limitation of this implementation
|
||
func (sm *SessionManager) CleanupExpired() int {
|
||
// No way to iterate through all keys in fastcache
|
||
// We'd need to track expiring sessions separately
|
||
return 0
|
||
}
|
||
|
||
// SetCookieOptions configures cookie parameters
|
||
func (sm *SessionManager) SetCookieOptions(name, path, domain string, secure, httpOnly bool, maxAge int) {
|
||
sm.cookieMu.Lock()
|
||
defer sm.cookieMu.Unlock()
|
||
|
||
sm.cookieName = name
|
||
sm.cookiePath = path
|
||
sm.cookieDomain = domain
|
||
sm.cookieSecure = secure
|
||
sm.cookieHTTPOnly = httpOnly
|
||
sm.cookieMaxAge = maxAge
|
||
}
|
||
|
||
// GetSessionFromRequest extracts the session from a request
|
||
func (sm *SessionManager) GetSessionFromRequest(ctx *fasthttp.RequestCtx) *Session {
|
||
sm.cookieMu.RLock()
|
||
cookieName := sm.cookieName
|
||
sm.cookieMu.RUnlock()
|
||
|
||
cookie := ctx.Request.Header.Cookie(cookieName)
|
||
if len(cookie) == 0 {
|
||
return sm.CreateSession()
|
||
}
|
||
|
||
return sm.GetSession(string(cookie))
|
||
}
|
||
|
||
// ApplySessionCookie adds the session cookie to the response
|
||
func (sm *SessionManager) ApplySessionCookie(ctx *fasthttp.RequestCtx, session *Session) {
|
||
cookie := fasthttp.AcquireCookie()
|
||
defer fasthttp.ReleaseCookie(cookie)
|
||
|
||
// Get cookie options with minimal lock time
|
||
sm.cookieMu.RLock()
|
||
cookieName := sm.cookieName
|
||
cookiePath := sm.cookiePath
|
||
cookieDomain := sm.cookieDomain
|
||
cookieSecure := sm.cookieSecure
|
||
cookieHTTPOnly := sm.cookieHTTPOnly
|
||
cookieMaxAge := sm.cookieMaxAge
|
||
sm.cookieMu.RUnlock()
|
||
|
||
// Store updated session only if it has changes
|
||
if session.IsDirty() {
|
||
if data, err := session.Marshal(); err == nil {
|
||
sm.cache.Set([]byte(session.ID), data)
|
||
}
|
||
session.ResetDirty()
|
||
}
|
||
|
||
cookie.SetKey(cookieName)
|
||
cookie.SetValue(session.ID)
|
||
cookie.SetPath(cookiePath)
|
||
cookie.SetHTTPOnly(cookieHTTPOnly)
|
||
cookie.SetMaxAge(cookieMaxAge)
|
||
|
||
if cookieDomain != "" {
|
||
cookie.SetDomain(cookieDomain)
|
||
}
|
||
|
||
cookie.SetSecure(cookieSecure)
|
||
|
||
ctx.Response.Header.SetCookie(cookie)
|
||
}
|
||
|
||
// CookieOptions returns the cookie options for this session manager
|
||
func (sm *SessionManager) CookieOptions() map[string]any {
|
||
sm.cookieMu.RLock()
|
||
defer sm.cookieMu.RUnlock()
|
||
|
||
return map[string]any{
|
||
"name": sm.cookieName,
|
||
"path": sm.cookiePath,
|
||
"domain": sm.cookieDomain,
|
||
"secure": sm.cookieSecure,
|
||
"http_only": sm.cookieHTTPOnly,
|
||
"max_age": sm.cookieMaxAge,
|
||
}
|
||
}
|
||
|
||
// GlobalSessionManager is the default session manager instance
|
||
var GlobalSessionManager = NewSessionManager(DefaultMaxSessions)
|