From e4cd490f0f0af4868d7ae37fe9b0cac7da1e8a5a Mon Sep 17 00:00:00 2001 From: Sky Johnson Date: Mon, 26 May 2025 12:56:20 -0500 Subject: [PATCH] add table support to sessions, fix root lua path, optimize sesison manager --- http/server.go | 6 +- routers/luaRouter.go | 9 ++ sessions/manager.go | 143 ++++++------------ sessions/session.go | 344 +++++++++++++++++++++---------------------- 4 files changed, 228 insertions(+), 274 deletions(-) diff --git a/http/server.go b/http/server.go index a4a0c68..443ddd7 100644 --- a/http/server.go +++ b/http/server.go @@ -164,7 +164,8 @@ func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scrip session := s.sessionManager.GetSessionFromRequest(ctx) sessionMap["id"] = session.ID - sessionMap["data"] = session.Data + + sessionMap["data"] = session.GetAll() // This now returns a deep copy luaCtx.Set("method", method) luaCtx.Set("path", path) @@ -209,11 +210,12 @@ func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scrip delete(response.SessionData, "__clear_all") } + // Apply session changes - now supports nested tables for k, v := range response.SessionData { if v == "__SESSION_DELETE_MARKER__" { session.Delete(k) } else { - session.Set(k, v) + session.Set(k, v) // This will handle tables through marshalling } } diff --git a/routers/luaRouter.go b/routers/luaRouter.go index f5c2e49..96cb7be 100644 --- a/routers/luaRouter.go +++ b/routers/luaRouter.go @@ -485,6 +485,15 @@ func (r *LuaRouter) Match(method, path string, params *Params) (*node, bool) { // matchPath recursively matches a path against the routing tree func (r *LuaRouter) matchPath(current *node, segments []string, params *Params, depth int) (*node, bool) { + // Filter empty segments + filteredSegments := segments[:0] + for _, segment := range segments { + if segment != "" { + filteredSegments = append(filteredSegments, segment) + } + } + segments = filteredSegments + if len(segments) == 0 { if current.handler != "" { return current, true diff --git a/sessions/manager.go b/sessions/manager.go index a84018e..0e4cdf5 100644 --- a/sessions/manager.go +++ b/sessions/manager.go @@ -20,40 +20,25 @@ const ( // SessionManager handles multiple sessions type SessionManager struct { cache *fastcache.Cache - maxSessions int cookieName string cookiePath string cookieDomain string cookieSecure bool cookieHTTPOnly bool cookieMaxAge int - cookieMu sync.RWMutex // Only cookie options need a mutex + cookieMu sync.RWMutex cleanupTicker *time.Ticker cleanupDone chan struct{} } -// InitializeSessionPool pre-allocates session objects -func InitializeSessionPool(size int) { - for range size { - session := &Session{ - Data: make(map[string]any, 8), - } - ReturnToPool(session) - } -} - // NewSessionManager creates a new session manager func NewSessionManager(maxSessions int) *SessionManager { if maxSessions <= 0 { maxSessions = DefaultMaxSessions } - // Estimate max memory: ~4KB per session × maxSessions - maxBytes := maxSessions * 4096 - sm := &SessionManager{ - cache: fastcache.New(maxBytes), - maxSessions: maxSessions, + cache: fastcache.New(maxSessions * 4096), cookieName: DefaultCookieName, cookiePath: DefaultCookiePath, cookieHTTPOnly: true, @@ -61,10 +46,12 @@ func NewSessionManager(maxSessions int) *SessionManager { cleanupDone: make(chan struct{}), } - // Pre-allocate session objects for common pool size - InitializeSessionPool(100) // Adjust based on expected concurrent requests + // Pre-populate session pool + for i := 0; i < 100; i++ { + s := NewSession("", 0) + s.Release() + } - // Start periodic cleanup sm.cleanupTicker = time.NewTicker(CleanupInterval) go sm.cleanupRoutine() @@ -76,7 +63,6 @@ func (sm *SessionManager) Stop() { close(sm.cleanupDone) } -// cleanupRoutine periodically removes expired sessions func (sm *SessionManager) cleanupRoutine() { for { select { @@ -89,124 +75,80 @@ func (sm *SessionManager) cleanupRoutine() { } } -// generateSessionID creates a random session ID -func generateSessionID() string { - id, _ := gonanoid.New() - return id -} - // GetSession retrieves a session by ID, or creates a new one if it doesn't exist func (sm *SessionManager) GetSession(id string) *Session { - // Try to get an existing session if id != "" { - data := sm.cache.Get(nil, []byte(id)) - if len(data) > 0 { - session, err := Unmarshal(data) - if err == nil && !session.IsExpired() { - session.UpdateLastUsed() - session.ResetDirty() // Start clean - return session + if data := sm.cache.Get(nil, []byte(id)); len(data) > 0 { + if s, err := Unmarshal(data); err == nil && !s.IsExpired() { + s.UpdateLastUsed() + s.ResetDirty() + return s } - // Session expired or corrupt, remove it sm.cache.Del([]byte(id)) } } - - // Create a new session return sm.CreateSession() } // CreateSession generates a new session with a unique ID func (sm *SessionManager) CreateSession() *Session { - id := generateSessionID() + id, _ := gonanoid.New() - // Ensure ID uniqueness - attempts := 0 - for attempts < 3 { - if sm.cache.Has([]byte(id)) { - id = generateSessionID() - attempts++ - } else { - break - } + // Ensure uniqueness (max 3 attempts) + for i := 0; i < 3 && sm.cache.Has([]byte(id)); i++ { + id, _ = gonanoid.New() } - session := NewSession(id, sm.cookieMaxAge) - - // Serialize and store the session - if data, err := session.Marshal(); err == nil { + s := NewSession(id, sm.cookieMaxAge) + if data, err := s.Marshal(); err == nil { sm.cache.Set([]byte(id), data) } - - session.ResetDirty() // Start clean - return session + s.ResetDirty() + return s } // DestroySession removes a session func (sm *SessionManager) DestroySession(id string) { - // Get and clean session from cache before deleting - data := sm.cache.Get(nil, []byte(id)) - if len(data) > 0 { - if session, err := Unmarshal(data); err == nil { - ReturnToPool(session) + if data := sm.cache.Get(nil, []byte(id)); len(data) > 0 { + if s, err := Unmarshal(data); err == nil { + s.Release() } } - sm.cache.Del([]byte(id)) } // CleanupExpired removes all expired sessions -// Note: fastcache doesn't provide iteration, so we can't clean all expired sessions -// This is a limitation of this implementation func (sm *SessionManager) CleanupExpired() int { - // No way to iterate through all keys in fastcache - // We'd need to track expiring sessions separately + // fastcache doesn't support iteration return 0 } // SetCookieOptions configures cookie parameters func (sm *SessionManager) SetCookieOptions(name, path, domain string, secure, httpOnly bool, maxAge int) { sm.cookieMu.Lock() - defer sm.cookieMu.Unlock() - sm.cookieName = name sm.cookiePath = path sm.cookieDomain = domain sm.cookieSecure = secure sm.cookieHTTPOnly = httpOnly sm.cookieMaxAge = maxAge + sm.cookieMu.Unlock() } // GetSessionFromRequest extracts the session from a request func (sm *SessionManager) GetSessionFromRequest(ctx *fasthttp.RequestCtx) *Session { sm.cookieMu.RLock() - cookieName := sm.cookieName + name := sm.cookieName sm.cookieMu.RUnlock() - cookie := ctx.Request.Header.Cookie(cookieName) - if len(cookie) == 0 { - return sm.CreateSession() + if cookie := ctx.Request.Header.Cookie(name); len(cookie) > 0 { + return sm.GetSession(string(cookie)) } - - return sm.GetSession(string(cookie)) + return sm.CreateSession() } // ApplySessionCookie adds the session cookie to the response func (sm *SessionManager) ApplySessionCookie(ctx *fasthttp.RequestCtx, session *Session) { - cookie := fasthttp.AcquireCookie() - defer fasthttp.ReleaseCookie(cookie) - - // Get cookie options with minimal lock time - sm.cookieMu.RLock() - cookieName := sm.cookieName - cookiePath := sm.cookiePath - cookieDomain := sm.cookieDomain - cookieSecure := sm.cookieSecure - cookieHTTPOnly := sm.cookieHTTPOnly - cookieMaxAge := sm.cookieMaxAge - sm.cookieMu.RUnlock() - - // Store updated session only if it has changes if session.IsDirty() { if data, err := session.Marshal(); err == nil { sm.cache.Set([]byte(session.ID), data) @@ -214,18 +156,21 @@ func (sm *SessionManager) ApplySessionCookie(ctx *fasthttp.RequestCtx, session * session.ResetDirty() } - cookie.SetKey(cookieName) - cookie.SetValue(session.ID) - cookie.SetPath(cookiePath) - cookie.SetHTTPOnly(cookieHTTPOnly) - cookie.SetMaxAge(cookieMaxAge) + cookie := fasthttp.AcquireCookie() + defer fasthttp.ReleaseCookie(cookie) - if cookieDomain != "" { - cookie.SetDomain(cookieDomain) + sm.cookieMu.RLock() + cookie.SetKey(sm.cookieName) + cookie.SetPath(sm.cookiePath) + cookie.SetHTTPOnly(sm.cookieHTTPOnly) + cookie.SetMaxAge(sm.cookieMaxAge) + if sm.cookieDomain != "" { + cookie.SetDomain(sm.cookieDomain) } + cookie.SetSecure(sm.cookieSecure) + sm.cookieMu.RUnlock() - cookie.SetSecure(cookieSecure) - + cookie.SetValue(session.ID) ctx.Response.Header.SetCookie(cookie) } @@ -244,9 +189,6 @@ func (sm *SessionManager) CookieOptions() map[string]any { } } -// GlobalSessionManager is the default session manager instance -var GlobalSessionManager = NewSessionManager(DefaultMaxSessions) - // GetCacheStats returns statistics about the session cache func (sm *SessionManager) GetCacheStats() map[string]uint64 { if sm == nil || sm.cache == nil { @@ -265,3 +207,6 @@ func (sm *SessionManager) GetCacheStats() map[string]uint64 { "misses": stats.Misses, } } + +// GlobalSessionManager is the default session manager instance +var GlobalSessionManager = NewSessionManager(DefaultMaxSessions) diff --git a/sessions/session.go b/sessions/session.go index 83b92e8..c01fdf6 100644 --- a/sessions/session.go +++ b/sessions/session.go @@ -1,6 +1,7 @@ package sessions import ( + "fmt" "sync" "time" @@ -16,67 +17,66 @@ type Session struct { UpdatedAt time.Time LastUsed time.Time Expiry time.Time - dirty bool // Tracks if session has changes, not serialized + dirty bool } -// Session pool to reduce allocations -var sessionPool = sync.Pool{ - New: func() any { - return &Session{ - Data: make(map[string]any, 8), - } - }, -} - -// BufPool for reusing serialization buffers -var bufPool = benc.NewBufPool(benc.WithBufferSize(4096)) - -// GetFromPool retrieves a session from the pool -func GetFromPool() *Session { - return sessionPool.Get().(*Session) -} - -// ReturnToPool returns a session to the pool after cleaning it -func ReturnToPool(s *Session) { - if s == nil { - return +var ( + sessionPool = sync.Pool{ + New: func() any { + return &Session{Data: make(map[string]any, 8)} + }, } - - // Clean the session for reuse - s.ID = "" - for k := range s.Data { - delete(s.Data, k) - } - s.CreatedAt = time.Time{} - s.UpdatedAt = time.Time{} - s.LastUsed = time.Time{} - s.Expiry = time.Time{} - s.dirty = false - - sessionPool.Put(s) -} + bufPool = benc.NewBufPool(benc.WithBufferSize(4096)) +) // NewSession creates a new session with the given ID func NewSession(id string, maxAge int) *Session { + s := sessionPool.Get().(*Session) now := time.Now() - - // Get from pool or create new - session := GetFromPool() - - // Initialize - session.ID = id - session.CreatedAt = now - session.UpdatedAt = now - session.LastUsed = now - session.Expiry = now.Add(time.Duration(maxAge) * time.Second) - session.dirty = false - - return session + *s = Session{ + ID: id, + Data: s.Data, // Reuse map + CreatedAt: now, + UpdatedAt: now, + LastUsed: now, + Expiry: now.Add(time.Duration(maxAge) * time.Second), + } + return s } -// Get retrieves a value from the session +// Release returns the session to the pool +func (s *Session) Release() { + for k := range s.Data { + delete(s.Data, k) + } + sessionPool.Put(s) +} + +// Get returns a deep copy of a value func (s *Session) Get(key string) any { - return s.Data[key] + if v, ok := s.Data[key]; ok { + return deepCopy(v) + } + return nil +} + +// GetTable returns a value as a table +func (s *Session) GetTable(key string) map[string]any { + if v := s.Get(key); v != nil { + if t, ok := v.(map[string]any); ok { + return t + } + } + return nil +} + +// GetAll returns a deep copy of all session data +func (s *Session) GetAll() map[string]any { + copy := make(map[string]any, len(s.Data)) + for k, v := range s.Data { + copy[k] = deepCopy(v) + } + return copy } // Set stores a value in the session @@ -86,6 +86,20 @@ func (s *Session) Set(key string, value any) { s.dirty = true } +// SetSafe stores a value with validation +func (s *Session) SetSafe(key string, value any) error { + if err := validate(value); err != nil { + return fmt.Errorf("session.SetSafe: %w", err) + } + s.Set(key, value) + return nil +} + +// SetTable is a convenience method for setting table data +func (s *Session) SetTable(key string, table map[string]any) error { + return s.SetSafe(key, table) +} + // Delete removes a value from the session func (s *Session) Delete(key string) { delete(s.Data, key) @@ -100,27 +114,16 @@ func (s *Session) Clear() { s.dirty = true } -// GetAll returns a copy of all session data -func (s *Session) GetAll() map[string]any { - copy := make(map[string]any, len(s.Data)) - for k, v := range s.Data { - copy[k] = v - } - return copy -} - // IsExpired checks if the session has expired func (s *Session) IsExpired() bool { return time.Now().After(s.Expiry) } // UpdateLastUsed updates the last used time -// Only updates if at least 5 seconds have passed since last update func (s *Session) UpdateLastUsed() { now := time.Now() if now.Sub(s.LastUsed) > 5*time.Second { s.LastUsed = now - // Not marking dirty for LastUsed updates to reduce writes } } @@ -135,115 +138,64 @@ func (s *Session) ResetDirty() { } // SizePlain calculates the size needed to marshal the session -func (s *Session) SizePlain() (size int) { - // ID - size += bstd.SizeString(s.ID) - - // Data map - size += bstd.SizeMap(s.Data, bstd.SizeString, func(v any) int { - return sizeAny(v) - }) - - // Time fields stored as int64 Unix timestamps - size += bstd.SizeInt64() * 4 - - return size +func (s *Session) SizePlain() int { + return bstd.SizeString(s.ID) + + bstd.SizeMap(s.Data, bstd.SizeString, sizeAny) + + bstd.SizeInt64()*4 } // MarshalPlain serializes the session to binary func (s *Session) MarshalPlain(n int, b []byte) int { - // ID n = bstd.MarshalString(n, b, s.ID) - - // Data map - n = bstd.MarshalMap(n, b, s.Data, bstd.MarshalString, func(n int, b []byte, v any) int { - return marshalAny(n, b, v) - }) - - // Time fields as Unix timestamps + n = bstd.MarshalMap(n, b, s.Data, bstd.MarshalString, marshalAny) n = bstd.MarshalInt64(n, b, s.CreatedAt.Unix()) n = bstd.MarshalInt64(n, b, s.UpdatedAt.Unix()) n = bstd.MarshalInt64(n, b, s.LastUsed.Unix()) - n = bstd.MarshalInt64(n, b, s.Expiry.Unix()) - - return n + return bstd.MarshalInt64(n, b, s.Expiry.Unix()) } // UnmarshalPlain deserializes the session from binary func (s *Session) UnmarshalPlain(n int, b []byte) (int, error) { var err error - - // ID n, s.ID, err = bstd.UnmarshalString(n, b) if err != nil { return n, err } - // Data map - n, s.Data, err = bstd.UnmarshalMap[string, any](n, b, bstd.UnmarshalString, func(n int, b []byte) (int, any, error) { - return unmarshalAny(n, b) - }) + n, s.Data, err = bstd.UnmarshalMap[string, any](n, b, bstd.UnmarshalString, unmarshalAny) if err != nil { return n, err } - // Time fields as Unix timestamps - var timestamp int64 - - n, timestamp, err = bstd.UnmarshalInt64(n, b) - if err != nil { - return n, err + var ts int64 + for _, t := range []*time.Time{&s.CreatedAt, &s.UpdatedAt, &s.LastUsed, &s.Expiry} { + n, ts, err = bstd.UnmarshalInt64(n, b) + if err != nil { + return n, err + } + *t = time.Unix(ts, 0) } - s.CreatedAt = time.Unix(timestamp, 0) - - n, timestamp, err = bstd.UnmarshalInt64(n, b) - if err != nil { - return n, err - } - s.UpdatedAt = time.Unix(timestamp, 0) - - n, timestamp, err = bstd.UnmarshalInt64(n, b) - if err != nil { - return n, err - } - s.LastUsed = time.Unix(timestamp, 0) - - n, timestamp, err = bstd.UnmarshalInt64(n, b) - if err != nil { - return n, err - } - s.Expiry = time.Unix(timestamp, 0) - return n, nil } // Marshal serializes the session using benc func (s *Session) Marshal() ([]byte, error) { - size := s.SizePlain() - - data, err := bufPool.Marshal(size, func(b []byte) (n int) { + return bufPool.Marshal(s.SizePlain(), func(b []byte) int { return s.MarshalPlain(0, b) }) - - if err != nil { - return nil, err - } - - return data, nil } // Unmarshal deserializes a session using benc func Unmarshal(data []byte) (*Session, error) { - session := GetFromPool() - _, err := session.UnmarshalPlain(0, data) - if err != nil { - ReturnToPool(session) + s := sessionPool.Get().(*Session) + if _, err := s.UnmarshalPlain(0, data); err != nil { + s.Release() return nil, err } - return session, nil + return s, nil } -// Type identifiers for any values +// Type identifiers const ( typeNull byte = 0 typeString byte = 1 @@ -251,32 +203,38 @@ const ( typeFloat byte = 3 typeBool byte = 4 typeBytes byte = 5 + typeTable byte = 6 + typeArray byte = 7 ) // sizeAny calculates the size needed for any value func sizeAny(v any) int { if v == nil { - return 1 // Just the type byte + return 1 } - // 1 byte for type + size of the value - switch val := v.(type) { + size := 1 // type byte + switch v := v.(type) { case string: - return 1 + bstd.SizeString(val) + size += bstd.SizeString(v) case int: - return 1 + bstd.SizeInt64() + size += bstd.SizeInt64() case int64: - return 1 + bstd.SizeInt64() + size += bstd.SizeInt64() case float64: - return 1 + bstd.SizeFloat64() + size += bstd.SizeFloat64() case bool: - return 1 + bstd.SizeBool() + size += bstd.SizeBool() case []byte: - return 1 + bstd.SizeBytes(val) + size += bstd.SizeBytes(v) + case map[string]any: + size += bstd.SizeMap(v, bstd.SizeString, sizeAny) + case []any: + size += bstd.SizeSlice(v, sizeAny) default: - // Convert unhandled types to string - return 1 + bstd.SizeString("unknown") + size += bstd.SizeString("unknown") } + return size } // marshalAny serializes any value @@ -286,27 +244,32 @@ func marshalAny(n int, b []byte, v any) int { return n + 1 } - switch val := v.(type) { + switch v := v.(type) { case string: b[n] = typeString - return bstd.MarshalString(n+1, b, val) + return bstd.MarshalString(n+1, b, v) case int: b[n] = typeInt - return bstd.MarshalInt64(n+1, b, int64(val)) + return bstd.MarshalInt64(n+1, b, int64(v)) case int64: b[n] = typeInt - return bstd.MarshalInt64(n+1, b, val) + return bstd.MarshalInt64(n+1, b, v) case float64: b[n] = typeFloat - return bstd.MarshalFloat64(n+1, b, val) + return bstd.MarshalFloat64(n+1, b, v) case bool: b[n] = typeBool - return bstd.MarshalBool(n+1, b, val) + return bstd.MarshalBool(n+1, b, v) case []byte: b[n] = typeBytes - return bstd.MarshalBytes(n+1, b, val) + return bstd.MarshalBytes(n+1, b, v) + case map[string]any: + b[n] = typeTable + return bstd.MarshalMap(n+1, b, v, bstd.MarshalString, marshalAny) + case []any: + b[n] = typeArray + return bstd.MarshalSlice(n+1, b, v, marshalAny) default: - // Convert unhandled types to string b[n] = typeString return bstd.MarshalString(n+1, b, "unknown") } @@ -318,33 +281,68 @@ func unmarshalAny(n int, b []byte) (int, any, error) { return n, nil, benc.ErrBufTooSmall } - typeId := b[n] - n++ - - switch typeId { + switch b[n] { case typeNull: - return n, nil, nil + return n + 1, nil, nil case typeString: - return bstd.UnmarshalString(n, b) + return bstd.UnmarshalString(n+1, b) case typeInt: - var val int64 - var err error - n, val, err = bstd.UnmarshalInt64(n, b) - return n, val, err + n, v, err := bstd.UnmarshalInt64(n+1, b) + return n, v, err case typeFloat: - var val float64 - var err error - n, val, err = bstd.UnmarshalFloat64(n, b) - return n, val, err + return bstd.UnmarshalFloat64(n+1, b) case typeBool: - var val bool - var err error - n, val, err = bstd.UnmarshalBool(n, b) - return n, val, err + return bstd.UnmarshalBool(n+1, b) case typeBytes: - return bstd.UnmarshalBytesCopied(n, b) + return bstd.UnmarshalBytesCopied(n+1, b) + case typeTable: + return bstd.UnmarshalMap[string, any](n+1, b, bstd.UnmarshalString, unmarshalAny) + case typeArray: + return bstd.UnmarshalSlice[any](n+1, b, unmarshalAny) default: - // Unknown type, return nil - return n, nil, nil + return n + 1, nil, nil } } + +// deepCopy creates a deep copy of any value +func deepCopy(v any) any { + switch v := v.(type) { + case map[string]any: + cp := make(map[string]any, len(v)) + for k, val := range v { + cp[k] = deepCopy(val) + } + return cp + case []any: + cp := make([]any, len(v)) + for i, val := range v { + cp[i] = deepCopy(val) + } + return cp + default: + return v + } +} + +// validate ensures a value can be safely serialized +func validate(v any) error { + switch v := v.(type) { + case nil, string, int, int64, float64, bool, []byte: + return nil + case map[string]any: + for k, val := range v { + if err := validate(val); err != nil { + return fmt.Errorf("invalid value for key %q: %w", k, err) + } + } + case []any: + for i, val := range v { + if err := validate(val); err != nil { + return fmt.Errorf("invalid value at index %d: %w", i, err) + } + } + default: + return fmt.Errorf("unsupported type: %T", v) + } + return nil +}