package sessions

import (
	"sync"
	"time"

	"github.com/deneonet/benc"
	bstd "github.com/deneonet/benc/std"
)

// defaultDataCapacity is the initial capacity used whenever a session's Data
// map is (re)allocated by this package.
const defaultDataCapacity = 8

// Session stores data for a single user session.
type Session struct {
	ID        string
	Data      map[string]any
	CreatedAt time.Time
	UpdatedAt time.Time
	LastUsed  time.Time
	Expiry    time.Time

	// dirty tracks whether the session has unsaved changes. It is not written
	// by MarshalPlain.
	dirty bool
}

// sessionPool recycles Session values to reduce allocations.
var sessionPool = sync.Pool{
	New: func() any {
		return &Session{
			Data: make(map[string]any, defaultDataCapacity),
		}
	},
}

// bufPool is the benc buffer pool used by Marshal for serialization buffers.
var bufPool = benc.NewBufPool(benc.WithBufferSize(4096))

// GetFromPool retrieves a session from the pool. The session is either freshly
// allocated or was reset by ReturnToPool.
func GetFromPool() *Session {
	return sessionPool.Get().(*Session)
}

// ReturnToPool resets s to its zero state (keeping an empty, non-nil Data map)
// and puts it back into the pool. A nil s is ignored. The caller must not use
// s after this call.
func ReturnToPool(s *Session) {
	if s == nil {
		return
	}

	s.ID = ""
	if s.Data == nil {
		// Data is exported and may have been set to nil by a caller; re-arm it
		// so the next GetFromPool user receives a writable map.
		s.Data = make(map[string]any, defaultDataCapacity)
	} else {
		// Empty the map in place so its buckets are reused.
		for k := range s.Data {
			delete(s.Data, k)
		}
	}
	s.CreatedAt = time.Time{}
	s.UpdatedAt = time.Time{}
	s.LastUsed = time.Time{}
	s.Expiry = time.Time{}
	s.dirty = false

	sessionPool.Put(s)
}

// NewSession creates a session with the given ID that expires maxAge seconds
// from now. CreatedAt, UpdatedAt and LastUsed are all set to the current time,
// and the session starts clean (not dirty).
func NewSession(id string, maxAge int) *Session {
	now := time.Now()

	session := GetFromPool()
	session.ID = id
	session.CreatedAt = now
	session.UpdatedAt = now
	session.LastUsed = now
	session.Expiry = now.Add(time.Duration(maxAge) * time.Second)
	session.dirty = false

	return session
}

// Get returns the value stored under key, or nil if the key is absent.
func (s *Session) Get(key string) any {
	return s.Data[key]
}

// Set stores value under key, bumps UpdatedAt and marks the session dirty.
// A nil Data map (e.g. on a zero-value Session) is allocated on demand.
func (s *Session) Set(key string, value any) {
	if s.Data == nil {
		s.Data = make(map[string]any, defaultDataCapacity)
	}
	s.Data[key] = value
	s.UpdatedAt = time.Now()
	s.dirty = true
}

// Delete removes key from the session, bumps UpdatedAt and marks the session
// dirty (even if the key was not present).
func (s *Session) Delete(key string) {
	delete(s.Data, key)
	s.UpdatedAt = time.Now()
	s.dirty = true
}

// Clear replaces Data with a fresh empty map, bumps UpdatedAt and marks the
// session dirty. Maps previously obtained from s.Data are left untouched.
func (s *Session) Clear() {
	s.Data = make(map[string]any, defaultDataCapacity)
	s.UpdatedAt = time.Now()
	s.dirty = true
}

// GetAll returns a shallow copy of all session data. The values themselves are
// not copied.
func (s *Session) GetAll() map[string]any {
	out := make(map[string]any, len(s.Data))
	for k, v := range s.Data {
		out[k] = v
	}
	return out
}

// IsExpired reports whether the current time is after the session's Expiry.
func (s *Session) IsExpired() bool {
	return time.Now().After(s.Expiry)
}

// UpdateLastUsed sets LastUsed to now, but only if more than five seconds have
// passed since the previous recorded use.
func (s *Session) UpdateLastUsed() {
	const minInterval = 5 * time.Second

	now := time.Now()
	if now.Sub(s.LastUsed) > minInterval {
		s.LastUsed = now
		// Deliberately not marking the session dirty: LastUsed changes alone
		// should not trigger a write.
	}
}

// IsDirty reports whether the session has unsaved changes.
func (s *Session) IsDirty() bool {
	return s.dirty
}

// ResetDirty marks the session as clean, typically after it has been saved.
func (s *Session) ResetDirty() {
	s.dirty = false
}

// stringifyData converts Data into the string-to-string form used on the wire,
// encoding every value with toString.
func (s *Session) stringifyData() map[string]string {
	out := make(map[string]string, len(s.Data))
	for k, v := range s.Data {
		out[k] = toString(v)
	}
	return out
}

// SizePlain returns the number of bytes MarshalPlain will write for s.
func (s *Session) SizePlain() (size int) {
	size += bstd.SizeString(s.ID)

	// Values are stored as strings (see toString); this is lossy for some
	// types, so the size must be computed from the same stringified form.
	size += bstd.SizeMap(s.stringifyData(), bstd.SizeString, bstd.SizeString)

	// CreatedAt, UpdatedAt, LastUsed and Expiry as Unix-second int64s.
	size += bstd.SizeInt64() * 4

	return size
}

// MarshalPlain writes s into b starting at offset n and returns the offset just
// past the written data. b is expected to have been sized with SizePlain (as
// Marshal does). The error is always nil.
//
// Layout: ID, Data (values encoded with toString), then CreatedAt, UpdatedAt,
// LastUsed and Expiry as Unix seconds, so sub-second precision is dropped.
func (s *Session) MarshalPlain(n int, b []byte) (int, error) {
	n = bstd.MarshalString(n, b, s.ID)
	n = bstd.MarshalMap(n, b, s.stringifyData(), bstd.MarshalString, bstd.MarshalString)

	n = bstd.MarshalInt64(n, b, s.CreatedAt.Unix())
	n = bstd.MarshalInt64(n, b, s.UpdatedAt.Unix())
	n = bstd.MarshalInt64(n, b, s.LastUsed.Unix())
	n = bstd.MarshalInt64(n, b, s.Expiry.Unix())

	return n, nil
}
session from binary func (s *Session) UnmarshalPlain(n int, b []byte) (int, error) { var err error // ID n, s.ID, err = bstd.UnmarshalString(n, b) if err != nil { return n, err } // Data var dataAsStrings map[string]string n, dataAsStrings, err = bstd.UnmarshalMap[string, string](n, b, bstd.UnmarshalString, bstd.UnmarshalString) if err != nil { return n, err } // Convert string data back to original types s.Data = make(map[string]any, len(dataAsStrings)) for k, v := range dataAsStrings { s.Data[k] = fromString(v) } // Time fields var timestamp int64 // CreatedAt n, timestamp, err = bstd.UnmarshalInt64(n, b) if err != nil { return n, err } s.CreatedAt = time.Unix(timestamp, 0) // UpdatedAt n, timestamp, err = bstd.UnmarshalInt64(n, b) if err != nil { return n, err } s.UpdatedAt = time.Unix(timestamp, 0) // LastUsed n, timestamp, err = bstd.UnmarshalInt64(n, b) if err != nil { return n, err } s.LastUsed = time.Unix(timestamp, 0) // Expiry n, timestamp, err = bstd.UnmarshalInt64(n, b) if err != nil { return n, err } s.Expiry = time.Unix(timestamp, 0) return n, nil } // Marshal serializes the session using benc func (s *Session) Marshal() ([]byte, error) { size := s.SizePlain() data, err := bufPool.Marshal(size, func(b []byte) (n int) { n, _ = s.MarshalPlain(0, b) return n }) if err != nil { return nil, err } return data, nil } // Unmarshal deserializes a session using benc func Unmarshal(data []byte) (*Session, error) { session := GetFromPool() _, err := session.UnmarshalPlain(0, data) if err != nil { ReturnToPool(session) return nil, err } return session, nil } // Helper functions to convert between any and string // In a production environment, you would use a more robust serialization method for the map values func toString(v any) string { if v == nil { return "" } switch t := v.(type) { case string: return t case []byte: return string(t) case int: return "i:" + string(rune(t)) case bool: if t { return "b:t" } return "b:f" default: return "u:" // unknown type } } 
// fromString decodes a value produced by toString:
//   - ""            -> nil
//   - "i:<rune>"    -> int code point of the first rune after the prefix
//     (0 if nothing follows the prefix)
//   - "b:t..."      -> true; any other "b:" value -> false
//   - "u:..."       -> nil (the original value had an unsupported type)
//   - anything else -> returned unchanged as a string
func fromString(s string) any {
	if s == "" {
		return nil
	}
	if len(s) < 2 {
		return s
	}

	switch s[:2] {
	case "i:":
		// toString stores an int as string(rune(n)), which is a multi-byte
		// UTF-8 sequence for n >= 128, so decode the whole first rune rather
		// than just its first byte.
		if runes := []rune(s[2:]); len(runes) > 0 {
			return int(runes[0])
		}
		return 0
	case "b:":
		return len(s) > 2 && s[2] == 't'
	case "u:":
		return nil
	default:
		return s
	}
}