package sessions import ( "sync" "time" "github.com/deneonet/benc" bstd "github.com/deneonet/benc/std" ) // Session stores data for a single user session type Session struct { ID string Data map[string]any CreatedAt time.Time UpdatedAt time.Time LastUsed time.Time Expiry time.Time dirty bool // Tracks if session has changes, not serialized } // Session pool to reduce allocations var sessionPool = sync.Pool{ New: func() any { return &Session{ Data: make(map[string]any, 8), } }, } // BufPool for reusing serialization buffers var bufPool = benc.NewBufPool(benc.WithBufferSize(4096)) // GetFromPool retrieves a session from the pool func GetFromPool() *Session { return sessionPool.Get().(*Session) } // ReturnToPool returns a session to the pool after cleaning it func ReturnToPool(s *Session) { if s == nil { return } // Clean the session for reuse s.ID = "" for k := range s.Data { delete(s.Data, k) } s.CreatedAt = time.Time{} s.UpdatedAt = time.Time{} s.LastUsed = time.Time{} s.Expiry = time.Time{} s.dirty = false sessionPool.Put(s) } // NewSession creates a new session with the given ID func NewSession(id string, maxAge int) *Session { now := time.Now() // Get from pool or create new session := GetFromPool() // Initialize session.ID = id session.CreatedAt = now session.UpdatedAt = now session.LastUsed = now session.Expiry = now.Add(time.Duration(maxAge) * time.Second) session.dirty = false return session } // Get retrieves a value from the session func (s *Session) Get(key string) any { return s.Data[key] } // Set stores a value in the session func (s *Session) Set(key string, value any) { s.Data[key] = value s.UpdatedAt = time.Now() s.dirty = true } // Delete removes a value from the session func (s *Session) Delete(key string) { delete(s.Data, key) s.UpdatedAt = time.Now() s.dirty = true } // Clear removes all data from the session func (s *Session) Clear() { s.Data = make(map[string]any, 8) s.UpdatedAt = time.Now() s.dirty = true } // GetAll returns a copy of all session data func (s *Session) GetAll() map[string]any { copy := make(map[string]any, len(s.Data)) for k, v := range s.Data { copy[k] = v } return copy } // IsExpired checks if the session has expired func (s *Session) IsExpired() bool { return time.Now().After(s.Expiry) } // UpdateLastUsed updates the last used time // Only updates if at least 5 seconds have passed since last update func (s *Session) UpdateLastUsed() { now := time.Now() if now.Sub(s.LastUsed) > 5*time.Second { s.LastUsed = now // Not marking dirty for LastUsed updates to reduce writes } } // IsDirty returns if the session has unsaved changes func (s *Session) IsDirty() bool { return s.dirty } // ResetDirty marks the session as clean after saving func (s *Session) ResetDirty() { s.dirty = false } // SizePlain calculates the size needed to marshal the session func (s *Session) SizePlain() (size int) { // ID size += bstd.SizeString(s.ID) // Data map size += bstd.SizeMap(s.Data, bstd.SizeString, func(v any) int { return sizeAny(v) }) // Time fields stored as int64 Unix timestamps size += bstd.SizeInt64() * 4 return size } // MarshalPlain serializes the session to binary func (s *Session) MarshalPlain(n int, b []byte) int { // ID n = bstd.MarshalString(n, b, s.ID) // Data map n = bstd.MarshalMap(n, b, s.Data, bstd.MarshalString, func(n int, b []byte, v any) int { return marshalAny(n, b, v) }) // Time fields as Unix timestamps n = bstd.MarshalInt64(n, b, s.CreatedAt.Unix()) n = bstd.MarshalInt64(n, b, s.UpdatedAt.Unix()) n = bstd.MarshalInt64(n, b, s.LastUsed.Unix()) n = bstd.MarshalInt64(n, b, s.Expiry.Unix()) return n } // UnmarshalPlain deserializes the session from binary func (s *Session) UnmarshalPlain(n int, b []byte) (int, error) { var err error // ID n, s.ID, err = bstd.UnmarshalString(n, b) if err != nil { return n, err } // Data map n, s.Data, err = bstd.UnmarshalMap[string, any](n, b, bstd.UnmarshalString, func(n int, b []byte) (int, any, error) { return unmarshalAny(n, b) }) if err != nil { return n, err } // Time fields as Unix timestamps var timestamp int64 n, timestamp, err = bstd.UnmarshalInt64(n, b) if err != nil { return n, err } s.CreatedAt = time.Unix(timestamp, 0) n, timestamp, err = bstd.UnmarshalInt64(n, b) if err != nil { return n, err } s.UpdatedAt = time.Unix(timestamp, 0) n, timestamp, err = bstd.UnmarshalInt64(n, b) if err != nil { return n, err } s.LastUsed = time.Unix(timestamp, 0) n, timestamp, err = bstd.UnmarshalInt64(n, b) if err != nil { return n, err } s.Expiry = time.Unix(timestamp, 0) return n, nil } // Marshal serializes the session using benc func (s *Session) Marshal() ([]byte, error) { size := s.SizePlain() data, err := bufPool.Marshal(size, func(b []byte) (n int) { return s.MarshalPlain(0, b) }) if err != nil { return nil, err } return data, nil } // Unmarshal deserializes a session using benc func Unmarshal(data []byte) (*Session, error) { session := GetFromPool() _, err := session.UnmarshalPlain(0, data) if err != nil { ReturnToPool(session) return nil, err } return session, nil } // Type identifiers for any values const ( typeNull byte = 0 typeString byte = 1 typeInt byte = 2 typeFloat byte = 3 typeBool byte = 4 typeBytes byte = 5 ) // sizeAny calculates the size needed for any value func sizeAny(v any) int { if v == nil { return 1 // Just the type byte } // 1 byte for type + size of the value switch val := v.(type) { case string: return 1 + bstd.SizeString(val) case int: return 1 + bstd.SizeInt64() case int64: return 1 + bstd.SizeInt64() case float64: return 1 + bstd.SizeFloat64() case bool: return 1 + bstd.SizeBool() case []byte: return 1 + bstd.SizeBytes(val) default: // Convert unhandled types to string return 1 + bstd.SizeString("unknown") } } // marshalAny serializes any value func marshalAny(n int, b []byte, v any) int { if v == nil { b[n] = typeNull return n + 1 } switch val := v.(type) { case string: b[n] = typeString return bstd.MarshalString(n+1, b, val) case int: b[n] = typeInt return bstd.MarshalInt64(n+1, b, int64(val)) case int64: b[n] = typeInt return bstd.MarshalInt64(n+1, b, val) case float64: b[n] = typeFloat return bstd.MarshalFloat64(n+1, b, val) case bool: b[n] = typeBool return bstd.MarshalBool(n+1, b, val) case []byte: b[n] = typeBytes return bstd.MarshalBytes(n+1, b, val) default: // Convert unhandled types to string b[n] = typeString return bstd.MarshalString(n+1, b, "unknown") } } // unmarshalAny deserializes any value func unmarshalAny(n int, b []byte) (int, any, error) { if len(b) <= n { return n, nil, benc.ErrBufTooSmall } typeId := b[n] n++ switch typeId { case typeNull: return n, nil, nil case typeString: return bstd.UnmarshalString(n, b) case typeInt: var val int64 var err error n, val, err = bstd.UnmarshalInt64(n, b) return n, val, err case typeFloat: var val float64 var err error n, val, err = bstd.UnmarshalFloat64(n, b) return n, val, err case typeBool: var val bool var err error n, val, err = bstd.UnmarshalBool(n, b) return n, val, err case typeBytes: return bstd.UnmarshalBytesCopied(n, b) default: // Unknown type, return nil return n, nil, nil } }