session rewrite
This commit is contained in:
parent
5ebcd97662
commit
85b0551e70
|
@ -225,13 +225,12 @@ func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scrip
|
|||
return
|
||||
}
|
||||
|
||||
// Save session if modified
|
||||
// Update session if modified
|
||||
if response.SessionModified {
|
||||
// Update session data
|
||||
for k, v := range response.SessionData {
|
||||
session.Set(k, v)
|
||||
}
|
||||
s.sessionManager.SaveSession(session)
|
||||
|
||||
s.sessionManager.ApplySessionCookie(ctx, session)
|
||||
}
|
||||
|
||||
|
|
|
@ -1,46 +1,43 @@
|
|||
package sessions
|
||||
|
||||
import (
|
||||
"Moonshark/core/utils/logger"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/VictoriaMetrics/fastcache"
|
||||
"github.com/goccy/go-json"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
const (
|
||||
// Default settings
|
||||
DefaultMaxSize = 100 * 1024 * 1024 // 100MB default cache size
|
||||
DefaultMaxSessions = 10000
|
||||
DefaultCookieName = "MoonsharkSID"
|
||||
DefaultCookiePath = "/"
|
||||
DefaultMaxAge = 86400 // 1 day in seconds
|
||||
)
|
||||
|
||||
// SessionManager handles multiple sessions using fastcache for storage
|
||||
// SessionManager handles multiple sessions
|
||||
type SessionManager struct {
|
||||
cache *fastcache.Cache
|
||||
sessions map[string]*Session
|
||||
maxSessions int
|
||||
cookieName string
|
||||
cookiePath string
|
||||
cookieDomain string
|
||||
cookieSecure bool
|
||||
cookieHTTPOnly bool
|
||||
cookieMaxAge int
|
||||
mu sync.RWMutex // Only for cookie settings
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewSessionManager creates a new session manager with optional cache size
|
||||
func NewSessionManager(maxSize ...int) *SessionManager {
|
||||
size := DefaultMaxSize
|
||||
if len(maxSize) > 0 && maxSize[0] > 0 {
|
||||
size = maxSize[0]
|
||||
// NewSessionManager creates a new session manager
|
||||
func NewSessionManager(maxSessions int) *SessionManager {
|
||||
if maxSessions <= 0 {
|
||||
maxSessions = DefaultMaxSessions
|
||||
}
|
||||
|
||||
return &SessionManager{
|
||||
cache: fastcache.New(size),
|
||||
sessions: make(map[string]*Session, maxSessions),
|
||||
maxSessions: maxSessions,
|
||||
cookieName: DefaultCookieName,
|
||||
cookiePath: DefaultCookiePath,
|
||||
cookieHTTPOnly: true,
|
||||
|
@ -48,7 +45,7 @@ func NewSessionManager(maxSize ...int) *SessionManager {
|
|||
}
|
||||
}
|
||||
|
||||
// generateSessionID creates a cryptographically secure random session ID
|
||||
// generateSessionID creates a random session ID
|
||||
func generateSessionID() string {
|
||||
b := make([]byte, 32)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
|
@ -59,59 +56,136 @@ func generateSessionID() string {
|
|||
|
||||
// GetSession retrieves a session by ID, or creates a new one if it doesn't exist
|
||||
func (sm *SessionManager) GetSession(id string) *Session {
|
||||
// Check if session exists
|
||||
data := sm.cache.Get(nil, []byte(id))
|
||||
|
||||
if len(data) > 0 {
|
||||
logger.Debug("Getting session %s", id)
|
||||
|
||||
// Session exists, unmarshal it
|
||||
session := &Session{}
|
||||
if err := json.Unmarshal(data, session); err == nil {
|
||||
// Initialize mutex properly
|
||||
session.mu = sync.RWMutex{}
|
||||
|
||||
// Update last accessed time
|
||||
session.UpdatedAt = time.Now()
|
||||
|
||||
// Store back with updated timestamp
|
||||
updatedData, _ := json.Marshal(session)
|
||||
sm.cache.Set([]byte(id), updatedData)
|
||||
// Try to get an existing session
|
||||
if id != "" {
|
||||
sm.mu.RLock()
|
||||
session, exists := sm.sessions[id]
|
||||
sm.mu.RUnlock()
|
||||
|
||||
if exists {
|
||||
// Check if session is expired
|
||||
if session.IsExpired() {
|
||||
sm.mu.Lock()
|
||||
delete(sm.sessions, id)
|
||||
sm.mu.Unlock()
|
||||
} else {
|
||||
// Update last used time
|
||||
session.UpdateLastUsed()
|
||||
return session
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logger.Debug("Session doesn't exist; creating it")
|
||||
|
||||
// Create new session
|
||||
session := NewSession(id)
|
||||
data, _ = json.Marshal(session)
|
||||
sm.cache.Set([]byte(id), data)
|
||||
|
||||
return session
|
||||
// Create a new session
|
||||
return sm.CreateSession()
|
||||
}
|
||||
|
||||
// CreateSession generates a new session with a unique ID
|
||||
func (sm *SessionManager) CreateSession() *Session {
|
||||
id := generateSessionID()
|
||||
session := NewSession(id, sm.cookieMaxAge)
|
||||
|
||||
session := NewSession(id)
|
||||
data, _ := json.Marshal(session)
|
||||
sm.cache.Set([]byte(id), data)
|
||||
sm.mu.Lock()
|
||||
// Enforce session limit - evict LRU if needed
|
||||
if len(sm.sessions) >= sm.maxSessions {
|
||||
sm.evictLRU()
|
||||
}
|
||||
|
||||
sm.sessions[id] = session
|
||||
sm.mu.Unlock()
|
||||
|
||||
return session
|
||||
}
|
||||
|
||||
// SaveSession persists a session back to the cache
|
||||
func (sm *SessionManager) SaveSession(session *Session) {
|
||||
data, _ := json.Marshal(session)
|
||||
sm.cache.Set([]byte(session.ID), data)
|
||||
// evictLRU removes the least recently used session
|
||||
func (sm *SessionManager) evictLRU() {
|
||||
// Called with mutex already held
|
||||
if len(sm.sessions) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
var oldestID string
|
||||
var oldestTime time.Time
|
||||
|
||||
// Find oldest session
|
||||
for id, session := range sm.sessions {
|
||||
if oldestID == "" || session.LastUsed.Before(oldestTime) {
|
||||
oldestID = id
|
||||
oldestTime = session.LastUsed
|
||||
}
|
||||
}
|
||||
|
||||
if oldestID != "" {
|
||||
delete(sm.sessions, oldestID)
|
||||
}
|
||||
}
|
||||
|
||||
// DestroySession removes a session
|
||||
func (sm *SessionManager) DestroySession(id string) {
|
||||
sm.cache.Del([]byte(id))
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
delete(sm.sessions, id)
|
||||
}
|
||||
|
||||
// CleanupExpired removes all expired sessions
|
||||
func (sm *SessionManager) CleanupExpired() int {
|
||||
removed := 0
|
||||
now := time.Now()
|
||||
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
for id, session := range sm.sessions {
|
||||
if now.After(session.Expiry) {
|
||||
delete(sm.sessions, id)
|
||||
removed++
|
||||
}
|
||||
}
|
||||
|
||||
return removed
|
||||
}
|
||||
|
||||
// SetCookieOptions configures cookie parameters
|
||||
func (sm *SessionManager) SetCookieOptions(name, path, domain string, secure, httpOnly bool, maxAge int) {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
sm.cookieName = name
|
||||
sm.cookiePath = path
|
||||
sm.cookieDomain = domain
|
||||
sm.cookieSecure = secure
|
||||
sm.cookieHTTPOnly = httpOnly
|
||||
sm.cookieMaxAge = maxAge
|
||||
}
|
||||
|
||||
// GetSessionFromRequest extracts the session from a request
|
||||
func (sm *SessionManager) GetSessionFromRequest(ctx *fasthttp.RequestCtx) *Session {
|
||||
cookie := ctx.Request.Header.Cookie(sm.cookieName)
|
||||
if len(cookie) == 0 {
|
||||
return sm.CreateSession()
|
||||
}
|
||||
|
||||
return sm.GetSession(string(cookie))
|
||||
}
|
||||
|
||||
// ApplySessionCookie adds the session cookie to the response
|
||||
func (sm *SessionManager) ApplySessionCookie(ctx *fasthttp.RequestCtx, session *Session) {
|
||||
cookie := fasthttp.AcquireCookie()
|
||||
defer fasthttp.ReleaseCookie(cookie)
|
||||
|
||||
cookie.SetKey(sm.cookieName)
|
||||
cookie.SetValue(session.ID)
|
||||
cookie.SetPath(sm.cookiePath)
|
||||
cookie.SetHTTPOnly(sm.cookieHTTPOnly)
|
||||
cookie.SetMaxAge(sm.cookieMaxAge)
|
||||
|
||||
if sm.cookieDomain != "" {
|
||||
cookie.SetDomain(sm.cookieDomain)
|
||||
}
|
||||
|
||||
cookie.SetSecure(sm.cookieSecure)
|
||||
|
||||
ctx.Response.Header.SetCookie(cookie)
|
||||
}
|
||||
|
||||
// CookieOptions returns the cookie options for this session manager
|
||||
|
@ -129,52 +203,5 @@ func (sm *SessionManager) CookieOptions() map[string]any {
|
|||
}
|
||||
}
|
||||
|
||||
// SetCookieOptions configures cookie parameters
|
||||
func (sm *SessionManager) SetCookieOptions(name, path, domain string, secure, httpOnly bool, maxAge int) {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
sm.cookieName = name
|
||||
sm.cookiePath = path
|
||||
sm.cookieDomain = domain
|
||||
sm.cookieSecure = secure
|
||||
sm.cookieHTTPOnly = httpOnly
|
||||
sm.cookieMaxAge = maxAge
|
||||
}
|
||||
|
||||
// GetSessionFromRequest extracts the session from a request context
|
||||
func (sm *SessionManager) GetSessionFromRequest(ctx *fasthttp.RequestCtx) *Session {
|
||||
cookie := ctx.Request.Header.Cookie(sm.cookieName)
|
||||
if len(cookie) == 0 {
|
||||
// No session cookie, create a new session
|
||||
return sm.CreateSession()
|
||||
}
|
||||
|
||||
// Session cookie exists, get the session
|
||||
return sm.GetSession(string(cookie))
|
||||
}
|
||||
|
||||
// SaveSessionToResponse adds the session cookie to an HTTP response
|
||||
func (sm *SessionManager) ApplySessionCookie(ctx *fasthttp.RequestCtx, session *Session) {
|
||||
cookie := fasthttp.AcquireCookie()
|
||||
defer fasthttp.ReleaseCookie(cookie)
|
||||
|
||||
sm.mu.RLock()
|
||||
cookie.SetKey(sm.cookieName)
|
||||
cookie.SetValue(session.ID)
|
||||
cookie.SetPath(sm.cookiePath)
|
||||
cookie.SetHTTPOnly(sm.cookieHTTPOnly)
|
||||
cookie.SetMaxAge(sm.cookieMaxAge)
|
||||
|
||||
if sm.cookieDomain != "" {
|
||||
cookie.SetDomain(sm.cookieDomain)
|
||||
}
|
||||
|
||||
cookie.SetSecure(sm.cookieSecure)
|
||||
sm.mu.RUnlock()
|
||||
|
||||
ctx.Response.Header.SetCookie(cookie)
|
||||
}
|
||||
|
||||
// GlobalSessionManager is the default session manager instance
|
||||
var GlobalSessionManager = NewSessionManager()
|
||||
var GlobalSessionManager = NewSessionManager(DefaultMaxSessions)
|
||||
|
|
|
@ -1,41 +1,31 @@
|
|||
package sessions
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/goccy/go-json"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultMaxValueSize = 256 * 1024 // 256KB per value
|
||||
)
|
||||
|
||||
var (
|
||||
ErrValueTooLarge = errors.New("session value exceeds size limit")
|
||||
)
|
||||
|
||||
// Session stores data for a single user session
|
||||
type Session struct {
|
||||
ID string `json:"id"`
|
||||
Data map[string]any `json:"data"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
mu sync.RWMutex `json:"-"`
|
||||
maxValueSize int `json:"max_value_size"`
|
||||
totalDataSize int `json:"total_data_size"`
|
||||
ID string
|
||||
Data map[string]any
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
LastUsed time.Time
|
||||
Expiry time.Time
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewSession creates a new session with the given ID
|
||||
func NewSession(id string) *Session {
|
||||
func NewSession(id string, maxAge int) *Session {
|
||||
now := time.Now()
|
||||
return &Session{
|
||||
ID: id,
|
||||
Data: make(map[string]any),
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
maxValueSize: DefaultMaxValueSize,
|
||||
LastUsed: now,
|
||||
Expiry: now.Add(time.Duration(maxAge) * time.Second),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -47,65 +37,17 @@ func (s *Session) Get(key string) any {
|
|||
}
|
||||
|
||||
// Set stores a value in the session
|
||||
func (s *Session) Set(key string, value any) error {
|
||||
// Estimate value size
|
||||
size, err := estimateSize(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check against limit
|
||||
if size > s.maxValueSize {
|
||||
return ErrValueTooLarge
|
||||
}
|
||||
|
||||
func (s *Session) Set(key string, value any) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// If replacing, subtract old value size
|
||||
if oldVal, exists := s.Data[key]; exists {
|
||||
oldSize, _ := estimateSize(oldVal)
|
||||
s.totalDataSize -= oldSize
|
||||
}
|
||||
|
||||
s.Data[key] = value
|
||||
s.totalDataSize += size
|
||||
s.UpdatedAt = time.Now()
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetMaxValueSize changes the maximum allowed value size
|
||||
func (s *Session) SetMaxValueSize(bytes int) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.maxValueSize = bytes
|
||||
}
|
||||
|
||||
// GetMaxValueSize returns the current max value size
|
||||
func (s *Session) GetMaxValueSize() int {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.maxValueSize
|
||||
}
|
||||
|
||||
// GetTotalSize returns the estimated total size of all session data
|
||||
func (s *Session) GetTotalSize() int {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.totalDataSize
|
||||
}
|
||||
|
||||
// Delete removes a value from the session
|
||||
func (s *Session) Delete(key string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Update size tracking
|
||||
if oldVal, exists := s.Data[key]; exists {
|
||||
oldSize, _ := estimateSize(oldVal)
|
||||
s.totalDataSize -= oldSize
|
||||
}
|
||||
|
||||
delete(s.Data, key)
|
||||
s.UpdatedAt = time.Now()
|
||||
}
|
||||
|
@ -115,7 +57,6 @@ func (s *Session) Clear() {
|
|||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.Data = make(map[string]any)
|
||||
s.totalDataSize = 0
|
||||
s.UpdatedAt = time.Now()
|
||||
}
|
||||
|
||||
|
@ -124,7 +65,6 @@ func (s *Session) GetAll() map[string]any {
|
|||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
// Create a copy to avoid concurrent map access issues
|
||||
copy := make(map[string]any, len(s.Data))
|
||||
for k, v := range s.Data {
|
||||
copy[k] = v
|
||||
|
@ -133,20 +73,14 @@ func (s *Session) GetAll() map[string]any {
|
|||
return copy
|
||||
}
|
||||
|
||||
// estimateSize approximates the memory footprint of a value
|
||||
func estimateSize(v any) (int, error) {
|
||||
// Fast path for common types
|
||||
switch val := v.(type) {
|
||||
case string:
|
||||
return len(val), nil
|
||||
case []byte:
|
||||
return len(val), nil
|
||||
}
|
||||
|
||||
// For other types, use JSON serialization as approximation
|
||||
data, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return len(data), nil
|
||||
// IsExpired checks if the session has expired
|
||||
func (s *Session) IsExpired() bool {
|
||||
return time.Now().After(s.Expiry)
|
||||
}
|
||||
|
||||
// UpdateLastUsed updates the last used time
|
||||
func (s *Session) UpdateLastUsed() {
|
||||
s.mu.Lock()
|
||||
s.LastUsed = time.Now()
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user