package sessions import ( "fmt" "sync" "time" "github.com/deneonet/benc" bstd "github.com/deneonet/benc/std" ) // Session stores data for a single user session type Session struct { ID string Data map[string]any CreatedAt time.Time UpdatedAt time.Time LastUsed time.Time Expiry time.Time dirty bool } var ( sessionPool = sync.Pool{ New: func() any { return &Session{Data: make(map[string]any, 8)} }, } bufPool = benc.NewBufPool(benc.WithBufferSize(4096)) ) // NewSession creates a new session with the given ID func NewSession(id string, maxAge int) *Session { s := sessionPool.Get().(*Session) now := time.Now() *s = Session{ ID: id, Data: s.Data, // Reuse map CreatedAt: now, UpdatedAt: now, LastUsed: now, Expiry: now.Add(time.Duration(maxAge) * time.Second), } return s } // Release returns the session to the pool func (s *Session) Release() { for k := range s.Data { delete(s.Data, k) } sessionPool.Put(s) } // Get returns a deep copy of a value func (s *Session) Get(key string) any { if v, ok := s.Data[key]; ok { return deepCopy(v) } return nil } // GetTable returns a value as a table func (s *Session) GetTable(key string) map[string]any { if v := s.Get(key); v != nil { if t, ok := v.(map[string]any); ok { return t } } return nil } // GetAll returns a deep copy of all session data func (s *Session) GetAll() map[string]any { copy := make(map[string]any, len(s.Data)) for k, v := range s.Data { copy[k] = deepCopy(v) } return copy } // Set stores a value in the session func (s *Session) Set(key string, value any) { if existing, ok := s.Data[key]; ok && deepEqual(existing, value) { return // No change } s.Data[key] = value s.UpdatedAt = time.Now() s.dirty = true } // SetSafe stores a value with validation func (s *Session) SetSafe(key string, value any) error { if err := validate(value); err != nil { return fmt.Errorf("session.SetSafe: %w", err) } s.Set(key, value) return nil } // SetTable is a convenience method for setting table data func (s *Session) SetTable(key string, table map[string]any) error { return s.SetSafe(key, table) } // Delete removes a value from the session func (s *Session) Delete(key string) { delete(s.Data, key) s.UpdatedAt = time.Now() s.dirty = true } // Clear removes all data from the session func (s *Session) Clear() { s.Data = make(map[string]any, 8) s.UpdatedAt = time.Now() s.dirty = true } // IsExpired checks if the session has expired func (s *Session) IsExpired() bool { return time.Now().After(s.Expiry) } // UpdateLastUsed updates the last used time func (s *Session) UpdateLastUsed() { now := time.Now() if now.Sub(s.LastUsed) > 5*time.Second { s.LastUsed = now } } // IsDirty returns if the session has unsaved changes func (s *Session) IsDirty() bool { return s.dirty } // ResetDirty marks the session as clean after saving func (s *Session) ResetDirty() { s.dirty = false } // SizePlain calculates the size needed to marshal the session func (s *Session) SizePlain() int { return bstd.SizeString(s.ID) + bstd.SizeMap(s.Data, bstd.SizeString, sizeAny) + bstd.SizeInt64()*4 } // MarshalPlain serializes the session to binary func (s *Session) MarshalPlain(n int, b []byte) int { n = bstd.MarshalString(n, b, s.ID) n = bstd.MarshalMap(n, b, s.Data, bstd.MarshalString, marshalAny) n = bstd.MarshalInt64(n, b, s.CreatedAt.Unix()) n = bstd.MarshalInt64(n, b, s.UpdatedAt.Unix()) n = bstd.MarshalInt64(n, b, s.LastUsed.Unix()) return bstd.MarshalInt64(n, b, s.Expiry.Unix()) } // UnmarshalPlain deserializes the session from binary func (s *Session) UnmarshalPlain(n int, b []byte) (int, error) { var err error n, s.ID, err = bstd.UnmarshalString(n, b) if err != nil { return n, err } n, s.Data, err = bstd.UnmarshalMap[string, any](n, b, bstd.UnmarshalString, unmarshalAny) if err != nil { return n, err } var ts int64 for _, t := range []*time.Time{&s.CreatedAt, &s.UpdatedAt, &s.LastUsed, &s.Expiry} { n, ts, err = bstd.UnmarshalInt64(n, b) if err != nil { return n, err } *t = time.Unix(ts, 0) } return n, nil } // Marshal serializes the session using benc func (s *Session) Marshal() ([]byte, error) { return bufPool.Marshal(s.SizePlain(), func(b []byte) int { return s.MarshalPlain(0, b) }) } // Unmarshal deserializes a session using benc func Unmarshal(data []byte) (*Session, error) { s := sessionPool.Get().(*Session) if _, err := s.UnmarshalPlain(0, data); err != nil { s.Release() return nil, err } return s, nil } // Type identifiers const ( typeNull byte = 0 typeString byte = 1 typeInt byte = 2 typeFloat byte = 3 typeBool byte = 4 typeBytes byte = 5 typeTable byte = 6 typeArray byte = 7 ) // sizeAny calculates the size needed for any value func sizeAny(v any) int { if v == nil { return 1 } size := 1 // type byte switch v := v.(type) { case string: size += bstd.SizeString(v) case int: size += bstd.SizeInt64() case int64: size += bstd.SizeInt64() case float64: size += bstd.SizeFloat64() case bool: size += bstd.SizeBool() case []byte: size += bstd.SizeBytes(v) case map[string]any: size += bstd.SizeMap(v, bstd.SizeString, sizeAny) case []any: size += bstd.SizeSlice(v, sizeAny) default: size += bstd.SizeString("unknown") } return size } // marshalAny serializes any value func marshalAny(n int, b []byte, v any) int { if v == nil { b[n] = typeNull return n + 1 } switch v := v.(type) { case string: b[n] = typeString return bstd.MarshalString(n+1, b, v) case int: b[n] = typeInt return bstd.MarshalInt64(n+1, b, int64(v)) case int64: b[n] = typeInt return bstd.MarshalInt64(n+1, b, v) case float64: b[n] = typeFloat return bstd.MarshalFloat64(n+1, b, v) case bool: b[n] = typeBool return bstd.MarshalBool(n+1, b, v) case []byte: b[n] = typeBytes return bstd.MarshalBytes(n+1, b, v) case map[string]any: b[n] = typeTable return bstd.MarshalMap(n+1, b, v, bstd.MarshalString, marshalAny) case []any: b[n] = typeArray return bstd.MarshalSlice(n+1, b, v, marshalAny) default: b[n] = typeString return bstd.MarshalString(n+1, b, "unknown") } } // unmarshalAny deserializes any value func unmarshalAny(n int, b []byte) (int, any, error) { if len(b) <= n { return n, nil, benc.ErrBufTooSmall } switch b[n] { case typeNull: return n + 1, nil, nil case typeString: return bstd.UnmarshalString(n+1, b) case typeInt: n, v, err := bstd.UnmarshalInt64(n+1, b) return n, v, err case typeFloat: return bstd.UnmarshalFloat64(n+1, b) case typeBool: return bstd.UnmarshalBool(n+1, b) case typeBytes: return bstd.UnmarshalBytesCopied(n+1, b) case typeTable: return bstd.UnmarshalMap[string, any](n+1, b, bstd.UnmarshalString, unmarshalAny) case typeArray: return bstd.UnmarshalSlice[any](n+1, b, unmarshalAny) default: return n + 1, nil, nil } } // deepCopy creates a deep copy of any value func deepCopy(v any) any { switch v := v.(type) { case map[string]any: cp := make(map[string]any, len(v)) for k, val := range v { cp[k] = deepCopy(val) } return cp case []any: cp := make([]any, len(v)) for i, val := range v { cp[i] = deepCopy(val) } return cp default: return v } } // validate ensures a value can be safely serialized func validate(v any) error { switch v := v.(type) { case nil, string, int, int64, float64, bool, []byte: return nil case map[string]any: for k, val := range v { if err := validate(val); err != nil { return fmt.Errorf("invalid value for key %q: %w", k, err) } } case []any: for i, val := range v { if err := validate(val); err != nil { return fmt.Errorf("invalid value at index %d: %w", i, err) } } default: return fmt.Errorf("unsupported type: %T", v) } return nil } // deepEqual efficiently compares two values for deep equality func deepEqual(a, b any) bool { if a == b { return true } if a == nil || b == nil { return false } switch va := a.(type) { case string: if vb, ok := b.(string); ok { return va == vb } case int: if vb, ok := b.(int); ok { return va == vb } if vb, ok := b.(int64); ok { return int64(va) == vb } case int64: if vb, ok := b.(int64); ok { return va == vb } if vb, ok := b.(int); ok { return va == int64(vb) } case float64: if vb, ok := b.(float64); ok { return va == vb } case bool: if vb, ok := b.(bool); ok { return va == vb } case []byte: if vb, ok := b.([]byte); ok { if len(va) != len(vb) { return false } for i, v := range va { if v != vb[i] { return false } } return true } case map[string]any: if vb, ok := b.(map[string]any); ok { if len(va) != len(vb) { return false } for k, v := range va { if bv, exists := vb[k]; !exists || !deepEqual(v, bv) { return false } } return true } case []any: if vb, ok := b.([]any); ok { if len(va) != len(vb) { return false } for i, v := range va { if !deepEqual(v, vb[i]) { return false } } return true } } return false } // IsEmpty returns true if the session has no data func (s *Session) IsEmpty() bool { return len(s.Data) == 0 }